зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add/reorganise metrics and correctly normalise confusion matrix (#578)
In this PR: Average precision metric is added to multi-class (n_classes > 1) case. AUROC is modified such that num_classes=None for binary case, as prescribed in Pytorch documentation https://torchmetrics.readthedocs.io/en/stable/classification/auroc.html. Parameter num_classes is not explicitly given in Specificity, similar to the other binary metrics. Hardcoded threshold=0.5 is removed from binary metrics (since this is default value). Confusion matrix is normalized on true values. Metrics are re-organized for readability.
This commit is contained in:
Родитель
952e93766b
Коммит
c9651adec4
|
@ -190,29 +190,31 @@ class BaseDeepMILModule(LightningModule):
|
|||
|
||||
def get_metrics(self) -> nn.ModuleDict:
|
||||
if self.n_classes > 1:
|
||||
return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
|
||||
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
|
||||
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
|
||||
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
|
||||
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
|
||||
# is calculated using Cohen's Kappa with quadratic weights
|
||||
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
|
||||
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
|
||||
return nn.ModuleDict({
|
||||
MetricsKey.ACC: Accuracy(num_classes=self.n_classes),
|
||||
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
|
||||
MetricsKey.AVERAGE_PRECISION: AveragePrecision(num_classes=self.n_classes),
|
||||
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
|
||||
# is calculated using Cohen's Kappa with quadratic weights
|
||||
# https://www.kaggle.com/code/reighns/understanding-the-quadratic-weighted-kappa/
|
||||
MetricsKey.COHENKAPPA: CohenKappa(num_classes=self.n_classes, weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes),
|
||||
# Metrics below are computed for multi-class case only
|
||||
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
|
||||
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted')})
|
||||
else:
|
||||
threshold = 0.5
|
||||
return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
|
||||
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
|
||||
MetricsKey.PRECISION: Precision(threshold=threshold),
|
||||
MetricsKey.RECALL: Recall(threshold=threshold),
|
||||
MetricsKey.F1: F1(threshold=threshold),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold),
|
||||
# Average precision is a measure of area under the PR curve
|
||||
# https://sanchom.wordpress.com/tag/average-precision/
|
||||
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
|
||||
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic',
|
||||
threshold=threshold),
|
||||
MetricsKey.SPECIFICITY: Specificity(num_classes=self.n_classes, threshold=threshold)})
|
||||
return nn.ModuleDict({
|
||||
MetricsKey.ACC: Accuracy(),
|
||||
MetricsKey.AUROC: AUROC(num_classes=None),
|
||||
# Average precision is a measure of area under the PR curve
|
||||
MetricsKey.AVERAGE_PRECISION: AveragePrecision(),
|
||||
MetricsKey.COHENKAPPA: CohenKappa(num_classes=2, weights='quadratic'),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2),
|
||||
# Metrics below are computed for binary case only
|
||||
MetricsKey.F1: F1(),
|
||||
MetricsKey.PRECISION: Precision(),
|
||||
MetricsKey.RECALL: Recall(),
|
||||
MetricsKey.SPECIFICITY: Specificity()})
|
||||
|
||||
def log_metrics(self, stage: str) -> None:
|
||||
valid_stages = [stage for stage in ModelKey]
|
||||
|
|
|
@ -68,7 +68,7 @@ def save_confusion_matrix(results: ResultsType, class_names: Sequence[str], figu
|
|||
true_labels,
|
||||
pred_labels,
|
||||
labels=all_potential_labels,
|
||||
normalize="pred"
|
||||
normalize="true"
|
||||
)
|
||||
|
||||
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=(class_names))
|
||||
|
|
Загрузка…
Ссылка в новой задаче