ENH: Separate attention heatmap and slide thumbnail plot options (#640)

Dissociate thumbnails and heatmaps plots to flexibly remove/add each of
these plot options.
Motivations: thumbnail plots can be costly, so we might want to omit it
sometime while including heatmap plots and vice versa.
This commit is contained in:
Kenza Bouzid 2022-10-24 09:15:38 +01:00 коммит произвёл GitHub
Родитель 84e37efc59
Коммит b2e873daa5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 82 добавлений и 68 удалений

Просмотреть файл

@ -94,7 +94,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
def get_test_plot_options(self) -> Set[PlotOption]:
plot_options = super().get_test_plot_options()
plot_options.add(PlotOption.SLIDE_THUMBNAIL_HEATMAP)
plot_options.update([PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP])
return plot_options

Просмотреть файл

@ -93,7 +93,8 @@ class AMLMetricsJsonKey(str, Enum):
class PlotOption(Enum):
TOP_BOTTOM_TILES = "top_bottom_tiles"
SLIDE_THUMBNAIL_HEATMAP = "slide_thumbnail_heatmap"
SLIDE_THUMBNAIL = "slide_thumbnail"
ATTENTION_HEATMAP = "attention_heatmap"
CONFUSION_MATRIX = "confusion_matrix"
HISTOGRAM = "histogram"
PR_CURVE = "pr_curve"

Просмотреть файл

@ -56,10 +56,13 @@ def validate_class_names(class_names: Optional[Sequence[str]], n_classes: int) -
def validate_slide_datasets_for_plot_options(
plot_options: Collection[PlotOption], slides_dataset: Optional[SlidesDataset]
) -> None:
if PlotOption.SLIDE_THUMBNAIL_HEATMAP in plot_options and not slides_dataset:
raise ValueError("You can not plot slide thumbnails and heatmaps without setting a slides_dataset. "
"Please remove `PlotOption.SLIDE_THUMBNAIL_HEATMAP` from your plot options or provide "
"a slide dataset.")
def _validate_slide_plot_option(plot_option: PlotOption) -> None:
if plot_option in plot_options and not slides_dataset:
raise ValueError(f"Plot option {plot_option} requires a slides dataset")
_validate_slide_plot_option(PlotOption.SLIDE_THUMBNAIL)
_validate_slide_plot_option(PlotOption.ATTENTION_HEATMAP)
def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:

Просмотреть файл

@ -25,6 +25,7 @@ from health_cpath.utils.viz_utils import load_image_dict, save_figure
ResultsType = Dict[ResultsKey, List[Any]]
SlideDictType = Dict[SlideKey, Any]
def validate_class_names_for_plot_options(
@ -120,44 +121,44 @@ def save_top_and_bottom_tiles(
save_figure(fig=bottom_tiles_fig, figpath=figures_dir / f"{slide_node.slide_id}_bottom.png")
def save_slide_thumbnail_and_heatmap(
def save_slide_thumbnail(case: str, slide_node: SlideNode, slide_dict: SlideDictType, figures_dir: Path) -> None:
"""Plots and saves a slide thumbnail
:param case: The report case (e.g., TP, FN, ...)
:param slide_node: The slide node that encapsulates the slide metadata.
:param slide_dict: The slide dictionary that contains the slide image and other metadata.
:param figures_dir: The path to the directory where to save the plots.
"""
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_dict[SlideKey.IMAGE], scale=1.0)
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
def save_attention_heatmap(
case: str,
slide_node: SlideNode,
slide_dict: SlideDictType,
figures_dir: Path,
results: ResultsType,
slides_dataset: SlidesDataset,
tile_size: int = 224,
level: int = 1,
wsi_has_mask: bool = True,
) -> None:
"""Plots and saves a slide thumbnail and attention heatmap
:param case: The report case (e.g., TP, FN, ...)
:param slide_node: The slide node that encapsulates the slide metadata.
:param slide_dict: The slide dictionary that contains the slide image and other metadata.
:param figures_dir: The path to the directory where to save the plots.
:param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
:param slides_dataset: The slides dataset from where to pick the slide.
:param tile_size: Size of each tile. Default 224.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
1 for 4x downsampled, 2 for 16x downsampled). Default 1.
:param wsi_has_mask: Whether the slide has a mask or not. Default True.
"""
slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=level, margin=0, wsi_has_mask=wsi_has_mask)
slide_image = slide_dict[SlideKey.IMAGE]
location_bbox = slide_dict[SlideKey.ORIGIN]
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0)
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
fig = plot_heatmap_overlay(
case=case,
slide_node=slide_node,
slide_image=slide_image,
slide_image=slide_dict[SlideKey.IMAGE],
results=results,
location_bbox=location_bbox,
location_bbox=slide_dict[SlideKey.ORIGIN],
tile_size=tile_size,
level=level,
)
@ -205,33 +206,34 @@ class DeepMILPlotsHandler:
self.wsi_has_mask = wsi_has_mask
self.slides_dataset: Optional[SlidesDataset] = None
def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
"""Returns the slide dictionary for a given slide node"""
assert self.slides_dataset is not None, "Cannot plot attention heatmap or wsi without slides dataset"
slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = self.slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask)
return slide_dict
def save_slide_node_figures(
self, case: str, slide_node: SlideNode, outputs_dir: Path, results: ResultsType
) -> None:
"""Plots and saves all slide related figures, e.g., `TOP_BOTTOM_TILES` and `SLIDE_THUMBNAIL_HEATMAP`"""
"""Plots and saves all slide related figures: `TOP_BOTTOM_TILES`, `SLIDE_THUMBNAIL` and `ATTENTION_HEATMAP`."""
case_dir = make_figure_dirs(subfolder=case, parent_dir=outputs_dir)
if PlotOption.TOP_BOTTOM_TILES in self.plot_options:
save_top_and_bottom_tiles(
case=case,
slide_node=slide_node,
figures_dir=case_dir,
num_columns=self.num_columns,
figsize=self.figsize,
)
if PlotOption.SLIDE_THUMBNAIL_HEATMAP in self.plot_options:
assert self.slides_dataset
save_slide_thumbnail_and_heatmap(
case=case,
slide_node=slide_node,
figures_dir=case_dir,
results=results,
slides_dataset=self.slides_dataset,
tile_size=self.tile_size,
level=self.level,
wsi_has_mask=self.wsi_has_mask,
)
save_top_and_bottom_tiles(case, slide_node, case_dir, self.num_columns, self.figsize)
if PlotOption.ATTENTION_HEATMAP in self.plot_options or PlotOption.SLIDE_THUMBNAIL in self.plot_options:
slide_dict = self.get_slide_dict(slide_node=slide_node)
if PlotOption.SLIDE_THUMBNAIL in self.plot_options:
save_slide_thumbnail(case=case, slide_node=slide_node, slide_dict=slide_dict, figures_dir=case_dir)
if PlotOption.ATTENTION_HEATMAP in self.plot_options:
save_attention_heatmap(
case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.level
)
def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:
"""Plots and saves all selected plot options during inference (validation or test) time.

Просмотреть файл

@ -5,7 +5,7 @@
import logging
import os
from pathlib import Path
from typing import Collection
from typing import Any, Collection, Dict, List
from unittest.mock import MagicMock, patch
import pytest
@ -17,21 +17,28 @@ from testhisto.mocks.container import MockDeepSMILETilesPanda
def test_plots_handler_wrong_class_names() -> None:
plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
with pytest.raises(ValueError) as ex:
with pytest.raises(ValueError, match=r"No class_names were provided while activating confusion matrix plotting."):
_ = DeepMILPlotsHandler(plot_options, class_names=[])
assert "No class_names were provided while activating confusion matrix plotting." in str(ex)
def test_plots_handler_slide_thumbnails_without_slide_dataset() -> None:
with pytest.raises(ValueError) as ex:
@pytest.mark.parametrize(
"slide_plot_options",
[
[PlotOption.SLIDE_THUMBNAIL],
[PlotOption.ATTENTION_HEATMAP],
[PlotOption.SLIDE_THUMBNAIL, PlotOption.ATTENTION_HEATMAP]
],
)
def test_plots_handler_slide_plot_options_without_slide_dataset(slide_plot_options: List[PlotOption]) -> None:
exception_prompt = f"Plot option {slide_plot_options[0]} requires a slides dataset"
with pytest.raises(ValueError, match=rf"{exception_prompt}"):
container = MockDeepSMILETilesPanda(tmp_path=Path("foo"))
container.setup()
container.data_module = MagicMock()
container.data_module.train_dataset.n_classes = 6
outputs_handler = container.get_outputs_handler()
outputs_handler.test_plots_handler.plot_options = {PlotOption.SLIDE_THUMBNAIL_HEATMAP}
outputs_handler.test_plots_handler.plot_options = slide_plot_options
outputs_handler.set_slides_dataset_for_plots_handlers(container.get_slides_dataset())
assert "You can not plot slide thumbnails and heatmaps without setting a slides_dataset." in str(ex)
def assert_plot_func_called_if_among_plot_options(
@ -52,13 +59,14 @@ def assert_plot_func_called_if_among_plot_options(
{},
{PlotOption.HISTOGRAM, PlotOption.PR_CURVE},
{PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX},
{PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.SLIDE_THUMBNAIL_HEATMAP},
{PlotOption.HISTOGRAM, PlotOption.TOP_BOTTOM_TILES, PlotOption.ATTENTION_HEATMAP},
{
PlotOption.HISTOGRAM,
PlotOption.PR_CURVE,
PlotOption.CONFUSION_MATRIX,
PlotOption.TOP_BOTTOM_TILES,
PlotOption.SLIDE_THUMBNAIL_HEATMAP,
PlotOption.SLIDE_THUMBNAIL,
PlotOption.ATTENTION_HEATMAP,
},
],
)
@ -72,23 +80,23 @@ def test_plots_handler_plots_only_desired_plot_options(plot_options: Collection[
tiles_selector.top_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
tiles_selector.bottom_slides_heaps = {0: [slide_node] * n_tiles, 1: [slide_node] * n_tiles}
with patch("health_cpath.utils.plots_utils.save_slide_thumbnail_and_heatmap") as mock_slide:
with patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles") as mock_tile:
with patch("health_cpath.utils.plots_utils.save_scores_histogram") as mock_histogram:
with patch("health_cpath.utils.plots_utils.save_confusion_matrix") as mock_conf:
with patch("health_cpath.utils.plots_utils.save_pr_curve") as mock_pr:
plots_handler.save_plots(
outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock()
)
patchers: Dict[PlotOption, Any] = {
PlotOption.SLIDE_THUMBNAIL: patch("health_cpath.utils.plots_utils.save_slide_thumbnail"),
PlotOption.ATTENTION_HEATMAP: patch("health_cpath.utils.plots_utils.save_attention_heatmap"),
PlotOption.TOP_BOTTOM_TILES: patch("health_cpath.utils.plots_utils.save_top_and_bottom_tiles"),
PlotOption.CONFUSION_MATRIX: patch("health_cpath.utils.plots_utils.save_confusion_matrix"),
PlotOption.HISTOGRAM: patch("health_cpath.utils.plots_utils.save_scores_histogram"),
PlotOption.PR_CURVE: patch("health_cpath.utils.plots_utils.save_pr_curve"),
}
mock_funcs = {option: patcher.start() for option, patcher in patchers.items()} # type: ignore
with patch.object(plots_handler, "get_slide_dict"):
plots_handler.save_plots(outputs_dir=MagicMock(), tiles_selector=tiles_selector, results=MagicMock())
patch.stopall()
calls_count = 0
calls_count += assert_plot_func_called_if_among_plot_options(
mock_slide, PlotOption.SLIDE_THUMBNAIL_HEATMAP, plot_options
)
calls_count += assert_plot_func_called_if_among_plot_options(mock_tile, PlotOption.TOP_BOTTOM_TILES, plot_options)
calls_count += assert_plot_func_called_if_among_plot_options(mock_histogram, PlotOption.HISTOGRAM, plot_options)
calls_count += assert_plot_func_called_if_among_plot_options(mock_conf, PlotOption.CONFUSION_MATRIX, plot_options)
calls_count += assert_plot_func_called_if_among_plot_options(mock_pr, PlotOption.PR_CURVE, plot_options)
for option, mock_func in mock_funcs.items():
calls_count += assert_plot_func_called_if_among_plot_options(mock_func, option, plot_options)
assert calls_count == len(plot_options)