зеркало из https://github.com/microsoft/hi-ml.git
ENH: Separate attention heatmap and slide thumbnail plot options (#640)
Dissociate thumbnails and heatmaps plots to flexibly remove/add each of these plot options. Motivations: thumbnail plots can be costly, so we might want to omit it sometime while including heatmap plots and vice versa.
This commit is contained in:
Родитель
84e37efc59
Коммит
b2e873daa5
|
@ -94,7 +94,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
|
|||
|
||||
def get_test_plot_options(self) -> Set[PlotOption]:
|
||||
plot_options = super().get_test_plot_options()
|
||||
plot_options.add(PlotOption.SLIDE_THUMBNAIL_HEATMAP)
|
||||
plot_options.update([PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP])
|
||||
return plot_options
|
||||
|
||||
|
||||
|
|
|
@ -93,7 +93,8 @@ class AMLMetricsJsonKey(str, Enum):
|
|||
|
||||
class PlotOption(Enum):
|
||||
TOP_BOTTOM_TILES = "top_bottom_tiles"
|
||||
SLIDE_THUMBNAIL_HEATMAP = "slide_thumbnail_heatmap"
|
||||
SLIDE_THUMBNAIL = "slide_thumbnail"
|
||||
ATTENTION_HEATMAP = "attention_heatmap"
|
||||
CONFUSION_MATRIX = "confusion_matrix"
|
||||
HISTOGRAM = "histogram"
|
||||
PR_CURVE = "pr_curve"
|
||||
|
|
|
@ -56,10 +56,13 @@ def validate_class_names(class_names: Optional[Sequence[str]], n_classes: int) -
|
|||
def validate_slide_datasets_for_plot_options(
|
||||
plot_options: Collection[PlotOption], slides_dataset: Optional[SlidesDataset]
|
||||
) -> None:
|
||||
if PlotOption.SLIDE_THUMBNAIL_HEATMAP in plot_options and not slides_dataset:
|
||||
raise ValueError("You can not plot slide thumbnails and heatmaps without setting a slides_dataset. "
|
||||
"Please remove `PlotOption.SLIDE_THUMBNAIL_HEATMAP` from your plot options or provide "
|
||||
"a slide dataset.")
|
||||
|
||||
def _validate_slide_plot_option(plot_option: PlotOption) -> None:
|
||||
if plot_option in plot_options and not slides_dataset:
|
||||
raise ValueError(f"Plot option {plot_option} requires a slides dataset")
|
||||
|
||||
_validate_slide_plot_option(PlotOption.SLIDE_THUMBNAIL)
|
||||
_validate_slide_plot_option(PlotOption.ATTENTION_HEATMAP)
|
||||
|
||||
|
||||
def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
|
||||
|
|
|
@ -25,6 +25,7 @@ from health_cpath.utils.viz_utils import load_image_dict, save_figure
|
|||
|
||||
|
||||
ResultsType = Dict[ResultsKey, List[Any]]
|
||||
SlideDictType = Dict[SlideKey, Any]
|
||||
|
||||
|
||||
def validate_class_names_for_plot_options(
|
||||
|
@ -120,44 +121,44 @@ def save_top_and_bottom_tiles(
|
|||
save_figure(fig=bottom_tiles_fig, figpath=figures_dir / f"{slide_node.slide_id}_bottom.png")
|
||||
|
||||
|
||||
def save_slide_thumbnail_and_heatmap(
|
||||
def save_slide_thumbnail(case: str, slide_node: SlideNode, slide_dict: SlideDictType, figures_dir: Path) -> None:
|
||||
"""Plots and saves a slide thumbnail
|
||||
|
||||
:param case: The report case (e.g., TP, FN, ...)
|
||||
:param slide_node: The slide node that encapsulates the slide metadata.
|
||||
:param slide_dict: The slide dictionary that contains the slide image and other metadata.
|
||||
:param figures_dir: The path to the directory where to save the plots.
|
||||
"""
|
||||
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_dict[SlideKey.IMAGE], scale=1.0)
|
||||
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
|
||||
|
||||
|
||||
def save_attention_heatmap(
|
||||
case: str,
|
||||
slide_node: SlideNode,
|
||||
slide_dict: SlideDictType,
|
||||
figures_dir: Path,
|
||||
results: ResultsType,
|
||||
slides_dataset: SlidesDataset,
|
||||
tile_size: int = 224,
|
||||
level: int = 1,
|
||||
wsi_has_mask: bool = True,
|
||||
) -> None:
|
||||
"""Plots and saves a slide thumbnail and attention heatmap
|
||||
|
||||
:param case: The report case (e.g., TP, FN, ...)
|
||||
:param slide_node: The slide node that encapsulates the slide metadata.
|
||||
:param slide_dict: The slide dictionary that contains the slide image and other metadata.
|
||||
:param figures_dir: The path to the directory where to save the plots.
|
||||
:param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
|
||||
:param slides_dataset: The slides dataset from where to pick the slide.
|
||||
:param tile_size: Size of each tile. Default 224.
|
||||
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
|
||||
1 for 4x downsampled, 2 for 16x downsampled). Default 1.
|
||||
:param wsi_has_mask: Whether the slide has a mask or not. Default True.
|
||||
"""
|
||||
slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
|
||||
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
|
||||
slide_dict = slides_dataset[slide_index]
|
||||
slide_dict = load_image_dict(slide_dict, level=level, margin=0, wsi_has_mask=wsi_has_mask)
|
||||
slide_image = slide_dict[SlideKey.IMAGE]
|
||||
location_bbox = slide_dict[SlideKey.ORIGIN]
|
||||
|
||||
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0)
|
||||
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
|
||||
|
||||
fig = plot_heatmap_overlay(
|
||||
case=case,
|
||||
slide_node=slide_node,
|
||||
slide_image=slide_image,
|
||||
slide_image=slide_dict[SlideKey.IMAGE],
|
||||
results=results,
|
||||
location_bbox=location_bbox,
|
||||
location_bbox=slide_dict[SlideKey.ORIGIN],
|
||||
tile_size=tile_size,
|
||||
level=level,
|
||||
)
|
||||
|
@ -205,33 +206,34 @@ class DeepMILPlotsHandler:
|
|||
self.wsi_has_mask = wsi_has_mask
|
||||
self.slides_dataset: Optional[SlidesDataset] = None
|
||||
|
||||
def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
|
||||
"""Returns the slide dictionary for a given slide node"""
|
||||
assert self.slides_dataset is not None, "Cannot plot attention heatmap or wsi without slides dataset"
|
||||
slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
|
||||
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
|
||||
slide_dict = self.slides_dataset[slide_index]
|
||||
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask)
|
||||
return slide_dict
|
||||
|
||||
def save_slide_node_figures(
|
||||
self, case: str, slide_node: SlideNode, outputs_dir: Path, results: ResultsType
|
||||
) -> None:
|
||||
"""Plots and saves all slide related figures, e.g., `TOP_BOTTOM_TILES` and `SLIDE_THUMBNAIL_HEATMAP`"""
|
||||
|
||||
"""Plots and saves all slide related figures: `TOP_BOTTOM_TILES`, `SLIDE_THUMBNAIL` and `ATTENTION_HEATMAP`."""
|
||||
case_dir = make_figure_dirs(subfolder=case, parent_dir=outputs_dir)
|
||||
|
||||
if PlotOption.TOP_BOTTOM_TILES in self.plot_options:
|
||||
save_top_and_bottom_tiles(
|
||||
case=case,
|
||||
slide_node=slide_node,
|
||||
figures_dir=case_dir,
|
||||
num_columns=self.num_columns,
|
||||
figsize=self.figsize,
|
||||
)
|
||||
if PlotOption.SLIDE_THUMBNAIL_HEATMAP in self.plot_options:
|
||||
assert self.slides_dataset
|
||||
save_slide_thumbnail_and_heatmap(
|
||||
case=case,
|
||||
slide_node=slide_node,
|
||||
figures_dir=case_dir,
|
||||
results=results,
|
||||
slides_dataset=self.slides_dataset,
|
||||
tile_size=self.tile_size,
|
||||
level=self.level,
|
||||
wsi_has_mask=self.wsi_has_mask,
|
||||
)
|
||||
save_top_and_bottom_tiles(case, slide_node, case_dir, self.num_columns, self.figsize)
|
||||
|
||||
if PlotOption.ATTENTION_HEATMAP in self.plot_options or PlotOption.SLIDE_THUMBNAIL in self.plot_options:
|
||||
slide_dict = self.get_slide_dict(slide_node=slide_node)
|
||||
|
||||
if PlotOption.SLIDE_THUMBNAIL in self.plot_options:
|
||||
save_slide_thumbnail(case=case, slide_node=slide_node, slide_dict=slide_dict, figures_dir=case_dir)
|
||||
|
||||
if PlotOption.ATTENTION_HEATMAP in self.plot_options:
|
||||
save_attention_heatmap(
|
||||
case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.level
|
||||
)
|
||||
|
||||
def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:
|
||||
"""Plots and saves all selected plot options during inference (validation or test) time.
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Collection
|
||||
from typing import Any, Collection, Dict, List
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
|
@ -17,21 +17,28 @@ from testhisto.mocks.container import MockDeepSMILETilesPanda
|
|||
|
||||
def test_plots_handler_wrong_class_names() -> None:
|
||||
plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
|
||||
with pytest.raises(ValueError) as ex:
|
||||
with pytest.raises(ValueError, match=r"No class_names were provided while activating confusion matrix plotting."):
|
||||
_ = DeepMILPlotsHandler(plot_options, class_names=[])
|
||||
assert "No class_names were provided while activating confusion matrix plotting." in str(ex)
|
||||
|
||||
|
||||
def test_plots_handler_slide_thumbnails_without_slide_dataset() -> None:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
@pytest.mark.parametrize(
|
||||
"slide_plot_options",
|
||||
[
|
||||
[PlotOption.SLIDE_THUMBNAIL],
|
||||
[PlotOption.ATTENTION_HEATMAP],
|
||||
[PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP]
|
||||
],
|
||||
)
|
||||
def test_plots_handler_slide_plot_options_without_slide_dataset(slide_plot_options: List[PlotOption]) -> None:
|
||||
exception_prompt = f"Plot option {slide_plot_options[0]} requires a slides dataset"
|
||||
with pytest.raises(ValueError, match=rf"{exception_prompt}"):
|
||||
container = MockDeepSMILETilesPanda(tmp_path=Path("foo"))
|
||||
container.setup()
|
||||
container.data_module = MagicMock()
|
||||
container.data_module.train_dataset.n_classes = 6
|
||||
outputs_handler = container.get_outputs_handler()
|
||||
outputs_handler.test_plots_handler.plot_options = {PlotOption.SLIDE_THUMBNAIL_HEATMAP}
|
||||
outputs_handler.test_plots_handler.plot_options = slide_plot_options
|
||||
outputs_handler.set_slides_dataset_for_plots_handlers(container.get_slides_dataset())
|
||||
assert "You can not plot slide thumbnails and heatmaps without setting a slides_dataset." in str(ex)
|
||||
|
||||
|
||||
def assert_plot_func_called_if_among_plot_options(
|
||||
|
@ -52,13 +59,14 @@ def assert_plot_func_called_if_among_plot_options(
|
|||
{},
|
||||
{PlotOption.HISTOGRAM, PlotOption.PR_CURVE},
|
||||
{PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX},
|
||||
{PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.SLIDE_THUMBNAIL_HEATMAP},
|
||||
{PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.ATTENTION_HEATMAP},
|
||||
{
|
||||
PlotOption.HISTOGRAM,
|
||||
PlotOption.PR_CURVE,
|
||||
PlotOption.CONFUSION_MATRIX,
|
||||
PlotOption.TOP_BOTTOM_TILES,
|
||||
PlotOption.SLIDE_THUMBNAIL_HEATMAP,
|
||||
PlotOption.SLIDE_THUMBNAIL,
|
||||
PlotOption.ATTENTION_HEATMAP,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
@ -72,23 +80,23 @@ def test_plots_handler_plots_only_desired_plot_options(plot_options: Collection[
|
|||
tiles_selector.top_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
|
||||
tiles_selector.bottom_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
|
||||
|
||||
with patch("health_cpath.utils.plots_utils.save_slide_thumbnail_and_heatmap") as mock_slide:
|
||||
with patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles") as mock_tile:
|
||||
with patch("health_cpath.utils.plots_utils.save_scores_histogram") as mock_histogram:
|
||||
with patch("health_cpath.utils.plots_utils.save_confusion_matrix") as mock_conf:
|
||||
with patch("health_cpath.utils.plots_utils.save_pr_curve") as mock_pr:
|
||||
plots_handler.save_plots(
|
||||
outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock()
|
||||
)
|
||||
patchers: Dict[PlotOption, Any] = {
|
||||
PlotOption.SLIDE_THUMBNAIL: patch("health_cpath.utils.plots_utils.save_slide_thumbnail"),
|
||||
PlotOption.ATTENTION_HEATMAP: patch("health_cpath.utils.plots_utils.save_attention_heatmap"),
|
||||
PlotOption.TOP_BOTTOM_TILES: patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles"),
|
||||
PlotOption.CONFUSION_MATRIX: patch("health_cpath.utils.plots_utils.save_confusion_matrix"),
|
||||
PlotOption.HISTOGRAM: patch("health_cpath.utils.plots_utils.save_scores_histogram"),
|
||||
PlotOption.PR_CURVE: patch("health_cpath.utils.plots_utils.save_pr_curve"),
|
||||
}
|
||||
|
||||
mock_funcs = {option: patcher.start() for option, patcher in patchers.items()} # type: ignore
|
||||
with patch.object(plots_handler, "get_slide_dict"):
|
||||
plots_handler.save_plots(outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock())
|
||||
patch.stopall()
|
||||
|
||||
calls_count = 0
|
||||
calls_count += assert_plot_func_called_if_among_plot_options(
|
||||
mock_slide, PlotOption.SLIDE_THUMBNAIL_HEATMAP, plot_options
|
||||
)
|
||||
calls_count += assert_plot_func_called_if_among_plot_options(mock_tile, PlotOption.TOP_BOTTOM_TILES, plot_options)
|
||||
calls_count += assert_plot_func_called_if_among_plot_options(mock_histogram, PlotOption.HISTOGRAM, plot_options)
|
||||
calls_count += assert_plot_func_called_if_among_plot_options(mock_conf, PlotOption.CONFUSION_MATRIX, plot_options)
|
||||
calls_count += assert_plot_func_called_if_among_plot_options(mock_pr, PlotOption.PR_CURVE, plot_options)
|
||||
for option, mock_func in mock_funcs.items():
|
||||
calls_count += assert_plot_func_called_if_among_plot_options(mock_func, option, plot_options)
|
||||
|
||||
assert calls_count == len(plot_options)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче