зеркало из https://github.com/microsoft/hi-ml.git
BUG: Use micro average for MultiClassAccuracy (#671)
the default for MultiClassAccuracy is macro [Accuracy — PyTorch-Metrics 0.10.3 documentation (torchmetrics.readthedocs.io)](https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#multiclassaccuracy) But base class (which we were using before) is micro [Accuracy — PyTorch-Metrics 0.10.3 documentation (torchmetrics.readthedocs.io)](https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html) We need to set average=micro
This commit is contained in:
Родитель
37c7860f2f
Коммит
9e812e7ca8
|
@ -176,7 +176,7 @@ class BaseDeepMILModule(LightningModule):
|
|||
def get_metrics(self) -> nn.ModuleDict:
|
||||
if self.n_classes > 1:
|
||||
return nn.ModuleDict({
|
||||
MetricsKey.ACC: MulticlassAccuracy(num_classes=self.n_classes),
|
||||
MetricsKey.ACC: MulticlassAccuracy(num_classes=self.n_classes, average='micro'),
|
||||
MetricsKey.AUROC: MulticlassAUROC(num_classes=self.n_classes),
|
||||
MetricsKey.AVERAGE_PRECISION: MulticlassAveragePrecision(num_classes=self.n_classes),
|
||||
# Quadratic Weighted Kappa (QWK) used in PANDA challenge
|
||||
|
|
Загрузка…
Ссылка в новой задаче