зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add non normalized confusion matrix to plot options (#791)
Add non normalized confusion matrix to plot options
This commit is contained in:
Родитель
fc5776875a
Коммит
273e84e6bb
|
@ -17,7 +17,7 @@ from health_cpath.utils.viz_utils import (
|
|||
plot_attention_tiles,
|
||||
plot_heatmap_overlay,
|
||||
plot_attention_histogram,
|
||||
plot_normalized_confusion_matrix,
|
||||
plot_normalized_and_non_normalized_confusion_matrices,
|
||||
plot_scores_hist,
|
||||
plot_slide,
|
||||
)
|
||||
|
@ -130,15 +130,11 @@ def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figu
|
|||
if pred_labels_diff_expected != set():
|
||||
raise ValueError("More entries were found in predicted labels than are available in class names")
|
||||
|
||||
cf_matrix_n = confusion_matrix(
|
||||
true_labels,
|
||||
pred_labels,
|
||||
labels=all_potential_labels,
|
||||
normalize="true"
|
||||
)
|
||||
cf_matrix = confusion_matrix(true_labels, pred_labels, labels=all_potential_labels)
|
||||
cf_matrix_n = confusion_matrix(true_labels, pred_labels, labels=all_potential_labels, normalize="true")
|
||||
|
||||
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=(class_names))
|
||||
save_figure(fig=fig, figpath=figures_dir / f"normalized_confusion_matrix_{stage}.png")
|
||||
fig = plot_normalized_and_non_normalized_confusion_matrices(cm=cf_matrix, cm_n=cf_matrix_n, class_names=class_names)
|
||||
save_figure(fig=fig, figpath=figures_dir / f"confusion_matrices_{stage}.png")
|
||||
|
||||
|
||||
def save_top_and_bottom_tiles(
|
||||
|
|
|
@ -30,6 +30,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
|
|||
def load_image_dict(sample: dict, loading_params: LoadingParams) -> Dict[SlideKey, Any]:
|
||||
"""
|
||||
Load image from metadata dictionary
|
||||
|
||||
:param sample: dict describing image metadata. Example:
|
||||
{'image_id': ['1ca999adbbc948e69783686e5b5414e4'],
|
||||
'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'],
|
||||
|
@ -258,8 +259,9 @@ def plot_attention_histogram(case: str, slide_node: SlideNode, results: Dict[Res
|
|||
|
||||
def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: Sequence[str]) -> plt.Figure:
|
||||
"""Plots a normalized confusion matrix and returns the figure.
|
||||
param cm: Normalized confusion matrix to be plotted.
|
||||
param class_names: List of class names.
|
||||
|
||||
:param cm: Normalized confusion matrix to be plotted.
|
||||
:param class_names: List of class names.
|
||||
"""
|
||||
fig, ax = plt.subplots()
|
||||
ax = sns.heatmap(cm, annot=True, cmap="Blues", fmt=".2%")
|
||||
|
@ -270,9 +272,30 @@ def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: Sequence[str])
|
|||
return fig
|
||||
|
||||
|
||||
def plot_normalized_and_non_normalized_confusion_matrices(
|
||||
cm: np.ndarray, cm_n: np.ndarray, class_names: Sequence[str],
|
||||
) -> plt.Figure:
|
||||
"""Plots a normalized and non-normalized confusion matrix and returns the figure.
|
||||
|
||||
:param cm: Non normalized confusion matrix to be plotted.
|
||||
:param cm_n: Normalized confusion matrix to be plotted.
|
||||
:param class_names: List of class names.
|
||||
"""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
|
||||
axes[0] = sns.heatmap(cm, annot=True, cmap="Blues", fmt="d", ax=axes[0])
|
||||
axes[1] = sns.heatmap(cm_n, annot=True, cmap="Blues", fmt=".2%", ax=axes[1])
|
||||
for ax in axes:
|
||||
ax.set_xlabel("Predicted")
|
||||
ax.set_ylabel("True")
|
||||
ax.xaxis.set_ticklabels(class_names)
|
||||
ax.yaxis.set_ticklabels(class_names)
|
||||
return fig
|
||||
|
||||
|
||||
def resize_and_save(width_inch: int, height_inch: int, filename: Union[Path, str], dpi: int = 150) -> None:
|
||||
"""
|
||||
Resizes the present figure to the given (width, height) in inches, and saves it to the given filename.
|
||||
|
||||
:param width_inch: The width of the figure in inches.
|
||||
:param height_inch: The height of the figure in inches.
|
||||
:param filename: The filename to save to.
|
||||
|
|
Двоичные данные
hi-ml-cpath/testhisto/test_data/histo_heatmaps/confusion_matrices_foo.png
Normal file
Двоичные данные
hi-ml-cpath/testhisto/test_data/histo_heatmaps/confusion_matrices_foo.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 22 KiB |
|
@ -131,8 +131,12 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
|
|||
class_names = ["foo", "bar"]
|
||||
|
||||
save_confusion_matrix(results, class_names, tmp_path, stage='foo')
|
||||
file = Path(tmp_path) / "normalized_confusion_matrix_foo.png"
|
||||
file = Path(tmp_path) / "confusion_matrices_foo.png"
|
||||
assert file.exists()
|
||||
expected = full_ml_test_data_path("histo_heatmaps") / "confusion_matrices_foo.png"
|
||||
# To update the stored results, uncomment this line:
|
||||
# expected.write_bytes(file.read_bytes())
|
||||
assert_binary_files_match(file, expected)
|
||||
|
||||
# check that an error is raised if true labels include indices greater than the expected number of classes
|
||||
invalid_results_1 = {
|
||||
|
@ -156,11 +160,15 @@ def test_save_conf_matrix_integration(tmp_path: Path) -> None:
|
|||
class_names_extended = ["foo", "bar", "baz"]
|
||||
num_classes = len(class_names_extended)
|
||||
expected_conf_matrix_shape = (num_classes, num_classes)
|
||||
with patch("health_cpath.utils.plots_utils.plot_normalized_confusion_matrix") as mock_plot_conf_matrix:
|
||||
with patch(
|
||||
"health_cpath.utils.plots_utils.plot_normalized_and_non_normalized_confusion_matrices"
|
||||
) as mock_plot_conf_matrix:
|
||||
with patch("health_cpath.utils.plots_utils.save_figure"):
|
||||
save_confusion_matrix(results, class_names_extended, tmp_path)
|
||||
mock_plot_conf_matrix.assert_called_once()
|
||||
actual_n_conf_matrix = mock_plot_conf_matrix.call_args[1].get('cm_n')
|
||||
actual_conf_matrix = mock_plot_conf_matrix.call_args[1].get('cm')
|
||||
assert actual_n_conf_matrix.shape == expected_conf_matrix_shape
|
||||
assert actual_conf_matrix.shape == expected_conf_matrix_shape
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче