Add FP and TN to DeepMIL outputs and configure outputs for multi-class classification (#679)
This commit is contained in:
Parent: a4c82df78e
Commit: 501f3920c6
@@ -51,6 +51,7 @@ jobs that run in AzureML.
 - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
 - ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
 - ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.
+- ([#679](https://github.com/microsoft/InnerEye-DeepLearning/pull/679)) Add FP and TN slides/tiles to DeepMIL outputs and extend outputs to multi-class problems.

 ### Changed
 - ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.
@@ -27,7 +27,7 @@ from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict
 from health_ml.utils import log_on_epoch

 RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
-                ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
+                ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]


 def _format_cuda_memory_stats() -> str:
@@ -242,21 +242,28 @@ class DeepMILModule(LightningModule):
         predicted_probs = self.activation_fn(bag_logits)
         if self.n_classes > 1:
             predicted_labels = argmax(predicted_probs, dim=1)
+            probs_perclass = predicted_probs
         else:
             predicted_labels = round(predicted_probs)
+            probs_perclass = Tensor([[1.0 - predicted_probs[i][0].item(), predicted_probs[i][0].item()] for i in range(len(predicted_probs))])

         loss = loss.view(-1, 1)
         predicted_labels = predicted_labels.view(-1, 1)
-        predicted_probs = predicted_probs.view(-1, 1)
+        if self.n_classes == 1:
+            predicted_probs = predicted_probs.view(-1, 1)
         bag_labels = bag_labels.view(-1, 1)

         results = dict()
         for metric_object in self.get_metrics_dict(stage).values():
-            metric_object.update(predicted_probs, bag_labels)
+            if self.n_classes > 1:
+                metric_object.update(predicted_probs, bag_labels.squeeze())
+            else:
+                metric_object.update(predicted_probs, bag_labels)
         results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN],
                         ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN],
                         ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss,
-                        ResultsKey.PROB: predicted_probs, ResultsKey.PRED_LABEL: predicted_labels,
+                        ResultsKey.PROB: predicted_probs, ResultsKey.CLASS_PROBS: probs_perclass,
+                        ResultsKey.PRED_LABEL: predicted_labels,
                         ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
                         ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})
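For context on the new probs_perclass value: in the binary branch, the single sigmoid output per bag is expanded into an explicit two-column [P(class 0), P(class 1)] tensor, so that downstream reporting can treat binary and multi-class models uniformly. A minimal standalone sketch of that expansion (the helper name is hypothetical, plain PyTorch rather than the module's code):

    import torch

    def expand_binary_probs(predicted_probs: torch.Tensor) -> torch.Tensor:
        """Expand an (N, 1) tensor of sigmoid outputs into an (N, 2) per-class tensor."""
        pos = predicted_probs.view(-1, 1)          # P(class 1)
        return torch.cat([1.0 - pos, pos], dim=1)  # columns: [P(class 0), P(class 1)]

    print(expand_binary_probs(torch.tensor([[0.7], [0.2]])))
    # tensor([[0.3000, 0.7000],
    #         [0.8000, 0.2000]])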
@@ -339,11 +346,21 @@ class DeepMILModule(LightningModule):
         torch.save(features_list, encoded_features_filename)

         print("Selecting tiles ...")
-        fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att'))
-        fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att'))
-        tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att'))
-        tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att'))
-        report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]}
+        # Class 0
+        tn_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'highest_att'))
+        tn_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'lowest_att'))
+        fp_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'highest_att'))
+        fp_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'lowest_att'))
+        report_cases = {'TN': [tn_top_tiles, tn_bottom_tiles], 'FP': [fp_top_tiles, fp_bottom_tiles]}
+
+        # Class 1 to n_classes-1
+        n_classes_to_select = self.n_classes if self.n_classes > 1 else 2
+        for i in range(1, n_classes_to_select):
+            fn_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'highest_att'))
+            fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'lowest_att'))
+            tp_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'highest_att'))
+            tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'lowest_att'))
+            report_cases.update({'TP_'+str(i): [tp_top_tiles, tp_bottom_tiles], 'FN_'+str(i): [fn_top_tiles, fn_bottom_tiles]})

         for key in report_cases.keys():
             print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...")
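The report now always covers class 0 via TN/FP and every remaining class via TP_i/FN_i; for a binary model the loop runs exactly once (i = 1). A small sketch of the resulting report keys, assuming a hypothetical 3-class model:

    # Keys produced by the selection loop above for a hypothetical 3-class model.
    n_classes = 3
    keys = ['TN', 'FP']                                    # class 0
    for i in range(1, n_classes if n_classes > 1 else 2):  # classes 1..n_classes-1
        keys += ['TP_' + str(i), 'FN_' + str(i)]
    print(keys)  # ['TN', 'FP', 'TP_1', 'FN_1', 'TP_2', 'FN_2']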
@@ -397,13 +414,19 @@ class DeepMILModule(LightningModule):
         # these steps are required to convert the dictionary to pandas dataframe.
         device = 'cuda' if use_gpu else 'cpu'
         dict_new = dict()
+        bag_size = len(dict_old[ResultsKey.SLIDE_ID])
         for key, value in dict_old.items():
-            if isinstance(value, Tensor):
-                value = value.squeeze(0).to(device).numpy()
-                if value.ndim == 0:
-                    bag_size = len(dict_old[ResultsKey.SLIDE_ID])
-                    value = np.full(bag_size, fill_value=value)
-            dict_new[key] = value
+            if key not in [ResultsKey.CLASS_PROBS, ResultsKey.PROB]:
+                if isinstance(value, Tensor):
+                    value = value.squeeze(0).to(device).numpy()
+                    if value.ndim == 0:
+                        value = np.full(bag_size, fill_value=value)
+                dict_new[key] = value
+            elif key == ResultsKey.CLASS_PROBS:
+                if isinstance(value, Tensor):
+                    value = value.squeeze(0).to(device).numpy()
+                    for i in range(len(value)):
+                        dict_new[key+str(i)] = np.repeat(value[i], bag_size)
         return dict_new

     @staticmethod
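The new branch expands the slide-level CLASS_PROBS tensor into one dataframe column per class, repeating each probability across all tiles of the bag. A standalone NumPy illustration with hypothetical values (the real code derives bag_size from the SLIDE_ID entry and builds the column name from the ResultsKey.CLASS_PROBS key, which is the string 'prob_class'):

    import numpy as np
    import torch

    class_probs = torch.tensor([0.2, 0.8])  # slide-level per-class probabilities for one bag
    bag_size = 4                            # number of tiles in the bag
    columns = {'prob_class' + str(i): np.repeat(class_probs[i].numpy(), bag_size)
               for i in range(len(class_probs))}
    print(columns)
    # {'prob_class0': array([0.2, 0.2, 0.2, 0.2], dtype=float32),
    #  'prob_class1': array([0.8, 0.8, 0.8, 0.8], dtype=float32)}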
@@ -20,7 +20,7 @@ from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_til

 def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1,
                    select: Tuple = ('lowest_pred', 'highest_att'),
                    slide_col: str = ResultsKey.SLIDE_ID, gt_col: str = ResultsKey.TRUE_LABEL,
-                   attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.PROB,
+                   attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.CLASS_PROBS,
                    return_col: str = ResultsKey.IMAGE_PATH) -> List[Tuple[Any, Any, List[Any], List[Any]]]:
     """
     :param results: List that contains slide_level dicts
@@ -35,7 +35,7 @@ def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1,
     :param return_col: column name of the values we want to return for each tile
     :return: tuple containing the slides id, the slide score, the tile ids, the tiles scores
     """
-    tmp_s = [(results[prob_col][i], i) for i, gt in enumerate(results[gt_col]) if gt == label]  # type ignore
+    tmp_s = [(results[prob_col][i][label], i) for i, gt in enumerate(results[gt_col]) if gt == label]  # type ignore
     if select[0] == 'lowest_pred':
         tmp_s.sort(reverse=False)
     elif select[0] == 'highest_pred':
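With prob_col now pointing at CLASS_PROBS, slides of the queried class are ranked by the predicted probability of that class via the extra [label] index. A toy illustration with hypothetical values of how that ranking behaves when select[0] == 'lowest_pred':

    from torch import Tensor

    # Hypothetical per-slide class probabilities and true labels.
    class_probs = [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.1, 0.9])]
    true_labels = [1, 1, 0]
    label = 1
    ranked = sorted((class_probs[i][label].item(), i)
                    for i, gt in enumerate(true_labels) if gt == label)
    print(ranked)  # [(~0.4, 0), (~0.7, 1)] -> 'lowest_pred' takes slides from the front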
@@ -58,12 +58,12 @@ def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1,
             scores.append(results[attn_col][slide_idx][0][t_idx])
         # slide_ids are duplicated
         k_idx.append((results[slide_col][slide_idx][0],
-                      results[prob_col][slide_idx].item(),
+                      results[prob_col][slide_idx],
                       k_tiles, scores))
     return k_idx


-def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB,
+def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.CLASS_PROBS,
                      gt_col: str = ResultsKey.TRUE_LABEL) -> plt.figure:
     """
     :param results: List that contains slide_level dicts
@@ -71,20 +71,23 @@ def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB,
     :param gt_col: column name that contains the true label
     :return: matplotlib figure of the scores histogram by class
     """
-    pos_scores = [results[prob_col][i][0].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == 1]
-    neg_scores = [results[prob_col][i][0].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == 0]
-    fig, ax = plt.subplots()
-    ax.hist([pos_scores, neg_scores], label=['1', '0'], alpha=0.5)
+    n_classes = len(results[prob_col][0])
+    scores_class = []
+    for j in range(n_classes):
+        scores = [results[prob_col][i][j].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == j]
+        scores_class.append(scores)
+    fig, ax = plt.subplots()
+    ax.hist(scores_class, label=[str(i) for i in range(n_classes)], alpha=0.5)
     ax.set_xlabel("Predicted Score")
     ax.legend()
     return fig


-def plot_attention_tiles(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5,
+def plot_attention_tiles(slide: str, scores: List[float], paths: List, attn: List, case: str, ncols: int = 5,
                          size: Tuple = (10, 10)) -> plt.figure:
     """
     :param slide: slide identifier
-    :param score: predicted score for the slide
+    :param scores: predicted scores of each class for the slide
     :param paths: list of paths to tiles belonging to the slide
     :param attn: list of scores belonging to the tiles in paths. paths and attn are expected to have the same shape
     :param case: string used to define the title of the plot e.g. TP
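The histogram is now built per class by reading the probability of each slide's true class from CLASS_PROBS, instead of only positive/negative scores. A self-contained usage sketch with hypothetical values (plain string keys stand in for the ResultsKey members, which are string-valued):

    import matplotlib.pyplot as plt
    from torch import Tensor

    # Hypothetical results dict in the new per-class format (2 classes, 3 slides).
    results = {'prob_class': [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.8, 0.2])],
               'true_label': [0, 1, 0]}
    n_classes = len(results['prob_class'][0])
    scores_class = [[results['prob_class'][i][j].item()
                     for i, gt in enumerate(results['true_label']) if gt == j]
                    for j in range(n_classes)]
    fig, ax = plt.subplots()
    ax.hist(scores_class, label=[str(j) for j in range(n_classes)], alpha=0.5)
    ax.set_xlabel("Predicted Score")
    ax.legend()
    fig.savefig("scores_hist.png")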
@@ -94,7 +97,7 @@ def plot_attention_tiles(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5,
     """
     nrows = int(ceil(len(paths) / ncols))
     fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=size)
-    fig.suptitle(f"{case}: {slide} P=%.2f" % score)
+    fig.suptitle(f"{case}: {slide} P=%.2f" % max(scores))
     for i in range(len(paths)):
         img = load_pil_image(paths[i])
         axs.ravel()[i].imshow(img, clim=(0, 255), cmap='gray')
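Since the slide now carries a vector of per-class scores, the figure title reports the largest one. A minimal standalone sketch with hypothetical values (note the title string mixes an f-string with %-formatting, exactly as in the code above):

    import matplotlib.pyplot as plt

    case, slide, scores = 'TP', 'slide_004', [0.0, 1.0]     # hypothetical values
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 3))
    fig.suptitle(f"{case}: {slide} P=%.2f" % max(scores))   # -> "TP: slide_004 P=1.00"
    fig.savefig("attention_tiles_title_demo.png")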
@@ -48,6 +48,7 @@ class ResultsKey(str, Enum):
     IMAGE_PATH = 'image_path'
     LOSS = 'loss'
     PROB = 'prob'
+    CLASS_PROBS = 'prob_class'
     PRED_LABEL = 'pred_label'
     TRUE_LABEL = 'true_label'
     BAG_ATTN = 'bag_attn'
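After this change two related keys coexist: PROB keeps the raw activation output used for metrics, while the new CLASS_PROBS ('prob_class') holds an explicit per-class probability vector used for tile selection and reporting. A toy sketch of what one bag's entries might look like (values hypothetical):

    from torch import Tensor

    bag_outputs = {
        'prob': Tensor([0.7]),             # ResultsKey.PROB: raw activation output (binary case)
        'prob_class': Tensor([0.3, 0.7]),  # ResultsKey.CLASS_PROBS: [P(class 0), P(class 1)]
        'pred_label': Tensor([1]),
        'true_label': Tensor([1]),
    }
    print(bag_outputs['prob_class'].argmax().item())  # 1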
@@ -31,6 +31,9 @@ def assert_equal_lists(pred: List, expected: List) -> None:
         for j, value in enumerate(slide):
             if type(value) in [int, float]:
                 assert math.isclose(value, expected[i][j], rel_tol=1e-06)
+            elif (type(value) == Tensor) and (value.ndim >= 1):
+                for k, idx in enumerate(value):
+                    assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06)
             elif isinstance(value, List):
                 for k, idx in enumerate(value):
                     if type(idx) in [int, float]:
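The new branch lets the test helper compare 1-D tensors (such as per-class probabilities) element-wise. A minimal standalone check mirroring that branch, with hypothetical values:

    import math
    from torch import Tensor

    value = Tensor([0.3, 0.7])       # e.g. a per-class probability entry
    expected = [0.3, 0.7]
    for k, idx in enumerate(value):  # iterating a 1-D Tensor yields 0-dim tensors
        assert math.isclose(idx, expected[k], rel_tol=1e-06)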
@@ -41,15 +44,20 @@ def assert_equal_lists(pred: List, expected: List) -> None:
             raise TypeError("Unexpected list composition")


-test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
-             ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]],
-             ResultsKey.PROB: [Tensor([0.5]), Tensor([0.7]), Tensor([0.4]), Tensor([1.0])],
-             ResultsKey.TRUE_LABEL: [0, 1, 1, 1],
+test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
+             ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]],
+             ResultsKey.CLASS_PROBS: [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.6, 0.4]), Tensor([0.0, 1.0]),
+                                      Tensor([0.7, 0.3]), Tensor([0.8, 0.2]), Tensor([0.1, 0.9]), Tensor([0.01, 0.99])],
+             ResultsKey.TRUE_LABEL: [0, 1, 1, 1, 1, 0, 0, 0],
              ResultsKey.BAG_ATTN:
-             [Tensor([[0.1, 0.0, 0.2, 0.15]]),
+             [Tensor([[0.10, 0.00, 0.20, 0.15]]),
              Tensor([[0.10, 0.18, 0.15, 0.13]]),
              Tensor([[0.25, 0.23, 0.20, 0.21]]),
-             Tensor([[0.33, 0.31, 0.37, 0.35]])],
+             Tensor([[0.33, 0.31, 0.37, 0.35]]),
+             Tensor([[0.43, 0.01, 0.07, 0.25]]),
+             Tensor([[0.53, 0.11, 0.17, 0.55]]),
+             Tensor([[0.63, 0.21, 0.27, 0.05]]),
+             Tensor([[0.73, 0.31, 0.37, 0.15]])],
              ResultsKey.TILE_X:
              [Tensor([200, 200, 424, 424]),
               Tensor([200, 200, 424, 424]),
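For orientation, the toy slides above are chosen so that every confusion-matrix category is represented (slides 2 and 4 are TP, 3 and 5 are FN, 1 and 6 are TN, 7 and 8 are FP). A short sketch that derives those categories from the dictionary's per-class probabilities and true labels:

    from torch import Tensor

    # Derive the confusion-matrix category of each toy slide from its per-class
    # probabilities and true label (the same data as in test_dict above).
    class_probs = [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.6, 0.4]), Tensor([0.0, 1.0]),
                   Tensor([0.7, 0.3]), Tensor([0.8, 0.2]), Tensor([0.1, 0.9]), Tensor([0.01, 0.99])]
    true_labels = [0, 1, 1, 1, 1, 0, 0, 0]
    for slide_id, (probs, gt) in enumerate(zip(class_probs, true_labels), start=1):
        pred = probs.argmax().item()
        category = {(1, 1): 'TP', (1, 0): 'FN', (0, 0): 'TN', (0, 1): 'FP'}[(gt, pred)]
        print(slide_id, category)
    # 1 TN, 2 TP, 3 FN, 4 TP, 5 FN, 6 TN, 7 FP, 8 FP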
@@ -64,27 +72,40 @@ test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4],


 def test_select_k_tiles() -> None:
-    top_tn = select_k_tiles(test_dict, n_slides=1, label=0, n_tiles=2, select=('lowest_pred', 'highest_att'))
-    assert_equal_lists(top_tn, [(1, 0.5, [3, 4], [Tensor([0.2]), Tensor([0.15])])])
-
     nslides = 2
     ntiles = 2
-    top_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
-    bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles,
-                               select=('lowest_pred', 'lowest_att'))
-    assert_equal_lists(top_fn, [(3, 0.4, [1, 2], [Tensor([0.25]), Tensor([0.23])]),
-                                (2, 0.7, [2, 3], [Tensor([0.18]), Tensor([0.15])])])
-    assert_equal_lists(bottom_fn, [(3, 0.4, [3, 4], [Tensor([0.20]), Tensor([0.21])]),
-                                   (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])])
+    # TP
+    top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'highest_att'))
+    bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'lowest_att'))
+    print(top_tp)
+    assert_equal_lists(top_tp, [(4, Tensor([0.0, 1.0]), [3, 4], [Tensor([0.37]), Tensor([0.35])]),
+                                (2, Tensor([0.3, 0.7]), [2, 3], [Tensor([0.18]), Tensor([0.15])])])
+    assert_equal_lists(bottom_tp, [(4, Tensor([0.0, 1.0]), [2, 1], [Tensor([0.31]), Tensor([0.33])]),
+                                   (2, Tensor([0.3, 0.7]), [1, 4], [Tensor([0.10]), Tensor([0.13])])])

-    top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles,
-                            select=('highest_pred', 'highest_att'))
-    bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles,
-                               select=('highest_pred', 'lowest_att'))
-    assert_equal_lists(top_tp, [(4, 1.0, [3, 4], [Tensor([0.37]), Tensor([0.35])]),
-                                (2, 0.7, [2, 3], [Tensor([0.18]), Tensor([0.15])])])
-    assert_equal_lists(bottom_tp, [(4, 1.0, [2, 1], [Tensor([0.31]), Tensor([0.33])]),
-                                   (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])])
+    # FN
+    top_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
+    bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'lowest_att'))
+    assert_equal_lists(top_fn, [(5, Tensor([0.7, 0.3]), [1, 4], [Tensor([0.43]), Tensor([0.25])]),
+                                (3, Tensor([0.6, 0.4]), [1, 2], [Tensor([0.25]), Tensor([0.23])])])
+    assert_equal_lists(bottom_fn, [(5, Tensor([0.7, 0.3]), [2, 3], [Tensor([0.01]), Tensor([0.07])]),
+                                   (3, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.20]), Tensor([0.21])])])
+
+    # TN
+    top_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'highest_att'))
+    bottom_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'lowest_att'))
+    assert_equal_lists(top_tn, [(6, Tensor([0.8, 0.2]), [4, 1], [Tensor([0.55]), Tensor([0.53])]),
+                                (1, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.2]), Tensor([0.15])])])
+    assert_equal_lists(bottom_tn, [(6, Tensor([0.8, 0.2]), [2, 3], [Tensor([0.11]), Tensor([0.17])]),
+                                   (1, Tensor([0.6, 0.4]), [2, 1], [Tensor([0.00]), Tensor([0.10])])])
+
+    # FP
+    top_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
+    bottom_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'lowest_att'))
+    assert_equal_lists(top_fp, [(8, Tensor([0.01, 0.99]), [1, 3], [Tensor([0.73]), Tensor([0.37])]),
+                                (7, Tensor([0.1, 0.9]), [1, 3], [Tensor([0.63]), Tensor([0.27])])])
+    assert_equal_lists(bottom_fp, [(8, Tensor([0.01, 0.99]), [4, 2], [Tensor([0.15]), Tensor([0.31])]),
+                                   (7, Tensor([0.1, 0.9]), [4, 2], [Tensor([0.05]), Tensor([0.21])])])


 @pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ca95c0017d0a51d75d118e54f21c1e907b3d90dcca822b23622e369267907198
-size 17057
+oid sha256:6ddc430ffcade51a072e9452833143840b1e5726148fd850ad3f370f1315bb32
+size 20452
hi-ml
@@ -1 +1 @@
-Subproject commit 30854eae4fd27776be9f0105099ddba663ef3eb5
+Subproject commit 0250715c5ac1ef09227b51388df44b568a496f65