* Fix DeepMIL metrics input bug

* Add first version of metrics tests

* Update submodule

* Add test for DeepMIL metrics inputs

* Clean-up and update submodule

* Update changelog

* Upgrade mlflow due to Component Governance warning
Daniel Coelho de Castro 2022-03-01 09:10:13 +00:00 committed by GitHub
Parent d7e5d8b5e5
Commit e984554c9e
5 changed files with 122 additions and 17 deletions

View file

@@ -128,6 +128,7 @@ in inference-only runs when using lightning containers.
 - ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training
 - ([#652](https://github.com/microsoft/InnerEye-DeepLearning/pull/652)) Run pytest build on Windows after Linux agent version upgrade
 - ([#655](https://github.com/microsoft/InnerEye-DeepLearning/pull/655)) Run pytest on Linux again, but with Ubuntu 20.04
+- ([#674](https://github.com/microsoft/InnerEye-DeepLearning/pull/674)) Fix DeepMIL metrics bug whereby hard labels were used instead of probabilities.

 ### Removed
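
Side note on the entry above: a minimal illustration of why passing hard labels instead of probabilities is harmful. It is not part of this PR's diff; it assumes a pre-1.0 torchmetrics API where AUROC() defaults to the binary task, and the tensors are made up. Once the scores are thresholded, ranking-based metrics such as AUROC collapse towards chance level.

import torch
from torchmetrics import AUROC

labels = torch.tensor([0, 1, 0, 1])
probs = torch.tensor([0.1, 0.4, 0.6, 0.8])   # continuous scores
hard_labels = (probs > 0.5).float()          # what the buggy code effectively passed

print(AUROC()(probs, labels))        # tensor(0.7500), uses the full ranking of the scores
print(AUROC()(hard_labels, labels))  # tensor(0.5000), the ranking information is gone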

View file

@@ -179,12 +179,13 @@ class DeepMILModule(LightningModule):
                                   MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
                                   MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
         else:
-            return nn.ModuleDict({MetricsKey.ACC: Accuracy(),
+            threshold = 0.5
+            return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
                                   MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
-                                  MetricsKey.PRECISION: Precision(),
-                                  MetricsKey.RECALL: Recall(),
-                                  MetricsKey.F1: F1(),
-                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes+1)})
+                                  MetricsKey.PRECISION: Precision(threshold=threshold),
+                                  MetricsKey.RECALL: Recall(threshold=threshold),
+                                  MetricsKey.F1: F1(threshold=threshold),
+                                  MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})

     def log_metrics(self,
                     stage: str) -> None:
@@ -238,24 +239,24 @@
         else:
             loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())

-        probs = self.activation_fn(bag_logits)
+        predicted_probs = self.activation_fn(bag_logits)
         if self.n_classes > 1:
-            preds = argmax(probs, dim=1)
+            predicted_labels = argmax(predicted_probs, dim=1)
         else:
-            preds = round(probs)
+            predicted_labels = round(predicted_probs)

         loss = loss.view(-1, 1)
-        preds = preds.view(-1, 1)
-        probs = probs.view(-1, 1)
+        predicted_labels = predicted_labels.view(-1, 1)
+        predicted_probs = predicted_probs.view(-1, 1)
         bag_labels = bag_labels.view(-1, 1)

         results = dict()
         for metric_object in self.get_metrics_dict(stage).values():
-            metric_object.update(preds, bag_labels)
+            metric_object.update(predicted_probs, bag_labels)
         results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN],
                         ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN],
                         ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss,
-                        ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds,
+                        ResultsKey.PROB: predicted_probs, ResultsKey.PRED_LABEL: predicted_labels,
                         ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
                         ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})
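
For readers skimming the renamed variables above, here is a minimal sketch of the corrected flow (binary case only; pre-1.0 torchmetrics API and made-up tensors assumed): the metric objects receive continuous probabilities and apply their own threshold internally, while the rounded labels are kept purely for reporting.

import torch
from torchmetrics import Accuracy

bag_logits = torch.tensor([-1.2, 0.3, 2.1, -0.4])  # stand-in for the pooled bag logits
bag_labels = torch.tensor([0, 1, 1, 0])

predicted_probs = torch.sigmoid(bag_logits)        # what the metrics should see
predicted_labels = torch.round(predicted_probs)    # only reported as ResultsKey.PRED_LABEL

accuracy = Accuracy(threshold=0.5)
accuracy.update(predicted_probs, bag_labels)       # thresholding happens inside the metric
print(accuracy.compute())                          # tensor(1.) for these example values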

View file

@@ -4,13 +4,16 @@
# ------------------------------------------------------------------------------------------
import os
from typing import Callable, Dict, List, Optional, Type # noqa
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from unittest.mock import MagicMock
import pytest
import torch
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from torchmetrics import Accuracy, Metric # noqa
from torchvision.models import resnet18
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
@@ -30,7 +33,7 @@ from InnerEye.ML.Histopathology.datasets.default_paths import (
     PANDA_TILES_DATASET_DIR,
 )
 from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
-from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder, TileEncoder
+from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
 from InnerEye.ML.Histopathology.utils.naming import MetricsKey, ResultsKey
@@ -157,6 +160,106 @@ def test_lightningmodule_mean_pooling(
                           dropout_rate=dropout_rate)


+def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
+    def is_integral(x: torch.Tensor) -> bool:
+        return (x == x.long()).all()  # type: ignore
+
+    assert scores.shape == labels.shape
+    assert torch.is_floating_point(scores), "Received scores with integer dtype"
+    assert not is_integral(scores), "Received scores with integral values"
+    assert is_integral(labels), "Received labels with floating-point values"
+
+
+def add_callback(fn: Callable, callback: Callable) -> Callable:
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        callback(*args, **kwargs)
+        return fn(*args, **kwargs)
+    return wrapper
+
+
+def test_metrics() -> None:
+    input_dim = (128,)
+    module = DeepMILModule(
+        encoder=IdentityEncoder(input_dim=input_dim),
+        label_column=TilesDataset.LABEL_COLUMN,
+        n_classes=1,
+        pooling_layer=AttentionLayer,
+    )
+
+    # Patching to enable running the module without a Trainer object
+    module.trainer = MagicMock(world_size=1)  # type: ignore
+    module.log = MagicMock()  # type: ignore
+
+    batch_size = 20
+    bag_size = 5
+    class_weights = torch.tensor([.8, .2])
+    bags: List[Dict] = []
+    for slide_idx in range(batch_size):
+        bag_label = torch.multinomial(class_weights, 1)
+        sample: Dict[str, Iterable] = {
+            TilesDataset.SLIDE_ID_COLUMN: [str(slide_idx)] * bag_size,
+            TilesDataset.TILE_ID_COLUMN: [f"{slide_idx}-{tile_idx}"
+                                          for tile_idx in range(bag_size)],
+            TilesDataset.IMAGE_COLUMN: rand(bag_size, *input_dim),
+            TilesDataset.LABEL_COLUMN: bag_label.expand(bag_size),
+        }
+        sample[TilesDataset.PATH_COLUMN] = [tile_id + '.png'
+                                            for tile_id in sample[TilesDataset.TILE_ID_COLUMN]]
+        bags.append(sample)
+    batch = default_collate(bags)
+
+    # ================
+    # Test that the module metrics match manually computed metrics with the correct inputs
+    module_metrics_dict = module.test_metrics
+    independent_metrics_dict = module.get_metrics()
+
+    # Patch the metrics to check that the inputs are valid. In particular, test that the scores
+    # do not have integral values, which would suggest that hard labels were passed instead.
+    for metric_obj in module_metrics_dict.values():
+        metric_obj.update = add_callback(metric_obj.update, validate_metric_inputs)
+
+    results = module.test_step(batch, 0)
+    predicted_probs = results[ResultsKey.PROB]
+    true_labels = results[ResultsKey.TRUE_LABEL]
+
+    for key, metric_obj in module_metrics_dict.items():
+        value = metric_obj.compute()
+        expected_value = independent_metrics_dict[key](predicted_probs, true_labels)
+        assert torch.allclose(value, expected_value), f"Discrepancy in '{key}' metric"
+
+    # ================
+    # Test that thresholded metrics (e.g. accuracy, precision, etc.) change as the threshold is varied.
+    # If they don't, it suggests the inputs are hard labels instead of continuous scores.
+    thresholded_metrics_keys = [key for key, metric in module_metrics_dict.items()
+                                if hasattr(metric, 'threshold')]
+
+    def set_metrics_threshold(metrics_dict: Any, threshold: float) -> None:
+        for key in thresholded_metrics_keys:
+            metrics_dict[key].threshold = threshold
+
+    def reset_metrics(metrics_dict: Any) -> None:
+        for metric_obj in metrics_dict.values():
+            metric_obj.reset()
+
+    low_threshold, high_threshold = torch.quantile(predicted_probs, torch.tensor([0.1, 0.9]))
+
+    reset_metrics(module_metrics_dict)
+    set_metrics_threshold(module_metrics_dict, threshold=low_threshold)
+    _ = module.test_step(batch, 0)
+    results_low_threshold = {key: module_metrics_dict[key].compute()
+                             for key in thresholded_metrics_keys}
+
+    reset_metrics(module_metrics_dict)
+    set_metrics_threshold(module_metrics_dict, threshold=high_threshold)
+    _ = module.test_step(batch, 0)
+    results_high_threshold = {key: module_metrics_dict[key].compute()
+                              for key in thresholded_metrics_keys}
+
+    for key in thresholded_metrics_keys:
+        assert not torch.allclose(results_low_threshold[key], results_high_threshold[key]), \
+            f"Got same value for '{key}' metric with low and high thresholds"
+
+
 def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
     device = "cuda" if use_gpu else "cpu"
     return {
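
As a footnote to the second half of the test above, a small standalone sketch (illustrative values; pre-1.0 torchmetrics API assumed) of the property it exploits: with genuine probabilities a thresholded metric such as Accuracy changes as the threshold moves, whereas with hard 0/1 inputs it stays constant, which is exactly how the original bug would have been caught.

import torch
from torchmetrics import Accuracy

labels = torch.tensor([1, 0, 1, 1])
probs = torch.tensor([0.35, 0.4, 0.6, 0.8])
hard_labels = torch.round(probs)  # 0/1 values, as in the buggy code path

for name, preds in [("probabilities", probs), ("hard labels", hard_labels)]:
    acc_low = Accuracy(threshold=0.3)(preds, labels)
    acc_high = Accuracy(threshold=0.7)(preds, labels)
    print(name, acc_low.item(), acc_high.item())
# probabilities: 0.75 vs 0.5, so the threshold matters
# hard labels: 0.75 vs 0.75, indistinguishable across thresholds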

View file

@@ -34,7 +34,7 @@ dependencies:
 - jupyter-client==6.1.5
 - lightning-bolts==0.4.0
 - matplotlib==3.3.0
-- mlflow==1.17.0
+- mlflow==1.23.1
 - monai==0.6.0
 - mypy==0.910
 - mypy-extensions==0.4.3

hi-ml

@@ -1 +1 @@
-Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b
+Subproject commit 30854eae4fd27776be9f0105099ddba663ef3eb5