Add accuracy at threshold 0.5 to classification report (#450)

Adds the metric "Accuracy at threshold 0.5" to the classification report (`classification_crossval_report.ipynb`). Also deletes the unused `classification_report.ipynb`.
This commit is contained in:
Shruthi42 2021-05-04 11:09:35 +01:00 committed by GitHub
Parent 9a7ac87da4
Commit 35423b3674
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 310 deletions
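For context on the headline change, a minimal sketch of what "Accuracy at threshold 0.5" computes, written with plain numpy rather than the repository's `binary_classification_accuracy` helper (whose call signature appears in the diff below); the output and label arrays are hypothetical.

import numpy as np

def accuracy_at_threshold(model_outputs: np.ndarray, labels: np.ndarray,
                          threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction matches the binary label."""
    predictions = (model_outputs >= threshold).astype(int)
    return float(np.mean(predictions == labels))

# Hypothetical posterior probabilities and ground-truth labels:
outputs = np.array([0.1, 0.4, 0.6, 0.9])
labels = np.array([0, 1, 1, 1])
print(accuracy_at_threshold(outputs, labels))  # 0.75: the 0.4 sample falls below 0.5

Unlike "Accuracy at optimal threshold", this metric fixes the operating point at 0.5, so it reflects how the model would behave if deployed with the default decision rule.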

View file

@@ -60,6 +60,7 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
model configs with custom behaviour while leveraging the existing InnerEye workflows.
- ([#445](https://github.com/microsoft/InnerEye-DeepLearning/pull/445)) Adding test coverage for the `HelloContainer`
model with multiple GPUs
- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Adds the metric "Accuracy at threshold 0.5" to the classification report (`classification_crossval_report.ipynb`).
### Changed
@@ -93,6 +94,7 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
### Removed
- ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Deprecated `start_epoch` config argument.
- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Delete unused `classification_report.ipynb`.
### Deprecated

View file

@@ -1,287 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"%%javascript\n",
"IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
" return false;\n",
"}\n",
"// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
"# This cell must carry the tag \"parameters\" in its metadata.\n",
"from pathlib import Path\n",
"import pickle\n",
"import codecs\n",
"\n",
"innereye_path = Path.cwd().parent.parent.parent\n",
"train_metrics_csv = \"\"\n",
"val_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'val_metrics_classification.csv'\n",
"test_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'test_metrics_classification.csv'\n",
"number_best_and_worst_performing = 20\n",
"config= \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import sys\n",
"print(f\"Adding to path: {innereye_path}\")\n",
"if str(innereye_path) not in sys.path:\n",
" sys.path.append(str(innereye_path))\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
"\n",
"from InnerEye.ML.reports.notebook_report import print_header\n",
"from InnerEye.ML.reports.classification_report import plot_pr_and_roc_curves_from_csv, \\\n",
"print_k_best_and_worst_performing, print_metrics_for_all_prediction_targets, \\\n",
"plot_k_best_and_worst_performing, get_labels_and_predictions\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"plt.rcParams['figure.figsize'] = (20, 10)\n",
"\n",
"#convert params to Path\n",
"train_metrics_csv = Path(train_metrics_csv)\n",
"val_metrics_csv = Path(val_metrics_csv)\n",
"test_metrics_csv = Path(test_metrics_csv)"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"# Train Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"if train_metrics_csv.is_file():\n",
" print_metrics_for_all_prediction_targets(val_metrics_csv=train_metrics_csv, test_metrics_csv=train_metrics_csv,\n",
" config=config, is_thresholded=False)"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"# Validation Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file():\n",
" print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv,\n",
" config=config, is_thresholded=False)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"# Test Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" config=config, is_thresholded=False)"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# AUC and PR curves\n",
"## Train Set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"if train_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(metrics_csv=train_metrics_csv, config=config)"
]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {},
"source": [
"## Validation set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"if val_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(metrics_csv=val_metrics_csv, config=config)"
]
},
{
"cell_type": "markdown",
"id": "14",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Test set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"if test_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(metrics_csv=test_metrics_csv, config=config)"
]
},
{
"cell_type": "markdown",
"id": "16",
"metadata": {},
"source": [
"# Best and worst samples by ID"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.target_names:\n",
" print_header(f\"Class {prediction_target}\", level=3)\n",
" print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing,\n",
" prediction_target=prediction_target)"
]
},
{
"cell_type": "markdown",
"id": "18",
"metadata": {},
"source": [
"# Plot best and worst sample images"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.target_names:\n",
" print_header(f\"Class {prediction_target}\", level=3)\n",
" plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing, prediction_target=prediction_target, config=config)"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
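One pattern in the deleted notebook worth noting: papermill can only inject string-valued parameters, so the calling code passes the model config as a base64-encoded pickle, which the notebook decodes in its setup cell. A minimal sketch of that round trip, assuming both sides share the class definition (`DummyConfig` is hypothetical):

import codecs
import pickle

class DummyConfig:
    """Hypothetical stand-in for a model config object."""
    def __init__(self, target_names):
        self.target_names = target_names

# Sender side (e.g. the code launching the notebook via papermill):
config = DummyConfig(target_names=["Default"])
encoded = codecs.encode(pickle.dumps(config), "base64").decode()

# Receiver side, mirroring the notebook's decoding line:
decoded = pickle.loads(codecs.decode(encoded.encode(), "base64"))
print(decoded.target_names)  # ['Default']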

View file

@@ -49,7 +49,8 @@ class ReportedScalarMetrics(Enum):
AUC_PR = "Area under PR Curve", False
AUC_ROC = "Area under ROC Curve", False
OptimalThreshold = "Optimal threshold", False
Accuracy = "Accuracy at optimal threshold", True
AccuracyAtOptimalThreshold = "Accuracy at optimal threshold", True
AccuracyAtThreshold05 = "Accuracy at threshold 0.5", True
Sensitivity = "Sensitivity at optimal threshold", True
Specificity = "Specificity at optimal threshold", True
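Each member of this enum carries a tuple value: a display name plus a boolean flag (from the surrounding lines, apparently marking metrics computed at a decision threshold; that reading is an assumption). A minimal sketch of the tuple-valued Enum pattern, which also explains the `.value[0]` lookups in the tests further down:

from enum import Enum

class MetricsSketch(Enum):
    # (display name, thresholded?) -- the second field's meaning is assumed from context
    AUC_ROC = "Area under ROC Curve", False
    AccuracyAtThreshold05 = "Accuracy at threshold 0.5", True

print(MetricsSketch.AccuracyAtThreshold05.value[0])  # Accuracy at threshold 0.5
print(MetricsSketch.AccuracyAtThreshold05.value[1])  # True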
@@ -326,10 +327,14 @@ def get_metric(predictions_to_set_optimal_threshold: LabelsAndPredictions,
precision, recall, _ = precision_recall_curve(predictions_to_compute_metrics.labels,
predictions_to_compute_metrics.model_outputs)
return auc(recall, precision)
elif metric is ReportedScalarMetrics.Accuracy:
elif metric is ReportedScalarMetrics.AccuracyAtOptimalThreshold:
return binary_classification_accuracy(model_output=predictions_to_compute_metrics.model_outputs,
label=predictions_to_compute_metrics.labels,
threshold=optimal_threshold)
elif metric is ReportedScalarMetrics.AccuracyAtThreshold05:
return binary_classification_accuracy(model_output=predictions_to_compute_metrics.model_outputs,
label=predictions_to_compute_metrics.labels,
threshold=0.5)
elif metric is ReportedScalarMetrics.Specificity:
return recall_score(predictions_to_compute_metrics.labels,
predictions_to_compute_metrics.model_outputs >= optimal_threshold, pos_label=0)
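The specificity branch above leans on a scikit-learn idiom: specificity (the true negative rate) is simply recall computed on the negative class, hence `recall_score(..., pos_label=0)` on the thresholded outputs. A small sketch with hypothetical data:

import numpy as np
from sklearn.metrics import recall_score

labels = np.array([0, 0, 1, 1])
model_outputs = np.array([0.2, 0.7, 0.4, 0.8])
optimal_threshold = 0.5

# Recall of the positive class is sensitivity: TP / (TP + FN)
sensitivity = recall_score(labels, model_outputs >= optimal_threshold, pos_label=1)
# Recall of the negative class is specificity: TN / (TN + FP)
specificity = recall_score(labels, model_outputs >= optimal_threshold, pos_label=0)
print(sensitivity, specificity)  # 0.5 0.5 for these hypothetical values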

View file

@@ -894,5 +894,5 @@ class MLRunner:
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
except Exception as ex:
print_exception(ex, "Failed to generated reporting notebook.")
print_exception(ex, "Failed to generate reporting notebook.")
raise

View file

@@ -196,17 +196,23 @@ def test_get_metric() -> None:
accuracy = get_metric(predictions_to_compute_metrics=test_metrics,
predictions_to_set_optimal_threshold=val_metrics,
metric=ReportedScalarMetrics.Accuracy)
metric=ReportedScalarMetrics.AccuracyAtOptimalThreshold)
assert accuracy == 0.5
accuracy = get_metric(predictions_to_compute_metrics=test_metrics,
predictions_to_set_optimal_threshold=val_metrics,
metric=ReportedScalarMetrics.Accuracy,
metric=ReportedScalarMetrics.AccuracyAtOptimalThreshold,
optimal_threshold=0.1)
assert accuracy == 0.5
accuracy = get_metric(predictions_to_compute_metrics=test_metrics,
predictions_to_set_optimal_threshold=val_metrics,
metric=ReportedScalarMetrics.AccuracyAtThreshold05)
assert accuracy == 0.5
specificity = get_metric(predictions_to_compute_metrics=test_metrics,
predictions_to_set_optimal_threshold=val_metrics,
metric=ReportedScalarMetrics.Specificity)
@@ -257,12 +263,13 @@ def test_get_metrics_table_single_run() -> None:
is_thresholded=False, is_crossval_report=False)
expected_header = "Metric Value".split('\t')
expected_rows = [
"Area under PR Curve 0.5417".split('\t'),
"Area under ROC Curve 0.5000".split('\t'),
"Optimal threshold 0.6000".split('\t'),
"Accuracy at optimal threshold 0.5000".split('\t'),
"Sensitivity at optimal threshold 0.5000".split('\t'),
"Specificity at optimal threshold 0.5000".split('\t'),
f"{ReportedScalarMetrics.AUC_PR.value[0]} 0.5417".split('\t'),
f"{ReportedScalarMetrics.AUC_ROC.value[0]} 0.5000".split('\t'),
f"{ReportedScalarMetrics.OptimalThreshold.value[0]} 0.6000".split('\t'),
f"{ReportedScalarMetrics.AccuracyAtOptimalThreshold.value[0]} 0.5000".split('\t'),
f"{ReportedScalarMetrics.AccuracyAtThreshold05.value[0]} 0.5000".split('\t'),
f"{ReportedScalarMetrics.Sensitivity.value[0]} 0.5000".split('\t'),
f"{ReportedScalarMetrics.Specificity.value[0]} 0.5000".split('\t'),
]
check_table_equality(header, rows, expected_header, expected_rows)
@@ -283,12 +290,13 @@ def test_get_metrics_table_crossval() -> None:
is_thresholded=False, is_crossval_report=True)
expected_header = "Metric Split 0 Split 1 Split 2 Mean (std)".split('\t')
expected_rows = [
"Area under PR Curve 0.5417 0.4481 0.6889 0.5595 (0.0991)".split('\t'),
"Area under ROC Curve 0.5000 0.2778 0.7222 0.5000 (0.1814)".split('\t'),
"Optimal threshold 0.6000 0.6000 0.6000 0.6000 (0.0000)".split('\t'),
"Accuracy at optimal threshold 0.5000 0.2500 0.7500 0.5000 (0.2041)".split('\t'),
"Sensitivity at optimal threshold 0.5000 0.1667 0.8333 0.5000 (0.2722)".split('\t'),
"Specificity at optimal threshold 0.5000 0.1667 0.8333 0.5000 (0.2722)".split('\t')
f"{ReportedScalarMetrics.AUC_PR.value[0]} 0.5417 0.4481 0.6889 0.5595 (0.0991)".split('\t'),
f"{ReportedScalarMetrics.AUC_ROC.value[0]} 0.5000 0.2778 0.7222 0.5000 (0.1814)".split('\t'),
f"{ReportedScalarMetrics.OptimalThreshold.value[0]} 0.6000 0.6000 0.6000 0.6000 (0.0000)".split('\t'),
f"{ReportedScalarMetrics.AccuracyAtOptimalThreshold.value[0]} 0.5000 0.2500 0.7500 0.5000 (0.2041)".split('\t'),
f"{ReportedScalarMetrics.AccuracyAtThreshold05.value[0]} 0.5000 0.1667 0.8333 0.5000 (0.2722)".split('\t'),
f"{ReportedScalarMetrics.Sensitivity.value[0]} 0.5000 0.1667 0.8333 0.5000 (0.2722)".split('\t'),
f"{ReportedScalarMetrics.Specificity.value[0]} 0.5000 0.1667 0.8333 0.5000 (0.2722)".split('\t')
]
check_table_equality(header, rows, expected_header, expected_rows)

View file

@@ -13,7 +13,7 @@ from ruamel.yaml.comments import CommentedMap
from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import namespace_to_path
from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, INNEREYE_PACKAGE_ROOT
from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME
ML_NAMESPACE = "InnerEye.ML"
@@ -118,11 +118,6 @@ if is_dev_package:
print("\n ***** NOTE: This package is built for development purpose only. DO NOT RELEASE THIS! *****")
print(f"\n ***** Will install dev package data: {package_data} *****\n")
package_data[INNEREYE_PACKAGE_NAME] += [
str(INNEREYE_PACKAGE_ROOT / r"ML/reports/segmentation_report.ipynb"),
str(INNEREYE_PACKAGE_ROOT / r"ML/reports/classification_report.ipynb")
]
pre_processed_packages = _pre_process_packages()
try:
setuptools.setup(