BUG: Use micro average for MultiClassAccuracy (#671)

the default for MultiClassAccuracy is macro [Accuracy — PyTorch-Metrics
0.10.3 documentation
(torchmetrics.readthedocs.io)](https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#multiclassaccuracy)

But base class (which we were using before) is micro [Accuracy —
PyTorch-Metrics 0.10.3 documentation
(torchmetrics.readthedocs.io)](https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html)

We need to set average=micro
This commit is contained in:
Kenza Bouzid 2022-11-18 14:19:27 +00:00 коммит произвёл GitHub
Родитель 37c7860f2f
Коммит 9e812e7ca8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 1 добавлений и 1 удалений

Просмотреть файл

@ -176,7 +176,7 @@ class BaseDeepMILModule(LightningModule):
def get_metrics(self) -> nn.ModuleDict:
if self.n_classes > 1:
return nn.ModuleDict({
MetricsKey.ACC: MulticlassAccuracy(num_classes=self.n_classes),
MetricsKey.ACC: MulticlassAccuracy(num_classes=self.n_classes, average='micro'),
MetricsKey.AUROC: MulticlassAUROC(num_classes=self.n_classes),
MetricsKey.AVERAGE_PRECISION: MulticlassAveragePrecision(num_classes=self.n_classes),
# Quadratic Weighted Kappa (QWK) used in PANDA challenge