Add accuracy at threshold 0.5 to classification report (#450)
Adds the metric "Accuracy at threshold 0.5" to the classification report (`classification_crossval_report.ipynb`). Also deletes the unused `classification_report.ipynb`.
This commit is contained in:
Parent: 9a7ac87da4
Commit: 35423b3674
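For context, "Accuracy at threshold 0.5" is plain binary accuracy computed by thresholding the model's outputs at a fixed 0.5, rather than at the optimal threshold derived from a validation set. A minimal sketch of the distinction, on made-up data (none of these names come from the InnerEye codebase):

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Illustrative posterior outputs and binary labels.
model_outputs = np.array([0.1, 0.4, 0.6, 0.9])
labels = np.array([0, 1, 1, 1])

# The metric added by this PR: accuracy at the fixed threshold 0.5.
accuracy_at_05 = accuracy_score(labels, (model_outputs >= 0.5).astype(int))

# The pre-existing metric: accuracy at a threshold tuned elsewhere,
# e.g. on the validation set (the value here is hypothetical).
optimal_threshold = 0.35
accuracy_at_optimal = accuracy_score(labels, (model_outputs >= optimal_threshold).astype(int))

print(accuracy_at_05, accuracy_at_optimal)  # 0.75 1.0
```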
@@ -60,6 +60,7 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
   model configs with custom behaviour while leveraging the existing InnerEye workflows.
 - ([#445](https://github.com/microsoft/InnerEye-DeepLearning/pull/445)) Adding test coverage for the `HelloContainer`
   model with multiple GPUs
+- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Adds the metric "Accuracy at threshold 0.5" to the classification report (`classification_crossval_report.ipynb`).
 
 ### Changed
 
@@ -93,6 +94,7 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
 
 ### Removed
 - ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Deprecated `start_epoch` config argument.
+- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Delete unused `classification_report.ipynb`.
 
 ### Deprecated
 
@@ -1,287 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "1",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%%javascript\n",
-    "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
-    "    return false;\n",
-    "}\n",
-    "// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "2",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    },
-    "tags": [
-     "parameters"
-    ]
-   },
-   "outputs": [],
-   "source": [
-    "# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
-    "# This cell must carry the tag \"parameters\" in its metadata.\n",
-    "from pathlib import Path\n",
-    "import pickle\n",
-    "import codecs\n",
-    "\n",
-    "innereye_path = Path.cwd().parent.parent.parent\n",
-    "train_metrics_csv = \"\"\n",
-    "val_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'val_metrics_classification.csv'\n",
-    "test_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'test_metrics_classification.csv'\n",
-    "number_best_and_worst_performing = 20\n",
-    "config = \"\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "3",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "import sys\n",
-    "print(f\"Adding to path: {innereye_path}\")\n",
-    "if str(innereye_path) not in sys.path:\n",
-    "    sys.path.append(str(innereye_path))\n",
-    "\n",
-    "%matplotlib inline\n",
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
-    "\n",
-    "from InnerEye.ML.reports.notebook_report import print_header\n",
-    "from InnerEye.ML.reports.classification_report import plot_pr_and_roc_curves_from_csv, \\\n",
-    "    print_k_best_and_worst_performing, print_metrics_for_all_prediction_targets, \\\n",
-    "    plot_k_best_and_worst_performing, get_labels_and_predictions\n",
-    "\n",
-    "import warnings\n",
-    "warnings.filterwarnings(\"ignore\")\n",
-    "plt.rcParams['figure.figsize'] = (20, 10)\n",
-    "\n",
-    "# Convert params to Path\n",
-    "train_metrics_csv = Path(train_metrics_csv)\n",
-    "val_metrics_csv = Path(val_metrics_csv)\n",
-    "test_metrics_csv = Path(test_metrics_csv)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "4",
-   "metadata": {},
-   "source": [
-    "# Train Metrics"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "5",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if train_metrics_csv.is_file():\n",
-    "    print_metrics_for_all_prediction_targets(val_metrics_csv=train_metrics_csv, test_metrics_csv=train_metrics_csv,\n",
-    "                                             config=config, is_thresholded=False)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "6",
-   "metadata": {},
-   "source": [
-    "# Validation Metrics"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "7",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if val_metrics_csv.is_file():\n",
-    "    print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv,\n",
-    "                                             config=config, is_thresholded=False)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "8",
-   "metadata": {},
-   "source": [
-    "# Test Metrics"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "9",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
-    "    print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
-    "                                             config=config, is_thresholded=False)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "10",
-   "metadata": {
-    "pycharm": {
-     "name": "#%% md\n"
-    }
-   },
-   "source": [
-    "# AUC and PR curves\n",
-    "## Train Set"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "11",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "if train_metrics_csv.is_file():\n",
-    "    plot_pr_and_roc_curves_from_csv(metrics_csv=train_metrics_csv, config=config)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "12",
-   "metadata": {},
-   "source": [
-    "## Validation set"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "13",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "if val_metrics_csv.is_file():\n",
-    "    plot_pr_and_roc_curves_from_csv(metrics_csv=val_metrics_csv, config=config)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "14",
-   "metadata": {
-    "pycharm": {
-     "name": "#%% md\n"
-    }
-   },
-   "source": [
-    "## Test set"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "15",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "if test_metrics_csv.is_file():\n",
-    "    plot_pr_and_roc_curves_from_csv(metrics_csv=test_metrics_csv, config=config)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "16",
-   "metadata": {},
-   "source": [
-    "# Best and worst samples by ID"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "17",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
-    "    for prediction_target in config.target_names:\n",
-    "        print_header(f\"Class {prediction_target}\", level=3)\n",
-    "        print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
-    "                                          k=number_best_and_worst_performing,\n",
-    "                                          prediction_target=prediction_target)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "18",
-   "metadata": {},
-   "source": [
-    "# Plot best and worst sample images"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "19",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
-    "    for prediction_target in config.target_names:\n",
-    "        print_header(f\"Class {prediction_target}\", level=3)\n",
-    "        plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
-    "                                         k=number_best_and_worst_performing, prediction_target=prediction_target, config=config)"
-   ]
-  }
- ],
- "metadata": {
-  "celltoolbar": "Tags",
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.3"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
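The deleted notebook (like its crossval replacement) is papermill-driven: the cell tagged "parameters" holds defaults that papermill overwrites at execution time, and `config` travels as a base64-encoded pickle string. A hedged sketch of how such a report notebook could be invoked; the file names and the config object are illustrative, not InnerEye's actual call site:

```python
import codecs
import pickle

import papermill as pm

# Hypothetical config object; InnerEye passes its model config here.
config = {"target_names": ["Default"]}

# Encode the config the same way the notebook decodes it:
# pickle.loads(codecs.decode(config.encode(), "base64"))
encoded_config = codecs.encode(pickle.dumps(config), "base64").decode()

pm.execute_notebook(
    "classification_crossval_report.ipynb",      # input notebook
    "classification_crossval_report_out.ipynb",  # executed copy with outputs
    parameters=dict(
        train_metrics_csv="",  # empty string -> Path("").is_file() is False, so Train sections are skipped
        val_metrics_csv="val_metrics_classification.csv",
        test_metrics_csv="test_metrics_classification.csv",
        number_best_and_worst_performing=20,
        config=encoded_config,
    ),
)
```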
@@ -49,7 +49,8 @@ class ReportedScalarMetrics(Enum):
     AUC_PR = "Area under PR Curve", False
     AUC_ROC = "Area under ROC Curve", False
     OptimalThreshold = "Optimal threshold", False
-    Accuracy = "Accuracy at optimal threshold", True
+    AccuracyAtOptimalThreshold = "Accuracy at optimal threshold", True
+    AccuracyAtThreshold05 = "Accuracy at threshold 0.5", True
     Sensitivity = "Sensitivity at optimal threshold", True
     Specificity = "Specificity at optimal threshold", True
 
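The rename keeps the enum's two-part value convention: each member's value is a tuple of (display name, boolean flag); from the usage in this file the flag appears to relate to whether the metric still applies when model outputs are already thresholded, though that reading is an inference. A tiny self-contained sketch of the pattern:

```python
from enum import Enum

class ReportedMetricSketch(Enum):
    # Value tuple: (display name, flag). The flag's meaning is assumed here;
    # see classification_report.py for the authoritative usage.
    AccuracyAtOptimalThreshold = "Accuracy at optimal threshold", True
    AccuracyAtThreshold05 = "Accuracy at threshold 0.5", True
    AUC_ROC = "Area under ROC Curve", False

# .value is the whole tuple; .value[0] is the display name that the reports
# print and the updated tests below reference.
assert ReportedMetricSketch.AccuracyAtThreshold05.value[0] == "Accuracy at threshold 0.5"
```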
@@ -326,10 +327,14 @@ def get_metric(predictions_to_set_optimal_threshold: LabelsAndPredictions,
         precision, recall, _ = precision_recall_curve(predictions_to_compute_metrics.labels,
                                                       predictions_to_compute_metrics.model_outputs)
         return auc(recall, precision)
-    elif metric is ReportedScalarMetrics.Accuracy:
+    elif metric is ReportedScalarMetrics.AccuracyAtOptimalThreshold:
         return binary_classification_accuracy(model_output=predictions_to_compute_metrics.model_outputs,
                                               label=predictions_to_compute_metrics.labels,
                                               threshold=optimal_threshold)
+    elif metric is ReportedScalarMetrics.AccuracyAtThreshold05:
+        return binary_classification_accuracy(model_output=predictions_to_compute_metrics.model_outputs,
+                                              label=predictions_to_compute_metrics.labels,
+                                              threshold=0.5)
     elif metric is ReportedScalarMetrics.Specificity:
         return recall_score(predictions_to_compute_metrics.labels,
                             predictions_to_compute_metrics.model_outputs >= optimal_threshold, pos_label=0)
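The new `AccuracyAtThreshold05` branch reuses `binary_classification_accuracy`, only swapping the validation-derived `optimal_threshold` for a hard-coded 0.5. The neighbouring specificity branch relies on the identity that specificity is the recall of the negative class; a quick sklearn-only check of that identity on toy data (values are made up):

```python
import numpy as np
from sklearn.metrics import recall_score

# Toy data, not taken from the tests.
labels = np.array([0, 0, 1, 1])
model_outputs = np.array([0.2, 0.7, 0.4, 0.8])
optimal_threshold = 0.5  # hypothetical value

predictions = (model_outputs >= optimal_threshold).astype(int)  # [0, 1, 0, 1]

# Specificity = TN / (TN + FP), i.e. the recall of the negative class,
# which is what recall_score(..., pos_label=0) computes.
specificity = recall_score(labels, predictions, pos_label=0)
tn = np.sum((labels == 0) & (predictions == 0))
fp = np.sum((labels == 0) & (predictions == 1))
assert specificity == tn / (tn + fp) == 0.5
```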
@@ -894,5 +894,5 @@ class MLRunner:
                                          val_metrics=path_to_best_epoch_val,
                                          test_metrics=path_to_best_epoch_test)
         except Exception as ex:
-            print_exception(ex, "Failed to generated reporting notebook.")
+            print_exception(ex, "Failed to generate reporting notebook.")
             raise
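Besides the typo fix ("generated" to "generate"), the surrounding code shows a log-then-reraise pattern: the exception is printed with context but still propagated so the run fails loudly. A generic sketch of the pattern; `print_exception` here is a stand-in whose signature is assumed from the call site above:

```python
import traceback

def print_exception(ex: Exception, message: str) -> None:
    # Stand-in for InnerEye.Common.common_util.print_exception: log a
    # contextual message plus the traceback.
    print(f"{message} Exception: {ex}")
    traceback.print_exc()

def generate_notebook() -> None:
    raise RuntimeError("boom")  # placeholder failure for the sketch

try:
    generate_notebook()
except Exception as ex:
    print_exception(ex, "Failed to generate reporting notebook.")
    raise  # re-raise: the failure is logged with context but not swallowed
```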
@@ -196,17 +196,23 @@ def test_get_metric() -> None:
 
     accuracy = get_metric(predictions_to_compute_metrics=test_metrics,
                           predictions_to_set_optimal_threshold=val_metrics,
-                          metric=ReportedScalarMetrics.Accuracy)
+                          metric=ReportedScalarMetrics.AccuracyAtOptimalThreshold)
 
     assert accuracy == 0.5
 
     accuracy = get_metric(predictions_to_compute_metrics=test_metrics,
                           predictions_to_set_optimal_threshold=val_metrics,
-                          metric=ReportedScalarMetrics.Accuracy,
+                          metric=ReportedScalarMetrics.AccuracyAtOptimalThreshold,
                           optimal_threshold=0.1)
 
     assert accuracy == 0.5
 
+    accuracy = get_metric(predictions_to_compute_metrics=test_metrics,
+                          predictions_to_set_optimal_threshold=val_metrics,
+                          metric=ReportedScalarMetrics.AccuracyAtThreshold05)
+
+    assert accuracy == 0.5
+
     specificity = get_metric(predictions_to_compute_metrics=test_metrics,
                              predictions_to_set_optimal_threshold=val_metrics,
                              metric=ReportedScalarMetrics.Specificity)
@@ -257,12 +263,13 @@ def test_get_metrics_table_single_run() -> None:
                                     is_thresholded=False, is_crossval_report=False)
     expected_header = "Metric\tValue".split('\t')
     expected_rows = [
-        "Area under PR Curve\t0.5417".split('\t'),
-        "Area under ROC Curve\t0.5000".split('\t'),
-        "Optimal threshold\t0.6000".split('\t'),
-        "Accuracy at optimal threshold\t0.5000".split('\t'),
-        "Sensitivity at optimal threshold\t0.5000".split('\t'),
-        "Specificity at optimal threshold\t0.5000".split('\t'),
+        f"{ReportedScalarMetrics.AUC_PR.value[0]}\t0.5417".split('\t'),
+        f"{ReportedScalarMetrics.AUC_ROC.value[0]}\t0.5000".split('\t'),
+        f"{ReportedScalarMetrics.OptimalThreshold.value[0]}\t0.6000".split('\t'),
+        f"{ReportedScalarMetrics.AccuracyAtOptimalThreshold.value[0]}\t0.5000".split('\t'),
+        f"{ReportedScalarMetrics.AccuracyAtThreshold05.value[0]}\t0.5000".split('\t'),
+        f"{ReportedScalarMetrics.Sensitivity.value[0]}\t0.5000".split('\t'),
+        f"{ReportedScalarMetrics.Specificity.value[0]}\t0.5000".split('\t'),
     ]
     check_table_equality(header, rows, expected_header, expected_rows)
 
@@ -283,12 +290,13 @@ def test_get_metrics_table_crossval() -> None:
                                     is_thresholded=False, is_crossval_report=True)
     expected_header = "Metric\tSplit 0\tSplit 1\tSplit 2\tMean (std)".split('\t')
     expected_rows = [
-        "Area under PR Curve\t0.5417\t0.4481\t0.6889\t0.5595 (0.0991)".split('\t'),
-        "Area under ROC Curve\t0.5000\t0.2778\t0.7222\t0.5000 (0.1814)".split('\t'),
-        "Optimal threshold\t0.6000\t0.6000\t0.6000\t0.6000 (0.0000)".split('\t'),
-        "Accuracy at optimal threshold\t0.5000\t0.2500\t0.7500\t0.5000 (0.2041)".split('\t'),
-        "Sensitivity at optimal threshold\t0.5000\t0.1667\t0.8333\t0.5000 (0.2722)".split('\t'),
-        "Specificity at optimal threshold\t0.5000\t0.1667\t0.8333\t0.5000 (0.2722)".split('\t')
+        f"{ReportedScalarMetrics.AUC_PR.value[0]}\t0.5417\t0.4481\t0.6889\t0.5595 (0.0991)".split('\t'),
+        f"{ReportedScalarMetrics.AUC_ROC.value[0]}\t0.5000\t0.2778\t0.7222\t0.5000 (0.1814)".split('\t'),
+        f"{ReportedScalarMetrics.OptimalThreshold.value[0]}\t0.6000\t0.6000\t0.6000\t0.6000 (0.0000)".split('\t'),
+        f"{ReportedScalarMetrics.AccuracyAtOptimalThreshold.value[0]}\t0.5000\t0.2500\t0.7500\t0.5000 (0.2041)".split('\t'),
+        f"{ReportedScalarMetrics.AccuracyAtThreshold05.value[0]}\t0.5000\t0.1667\t0.8333\t0.5000 (0.2722)".split('\t'),
+        f"{ReportedScalarMetrics.Sensitivity.value[0]}\t0.5000\t0.1667\t0.8333\t0.5000 (0.2722)".split('\t'),
+        f"{ReportedScalarMetrics.Specificity.value[0]}\t0.5000\t0.1667\t0.8333\t0.5000 (0.2722)".split('\t')
     ]
     check_table_equality(header, rows, expected_header, expected_rows)
 
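Rewriting the expected rows in terms of `ReportedScalarMetrics.<member>.value[0]` keeps the tests in sync with the enum: renaming a metric's display name now only has to happen in one place. A self-contained illustration of the tab-separated row pattern (the enum here is a made-up stand-in):

```python
from enum import Enum

class Metrics(Enum):
    AUC_PR = "Area under PR Curve", False  # hypothetical stand-in member

# A tab-separated row, split into cells exactly as in the tests:
row = f"{Metrics.AUC_PR.value[0]}\t0.5417".split('\t')
assert row == ["Area under PR Curve", "0.5417"]

# Crossval rows carry one value per split plus a final "mean (std)" cell:
crossval = f"{Metrics.AUC_PR.value[0]}\t0.5417\t0.4481\t0.6889\t0.5595 (0.0991)".split('\t')
assert crossval[-1] == "0.5595 (0.0991)"
```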
setup.py (7 changes)
@@ -13,7 +13,7 @@ from ruamel.yaml.comments import CommentedMap
 
 from InnerEye.Common import fixed_paths
 from InnerEye.Common.common_util import namespace_to_path
-from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, INNEREYE_PACKAGE_ROOT
+from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME
 
 ML_NAMESPACE = "InnerEye.ML"
 
@@ -118,11 +118,6 @@ if is_dev_package:
     print("\n ***** NOTE: This package is built for development purpose only. DO NOT RELEASE THIS! *****")
     print(f"\n ***** Will install dev package data: {package_data} *****\n")
 
-    package_data[INNEREYE_PACKAGE_NAME] += [
-        str(INNEREYE_PACKAGE_ROOT / r"ML/reports/segmentation_report.ipynb"),
-        str(INNEREYE_PACKAGE_ROOT / r"ML/reports/classification_report.ipynb")
-    ]
-
 pre_processed_packages = _pre_process_packages()
 try:
     setuptools.setup(
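The removed block had been force-adding the two report notebooks to the dev package's `package_data`; it disappears along with the deleted `classification_report.ipynb` (how the remaining notebooks are packaged is outside this diff). For reference, a minimal sketch of how `package_data` ships non-Python files with setuptools; names are illustrative, not InnerEye's actual setup configuration:

```python
import setuptools

# Minimal sketch: include .ipynb report templates inside a built package.
setuptools.setup(
    name="example-reports",        # hypothetical package name
    version="0.1.0",
    packages=["example_reports"],
    package_data={
        # Paths are relative to the package directory.
        "example_reports": ["reports/*.ipynb"],
    },
)
```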