ENH: Add non-normalized confusion matrix to plot options (#791)

Add non-normalized confusion matrix to plot options
This commit is contained in:
Kenza Bouzid 2023-03-14 13:31:45 +00:00 коммит произвёл GitHub
Родитель fc5776875a
Коммит 273e84e6bb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 40 добавлений и 13 удалений

Просмотреть файл

@ -17,7 +17,7 @@ from health_cpath.utils.viz_utils import (
plot_attention_tiles,
plot_heatmap_overlay,
plot_attention_histogram,
plot_normalized_confusion_matrix,
plot_normalized_and_non_normalized_confusion_matrices,
plot_scores_hist,
plot_slide,
)
@ -130,15 +130,11 @@ def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figu
if pred_labels_diff_expected != set():
raise ValueError("More entries were found in predicted labels than are available in class names")
cf_matrix_n = confusion_matrix(
true_labels,
pred_labels,
labels=all_potential_labels,
normalize="true"
)
cf_matrix = confusion_matrix(true_labels, pred_labels, labels=all_potential_labels)
cf_matrix_n = confusion_matrix(true_labels, pred_labels, labels=all_potential_labels, normalize="true")
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=(class_names))
save_figure(fig=fig, figpath=figures_dir / f"normalized_confusion_matrix_{stage}.png")
fig = plot_normalized_and_non_normalized_confusion_matrices(cm=cf_matrix, cm_n=cf_matrix_n, class_names=class_names)
save_figure(fig=fig, figpath=figures_dir / f"confusion_matrices_{stage}.png")
def save_top_and_bottom_tiles(

Просмотреть файл

@ -30,6 +30,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
def load_image_dict(sample: dict, loading_params: LoadingParams) -> Dict[SlideKey, Any]:
"""
Load image from metadata dictionary
:param sample: dict describing image metadata. Example:
{'image_id': ['1ca999adbbc948e69783686e5b5414e4'],
'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'],
@ -258,8 +259,9 @@ def plot_attention_histogram(case: str, slide_node: SlideNode, results: Dict[Res
def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: Sequence[str]) -> plt.Figure:
"""Plots a normalized confusion matrix and returns the figure.
param cm: Normalized confusion matrix to be plotted.
param class_names: List of class names.
:param cm: Normalized confusion matrix to be plotted.
:param class_names: List of class names.
"""
fig, ax = plt.subplots()
ax = sns.heatmap(cm, annot=True, cmap="Blues", fmt=".2%")
@ -270,9 +272,30 @@ def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: Sequence[str])
return fig
def plot_normalized_and_non_normalized_confusion_matrices(
    cm: np.ndarray, cm_n: np.ndarray, class_names: Sequence[str],
) -> plt.Figure:
    """Draws the non-normalized and the normalized confusion matrices side by side and returns the figure.

    The left panel shows ``cm`` with integer annotations, the right panel shows ``cm_n`` with percentage
    annotations. Both panels get the same axis labels and the same class tick labels.

    :param cm: Non normalized confusion matrix to be plotted.
    :param cm_n: Normalized confusion matrix to be plotted.
    :param class_names: List of class names, used as tick labels on both axes of each panel.
    :return: The figure holding both heatmaps.
    """
    fig, (counts_ax, ratios_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # One entry per panel: (matrix to draw, annotation format, axis to draw on).
    panels = (
        (cm, "d", counts_ax),
        (cm_n, ".2%", ratios_ax),
    )
    for matrix, annotation_format, target_ax in panels:
        drawn_ax = sns.heatmap(matrix, annot=True, cmap="Blues", fmt=annotation_format, ax=target_ax)
        drawn_ax.set_xlabel("Predicted")
        drawn_ax.set_ylabel("True")
        drawn_ax.xaxis.set_ticklabels(class_names)
        drawn_ax.yaxis.set_ticklabels(class_names)

    return fig
def resize_and_save(width_inch: int, height_inch: int, filename: Union[Path, str], dpi: int = 150) -> None:
"""
Resizes the present figure to the given (width, height) in inches, and saves it to the given filename.
:param width_inch: The width of the figure in inches.
:param height_inch: The height of the figure in inches.
:param filename: The filename to save to.

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 22 KiB

Просмотреть файл

@ -131,8 +131,12 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
class_names = ["foo", "bar"]
save_confusion_matrix(results, class_names, tmp_path, stage='foo')
file = Path(tmp_path) / "normalized_confusion_matrix_foo.png"
file = Path(tmp_path) / "confusion_matrices_foo.png"
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / "confusion_matrices_foo.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)
# check that an error is raised if true labels include indices greater than the expected number of classes
invalid_results_1 = {
@ -156,11 +160,15 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
class_names_extended = ["foo", "bar", "baz"]
num_classes = len(class_names_extended)
expected_conf_matrix_shape = (num_classes, num_classes)
with patch("health_cpath.utils.plots_utils.plot_normalized_confusion_matrix") as mock_plot_conf_matrix:
with patch(
"health_cpath.utils.plots_utils.plot_normalized_and_non_normalized_confusion_matrices"
) as mock_plot_conf_matrix:
with patch("health_cpath.utils.plots_utils.save_figure"):
save_confusion_matrix(results, class_names_extended, tmp_path)
mock_plot_conf_matrix.assert_called_once()
actual_n_conf_matrix = mock_plot_conf_matrix.call_args[1].get('cm_n')
actual_conf_matrix = mock_plot_conf_matrix.call_args[1].get('cm')
assert actual_n_conf_matrix.shape == expected_conf_matrix_shape
assert actual_conf_matrix.shape == expected_conf_matrix_shape