FIX: Fix bug in saving validation outputs (#649)

Fix bug causing flat validation accuracy plots and upgrade torchmetrics
Melissa Bristow 2022-11-10 14:57:39 +00:00 committed by GitHub
Parent b711b2b545
Commit e18d37ee84
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 5 additions and 10 deletions


@@ -305,7 +305,7 @@ dependencies:
 - toml==0.10.2
 - tomli==2.0.1
 - tomlkit==0.11.6
-- torchmetrics==0.6.0
+- torchmetrics==0.10.2
 - tqdm==4.64.1
 - twine==3.3.0
 - typing-inspect==0.8.0


@@ -14,6 +14,6 @@ seaborn==0.10.1
 simpleitk==2.1.1.2
 tifffile==2021.11.2
 torch==1.10
-torchmetrics==0.6.0
+torchmetrics==0.10.2
 umap-learn==0.5.2
 yacs==0.1.8


@@ -9,7 +9,7 @@ from pathlib import Path
 from pytorch_lightning import LightningModule
 from torch import Tensor, argmax, mode, nn, optim, round, set_grad_enabled
-from torchmetrics import AUROC, F1, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa  # type: ignore
+from torchmetrics import AUROC, F1Score, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa  # type: ignore
 from health_ml.utils import log_on_epoch
 from health_ml.deep_learning_config import OptimizerParams
@@ -148,7 +148,7 @@ class BaseDeepMILModule(LightningModule):
 MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
 MetricsKey.PRECISION: Precision(threshold=threshold),
 MetricsKey.RECALL: Recall(threshold=threshold),
-MetricsKey.F1: F1(threshold=threshold),
+MetricsKey.F1: F1Score(threshold=threshold),
 MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
 def log_metrics(self, stage: str) -> None:
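
The import change above tracks a rename in torchmetrics: the F1 metric class became F1Score in newer releases (the rename appears to have landed around v0.7), and by the 0.10.x pin used here only the new name is available. A minimal sketch of a version-tolerant import, assuming only the rename itself; the threshold and tensors below are illustrative, not from the repo:

    try:
        from torchmetrics import F1Score  # torchmetrics >= 0.7
    except ImportError:  # older pins, e.g. torchmetrics==0.6.0
        from torchmetrics import F1 as F1Score

    import torch

    f1 = F1Score(threshold=0.5)  # constructor arguments are unchanged by the rename
    f1.update(torch.tensor([0.2, 0.8, 0.6]), torch.tensor([0, 1, 1]))
    print(f1.compute())  # tensor(1.) - all thresholded predictions match the targets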


@@ -14,7 +14,6 @@ import torch
 import logging
 from ruamel.yaml import YAML
-from torchmetrics import Accuracy
 from torchmetrics.metric import Metric
 from health_azure.utils import replace_directory
@@ -213,10 +212,6 @@ class OutputsPolicy:
 # The metric needs to be computed on all ranks to allow synchronisation
 metric_value = float(metric.compute())
-# It seems to be necessary to reset the Accuracy metric after computing, else some processes get stuck here
-if isinstance(metric, Accuracy):
-    metric.reset()
 # Validation outputs and best metric should be saved only by the global rank-0 process
 if not is_global_rank_zero:
     return False
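
This hunk is the actual bug fix: OutputsPolicy was resetting the Accuracy metric immediately after computing it, which wipes the metric's accumulated state before PyTorch Lightning logs it at the end of the validation epoch. That is the likely cause of the flat validation accuracy plots named in the commit message. A minimal sketch of the failure mode (names and values are illustrative, not from the repo):

    import torch
    from torchmetrics import Accuracy

    acc = Accuracy()
    acc.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

    print(acc.compute())  # tensor(0.6667) - compute() leaves the state intact
    acc.reset()           # wipes the accumulated predictions and targets
    # Any later compute() or epoch-end log of this same metric object now
    # aggregates over no data, so the logged value stops tracking the epoch.

The deleted workaround was guarding against a hang during compute(); presumably the torchmetrics version bump resolves that hang, so both the reset and the now-unused Accuracy import can be dropped.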


@@ -236,7 +236,7 @@ dependencies:
 - tomli==2.0.1
 - tomlkit==0.11.5
 - torch==1.12.1
-- torchmetrics==0.10.0
+- torchmetrics==0.10.2
 - torchvision==0.13.1
 - tqdm==4.63.0
 - twine==3.3.0
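
With this hunk, torchmetrics is pinned to 0.10.2 in every dependency file touched by the commit. A quick, illustrative check that a rebuilt environment actually picked up the new pin:

    import torchmetrics
    assert torchmetrics.__version__ == "0.10.2", torchmetrics.__version__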