ENH: Add/reorganise metrics and correctly normalise confusion matrix (#578)

In this PR:

The average precision metric is added to the multi-class (n_classes > 1) case.
AUROC is modified to use num_classes=None for the binary case, as prescribed in the torchmetrics documentation https://torchmetrics.readthedocs.io/en/stable/classification/auroc.html.
The num_classes parameter is no longer explicitly passed to Specificity, consistent with the other binary metrics.
The hardcoded threshold=0.5 is removed from the binary metrics (since this is the default value).
The confusion matrix is now normalized over the true labels (rows) instead of the predicted labels.
Metrics are re-organized for readability.
This commit is contained in:
Harshita Sharma 2022-08-19 15:45:18 +01:00 committed by GitHub
Parent 952e93766b
Commit c9651adec4
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 23 deletions

View file

@@ -190,29 +190,31 @@ class BaseDeepMILModule(LightningModule):
def get_metrics(self) -> nn.ModuleDict:
if self.n_classes > 1:
return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
# is calculated using Cohen's Kappa with quadratic weights
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
return nn.ModuleDict({
MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.AVERAGE_PRECISION: AveragePrecision(num_classes=self.n_classes),
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
# is calculated using Cohen's Kappa with quadratic weights
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes),
# Metrics below are computed for multi-class case only
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted')})
else:
threshold = 0.5
return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.PRECISION: Precision(threshold=threshold),
MetricsKey.RECALL: Recall(threshold=threshold),
MetricsKey.F1: F1(threshold=threshold),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold),
# Average precision is a measure of area under the PR curve
# https://sanchom.wordpress.com/tag/average-precision/
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic',
threshold=threshold),
MetricsKey.SPECIFICITY: Specificity(num_classes=self.n_classes, threshold=threshold)})
return nn.ModuleDict({
MetricsKey.ACC: Accuracy(),
MetricsKey.AUROC: AUROC(num_classes=None),
# Average precision is a measure of area under the PR curve
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2),
# Metrics below are computed for binary case only
MetricsKey.F1: F1(),
MetricsKey.PRECISION: Precision(),
MetricsKey.RECALL: Recall(),
MetricsKey.SPECIFICITY: Specificity()})
def log_metrics(self, stage: str) -> None:
valid_stages = [stage for stage in ModelKey]

View file

@@ -68,7 +68,7 @@ def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figu
true_labels,
pred_labels,
labels=all_potential_labels,
normalize="pred"
normalize="true"
)
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=(class_names))