зеркало из https://github.com/microsoft/hi-ml.git
FIX: Fix bug in saving validation outputs (#649)
Fix bug causing flat validation accuracy plots and upgrade torchmetrics
This commit is contained in:
Родитель
b711b2b545
Коммит
e18d37ee84
|
@ -305,7 +305,7 @@ dependencies:
|
|||
- toml==0.10.2
|
||||
- tomli==2.0.1
|
||||
- tomlkit==0.11.6
|
||||
- torchmetrics==0.6.0
|
||||
- torchmetrics==0.10.2
|
||||
- tqdm==4.64.1
|
||||
- twine==3.3.0
|
||||
- typing-inspect==0.8.0
|
||||
|
|
|
@ -14,6 +14,6 @@ seaborn==0.10.1
|
|||
simpleitk==2.1.1.2
|
||||
tifffile==2021.11.2
|
||||
torch==1.10
|
||||
torchmetrics==0.6.0
|
||||
torchmetrics==0.10.2
|
||||
umap-learn==0.5.2
|
||||
yacs==0.1.8
|
||||
|
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||
|
||||
from pytorch_lightning import LightningModule
|
||||
from torch import Tensor, argmax, mode, nn, optim, round, set_grad_enabled
|
||||
from torchmetrics import AUROC, F1, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa # type: ignore
|
||||
from torchmetrics import AUROC, F1Score, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa # type: ignore
|
||||
|
||||
from health_ml.utils import log_on_epoch
|
||||
from health_ml.deep_learning_config import OptimizerParams
|
||||
|
@ -148,7 +148,7 @@ class BaseDeepMILModule(LightningModule):
|
|||
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
|
||||
MetricsKey.PRECISION: Precision(threshold=threshold),
|
||||
MetricsKey.RECALL: Recall(threshold=threshold),
|
||||
MetricsKey.F1: F1(threshold=threshold),
|
||||
MetricsKey.F1: F1Score(threshold=threshold),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
|
||||
|
||||
def log_metrics(self, stage: str) -> None:
|
||||
|
|
|
@ -14,7 +14,6 @@ import torch
|
|||
import logging
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
from torchmetrics import Accuracy
|
||||
from torchmetrics.metric import Metric
|
||||
|
||||
from health_azure.utils import replace_directory
|
||||
|
@ -213,10 +212,6 @@ class OutputsPolicy:
|
|||
# The metric needs to be computed on all ranks to allow synchronisation
|
||||
metric_value = float(metric.compute())
|
||||
|
||||
# It seems to be necessary to reset the Accuracy metric after computing, else some processes get stuck here
|
||||
if isinstance(metric, Accuracy):
|
||||
metric.reset()
|
||||
|
||||
# Validation outputs and best metric should be saved only by the global rank-0 process
|
||||
if not is_global_rank_zero:
|
||||
return False
|
||||
|
|
|
@ -236,7 +236,7 @@ dependencies:
|
|||
- tomli==2.0.1
|
||||
- tomlkit==0.11.5
|
||||
- torch==1.12.1
|
||||
- torchmetrics==0.10.0
|
||||
- torchmetrics==0.10.2
|
||||
- torchvision==0.13.1
|
||||
- tqdm==4.63.0
|
||||
- twine==3.3.0
|
||||
|
|
Загрузка…
Ссылка в новой задаче