Fix DeepMIL metrics bug (#674)

* Fix DeepMIL metrics input bug
* Add first version of metrics tests
* Update submodule
* Add test for DeepMIL metrics inputs
* Clean-up and update submodule
* Update changelog
* Upgrade mlflow due to Component Governance warning
Parent: d7e5d8b5e5
Commit: e984554c9e
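
The fix addressed here is that the DeepMIL module was feeding hard (0/1) predictions into its torchmetrics objects instead of the continuous probabilities produced by the activation function. A minimal, illustrative sketch of why that matters, assuming a torchmetrics release with the same binary-metric API that this repository pins (the tensors below are made-up values, not taken from the commit):

```python
import torch
from torchmetrics import AUROC, Accuracy

probs = torch.tensor([0.2, 0.4, 0.6, 0.8])     # continuous scores, e.g. from a sigmoid
labels = torch.tensor([0, 1, 0, 1])

# Ranking metrics need the scores themselves: AUROC uses their ordering.
print(AUROC()(probs, labels))                  # tensor(0.7500)

# Binarising first throws the ranking away and the metric degenerates.
hard_preds = (probs > 0.5).float()
print(AUROC()(hard_preds, labels))             # tensor(0.5000) for this example

# Thresholded metrics binarise internally, so they also expect probabilities.
print(Accuracy(threshold=0.5)(probs, labels))  # tensor(0.5000)
```

The new test added in this commit makes this explicit by asserting that every value reaching `Metric.update` is a non-integral floating-point score.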
@@ -128,6 +128,7 @@ in inference-only runs when using lightning containers.
 - ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training
 - ([#652](https://github.com/microsoft/InnerEye-DeepLearning/pull/652)) Run pytest build on Windows after Linux agent version upgrade
 - ([#655](https://github.com/microsoft/InnerEye-DeepLearning/pull/655)) Run pytest on Linux again, but with Ubuntu 20.04
+- ([#674](https://github.com/microsoft/InnerEye-DeepLearning/pull/674)) Fix DeepMIL metrics bug whereby hard labels were used instead of probabilities.

 ### Removed

@@ -179,12 +179,13 @@ class DeepMILModule(LightningModule):
                                   MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
                                   MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
         else:
-            return nn.ModuleDict({MetricsKey.ACC: Accuracy(),
+            threshold = 0.5
+            return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
                                   MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
-                                  MetricsKey.PRECISION: Precision(),
-                                  MetricsKey.RECALL: Recall(),
-                                  MetricsKey.F1: F1(),
-                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes+1)})
+                                  MetricsKey.PRECISION: Precision(threshold=threshold),
+                                  MetricsKey.RECALL: Recall(threshold=threshold),
+                                  MetricsKey.F1: F1(threshold=threshold),
+                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})

     def log_metrics(self,
                     stage: str) -> None:
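
In the binary (n_classes == 1) branch above, all thresholded metrics now share one explicit `threshold`, and the confusion matrix uses `num_classes=2` rather than `self.n_classes+1`. A rough sketch of the behaviour this relies on, again assuming the torchmetrics binary API used in this repository (values are illustrative only):

```python
import torch
from torchmetrics import Accuracy, ConfusionMatrix

threshold = 0.5
probs = torch.tensor([0.1, 0.4, 0.6, 0.9])   # probabilities of the positive class
labels = torch.tensor([0, 1, 1, 1])

# Accuracy binarises the probabilities internally at `threshold`.
print(Accuracy(threshold=threshold)(probs, labels))   # tensor(0.7500): 0.4 falls below 0.5

# The binary confusion matrix needs two classes and thresholds the scores itself.
cm = ConfusionMatrix(num_classes=2, threshold=threshold)
print(cm(probs, labels))                              # 2x2 tensor of counts
```

Because the metrics do the binarisation themselves, `test_step` (below) can pass the same probabilities to every metric and still report label-based quantities such as precision and recall.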
@@ -238,24 +239,24 @@ class DeepMILModule(LightningModule):
         else:
             loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())

-        probs = self.activation_fn(bag_logits)
+        predicted_probs = self.activation_fn(bag_logits)
         if self.n_classes > 1:
-            preds = argmax(probs, dim=1)
+            predicted_labels = argmax(predicted_probs, dim=1)
         else:
-            preds = round(probs)
+            predicted_labels = round(predicted_probs)

         loss = loss.view(-1, 1)
-        preds = preds.view(-1, 1)
-        probs = probs.view(-1, 1)
+        predicted_labels = predicted_labels.view(-1, 1)
+        predicted_probs = predicted_probs.view(-1, 1)
         bag_labels = bag_labels.view(-1, 1)

         results = dict()
         for metric_object in self.get_metrics_dict(stage).values():
-            metric_object.update(preds, bag_labels)
+            metric_object.update(predicted_probs, bag_labels)
         results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN],
                         ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN],
                         ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss,
-                        ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds,
+                        ResultsKey.PROB: predicted_probs, ResultsKey.PRED_LABEL: predicted_labels,
                         ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
                         ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})

@@ -4,13 +4,16 @@
 # ------------------------------------------------------------------------------------------

 import os
-from typing import Callable, Dict, List, Optional, Type  # noqa
+from typing import Any, Callable, Dict, Iterable, List, Optional, Type
+from unittest.mock import MagicMock

 import pytest
+import torch
 from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
+from torch.utils.data._utils.collate import default_collate
 from torchmetrics import Accuracy, Metric  # noqa
 from torchvision.models import resnet18

 from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
 from health_ml.networks.layers.attention_layers import (
     AttentionLayer,
     GatedAttentionLayer,

@@ -30,7 +33,7 @@ from InnerEye.ML.Histopathology.datasets.default_paths import (
     PANDA_TILES_DATASET_DIR,
 )
 from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
-from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder, TileEncoder
+from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
 from InnerEye.ML.Histopathology.utils.naming import MetricsKey, ResultsKey

@@ -157,6 +160,106 @@ def test_lightningmodule_mean_pooling(
                           dropout_rate=dropout_rate)
+
+
+def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
+    def is_integral(x: torch.Tensor) -> bool:
+        return (x == x.long()).all()  # type: ignore
+
+    assert scores.shape == labels.shape
+    assert torch.is_floating_point(scores), "Received scores with integer dtype"
+    assert not is_integral(scores), "Received scores with integral values"
+    assert is_integral(labels), "Received labels with floating-point values"
+
+
+def add_callback(fn: Callable, callback: Callable) -> Callable:
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        callback(*args, **kwargs)
+        return fn(*args, **kwargs)
+    return wrapper
+
+
+def test_metrics() -> None:
+    input_dim = (128,)
+    module = DeepMILModule(
+        encoder=IdentityEncoder(input_dim=input_dim),
+        label_column=TilesDataset.LABEL_COLUMN,
+        n_classes=1,
+        pooling_layer=AttentionLayer,
+    )
+
+    # Patching to enable running the module without a Trainer object
+    module.trainer = MagicMock(world_size=1)  # type: ignore
+    module.log = MagicMock()  # type: ignore
+
+    batch_size = 20
+    bag_size = 5
+    class_weights = torch.tensor([.8, .2])
+    bags: List[Dict] = []
+    for slide_idx in range(batch_size):
+        bag_label = torch.multinomial(class_weights, 1)
+        sample: Dict[str, Iterable] = {
+            TilesDataset.SLIDE_ID_COLUMN: [str(slide_idx)] * bag_size,
+            TilesDataset.TILE_ID_COLUMN: [f"{slide_idx}-{tile_idx}"
+                                          for tile_idx in range(bag_size)],
+            TilesDataset.IMAGE_COLUMN: rand(bag_size, *input_dim),
+            TilesDataset.LABEL_COLUMN: bag_label.expand(bag_size),
+        }
+        sample[TilesDataset.PATH_COLUMN] = [tile_id + '.png'
+                                            for tile_id in sample[TilesDataset.TILE_ID_COLUMN]]
+        bags.append(sample)
+    batch = default_collate(bags)
+
+    # ================
+    # Test that the module metrics match manually computed metrics with the correct inputs
+    module_metrics_dict = module.test_metrics
+    independent_metrics_dict = module.get_metrics()
+
+    # Patch the metrics to check that the inputs are valid. In particular, test that the scores
+    # do not have integral values, which would suggest that hard labels were passed instead.
+    for metric_obj in module_metrics_dict.values():
+        metric_obj.update = add_callback(metric_obj.update, validate_metric_inputs)
+
+    results = module.test_step(batch, 0)
+    predicted_probs = results[ResultsKey.PROB]
+    true_labels = results[ResultsKey.TRUE_LABEL]
+
+    for key, metric_obj in module_metrics_dict.items():
+        value = metric_obj.compute()
+        expected_value = independent_metrics_dict[key](predicted_probs, true_labels)
+        assert torch.allclose(value, expected_value), f"Discrepancy in '{key}' metric"
+
+    # ================
+    # Test that thresholded metrics (e.g. accuracy, precision, etc.) change as the threshold is varied.
+    # If they don't, it suggests the inputs are hard labels instead of continuous scores.
+    thresholded_metrics_keys = [key for key, metric in module_metrics_dict.items()
+                                if hasattr(metric, 'threshold')]
+
+    def set_metrics_threshold(metrics_dict: Any, threshold: float) -> None:
+        for key in thresholded_metrics_keys:
+            metrics_dict[key].threshold = threshold
+
+    def reset_metrics(metrics_dict: Any) -> None:
+        for metric_obj in metrics_dict.values():
+            metric_obj.reset()
+
+    low_threshold, high_threshold = torch.quantile(predicted_probs, torch.tensor([0.1, 0.9]))
+
+    reset_metrics(module_metrics_dict)
+    set_metrics_threshold(module_metrics_dict, threshold=low_threshold)
+    _ = module.test_step(batch, 0)
+    results_low_threshold = {key: module_metrics_dict[key].compute()
+                             for key in thresholded_metrics_keys}
+
+    reset_metrics(module_metrics_dict)
+    set_metrics_threshold(module_metrics_dict, threshold=high_threshold)
+    _ = module.test_step(batch, 0)
+    results_high_threshold = {key: module_metrics_dict[key].compute()
+                              for key in thresholded_metrics_keys}
+
+    for key in thresholded_metrics_keys:
+        assert not torch.allclose(results_low_threshold[key], results_high_threshold[key]), \
+            f"Got same value for '{key}' metric with low and high thresholds"
+
+
 def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
     device = "cuda" if use_gpu else "cpu"
     return {

@@ -34,7 +34,7 @@ dependencies:
   - jupyter-client==6.1.5
   - lightning-bolts==0.4.0
   - matplotlib==3.3.0
-  - mlflow==1.17.0
+  - mlflow==1.23.1
   - monai==0.6.0
   - mypy==0.910
   - mypy-extensions==0.4.3

hi-ml (submodule)
@@ -1 +1 @@
-Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b
+Subproject commit 30854eae4fd27776be9f0105099ddba663ef3eb5