Mirror of https://github.com/microsoft/hi-ml.git

ENH: Track validation loss and add entropy value (#629)

Add entropy values to detect ambiguity and track the validation loss during training.

Parent: 85a317d0f7
Commit: e2c1ca1cb4
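The ambiguity signal added by this commit is the Shannon entropy of the predicted class probabilities, computed per bag by the new LossAnalysisCallback.compute_entropy shown further down in the callbacks diff. A minimal sketch of the computation, with illustrative numbers that are not taken from the commit:

    import torch

    def compute_entropy(class_probs: torch.Tensor) -> list:
        # Shannon entropy per sample: H = -sum_c p_c * log(p_c)
        return (-torch.sum(class_probs * torch.log(class_probs), dim=1)).tolist()

    probs = torch.tensor([[0.50, 0.50],    # maximally ambiguous -> H = ln(2) ~ 0.693
                          [0.99, 0.01]])   # confident -> H ~ 0.056
    print(compute_entropy(probs))

High-entropy slides are the ambiguous ones; the callback stores this value next to the per-sample loss so both can be ranked and plotted per epoch. The hunks below touch BaseMILTiles/BaseMILSlides, BaseDeepMILModule, the LossAnalysisCallback, the ResultsKey enum, and the corresponding tests.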
@@ -291,7 +291,8 @@ class BaseMILTiles(BaseMIL):
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler)
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module

@@ -334,7 +335,8 @@ class BaseMILSlides(BaseMIL):
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler)
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module

@@ -51,7 +51,8 @@ class BaseDeepMILModule(LightningModule):
encoder_params: EncoderParams = EncoderParams(),
pooling_params: PoolingParams = PoolingParams(),
optimizer_params: OptimizerParams = OptimizerParams(),
outputs_handler: Optional[DeepMILOutputsHandler] = None) -> None:
outputs_handler: Optional[DeepMILOutputsHandler] = None,
analyse_loss: Optional[bool] = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be

@@ -71,6 +72,7 @@ class BaseDeepMILModule(LightningModule):
:param outputs_handler: A configured :py:class:`DeepMILOutputsHandler` object to save outputs for the best
validation epoch and test stage. If omitted (default), no outputs will be saved to disk (aside from usual
metrics logging).
:param analyse_loss: If True, the loss is analysed per sample and analysed with LossAnalysisCallback.
"""
super().__init__()

@@ -100,7 +102,8 @@ class BaseDeepMILModule(LightningModule):
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
self.classifier_fn = self.get_classifier()
self.activation_fn = self.get_activation()
self.loss_fn = self.get_loss()
self.analyse_loss = analyse_loss
self.loss_fn = self.get_loss(reduction="mean")
self.loss_fn_no_reduction = self.get_loss(reduction="none")

# Metrics Objects
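The module now builds two instances of the same loss: the mean-reduced one drives optimisation, while the unreduced one yields one loss value per bag for the analysis callback. A minimal sketch of the idea, assuming a cross-entropy loss (the concrete loss returned by get_loss is not shown in this hunk):

    import torch
    from torch import nn

    logits = torch.randn(4, 3)             # 4 bags, 3 classes
    labels = torch.tensor([0, 2, 1, 2])

    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    loss_fn_no_reduction = nn.CrossEntropyLoss(reduction="none")

    loss = loss_fn(logits, labels)                           # scalar used for backprop
    loss_per_sample = loss_fn_no_reduction(logits, labels)   # shape (4,), cached per slide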
@@ -295,14 +298,17 @@ class BaseDeepMILModule(LightningModule):
"""Update training results with data specific info. This can be either tiles or slides related metadata."""
raise NotImplementedError

def _compute_loss(self, loss_fn: Callable, bag_logits: Tensor, bag_labels: Tensor) -> Tensor:
if self.n_classes > 1:
return loss_fn(bag_logits, bag_labels.long())
else:
return loss_fn(bag_logits.squeeze(1), bag_labels.float())

def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> BatchResultsType:

bag_logits, bag_labels, bag_attn_list = self.compute_bag_labels_logits_and_attn_maps(batch)

if self.n_classes > 1:
loss = self.loss_fn(bag_logits, bag_labels.long())
else:
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
loss = self._compute_loss(self.loss_fn, bag_logits, bag_labels)

predicted_probs = self.activation_fn(bag_logits)
if self.n_classes > 1:

@@ -320,9 +326,13 @@ class BaseDeepMILModule(LightningModule):
if self.n_classes == 1:
predicted_probs = predicted_probs.squeeze(dim=1)

results = dict()
if self.analyse_loss and stage in [ModelKey.TRAIN, ModelKey.VAL]:
loss_per_sample = self._compute_loss(self.loss_fn_no_reduction, bag_logits, bag_labels)
results[ResultsKey.LOSS_PER_SAMPLE] = loss_per_sample.detach().cpu().numpy()

bag_labels = bag_labels.view(-1, 1)

results = dict()
for metric_object in self.get_metrics_dict(stage).values():
metric_object.update(predicted_probs, bag_labels.view(batch_size,).int())
results.update({ResultsKey.LOSS: loss,

@@ -341,13 +351,17 @@ class BaseDeepMILModule(LightningModule):
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
return results

def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
def training_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
train_result = self._shared_step(batch, batch_idx, ModelKey.TRAIN)
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
if self.verbose:
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
return train_result[ResultsKey.LOSS]
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
if self.analyse_loss:
results.update({ResultsKey.LOSS_PER_SAMPLE: train_result[ResultsKey.LOSS_PER_SAMPLE],
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS]})
return results

def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
val_result = self._shared_step(batch, batch_idx, ModelKey.VAL)
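training_step now returns a dictionary rather than a bare tensor. Lightning picks up the value stored under ResultsKey.LOSS (the string 'loss') for backpropagation, and the whole dictionary is handed to the batch-end hooks, which is how LossAnalysisCallback gets hold of the per-sample losses and class probabilities. An illustrative consumer-side sketch, not code from the repository:

    # inside a Callback
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0):
        per_sample_losses = outputs[ResultsKey.LOSS_PER_SAMPLE]   # one value per bag, detached numpy array
        class_probs = outputs[ResultsKey.CLASS_PROBS]             # feeds the entropy computation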
@@ -11,17 +11,20 @@ import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

from health_cpath.models.deepmil import BaseDeepMILModule
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.output_utils import BatchResultsType

LossCacheDictType = Dict[ResultsKey, List]
LossCacheDictType = Dict[Union[ResultsKey, str], List]
LossDictType = Dict[str, List]
AnomalyDictType = Dict[ModelKey, List[str]]


class LossCallbackParams(param.Parameterized):
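The widened LossCacheDictType and the new AnomalyDictType describe the per-stage bookkeeping the callback maintains. A sketch of what one stage's cache looks like after a couple of batches (the values are made up):

    loss_cache_train = {
        ResultsKey.SLIDE_ID: ["slide_01", "slide_02"],
        ResultsKey.LOSS:     [0.41, 1.87],      # unreduced, one loss per slide
        ResultsKey.ENTROPY:  [0.05, 0.69],      # low = confident, high = ambiguous
    }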
@@ -89,6 +92,7 @@ class LossAnalysisCallback(Callback):
num_slides_heatmap: int = 20,
save_tile_ids: bool = False,
log_exceptions: bool = True,
create_outputs_folders: bool = True,
) -> None:
"""

@@ -102,6 +106,7 @@ class LossAnalysisCallback(Callback):
False. This is useful to analyse the tiles that are contributing to the loss value of a slide.
:param log_exceptions: If True, will log exceptions raised during loss values analysis, defaults to True. If
False will raise the intercepted exceptions.
:param create_outputs_folders: If True, will create the output folders if they don't exist, defaults to True.
"""

self.patience = patience
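A hedged sketch of how the two new constructor flags are typically combined (the argument values below are illustrative, not repository defaults):

    callback = LossAnalysisCallback(
        outputs_folder=Path("outputs"),
        max_epochs=50,
        log_exceptions=True,           # warn and keep training if the analysis itself fails
        create_outputs_folders=True,   # lay out the per-stage folder tree at construction time
    )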
@@ -112,89 +117,89 @@ class LossAnalysisCallback(Callback):
self.save_tile_ids = save_tile_ids
self.log_exceptions = log_exceptions

self.outputs_folder = outputs_folder / "loss_values_callback"
self.create_outputs_folders()
self.outputs_folder = outputs_folder / "loss_analysis_callback"
if create_outputs_folders:
self.create_outputs_folders()

self.loss_cache = self.reset_loss_cache()
self.loss_cache: Dict[ModelKey, LossCacheDictType] = {
stage: self.get_empty_loss_cache() for stage in [ModelKey.TRAIN, ModelKey.VAL]
}
self.epochs_range = list(range(self.patience, self.max_epochs, self.epochs_interval))

self.nan_slides: List[str] = []
self.anomaly_slides: List[str] = []
self.nan_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
self.anomaly_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}

@property
def cache_folder(self) -> Path:
return self.outputs_folder / "loss_cache"
def get_cache_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_cache"

@property
def scatter_folder(self) -> Path:
return self.outputs_folder / "loss_scatter"
def get_scatter_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_scatter"

@property
def heatmap_folder(self) -> Path:
return self.outputs_folder / "loss_heatmap"
def get_heatmap_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_heatmap"

@property
def stats_folder(self) -> Path:
return self.outputs_folder / "loss_stats"
def get_stats_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_stats"

@property
def anomalies_folder(self) -> Path:
return self.outputs_folder / "loss_anomalies"
def get_anomalies_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_anomalies"

def create_outputs_folders(self) -> None:
folders = [
self.cache_folder,
self.scatter_folder,
self.heatmap_folder,
self.stats_folder,
self.anomalies_folder,
self.get_cache_folder,
self.get_scatter_folder,
self.get_heatmap_folder,
self.get_stats_folder,
self.get_anomalies_folder,
]
stages = [ModelKey.TRAIN, ModelKey.VAL]
for folder in folders:
os.makedirs(folder, exist_ok=True)
for stage in stages:
os.makedirs(folder(stage), exist_ok=True)

def reset_loss_cache(self) -> LossCacheDictType:
keys = [ResultsKey.LOSS, ResultsKey.SLIDE_ID]
def get_empty_loss_cache(self) -> LossCacheDictType:
keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
if self.save_tile_ids:
keys.append(ResultsKey.TILE_ID)
return {key: [] for key in keys}

def _get_filename(self, filename: str, epoch: int, order: Optional[str] = None) -> str:
zero_filled_epoch = str(epoch).zfill(len(str(self.max_epochs)))
filename = filename.format(zero_filled_epoch, order) if order else filename.format(zero_filled_epoch)
return filename
def _format_epoch(self, epoch: int) -> str:
return str(epoch).zfill(len(str(self.max_epochs)))

def get_loss_cache_file(self, epoch: int) -> Path:
return self.cache_folder / self._get_filename(filename="epoch_{}.csv", epoch=epoch)
def get_loss_cache_file(self, epoch: int, stage: ModelKey) -> Path:
return self.get_cache_folder(stage) / f"epoch_{self._format_epoch(epoch)}.csv"

def get_all_epochs_loss_cache_file(self) -> Path:
return self.cache_folder / "all_epochs.csv"
def get_all_epochs_loss_cache_file(self, stage: ModelKey) -> Path:
return self.get_cache_folder(stage) / f"all_epochs_{stage}.csv"

def get_loss_stats_file(self) -> Path:
return self.stats_folder / "loss_stats.csv"
def get_loss_stats_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_stats_{stage}.csv"

def get_loss_ranks_file(self) -> Path:
return self.stats_folder / "loss_ranks.csv"
def get_loss_ranks_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_ranks_{stage}.csv"

def get_loss_ranks_stats_file(self) -> Path:
return self.stats_folder / "loss_ranks_stats.csv"
def get_loss_ranks_stats_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_ranks_stats_{stage}.csv"

def get_nan_slides_file(self) -> Path:
return self.anomalies_folder / "nan_slides.txt"
def get_nan_slides_file(self, stage: ModelKey) -> Path:
return self.get_anomalies_folder(stage) / f"nan_slides_{stage}.txt"

def get_anomaly_slides_file(self) -> Path:
return self.anomalies_folder / "anomaly_slides.txt"
def get_anomaly_slides_file(self, stage: ModelKey) -> Path:
return self.get_anomalies_folder(stage) / f"anomaly_slides_{stage}.txt"

def get_scatter_plot_file(self, order: str) -> Path:
return self.scatter_folder / "slides_with_{}_loss_values.png".format(order)
def get_scatter_plot_file(self, order: str, stage: ModelKey) -> Path:
return self.get_scatter_folder(stage) / f"slides_with_{order}_loss_values_{stage}.png"

def get_heatmap_plot_file(self, epoch: int, order: str) -> Path:
return self.heatmap_folder / self._get_filename(filename="epoch_{}_{}_slides.png", epoch=epoch, order=order)
def get_heatmap_plot_file(self, epoch: int, order: str, stage: ModelKey) -> Path:
return self.get_heatmap_folder(stage) / f"epoch_{self._format_epoch(epoch)}_{order}_slides_{stage}.png"

def read_loss_cache(self, epoch: int, idx_col: Optional[ResultsKey] = None) -> pd.DataFrame:
columns = [ResultsKey.SLIDE_ID, ResultsKey.LOSS]
return pd.read_csv(self.get_loss_cache_file(epoch), index_col=idx_col, usecols=columns)
def read_loss_cache(self, epoch: int, stage: ModelKey, idx_col: Optional[ResultsKey] = None) -> pd.DataFrame:
columns = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
return pd.read_csv(self.get_loss_cache_file(epoch, stage), index_col=idx_col, usecols=columns)

def should_cache_loss_values(self, current_epoch: int) -> bool:
if current_epoch >= self.max_epochs:
return False # Don't cache loss values for the extra validation epoch
current_epoch = current_epoch + 1
first_time = (current_epoch - self.patience) == 1
return first_time or (
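With the per-stage getters above, every artefact now lands under a train/ or val/ subtree of the callback's output folder. A rough sketch of the resulting layout, assuming ModelKey.TRAIN renders as "train" and a two-digit max_epochs (the file names here are illustrative):

    <outputs_folder>/loss_analysis_callback/
        train/loss_cache/epoch_05.csv
        train/loss_cache/all_epochs_train.csv
        train/loss_scatter/slides_with_highest_loss_values_train.png
        train/loss_stats/loss_stats_train.csv
        train/loss_anomalies/nan_slides_train.txt
        val/...   (same structure for the validation stage)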
@@ -203,35 +208,41 @@ class LossAnalysisCallback(Callback):

def merge_loss_caches(self, loss_caches: List[LossCacheDictType]) -> LossCacheDictType:
"""Merges the loss caches from all the workers into a single loss cache"""
loss_cache = self.reset_loss_cache()
loss_cache = self.get_empty_loss_cache()
for loss_cache_per_device in loss_caches:
for key in loss_cache.keys():
loss_cache[key].extend(loss_cache_per_device[key])
return loss_cache

def gather_loss_cache(self, rank: int) -> None:
def gather_loss_cache(self, rank: int, stage: ModelKey) -> None:
"""Gathers the loss cache from all the workers"""
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
if world_size > 1:
loss_caches = [None] * world_size
torch.distributed.all_gather_object(loss_caches, self.loss_cache)
torch.distributed.all_gather_object(loss_caches, self.loss_cache[stage])
if rank == 0:
self.loss_cache = self.merge_loss_caches(loss_caches) # type: ignore
self.loss_cache[stage] = self.merge_loss_caches(loss_caches) # type: ignore

def save_loss_cache(self, current_epoch: int) -> None:
def save_loss_cache(self, current_epoch: int, stage: ModelKey) -> None:
"""Saves the loss cache to a csv file"""
loss_cache_df = pd.DataFrame(self.loss_cache)
loss_cache_df = pd.DataFrame(self.loss_cache[stage])
# Some slides may be appear multiple times in the loss cache in DDP mode. The Distributed Sampler duplicate
# some slides to even out the number of samples per device, so we only keep the first occurrence.
loss_cache_df.drop_duplicates(subset=ResultsKey.SLIDE_ID, inplace=True, keep="first")
loss_cache_df = loss_cache_df.sort_values(by=ResultsKey.LOSS, ascending=False)
loss_cache_df.to_csv(self.get_loss_cache_file(current_epoch), index=False)
loss_cache_df.to_csv(self.get_loss_cache_file(current_epoch, stage), index=False)

def _select_values_for_epoch(
self, keys: List[ResultsKey], epoch: int, high: Optional[bool] = None, num_values: Optional[int] = None
self,
keys: List[ResultsKey],
epoch: int,
stage: ModelKey,
high: Optional[bool] = None,
num_values: Optional[int] = None
) -> List[np.ndarray]:
loss_cache = self.read_loss_cache(epoch)
"""Selects the values corresponding to keys from a dataframe for the given epoch and stage"""
loss_cache = self.read_loss_cache(epoch, stage)
return_values = []
for key in keys:
values = loss_cache[key].values
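gather_loss_cache relies on torch.distributed.all_gather_object, which collects an arbitrary picklable object from every rank into a pre-sized list. A minimal sketch of that pattern on its own, outside of the callback (assumes the process group is already initialised):

    import torch.distributed as dist

    def gather_from_all_ranks(local_obj):
        # every rank contributes its object and receives the full list back
        world_size = dist.get_world_size()
        gathered = [None] * world_size
        dist.all_gather_object(gathered, local_obj)
        return gathered   # merged on rank 0 only, as merge_loss_caches does for the caches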
@@ -246,52 +257,63 @@ class LossAnalysisCallback(Callback):
return return_values

def select_slides_for_epoch(
self, epoch: int, high: Optional[bool] = None, num_slides: Optional[int] = None
self, epoch: int, stage: ModelKey, high: Optional[bool] = None, num_slides: Optional[int] = None
) -> np.ndarray:
"""Selects slides in ascending/descending order of loss values at a given epoch

:param epoch: The epoch to select the slides from.
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values. If None, selects all slides.
:param num_slides: The number of slides to select. If None, selects all slides.
"""
return self._select_values_for_epoch([ResultsKey.SLIDE_ID], epoch, high, num_slides)[0]
return self._select_values_for_epoch([ResultsKey.SLIDE_ID], epoch, stage, high, num_slides)[0]

def select_slides_losses_for_epoch(self, epoch: int, high: Optional[bool] = None) -> Tuple[np.ndarray, np.ndarray]:
def select_slides_entropy_for_epoch(
self, epoch: int, stage: ModelKey, high: Optional[bool] = None
) -> Tuple[np.ndarray, np.ndarray]:
"""Selects slides and loss values of slides in ascending/descending order of loss at a given epoch

:param epoch: The epoch to select the slides from.
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values. If None, selects all slides.
"""
keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS]
return_values = self._select_values_for_epoch(keys, epoch, high, self.num_slides_scatter)
keys = [ResultsKey.SLIDE_ID, ResultsKey.ENTROPY]
return_values = self._select_values_for_epoch(keys, epoch, stage, high, self.num_slides_scatter)
return return_values[0], return_values[1]

def select_all_losses_for_selected_slides(self, slides: np.ndarray) -> LossDictType:
"""Selects the loss values for a given set of slides"""
def select_all_losses_for_selected_slides(self, slides: np.ndarray, stage: ModelKey) -> LossDictType:
"""Selects the loss values for a given set of slides

:param slides: The slides to select the loss values for.
:param stage: The model's stage (train/val).
"""

slides_loss_values: LossDictType = {slide_id: [] for slide_id in slides}
for epoch in self.epochs_range:
loss_cache = self.read_loss_cache(epoch, idx_col=ResultsKey.SLIDE_ID)
loss_cache = self.read_loss_cache(epoch, stage, idx_col=ResultsKey.SLIDE_ID)
for slide_id in slides:
slides_loss_values[slide_id].append(loss_cache.loc[slide_id, ResultsKey.LOSS])
return slides_loss_values

def select_slides_and_losses_across_epochs(self, high: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""Selects the slides with the highest/lowest loss values across epochs
def select_slides_and_entropy_across_epochs(
self, stage: ModelKey, high: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
"""Selects the slides with the highest/lowest loss values across epochs and their entropy values

:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values.
"""
slides = []
slides_loss = []
slides_entropy = []
for epoch in self.epochs_range:
epoch_slides, epoch_slides_loss = self.select_slides_losses_for_epoch(epoch, high)
epoch_slides, epoch_slides_entropy = self.select_slides_entropy_for_epoch(epoch, stage, high)
slides.append(epoch_slides)
slides_loss.append(epoch_slides_loss)
slides_entropy.append(epoch_slides_entropy)

return np.array(slides).T, np.array(slides_loss).T
return np.array(slides).T, np.array(slides_entropy).T

def save_slide_ids(self, slide_ids: List[str], path: Path) -> None:
"""Dumps the slides ids in a txt file."""

@@ -300,10 +322,11 @@ class LossAnalysisCallback(Callback):
for slide_id in slide_ids:
f.write(f"{slide_id}\n")

def sanity_check_loss_values(self, loss_values: LossDictType) -> None:
def sanity_check_loss_values(self, loss_values: LossDictType, stage: ModelKey) -> None:
"""Checks if there are any NaNs or any other potential annomalies in the loss values.

:param loss_values: The loss values for all slides.
:param stage: The model's stage (train/val).
"""
# We don't want any of these exceptions to interrupt validation. So we catch them and log them.
loss_values_copy = loss_values.copy()

@@ -311,68 +334,83 @@ class LossAnalysisCallback(Callback):
try:
if np.isnan(loss).any():
logging.warning(f"NaNs found in loss values for slide {slide_id}.")
self.nan_slides.append(slide_id)
self.nan_slides[stage].append(slide_id)
loss_values.pop(slide_id)
except Exception as e:
logging.warning(f"Error while checking for NaNs in loss values for slide {slide_id} with error {e}.")
logging.warning(f"Loss values that caused the issue: {loss}")
self.anomaly_slides.append(slide_id)
self.anomaly_slides[stage].append(slide_id)
loss_values.pop(slide_id)
self.save_slide_ids(self.nan_slides, self.get_nan_slides_file())
self.save_slide_ids(self.anomaly_slides, self.get_anomaly_slides_file())
self.save_slide_ids(self.nan_slides[stage], self.get_nan_slides_file(stage))
self.save_slide_ids(self.anomaly_slides[stage], self.get_anomaly_slides_file(stage))

def save_loss_ranks(self, slides_loss_values: LossDictType) -> None:
"""Saves the loss ranks for each slide across all epochs and their respective statistics in csv files."""
def save_loss_ranks(self, slides_loss_values: LossDictType, stage: ModelKey) -> None:
"""Saves the loss ranks for each slide across all epochs and their respective statistics in csv files.

:param slides_loss_values: The loss values for all slides.
:param stage: The model's stage (train/val).
"""

loss_df = pd.DataFrame(slides_loss_values).T
loss_df.index.names = [ResultsKey.SLIDE_ID.value]
loss_df.to_csv(self.get_all_epochs_loss_cache_file())
loss_df.to_csv(self.get_all_epochs_loss_cache_file(stage))

loss_stats = loss_df.T.describe().T.sort_values(by="mean", ascending=False)
loss_stats.to_csv(self.get_loss_stats_file())
loss_stats.to_csv(self.get_loss_stats_file(stage))

loss_ranks = loss_df.rank(ascending=False)
loss_ranks.to_csv(self.get_loss_ranks_file())
loss_ranks.to_csv(self.get_loss_ranks_file(stage))

loss_ranks_stats = loss_ranks.T.describe().T.sort_values("mean", ascending=True)
loss_ranks_stats.to_csv(self.get_loss_ranks_stats_file())
loss_ranks_stats.to_csv(self.get_loss_ranks_stats_file(stage))

def plot_slides_loss_scatter(
self,
slides: np.ndarray,
slides_loss: np.ndarray,
slides_entropy: np.ndarray,
stage: ModelKey,
high: bool = True,
figsize: Tuple[float, float] = (30, 30),
) -> None:
"""Plots the slides with the highest/lowest loss values across epochs in a scatter plot

:param slides: The slides ids.
:param slides_loss: The loss values for each slide.
:param figsize: The figure size, defaults to (20, 20)
:param slides_entropy: The entropy values for each slide.
:param stage: The model's stage (train/val).
:param figsize: The figure size, defaults to (30, 30)
:param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
"""
label = self.TOP if high else self.BOTTOM
plt.figure(figsize=figsize)
markers_size = [15 * i for i in range(1, self.num_slides_scatter + 1)]
markers_size = markers_size[::-1] if high else markers_size
colors = cm.rainbow(np.linspace(0, 1, self.num_slides_scatter))
for i in range(self.num_slides_scatter - 1, -1, -1):
plt.scatter(self.epochs_range, slides[i], label=f"{label}_{i+1}")
for loss, epoch, slide in zip(slides_loss[i], self.epochs_range, slides[i]):
plt.annotate(f"{loss:.3f}", (epoch, slide))
plt.scatter(self.epochs_range, slides[i], label=f"{label}_{i+1}", s=markers_size[i], color=colors[i])
for entropy, epoch, slide in zip(slides_entropy[i], self.epochs_range, slides[i]):
plt.annotate(f"{entropy:.3f}", (epoch, slide))
plt.xlabel(self.X_LABEL)
plt.ylabel(self.Y_LABEL)
order = self.HIGHEST if high else self.LOWEST
plt.title(f"Slides with {order} loss values per epoch.")
plt.title(f"Slides with {order} loss values per epoch and their entropy values")
plt.xticks(self.epochs_range)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.grid(True, linestyle="--")
plt.savefig(self.get_scatter_plot_file(order), bbox_inches="tight")
plt.savefig(self.get_scatter_plot_file(order, stage), bbox_inches="tight")

def plot_loss_heatmap_for_slides_of_epoch(
self, slides_loss_values: LossDictType, epoch: int, high: bool, figsize: Tuple[float, float] = (15, 15)
self,
slides_loss_values: LossDictType,
epoch: int,
stage: ModelKey,
high: bool,
figsize: Tuple[float, float] = (15, 15)
) -> None:
"""Plots the loss values for each slide across all epochs in a heatmap.

:param slides_loss_values: The loss values for each slide across all epochs.
:param epoch: The epoch used to select the slides.
:param stage: The model's stage (train/val).
:param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
:param figsize: The figure size, defaults to (15, 15)
"""

@@ -384,69 +422,119 @@ class LossAnalysisCallback(Callback):
plt.xlabel(self.X_LABEL)
plt.ylabel(self.Y_LABEL)
plt.title(f"Loss values evolution for {order} slides of epoch {epoch}")
plt.savefig(self.get_heatmap_plot_file(epoch, order), bbox_inches="tight")
plt.savefig(self.get_heatmap_plot_file(epoch, order, stage), bbox_inches="tight")

@torch.no_grad()
def on_train_batch_start( # type: ignore
self, trainer: Trainer, pl_module: BaseDeepMILModule, batch: Dict, batch_idx: int, unused: int = 0,
) -> None:
"""Caches loss values per slide at each training step in a local variable self.loss_cache."""
@staticmethod
def compute_entropy(class_probs: torch.Tensor) -> List[float]:
"""Computes the entropy of the class probabilities.

:param class_probs: The class probabilities.
"""
return (-torch.sum(class_probs * torch.log(class_probs), dim=1)).tolist()

def update_loss_cache(self, trainer: Trainer, outputs: BatchResultsType, batch: Dict, stage: ModelKey) -> None:
"""Updates the loss cache with the loss values for the current batch."""
if self.should_cache_loss_values(trainer.current_epoch):
bag_logits, bag_labels, _ = pl_module.compute_bag_labels_logits_and_attn_maps(batch)
if pl_module.n_classes > 1:
loss = pl_module.loss_fn_no_reduction(bag_logits, bag_labels.long())
else:
loss = pl_module.loss_fn_no_reduction(bag_logits.squeeze(1), bag_labels.float())
self.loss_cache[ResultsKey.LOSS].extend(loss.tolist())
self.loss_cache[ResultsKey.SLIDE_ID].extend([slides[0] for slides in batch[ResultsKey.SLIDE_ID]])
self.loss_cache[stage][ResultsKey.LOSS].extend(outputs[ResultsKey.LOSS_PER_SAMPLE])
self.loss_cache[stage][ResultsKey.SLIDE_ID].extend([slides[0] for slides in batch[ResultsKey.SLIDE_ID]])
self.loss_cache[stage][ResultsKey.ENTROPY].extend(self.compute_entropy(outputs[ResultsKey.CLASS_PROBS]))
if self.save_tile_ids:
self.loss_cache[ResultsKey.TILE_ID].extend(
self.loss_cache[stage][ResultsKey.TILE_ID].extend(
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in batch[ResultsKey.TILE_ID]]
)

def synchronise_processes_and_reset(self, trainer: Trainer, pl_module: BaseDeepMILModule, stage: ModelKey) -> None:
"""Synchronises the processes in DDP mode and resets the loss cache for the next epoch."""
if self.should_cache_loss_values(trainer.current_epoch):
self.gather_loss_cache(rank=pl_module.global_rank, stage=stage)
if pl_module.global_rank == 0:
self.save_loss_cache(trainer.current_epoch, stage)
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache for all processes

def save_loss_outliers_analaysis_results(self, stage: ModelKey) -> None:
"""Saves the loss outliers analysis results."""
all_slides = self.select_slides_for_epoch(epoch=0, stage=stage)
all_loss_values_per_slides = self.select_all_losses_for_selected_slides(all_slides, stage=stage)

self.sanity_check_loss_values(all_loss_values_per_slides, stage=stage)
self.save_loss_ranks(all_loss_values_per_slides, stage=stage)

top_slides, top_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=True)
self.plot_slides_loss_scatter(top_slides, top_slides_entropy, stage, high=True)

bottom_slides, bottom_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=False)
self.plot_slides_loss_scatter(bottom_slides, bottom_slides_entropy, stage, high=False)

for epoch in self.epochs_range:
epoch_slides = self.select_slides_for_epoch(epoch, stage=stage)

top_slides = epoch_slides[:self.num_slides_heatmap]
top_slides_loss_values = self.select_all_losses_for_selected_slides(top_slides, stage=stage)
self.plot_loss_heatmap_for_slides_of_epoch(top_slides_loss_values, epoch, stage, high=True)

bottom_slides = epoch_slides[-self.num_slides_heatmap:]
bottom_slides_loss_values = self.select_all_losses_for_selected_slides(bottom_slides, stage=stage)
self.plot_loss_heatmap_for_slides_of_epoch(bottom_slides_loss_values, epoch, stage, high=False)
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache

def handle_loss_exceptions(self, stage: ModelKey, exception: Exception) -> None:
"""Handles the loss exceptions."""
if self.log_exceptions:
# If something goes wrong, we don't want to crash the training. We just log the error and carry on
# validation.
logging.warning(f"Error while detecting {stage} loss values outliers: {exception}")
else:
# If we want to debug the error, we raise it. This will crash the training. This is useful when
# running smoke tests.
raise Exception(f"Error while detecting {stage} loss values outliers: {exception}")

def on_train_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: BaseDeepMILModule,
outputs: BatchResultsType,
batch: Dict,
batch_idx: int,
unused: int = 0
) -> None:
"""Caches train loss values per slide at each training step in a local variable self.loss_cache."""
self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.TRAIN)

def on_validation_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: BaseDeepMILModule,
outputs: BatchResultsType,
batch: Any,
batch_idx: int,
dataloader_idx: int
) -> None:
"""Caches validation loss values per slide at each training step in a local variable self.loss_cache."""
self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.VAL)

def on_train_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
if self.should_cache_loss_values(trainer.current_epoch):
self.gather_loss_cache(rank=pl_module.global_rank)
if pl_module.global_rank == 0:
self.save_loss_cache(trainer.current_epoch)
self.loss_cache = self.reset_loss_cache() # reset loss cache for all processes
self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.TRAIN)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.VAL)

def on_train_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Hook called at the end of training. Plot the loss heatmap and scratter plots after ranking the slides by loss
values."""

if pl_module.global_rank == 0:
try:
all_slides = self.select_slides_for_epoch(epoch=0)
all_loss_values_per_slides = self.select_all_losses_for_selected_slides(all_slides)

self.sanity_check_loss_values(all_loss_values_per_slides)
self.save_loss_ranks(all_loss_values_per_slides)

top_slides, top_slides_loss = self.select_slides_and_losses_across_epochs(high=True)
self.plot_slides_loss_scatter(top_slides, top_slides_loss, high=True)

bottom_slides, bottom_slides_loss = self.select_slides_and_losses_across_epochs(high=False)
self.plot_slides_loss_scatter(bottom_slides, bottom_slides_loss, high=False)

for epoch in self.epochs_range:
epoch_slides = self.select_slides_for_epoch(epoch)

top_slides = epoch_slides[:self.num_slides_heatmap]
top_slides_loss_values = self.select_all_losses_for_selected_slides(top_slides)
self.plot_loss_heatmap_for_slides_of_epoch(top_slides_loss_values, epoch, high=True)

bottom_slides = epoch_slides[-self.num_slides_heatmap:]
bottom_slides_loss_values = self.select_all_losses_for_selected_slides(bottom_slides)
self.plot_loss_heatmap_for_slides_of_epoch(bottom_slides_loss_values, epoch, high=False)

self.save_loss_outliers_analaysis_results(stage=ModelKey.TRAIN)
except Exception as e:
if self.log_exceptions:
# If something goes wrong, we don't want to crash the training. We just log the error and carry on
# validation.
logging.warning(f"Error while detecting loss values outliers: {e}")
else:
# If we want to debug the error, we raise it. This will crash the training. This is useful when
# running smoke tests.
raise Exception(f"Error while detecting loss values outliers: {e}")
self.handle_loss_exceptions(stage=ModelKey.TRAIN, exception=e)

def on_validation_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Hook called at the end of validation. Plot the loss heatmap and scratter plots after ranking the slides by
loss values."""
epoch = trainer.current_epoch
if pl_module.global_rank == 0 and not pl_module._run_extra_val_epoch and epoch == (self.max_epochs - 1):
try:
self.save_loss_outliers_analaysis_results(stage=ModelKey.VAL)
except Exception as e:
self.handle_loss_exceptions(stage=ModelKey.VAL, exception=e)
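Taken together, the callback hooks above form one small pipeline per stage. A sketch of the call order over a run (the method names come from this diff; the grouping is illustrative):

    # per batch:   on_train_batch_end / on_validation_batch_end -> update_loss_cache
    #              (per-sample loss, slide ids and entropy appended to loss_cache[stage])
    # per epoch:   on_train_epoch_end / on_validation_epoch_end -> synchronise_processes_and_reset
    #              (gather_loss_cache across ranks, save_loss_cache to csv, reset the cache)
    # end of run:  on_train_end / on_validation_end -> save_loss_outliers_analaysis_results
    #              (rank slides, scatter and heatmap plots, NaN/anomaly reports)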
@@ -50,6 +50,7 @@ class ResultsKey(str, Enum):
FEATURES = 'features'
IMAGE_PATH = 'image_path'
LOSS = 'loss'
LOSS_PER_SAMPLE = 'loss_per_sample'
PROB = 'prob'
CLASS_PROBS = 'prob_class'
PRED_LABEL = 'pred_label'

@@ -59,6 +60,7 @@ class ResultsKey(str, Enum):
TILE_TOP = 'top'
TILE_RIGHT = 'right'
TILE_BOTTOM = 'bottom'
ENTROPY = 'entropy'


class MetricsKey(str, Enum):
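Because ResultsKey mixes in str, the new LOSS_PER_SAMPLE and ENTROPY members compare equal to their string values, so they can serve directly as dictionary keys and as column labels when the loss cache is written to and read back from csv. A tiny sketch (assuming ResultsKey has been imported from health_cpath.utils.naming):

    assert ResultsKey.ENTROPY == "entropy"   # str-mixin Enum members equal their values
    row = {ResultsKey.SLIDE_ID: "id_0", ResultsKey.ENTROPY: 0.69}
    print(row[ResultsKey.ENTROPY])           # 0.69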
@@ -87,7 +87,7 @@ def mock_panda_tiles_root_dir(
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=10,
n_slides=15,
n_channels=3,
tile_size=28,
img_size=224,

@@ -269,7 +269,7 @@ def assert_train_step(module: BaseDeepMILModule, data_module: HistoDataModule, u
train_data_loader = data_module.train_dataloader()
for batch_idx, batch in enumerate(train_data_loader):
batch = move_batch_to_expected_device(batch, use_gpu)
loss = module.training_step(batch, batch_idx)
loss = module.training_step(batch, batch_idx)[ResultsKey.LOSS]
loss.retain_grad()
loss.backward()
assert loss.grad is not None

@@ -5,9 +5,12 @@ import numpy as np
import pandas as pd

from pathlib import Path
from unittest.mock import MagicMock

from health_cpath.utils.naming import ResultsKey
from typing import Callable, List
from unittest.mock import MagicMock, patch
from health_cpath.configs.classification.BaseMIL import BaseMIL

from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCacheDictType
from testhisto.mocks.container import MockDeepSMILETilesPanda
from testhisto.utils.utils_testhisto import run_distributed

@@ -22,29 +25,40 @@ def _assert_loss_cache_contains_n_elements(loss_cache: LossCacheDictType, n: int
assert len(loss_cache[key]) == n


def dump_loss_cache_for_epochs(loss_callback: LossAnalysisCallback, epochs: int) -> None:
def get_loss_cache(n_slides: int = 4, rank: int = 0) -> LossCacheDictType:
return {
ResultsKey.LOSS: list(range(1, n_slides + 1)),
ResultsKey.ENTROPY: list(range(1, n_slides + 1)),
ResultsKey.SLIDE_ID: [f"id_{i}" for i in range(rank * n_slides, (rank + 1) * n_slides)],
ResultsKey.TILE_ID: [f"a${i * (rank + 1)}$b" for i in range(rank * n_slides, (rank + 1) * n_slides)],
}


def dump_loss_cache_for_epochs(loss_callback: LossAnalysisCallback, epochs: int, stage: ModelKey) -> None:
for epoch in range(epochs):
loss_callback.loss_cache = get_loss_cache()
loss_callback.save_loss_cache(epoch)
loss_callback.loss_cache[stage] = get_loss_cache(n_slides=4, rank=0)
loss_callback.save_loss_cache(epoch, stage)


def test_loss_callback_outputs_folder_exist(tmp_path: Path) -> None:
@pytest.mark.parametrize("create_outputs_folders", [True, False])
def test_loss_callback_outputs_folder_exist(create_outputs_folders: bool, tmp_path: Path) -> None:
outputs_folder = tmp_path / "outputs"
outputs_folder.mkdir()
callback = LossAnalysisCallback(outputs_folder=outputs_folder)
for folder in [
callback.outputs_folder,
callback.cache_folder,
callback.scatter_folder,
callback.heatmap_folder,
callback.anomalies_folder,
]:
assert folder.exists()
callback = LossAnalysisCallback(outputs_folder=outputs_folder, create_outputs_folders=create_outputs_folders)
for stage in [ModelKey.TRAIN, ModelKey.VAL]:
for folder in [
callback.outputs_folder,
callback.get_cache_folder(stage),
callback.get_scatter_folder(stage),
callback.get_heatmap_folder(stage),
callback.get_anomalies_folder(stage),
]:
assert folder.exists() == create_outputs_folders


@pytest.mark.parametrize("analyse_loss", [True, False])
def test_analyse_loss_param(analyse_loss: bool) -> None:
container = MockDeepSMILETilesPanda(tmp_path=Path("foo"), analyse_loss=analyse_loss)
container = BaseMIL(analyse_loss=analyse_loss)
container.data_module = MagicMock()
callbacks = container.get_callbacks()
assert isinstance(callbacks[-1], LossAnalysisCallback) == analyse_loss


@@ -53,7 +67,8 @@ def test_analyse_loss_param(analyse_loss: bool) -> None:
def test_save_tile_ids_param(save_tile_ids: bool) -> None:
callback = LossAnalysisCallback(outputs_folder=Path("foo"), save_tile_ids=save_tile_ids)
assert callback.save_tile_ids == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache) == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.TRAIN]) == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.VAL]) == save_tile_ids


@pytest.mark.parametrize("patience", [0, 1, 2])

@@ -93,68 +108,76 @@ def test_loss_analysis_epochs_interval(epochs_interval: int) -> None:
current_epoch = 5
assert callback.should_cache_loss_values(current_epoch)

current_epoch = max_epochs # no loss caching for extra validation epoch
assert callback.should_cache_loss_values(current_epoch) is False

def test_on_train_batch_start(tmp_path: Path, mock_panda_tiles_root_dir: Path) -> None:

def test_on_train_and_val_batch_end(tmp_path: Path, mock_panda_tiles_root_dir: Path) -> None:
batch_size = 2
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir)
container.batch_size = batch_size
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir, analyse_loss=True, batch_size=batch_size)
container.setup()
container.create_lightning_module_and_store()

current_epoch = 5
trainer = MagicMock(current_epoch=current_epoch)
batch = next(iter(container.data_module.train_dataloader()))

callback = LossAnalysisCallback(outputs_folder=tmp_path)
_assert_loss_cache_contains_n_elements(callback.loss_cache, 0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.TRAIN], 0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.VAL], 0)
dataloader = iter(container.data_module.train_dataloader())

callback.on_train_batch_start(trainer, container.model, batch, 0, None) # type: ignore
_assert_loss_cache_contains_n_elements(callback.loss_cache, batch_size)
def _call_on_batch_end_hook(on_batch_end_hook: Callable, batch_idx: int) -> None:
batch = next(dataloader)
outputs = container.model.training_step(batch, batch_idx)
on_batch_end_hook(trainer, container.model, outputs, batch, batch_idx, 0) # type: ignore

batch = next(iter(container.data_module.train_dataloader()))
callback.on_train_batch_start(trainer, container.model, batch, 1, None) # type: ignore
_assert_loss_cache_contains_n_elements(callback.loss_cache, 2 * batch_size)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks: List[Callable] = [callback.on_train_batch_end, callback.on_validation_batch_end]
for stage, on_batch_end_hook in zip(stages, hooks):
_call_on_batch_end_hook(on_batch_end_hook, batch_idx=0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[stage], batch_size)

_call_on_batch_end_hook(on_batch_end_hook, batch_idx=1)
_assert_loss_cache_contains_n_elements(callback.loss_cache[stage], 2 * batch_size)


def get_loss_cache(n_slides: int = 4, rank: int = 0) -> LossCacheDictType:
return {
ResultsKey.LOSS: list(range(1, n_slides + 1)),
ResultsKey.SLIDE_ID: [f"id_{i}" for i in range(rank * n_slides, (rank + 1) * n_slides)],
ResultsKey.TILE_ID: [f"a${i * (rank + 1)}$b" for i in range(rank * n_slides, (rank + 1) * n_slides)],
}


def test_on_train_epoch_end(
def test_on_train_and_val_epoch_end(
tmp_path: Path, duplicate: bool = False, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
current_epoch = 5
current_epoch = 2
n_slides_per_process = 4
trainer = MagicMock(current_epoch=current_epoch)
pl_module = MagicMock(global_rank=rank)

loss_callback = LossAnalysisCallback(outputs_folder=tmp_path, num_slides_heatmap=2, num_slides_scatter=2)
loss_callback.loss_cache = get_loss_cache(rank=rank, n_slides=n_slides_per_process)
if duplicate:
# Duplicate slide "id_0" to test that the duplicates are removed
loss_callback.loss_cache[ResultsKey.SLIDE_ID][0] = "id_0"
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, num_slides_heatmap=2, num_slides_scatter=2, max_epochs=10
)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_epoch_end, loss_callback.on_validation_epoch_end]
for stage, on_epoch_hook in zip(stages, hooks):
loss_callback.loss_cache[stage] = get_loss_cache(rank=rank, n_slides=n_slides_per_process)

_assert_loss_cache_contains_n_elements(loss_callback.loss_cache, n_slides_per_process)
loss_callback.on_train_epoch_end(trainer, pl_module)
# Loss cache is flushed after each epoch
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache, 0)
if duplicate:
# Duplicate slide "id_0" to test that the duplicates are removed
loss_callback.loss_cache[stage][ResultsKey.SLIDE_ID][0] = "id_0"

if rank > 0:
time.sleep(10) # Wait for rank 0 to save the loss cache in a csv file
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], n_slides_per_process)
on_epoch_hook(trainer, pl_module)
# Loss cache is flushed after each epoch
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], 0)

loss_cache_path = loss_callback.get_loss_cache_file(current_epoch)
assert loss_callback.cache_folder.exists()
assert loss_cache_path.exists()
assert loss_cache_path.parent == loss_callback.cache_folder
if rank > 0:
time.sleep(10) # Wait for rank 0 to save the loss cache in a csv file

loss_cache = pd.read_csv(loss_cache_path)
total_slides = n_slides_per_process * world_size if not duplicate else n_slides_per_process * world_size - 1
_assert_loss_cache_contains_n_elements(loss_cache, total_slides)
_assert_is_sorted(loss_cache[ResultsKey.LOSS].values)
loss_cache_path = loss_callback.get_loss_cache_file(current_epoch, stage)
assert loss_callback.get_cache_folder(stage).exists()
assert loss_cache_path.exists()
assert loss_cache_path.parent == loss_callback.get_cache_folder(stage)

loss_cache = pd.read_csv(loss_cache_path)
total_slides = n_slides_per_process * world_size if not duplicate else n_slides_per_process * world_size - 1
_assert_loss_cache_contains_n_elements(loss_cache, total_slides)
_assert_is_sorted(loss_cache[ResultsKey.LOSS].values)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")

@@ -162,39 +185,54 @@ def test_on_train_epoch_end(
def test_on_train_epoch_end_distributed(tmp_path: Path) -> None:
# Test that the loss cache is saved correctly when using multiple GPUs
# First scenario: no duplicates
run_distributed(test_on_train_epoch_end, [tmp_path, False], world_size=2)
run_distributed(test_on_train_and_val_epoch_end, [tmp_path, False], world_size=2)
# Second scenario: introduce duplicates
run_distributed(test_on_train_epoch_end, [tmp_path, True], world_size=2)
run_distributed(test_on_train_and_val_epoch_end, [tmp_path, True], world_size=2)


def test_on_train_end(tmp_path: Path) -> None:
trainer = MagicMock()
pl_module = MagicMock(global_rank=0)
def test_on_train_and_val_end(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
max_epochs = 4
trainer = MagicMock(current_epoch=max_epochs - 1)

loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
dump_loss_cache_for_epochs(loss_callback, max_epochs)
loss_callback.on_train_end(trainer, pl_module)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
for stage, on_end_hook in zip(stages, hooks):
dump_loss_cache_for_epochs(loss_callback, max_epochs, stage)
on_end_hook(trainer, pl_module)

for epoch in range(max_epochs):
assert loss_callback.get_loss_cache_file(epoch).exists()
for epoch in range(max_epochs):
assert loss_callback.get_loss_cache_file(epoch, stage).exists()

# check save_loss_ranks outputs
assert loss_callback.get_all_epochs_loss_cache_file().exists()
assert loss_callback.get_loss_stats_file().exists()
assert loss_callback.get_loss_ranks_file().exists()
assert loss_callback.get_loss_ranks_stats_file().exists()
# check save_loss_ranks outputs
assert loss_callback.get_all_epochs_loss_cache_file(stage).exists()
assert loss_callback.get_loss_stats_file(stage).exists()
assert loss_callback.get_loss_ranks_file(stage).exists()
assert loss_callback.get_loss_ranks_stats_file(stage).exists()

# check plot_slides_loss_scatter outputs
assert loss_callback.get_scatter_plot_file(loss_callback.HIGHEST).exists()
assert loss_callback.get_scatter_plot_file(loss_callback.LOWEST).exists()
# check plot_slides_loss_scatter outputs
assert loss_callback.get_scatter_plot_file(loss_callback.HIGHEST, stage).exists()
assert loss_callback.get_scatter_plot_file(loss_callback.LOWEST, stage).exists()

# check plot_loss_heatmap_for_slides_of_epoch outputs
for epoch in range(max_epochs):
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.HIGHEST).exists()
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.LOWEST).exists()
# check plot_loss_heatmap_for_slides_of_epoch outputs
for epoch in range(max_epochs):
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.HIGHEST, stage).exists()
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.LOWEST, stage).exists()


def test_on_validation_end_not_called_if_extra_val_epoch(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=True)
max_epochs = 4
trainer = MagicMock(current_epoch=0)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
with patch.object(loss_callback, "save_loss_outliers_analaysis_results") as mock_func:
loss_callback.on_validation_end(trainer, pl_module)
mock_func.assert_not_called()


def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:

@@ -203,36 +241,41 @@ def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> Non
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
for epoch in range(max_epochs):
loss_callback.loss_cache = get_loss_cache(n_slides)
loss_callback.loss_cache[ResultsKey.LOSS][epoch] = np.nan
loss_callback.save_loss_cache(epoch)
stages = [ModelKey.TRAIN, ModelKey.VAL]
for stage in stages:
for epoch in range(max_epochs):
loss_callback.loss_cache[stage] = get_loss_cache(n_slides)
loss_callback.loss_cache[stage][ResultsKey.LOSS][epoch] = np.nan # Introduce NaNs
loss_callback.save_loss_cache(epoch, stage)

all_slides = loss_callback.select_slides_for_epoch(epoch=0)
all_loss_values_per_slides = loss_callback.select_all_losses_for_selected_slides(all_slides)
loss_callback.sanity_check_loss_values(all_loss_values_per_slides)
all_slides = loss_callback.select_slides_for_epoch(epoch=0, stage=stage)
all_loss_values_per_slides = loss_callback.select_all_losses_for_selected_slides(all_slides, stage)
loss_callback.sanity_check_loss_values(all_loss_values_per_slides, stage)

assert "NaNs found in loss values for slide id_0" in caplog.records[-1].getMessage()
assert "NaNs found in loss values for slide id_1" in caplog.records[0].getMessage()
assert "NaNs found in loss values for slide id_0" in caplog.records[-1].getMessage()
assert "NaNs found in loss values for slide id_1" in caplog.records[0].getMessage()

assert loss_callback.nan_slides == ["id_1", "id_0"]
assert loss_callback.get_nan_slides_file().exists()
assert loss_callback.get_nan_slides_file().parent == loss_callback.anomalies_folder
assert loss_callback.nan_slides[stage] == ["id_1", "id_0"]
assert loss_callback.get_nan_slides_file(stage).exists()
assert loss_callback.get_nan_slides_file(stage).parent == loss_callback.get_anomalies_folder(stage)


@pytest.mark.parametrize("log_exceptions", [True, False])
def test_log_exceptions_flag(log_exceptions: bool, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
max_epochs = 3
trainer = MagicMock()
pl_module = MagicMock(global_rank=0)
trainer = MagicMock(current_epoch=max_epochs - 1)
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs,
num_slides_heatmap=2, num_slides_scatter=2, log_exceptions=log_exceptions
)
message = "Error while detecting loss values outliers:"
if log_exceptions:
loss_callback.on_train_end(trainer, pl_module)
assert message in caplog.records[-1].getMessage()
else:
with pytest.raises(Exception, match=fr"{message}"):
loss_callback.on_train_end(trainer, pl_module)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
for stage, on_end_hook in zip(stages, hooks):
message = "Error while detecting " + stage.value + " loss values outliers"
if log_exceptions:
on_end_hook(trainer, pl_module)
assert message in caplog.records[-1].getMessage()
else:
with pytest.raises(Exception, match=fr"{message}"):
on_end_hook(trainer, pl_module)