ENH: Track validation loss and add entropy value (#629)

Add entropy values to detect ambiguity and track validation loss during training
Kenza Bouzid 2022-10-12 14:24:25 +01:00, committed by GitHub
Parent 85a317d0f7
Commit e2c1ca1cb4
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 411 additions and 262 deletions
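Context for the entropy value introduced here: the entropy of the predicted class probabilities is cached alongside the loss so that slides with genuinely ambiguous predictions can be told apart from confidently misclassified ones. A minimal sketch of the computation, mirroring the compute_entropy helper added in callbacks.py below (probability values are made up for illustration):

import torch

# Class probabilities for three hypothetical slides: confident, fully ambiguous, moderately confident.
class_probs = torch.tensor([[0.99, 0.01],
                            [0.50, 0.50],
                            [0.80, 0.20]])

# Shannon entropy per slide: -sum(p * log(p)) over the class dimension.
entropy = -torch.sum(class_probs * torch.log(class_probs), dim=1)
print(entropy.tolist())  # ~[0.056, 0.693, 0.500] -- higher entropy = more ambiguous prediction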

View file

@ -291,7 +291,8 @@ class BaseMILTiles(BaseMIL):
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler)
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module
@ -334,7 +335,8 @@ class BaseMILSlides(BaseMIL):
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler)
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module
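Besides being forwarded into the DeepMIL module above, the same analyse_loss flag also controls whether the container registers a LossAnalysisCallback in get_callbacks, which is the behaviour exercised by test_analyse_loss_param further down. A toy sketch of that gating pattern, using stand-in classes rather than the real BaseMIL or LossAnalysisCallback:

from pathlib import Path
from typing import Any, List

class ToyLossAnalysisCallback:
    def __init__(self, outputs_folder: Path) -> None:
        self.outputs_folder = outputs_folder

class ToyContainer:
    def __init__(self, analyse_loss: bool = False, outputs_folder: Path = Path("outputs")) -> None:
        self.analyse_loss = analyse_loss
        self.outputs_folder = outputs_folder

    def get_callbacks(self) -> List[Any]:
        callbacks: List[Any] = []  # checkpointing and other callbacks would normally be added here
        if self.analyse_loss:
            callbacks.append(ToyLossAnalysisCallback(outputs_folder=self.outputs_folder))
        return callbacks

assert isinstance(ToyContainer(analyse_loss=True).get_callbacks()[-1], ToyLossAnalysisCallback)
assert ToyContainer(analyse_loss=False).get_callbacks() == []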

View file

@ -51,7 +51,8 @@ class BaseDeepMILModule(LightningModule):
encoder_params: EncoderParams = EncoderParams(),
pooling_params: PoolingParams = PoolingParams(),
optimizer_params: OptimizerParams = OptimizerParams(),
outputs_handler: Optional[DeepMILOutputsHandler] = None) -> None:
outputs_handler: Optional[DeepMILOutputsHandler] = None,
analyse_loss: Optional[bool] = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
@ -71,6 +72,7 @@ class BaseDeepMILModule(LightningModule):
:param outputs_handler: A configured :py:class:`DeepMILOutputsHandler` object to save outputs for the best
validation epoch and test stage. If omitted (default), no outputs will be saved to disk (aside from usual
metrics logging).
:param analyse_loss: If True, the loss is computed per sample and analysed with the LossAnalysisCallback.
"""
super().__init__()
@ -100,7 +102,8 @@ class BaseDeepMILModule(LightningModule):
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
self.classifier_fn = self.get_classifier()
self.activation_fn = self.get_activation()
self.loss_fn = self.get_loss()
self.analyse_loss = analyse_loss
self.loss_fn = self.get_loss(reduction="mean")
self.loss_fn_no_reduction = self.get_loss(reduction="none")
# Metrics Objects
@ -295,14 +298,17 @@ class BaseDeepMILModule(LightningModule):
"""Update training results with data specific info. This can be either tiles or slides related metadata."""
raise NotImplementedError
def _compute_loss(self, loss_fn: Callable, bag_logits: Tensor, bag_labels: Tensor) -> Tensor:
if self.n_classes > 1:
return loss_fn(bag_logits, bag_labels.long())
else:
return loss_fn(bag_logits.squeeze(1), bag_labels.float())
def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> BatchResultsType:
bag_logits, bag_labels, bag_attn_list = self.compute_bag_labels_logits_and_attn_maps(batch)
if self.n_classes > 1:
loss = self.loss_fn(bag_logits, bag_labels.long())
else:
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
loss = self._compute_loss(self.loss_fn, bag_logits, bag_labels)
predicted_probs = self.activation_fn(bag_logits)
if self.n_classes > 1:
@ -320,9 +326,13 @@ class BaseDeepMILModule(LightningModule):
if self.n_classes == 1:
predicted_probs = predicted_probs.squeeze(dim=1)
results = dict()
if self.analyse_loss and stage in [ModelKey.TRAIN, ModelKey.VAL]:
loss_per_sample = self._compute_loss(self.loss_fn_no_reduction, bag_logits, bag_labels)
results[ResultsKey.LOSS_PER_SAMPLE] = loss_per_sample.detach().cpu().numpy()
bag_labels = bag_labels.view(-1, 1)
results = dict()
for metric_object in self.get_metrics_dict(stage).values():
metric_object.update(predicted_probs, bag_labels.view(batch_size,).int())
results.update({ResultsKey.LOSS: loss,
@ -341,13 +351,17 @@ class BaseDeepMILModule(LightningModule):
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
return results
def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
def training_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
train_result = self._shared_step(batch, batch_idx, ModelKey.TRAIN)
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
if self.verbose:
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
return train_result[ResultsKey.LOSS]
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
if self.analyse_loss:
results.update({ResultsKey.LOSS_PER_SAMPLE: train_result[ResultsKey.LOSS_PER_SAMPLE],
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS]})
return results
def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
val_result = self._shared_step(batch, batch_idx, ModelKey.VAL)
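The module now instantiates its criterion twice: once with reduction="mean" for optimisation and once with reduction="none" so that a per-bag loss can be cached for the analysis callback. A minimal sketch of the difference, assuming a multi-class head with CrossEntropyLoss (the actual criterion is whatever get_loss returns):

import torch
from torch import nn

bag_logits = torch.randn(4, 3)          # 4 bags, 3 classes
bag_labels = torch.tensor([0, 2, 1, 2])

loss_fn = nn.CrossEntropyLoss(reduction="mean")                # scalar loss used for optimisation
loss_fn_no_reduction = nn.CrossEntropyLoss(reduction="none")   # one loss value per bag

loss = loss_fn(bag_logits, bag_labels.long())                            # shape ()
loss_per_sample = loss_fn_no_reduction(bag_logits, bag_labels.long())    # shape (4,)

# The unreduced values are what gets stored under ResultsKey.LOSS_PER_SAMPLE,
# detached and moved to CPU before the callback consumes them.
print(loss.item(), loss_per_sample.detach().cpu().numpy())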

View file

@ -11,17 +11,20 @@ import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from health_cpath.models.deepmil import BaseDeepMILModule
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.output_utils import BatchResultsType
LossCacheDictType = Dict[ResultsKey, List]
LossCacheDictType = Dict[Union[ResultsKey, str], List]
LossDictType = Dict[str, List]
AnomalyDictType = Dict[ModelKey, List[str]]
class LossCallbackParams(param.Parameterized):
@ -89,6 +92,7 @@ class LossAnalysisCallback(Callback):
num_slides_heatmap: int = 20,
save_tile_ids: bool = False,
log_exceptions: bool = True,
create_outputs_folders: bool = True,
) -> None:
"""
@ -102,6 +106,7 @@ class LossAnalysisCallback(Callback):
False. This is useful to analyse the tiles that are contributing to the loss value of a slide.
:param log_exceptions: If True, will log exceptions raised during loss values analysis, defaults to True. If
False will raise the intercepted exceptions.
:param create_outputs_folders: If True, will create the output folders if they don't exist, defaults to True.
"""
self.patience = patience
@ -112,89 +117,89 @@ class LossAnalysisCallback(Callback):
self.save_tile_ids = save_tile_ids
self.log_exceptions = log_exceptions
self.outputs_folder = outputs_folder / "loss_values_callback"
self.create_outputs_folders()
self.outputs_folder = outputs_folder / "loss_analysis_callback"
if create_outputs_folders:
self.create_outputs_folders()
self.loss_cache = self.reset_loss_cache()
self.loss_cache: Dict[ModelKey, LossCacheDictType] = {
stage: self.get_empty_loss_cache() for stage in [ModelKey.TRAIN, ModelKey.VAL]
}
self.epochs_range = list(range(self.patience, self.max_epochs, self.epochs_interval))
self.nan_slides: List[str] = []
self.anomaly_slides: List[str] = []
self.nan_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
self.anomaly_slides: AnomalyDictType = {stage: [] for stage in [ModelKey.TRAIN, ModelKey.VAL]}
@property
def cache_folder(self) -> Path:
return self.outputs_folder / "loss_cache"
def get_cache_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_cache"
@property
def scatter_folder(self) -> Path:
return self.outputs_folder / "loss_scatter"
def get_scatter_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_scatter"
@property
def heatmap_folder(self) -> Path:
return self.outputs_folder / "loss_heatmap"
def get_heatmap_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_heatmap"
@property
def stats_folder(self) -> Path:
return self.outputs_folder / "loss_stats"
def get_stats_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_stats"
@property
def anomalies_folder(self) -> Path:
return self.outputs_folder / "loss_anomalies"
def get_anomalies_folder(self, stage: ModelKey) -> Path:
return self.outputs_folder / f"{stage}/loss_anomalies"
def create_outputs_folders(self) -> None:
folders = [
self.cache_folder,
self.scatter_folder,
self.heatmap_folder,
self.stats_folder,
self.anomalies_folder,
self.get_cache_folder,
self.get_scatter_folder,
self.get_heatmap_folder,
self.get_stats_folder,
self.get_anomalies_folder,
]
stages = [ModelKey.TRAIN, ModelKey.VAL]
for folder in folders:
os.makedirs(folder, exist_ok=True)
for stage in stages:
os.makedirs(folder(stage), exist_ok=True)
def reset_loss_cache(self) -> LossCacheDictType:
keys = [ResultsKey.LOSS, ResultsKey.SLIDE_ID]
def get_empty_loss_cache(self) -> LossCacheDictType:
keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
if self.save_tile_ids:
keys.append(ResultsKey.TILE_ID)
return {key: [] for key in keys}
def _get_filename(self, filename: str, epoch: int, order: Optional[str] = None) -> str:
zero_filled_epoch = str(epoch).zfill(len(str(self.max_epochs)))
filename = filename.format(zero_filled_epoch, order) if order else filename.format(zero_filled_epoch)
return filename
def _format_epoch(self, epoch: int) -> str:
return str(epoch).zfill(len(str(self.max_epochs)))
def get_loss_cache_file(self, epoch: int) -> Path:
return self.cache_folder / self._get_filename(filename="epoch_{}.csv", epoch=epoch)
def get_loss_cache_file(self, epoch: int, stage: ModelKey) -> Path:
return self.get_cache_folder(stage) / f"epoch_{self._format_epoch(epoch)}.csv"
def get_all_epochs_loss_cache_file(self) -> Path:
return self.cache_folder / "all_epochs.csv"
def get_all_epochs_loss_cache_file(self, stage: ModelKey) -> Path:
return self.get_cache_folder(stage) / f"all_epochs_{stage}.csv"
def get_loss_stats_file(self) -> Path:
return self.stats_folder / "loss_stats.csv"
def get_loss_stats_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_stats_{stage}.csv"
def get_loss_ranks_file(self) -> Path:
return self.stats_folder / "loss_ranks.csv"
def get_loss_ranks_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_ranks_{stage}.csv"
def get_loss_ranks_stats_file(self) -> Path:
return self.stats_folder / "loss_ranks_stats.csv"
def get_loss_ranks_stats_file(self, stage: ModelKey) -> Path:
return self.get_stats_folder(stage) / f"loss_ranks_stats_{stage}.csv"
def get_nan_slides_file(self) -> Path:
return self.anomalies_folder / "nan_slides.txt"
def get_nan_slides_file(self, stage: ModelKey) -> Path:
return self.get_anomalies_folder(stage) / f"nan_slides_{stage}.txt"
def get_anomaly_slides_file(self) -> Path:
return self.anomalies_folder / "anomaly_slides.txt"
def get_anomaly_slides_file(self, stage: ModelKey) -> Path:
return self.get_anomalies_folder(stage) / f"anomaly_slides_{stage}.txt"
def get_scatter_plot_file(self, order: str) -> Path:
return self.scatter_folder / "slides_with_{}_loss_values.png".format(order)
def get_scatter_plot_file(self, order: str, stage: ModelKey) -> Path:
return self.get_scatter_folder(stage) / f"slides_with_{order}_loss_values_{stage}.png"
def get_heatmap_plot_file(self, epoch: int, order: str) -> Path:
return self.heatmap_folder / self._get_filename(filename="epoch_{}_{}_slides.png", epoch=epoch, order=order)
def get_heatmap_plot_file(self, epoch: int, order: str, stage: ModelKey) -> Path:
return self.get_heatmap_folder(stage) / f"epoch_{self._format_epoch(epoch)}_{order}_slides_{stage}.png"
def read_loss_cache(self, epoch: int, idx_col: Optional[ResultsKey] = None) -> pd.DataFrame:
columns = [ResultsKey.SLIDE_ID, ResultsKey.LOSS]
return pd.read_csv(self.get_loss_cache_file(epoch), index_col=idx_col, usecols=columns)
def read_loss_cache(self, epoch: int, stage: ModelKey, idx_col: Optional[ResultsKey] = None) -> pd.DataFrame:
columns = [ResultsKey.SLIDE_ID, ResultsKey.LOSS, ResultsKey.ENTROPY]
return pd.read_csv(self.get_loss_cache_file(epoch, stage), index_col=idx_col, usecols=columns)
def should_cache_loss_values(self, current_epoch: int) -> bool:
if current_epoch >= self.max_epochs:
return False # Don't cache loss values for the extra validation epoch
current_epoch = current_epoch + 1
first_time = (current_epoch - self.patience) == 1
return first_time or (
@ -203,35 +208,41 @@ class LossAnalysisCallback(Callback):
def merge_loss_caches(self, loss_caches: List[LossCacheDictType]) -> LossCacheDictType:
"""Merges the loss caches from all the workers into a single loss cache"""
loss_cache = self.reset_loss_cache()
loss_cache = self.get_empty_loss_cache()
for loss_cache_per_device in loss_caches:
for key in loss_cache.keys():
loss_cache[key].extend(loss_cache_per_device[key])
return loss_cache
def gather_loss_cache(self, rank: int) -> None:
def gather_loss_cache(self, rank: int, stage: ModelKey) -> None:
"""Gathers the loss cache from all the workers"""
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
if world_size > 1:
loss_caches = [None] * world_size
torch.distributed.all_gather_object(loss_caches, self.loss_cache)
torch.distributed.all_gather_object(loss_caches, self.loss_cache[stage])
if rank == 0:
self.loss_cache = self.merge_loss_caches(loss_caches) # type: ignore
self.loss_cache[stage] = self.merge_loss_caches(loss_caches) # type: ignore
def save_loss_cache(self, current_epoch: int) -> None:
def save_loss_cache(self, current_epoch: int, stage: ModelKey) -> None:
"""Saves the loss cache to a csv file"""
loss_cache_df = pd.DataFrame(self.loss_cache)
loss_cache_df = pd.DataFrame(self.loss_cache[stage])
# Some slides may appear multiple times in the loss cache in DDP mode. The distributed sampler duplicates
# some slides to even out the number of samples per device, so we only keep the first occurrence.
loss_cache_df.drop_duplicates(subset=ResultsKey.SLIDE_ID, inplace=True, keep="first")
loss_cache_df = loss_cache_df.sort_values(by=ResultsKey.LOSS, ascending=False)
loss_cache_df.to_csv(self.get_loss_cache_file(current_epoch), index=False)
loss_cache_df.to_csv(self.get_loss_cache_file(current_epoch, stage), index=False)
def _select_values_for_epoch(
self, keys: List[ResultsKey], epoch: int, high: Optional[bool] = None, num_values: Optional[int] = None
self,
keys: List[ResultsKey],
epoch: int,
stage: ModelKey,
high: Optional[bool] = None,
num_values: Optional[int] = None
) -> List[np.ndarray]:
loss_cache = self.read_loss_cache(epoch)
"""Selects the values corresponding to keys from a dataframe for the given epoch and stage"""
loss_cache = self.read_loss_cache(epoch, stage)
return_values = []
for key in keys:
values = loss_cache[key].values
@ -246,52 +257,63 @@ class LossAnalysisCallback(Callback):
return return_values
def select_slides_for_epoch(
self, epoch: int, high: Optional[bool] = None, num_slides: Optional[int] = None
self, epoch: int, stage: ModelKey, high: Optional[bool] = None, num_slides: Optional[int] = None
) -> np.ndarray:
"""Selects slides in ascending/descending order of loss values at a given epoch
:param epoch: The epoch to select the slides from.
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values. If None, selects all slides.
:param num_slides: The number of slides to select. If None, selects all slides.
"""
return self._select_values_for_epoch([ResultsKey.SLIDE_ID], epoch, high, num_slides)[0]
return self._select_values_for_epoch([ResultsKey.SLIDE_ID], epoch, stage, high, num_slides)[0]
def select_slides_losses_for_epoch(self, epoch: int, high: Optional[bool] = None) -> Tuple[np.ndarray, np.ndarray]:
def select_slides_entropy_for_epoch(
self, epoch: int, stage: ModelKey, high: Optional[bool] = None
) -> Tuple[np.ndarray, np.ndarray]:
"""Selects slides and loss values of slides in ascending/descending order of loss at a given epoch
:param epoch: The epoch to select the slides from.
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values. If None, selects all slides.
"""
keys = [ResultsKey.SLIDE_ID, ResultsKey.LOSS]
return_values = self._select_values_for_epoch(keys, epoch, high, self.num_slides_scatter)
keys = [ResultsKey.SLIDE_ID, ResultsKey.ENTROPY]
return_values = self._select_values_for_epoch(keys, epoch, stage, high, self.num_slides_scatter)
return return_values[0], return_values[1]
def select_all_losses_for_selected_slides(self, slides: np.ndarray) -> LossDictType:
"""Selects the loss values for a given set of slides"""
def select_all_losses_for_selected_slides(self, slides: np.ndarray, stage: ModelKey) -> LossDictType:
"""Selects the loss values for a given set of slides
:param slides: The slides to select the loss values for.
:param stage: The model's stage (train/val).
"""
slides_loss_values: LossDictType = {slide_id: [] for slide_id in slides}
for epoch in self.epochs_range:
loss_cache = self.read_loss_cache(epoch, idx_col=ResultsKey.SLIDE_ID)
loss_cache = self.read_loss_cache(epoch, stage, idx_col=ResultsKey.SLIDE_ID)
for slide_id in slides:
slides_loss_values[slide_id].append(loss_cache.loc[slide_id, ResultsKey.LOSS])
return slides_loss_values
def select_slides_and_losses_across_epochs(self, high: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""Selects the slides with the highest/lowest loss values across epochs
def select_slides_and_entropy_across_epochs(
self, stage: ModelKey, high: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
"""Selects the slides with the highest/lowest loss values across epochs and their entropy values
:param stage: The model's stage (train/val).
:param high: If True, selects the slides with the highest loss values, else selects the slides with the lowest
loss values.
"""
slides = []
slides_loss = []
slides_entropy = []
for epoch in self.epochs_range:
epoch_slides, epoch_slides_loss = self.select_slides_losses_for_epoch(epoch, high)
epoch_slides, epoch_slides_entropy = self.select_slides_entropy_for_epoch(epoch, stage, high)
slides.append(epoch_slides)
slides_loss.append(epoch_slides_loss)
slides_entropy.append(epoch_slides_entropy)
return np.array(slides).T, np.array(slides_loss).T
return np.array(slides).T, np.array(slides_entropy).T
def save_slide_ids(self, slide_ids: List[str], path: Path) -> None:
"""Dumps the slides ids in a txt file."""
@ -300,10 +322,11 @@ class LossAnalysisCallback(Callback):
for slide_id in slide_ids:
f.write(f"{slide_id}\n")
def sanity_check_loss_values(self, loss_values: LossDictType) -> None:
def sanity_check_loss_values(self, loss_values: LossDictType, stage: ModelKey) -> None:
"""Checks if there are any NaNs or any other potential annomalies in the loss values.
:param loss_values: The loss values for all slides.
:param stage: The model's stage (train/val).
"""
# We don't want any of these exceptions to interrupt validation. So we catch them and log them.
loss_values_copy = loss_values.copy()
@ -311,68 +334,83 @@ class LossAnalysisCallback(Callback):
try:
if np.isnan(loss).any():
logging.warning(f"NaNs found in loss values for slide {slide_id}.")
self.nan_slides.append(slide_id)
self.nan_slides[stage].append(slide_id)
loss_values.pop(slide_id)
except Exception as e:
logging.warning(f"Error while checking for NaNs in loss values for slide {slide_id} with error {e}.")
logging.warning(f"Loss values that caused the issue: {loss}")
self.anomaly_slides.append(slide_id)
self.anomaly_slides[stage].append(slide_id)
loss_values.pop(slide_id)
self.save_slide_ids(self.nan_slides, self.get_nan_slides_file())
self.save_slide_ids(self.anomaly_slides, self.get_anomaly_slides_file())
self.save_slide_ids(self.nan_slides[stage], self.get_nan_slides_file(stage))
self.save_slide_ids(self.anomaly_slides[stage], self.get_anomaly_slides_file(stage))
def save_loss_ranks(self, slides_loss_values: LossDictType) -> None:
"""Saves the loss ranks for each slide across all epochs and their respective statistics in csv files."""
def save_loss_ranks(self, slides_loss_values: LossDictType, stage: ModelKey) -> None:
"""Saves the loss ranks for each slide across all epochs and their respective statistics in csv files.
:param slides_loss_values: The loss values for all slides.
:param stage: The model's stage (train/val).
"""
loss_df = pd.DataFrame(slides_loss_values).T
loss_df.index.names = [ResultsKey.SLIDE_ID.value]
loss_df.to_csv(self.get_all_epochs_loss_cache_file())
loss_df.to_csv(self.get_all_epochs_loss_cache_file(stage))
loss_stats = loss_df.T.describe().T.sort_values(by="mean", ascending=False)
loss_stats.to_csv(self.get_loss_stats_file())
loss_stats.to_csv(self.get_loss_stats_file(stage))
loss_ranks = loss_df.rank(ascending=False)
loss_ranks.to_csv(self.get_loss_ranks_file())
loss_ranks.to_csv(self.get_loss_ranks_file(stage))
loss_ranks_stats = loss_ranks.T.describe().T.sort_values("mean", ascending=True)
loss_ranks_stats.to_csv(self.get_loss_ranks_stats_file())
loss_ranks_stats.to_csv(self.get_loss_ranks_stats_file(stage))
def plot_slides_loss_scatter(
self,
slides: np.ndarray,
slides_loss: np.ndarray,
slides_entropy: np.ndarray,
stage: ModelKey,
high: bool = True,
figsize: Tuple[float, float] = (30, 30),
) -> None:
"""Plots the slides with the highest/lowest loss values across epochs in a scatter plot
:param slides: The slides ids.
:param slides_loss: The loss values for each slide.
:param figsize: The figure size, defaults to (20, 20)
:param slides_entropy: The entropy values for each slide.
:param stage: The model's stage (train/val).
:param figsize: The figure size, defaults to (30, 30)
:param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
"""
label = self.TOP if high else self.BOTTOM
plt.figure(figsize=figsize)
markers_size = [15 * i for i in range(1, self.num_slides_scatter + 1)]
markers_size = markers_size[::-1] if high else markers_size
colors = cm.rainbow(np.linspace(0, 1, self.num_slides_scatter))
for i in range(self.num_slides_scatter - 1, -1, -1):
plt.scatter(self.epochs_range, slides[i], label=f"{label}_{i+1}")
for loss, epoch, slide in zip(slides_loss[i], self.epochs_range, slides[i]):
plt.annotate(f"{loss:.3f}", (epoch, slide))
plt.scatter(self.epochs_range, slides[i], label=f"{label}_{i+1}", s=markers_size[i], color=colors[i])
for entropy, epoch, slide in zip(slides_entropy[i], self.epochs_range, slides[i]):
plt.annotate(f"{entropy:.3f}", (epoch, slide))
plt.xlabel(self.X_LABEL)
plt.ylabel(self.Y_LABEL)
order = self.HIGHEST if high else self.LOWEST
plt.title(f"Slides with {order} loss values per epoch.")
plt.title(f"Slides with {order} loss values per epoch and their entropy values")
plt.xticks(self.epochs_range)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.grid(True, linestyle="--")
plt.savefig(self.get_scatter_plot_file(order), bbox_inches="tight")
plt.savefig(self.get_scatter_plot_file(order, stage), bbox_inches="tight")
def plot_loss_heatmap_for_slides_of_epoch(
self, slides_loss_values: LossDictType, epoch: int, high: bool, figsize: Tuple[float, float] = (15, 15)
self,
slides_loss_values: LossDictType,
epoch: int,
stage: ModelKey,
high: bool,
figsize: Tuple[float, float] = (15, 15)
) -> None:
"""Plots the loss values for each slide across all epochs in a heatmap.
:param slides_loss_values: The loss values for each slide across all epochs.
:param epoch: The epoch used to select the slides.
:param stage: The model's stage (train/val).
:param high: If True, plots the slides with the highest loss values, else plots the slides with the lowest loss.
:param figsize: The figure size, defaults to (15, 15)
"""
@ -384,69 +422,119 @@ class LossAnalysisCallback(Callback):
plt.xlabel(self.X_LABEL)
plt.ylabel(self.Y_LABEL)
plt.title(f"Loss values evolution for {order} slides of epoch {epoch}")
plt.savefig(self.get_heatmap_plot_file(epoch, order), bbox_inches="tight")
plt.savefig(self.get_heatmap_plot_file(epoch, order, stage), bbox_inches="tight")
@torch.no_grad()
def on_train_batch_start( # type: ignore
self, trainer: Trainer, pl_module: BaseDeepMILModule, batch: Dict, batch_idx: int, unused: int = 0,
) -> None:
"""Caches loss values per slide at each training step in a local variable self.loss_cache."""
@staticmethod
def compute_entropy(class_probs: torch.Tensor) -> List[float]:
"""Computes the entropy of the class probabilities.
:param class_probs: The class probabilities.
"""
return (-torch.sum(class_probs * torch.log(class_probs), dim=1)).tolist()
def update_loss_cache(self, trainer: Trainer, outputs: BatchResultsType, batch: Dict, stage: ModelKey) -> None:
"""Updates the loss cache with the loss values for the current batch."""
if self.should_cache_loss_values(trainer.current_epoch):
bag_logits, bag_labels, _ = pl_module.compute_bag_labels_logits_and_attn_maps(batch)
if pl_module.n_classes > 1:
loss = pl_module.loss_fn_no_reduction(bag_logits, bag_labels.long())
else:
loss = pl_module.loss_fn_no_reduction(bag_logits.squeeze(1), bag_labels.float())
self.loss_cache[ResultsKey.LOSS].extend(loss.tolist())
self.loss_cache[ResultsKey.SLIDE_ID].extend([slides[0] for slides in batch[ResultsKey.SLIDE_ID]])
self.loss_cache[stage][ResultsKey.LOSS].extend(outputs[ResultsKey.LOSS_PER_SAMPLE])
self.loss_cache[stage][ResultsKey.SLIDE_ID].extend([slides[0] for slides in batch[ResultsKey.SLIDE_ID]])
self.loss_cache[stage][ResultsKey.ENTROPY].extend(self.compute_entropy(outputs[ResultsKey.CLASS_PROBS]))
if self.save_tile_ids:
self.loss_cache[ResultsKey.TILE_ID].extend(
self.loss_cache[stage][ResultsKey.TILE_ID].extend(
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in batch[ResultsKey.TILE_ID]]
)
def synchronise_processes_and_reset(self, trainer: Trainer, pl_module: BaseDeepMILModule, stage: ModelKey) -> None:
"""Synchronises the processes in DDP mode and resets the loss cache for the next epoch."""
if self.should_cache_loss_values(trainer.current_epoch):
self.gather_loss_cache(rank=pl_module.global_rank, stage=stage)
if pl_module.global_rank == 0:
self.save_loss_cache(trainer.current_epoch, stage)
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache for all processes
def save_loss_outliers_analaysis_results(self, stage: ModelKey) -> None:
"""Saves the loss outliers analysis results."""
all_slides = self.select_slides_for_epoch(epoch=0, stage=stage)
all_loss_values_per_slides = self.select_all_losses_for_selected_slides(all_slides, stage=stage)
self.sanity_check_loss_values(all_loss_values_per_slides, stage=stage)
self.save_loss_ranks(all_loss_values_per_slides, stage=stage)
top_slides, top_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=True)
self.plot_slides_loss_scatter(top_slides, top_slides_entropy, stage, high=True)
bottom_slides, bottom_slides_entropy = self.select_slides_and_entropy_across_epochs(stage, high=False)
self.plot_slides_loss_scatter(bottom_slides, bottom_slides_entropy, stage, high=False)
for epoch in self.epochs_range:
epoch_slides = self.select_slides_for_epoch(epoch, stage=stage)
top_slides = epoch_slides[:self.num_slides_heatmap]
top_slides_loss_values = self.select_all_losses_for_selected_slides(top_slides, stage=stage)
self.plot_loss_heatmap_for_slides_of_epoch(top_slides_loss_values, epoch, stage, high=True)
bottom_slides = epoch_slides[-self.num_slides_heatmap:]
bottom_slides_loss_values = self.select_all_losses_for_selected_slides(bottom_slides, stage=stage)
self.plot_loss_heatmap_for_slides_of_epoch(bottom_slides_loss_values, epoch, stage, high=False)
self.loss_cache[stage] = self.get_empty_loss_cache() # reset loss cache
def handle_loss_exceptions(self, stage: ModelKey, exception: Exception) -> None:
"""Handles the loss exceptions."""
if self.log_exceptions:
# If something goes wrong, we don't want to crash the training. We just log the error and carry on
# validation.
logging.warning(f"Error while detecting {stage} loss values outliers: {exception}")
else:
# If we want to debug the error, we raise it. This will crash the training. This is useful when
# running smoke tests.
raise Exception(f"Error while detecting {stage} loss values outliers: {exception}")
def on_train_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: BaseDeepMILModule,
outputs: BatchResultsType,
batch: Dict,
batch_idx: int,
unused: int = 0
) -> None:
"""Caches train loss values per slide at each training step in a local variable self.loss_cache."""
self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.TRAIN)
def on_validation_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: BaseDeepMILModule,
outputs: BatchResultsType,
batch: Any,
batch_idx: int,
dataloader_idx: int
) -> None:
"""Caches validation loss values per slide at each training step in a local variable self.loss_cache."""
self.update_loss_cache(trainer, outputs, batch, stage=ModelKey.VAL)
def on_train_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
if self.should_cache_loss_values(trainer.current_epoch):
self.gather_loss_cache(rank=pl_module.global_rank)
if pl_module.global_rank == 0:
self.save_loss_cache(trainer.current_epoch)
self.loss_cache = self.reset_loss_cache() # reset loss cache for all processes
self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.TRAIN)
def on_validation_epoch_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Gathers loss values per slide from all processes at the end of each epoch and saves them to a csv file."""
self.synchronise_processes_and_reset(trainer, pl_module, ModelKey.VAL)
def on_train_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Hook called at the end of training. Plot the loss heatmap and scratter plots after ranking the slides by loss
values."""
if pl_module.global_rank == 0:
try:
all_slides = self.select_slides_for_epoch(epoch=0)
all_loss_values_per_slides = self.select_all_losses_for_selected_slides(all_slides)
self.sanity_check_loss_values(all_loss_values_per_slides)
self.save_loss_ranks(all_loss_values_per_slides)
top_slides, top_slides_loss = self.select_slides_and_losses_across_epochs(high=True)
self.plot_slides_loss_scatter(top_slides, top_slides_loss, high=True)
bottom_slides, bottom_slides_loss = self.select_slides_and_losses_across_epochs(high=False)
self.plot_slides_loss_scatter(bottom_slides, bottom_slides_loss, high=False)
for epoch in self.epochs_range:
epoch_slides = self.select_slides_for_epoch(epoch)
top_slides = epoch_slides[:self.num_slides_heatmap]
top_slides_loss_values = self.select_all_losses_for_selected_slides(top_slides)
self.plot_loss_heatmap_for_slides_of_epoch(top_slides_loss_values, epoch, high=True)
bottom_slides = epoch_slides[-self.num_slides_heatmap:]
bottom_slides_loss_values = self.select_all_losses_for_selected_slides(bottom_slides)
self.plot_loss_heatmap_for_slides_of_epoch(bottom_slides_loss_values, epoch, high=False)
self.save_loss_outliers_analaysis_results(stage=ModelKey.TRAIN)
except Exception as e:
if self.log_exceptions:
# If something goes wrong, we don't want to crash the training. We just log the error and carry on
# validation.
logging.warning(f"Error while detecting loss values outliers: {e}")
else:
# If we want to debug the error, we raise it. This will crash the training. This is useful when
# running smoke tests.
raise Exception(f"Error while detecting loss values outliers: {e}")
self.handle_loss_exceptions(stage=ModelKey.TRAIN, exception=e)
def on_validation_end(self, trainer: Trainer, pl_module: BaseDeepMILModule) -> None: # type: ignore
"""Hook called at the end of validation. Plot the loss heatmap and scratter plots after ranking the slides by
loss values."""
epoch = trainer.current_epoch
if pl_module.global_rank == 0 and not pl_module._run_extra_val_epoch and epoch == (self.max_epochs - 1):
try:
self.save_loss_outliers_analaysis_results(stage=ModelKey.VAL)
except Exception as e:
self.handle_loss_exceptions(stage=ModelKey.VAL, exception=e)
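In DDP mode each process accumulates its own loss cache; gather_loss_cache collects the per-device caches with torch.distributed.all_gather_object, merge_loss_caches concatenates them key by key, and save_loss_cache then drops the slides the distributed sampler duplicated for padding before writing the csv. A single-process sketch of the merge and de-duplication steps with toy values (plain string keys instead of ResultsKey members):

import pandas as pd

# Toy loss caches as two DDP workers might produce them (made-up values).
cache_rank0 = {"slide_id": ["id_0", "id_1"], "loss": [0.9, 0.4], "entropy": [0.60, 0.10]}
cache_rank1 = {"slide_id": ["id_2", "id_0"], "loss": [0.7, 0.9], "entropy": [0.30, 0.60]}

# Merge step: extend each key's list across devices, as merge_loss_caches does.
merged = {key: [] for key in cache_rank0}
for cache in [cache_rank0, cache_rank1]:
    for key in merged:
        merged[key].extend(cache[key])

# Save step: drop slides duplicated by the distributed sampler, sort by descending loss, write to csv.
loss_cache_df = pd.DataFrame(merged).drop_duplicates(subset="slide_id", keep="first")
loss_cache_df = loss_cache_df.sort_values(by="loss", ascending=False)
print(loss_cache_df)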

View file

@ -50,6 +50,7 @@ class ResultsKey(str, Enum):
FEATURES = 'features'
IMAGE_PATH = 'image_path'
LOSS = 'loss'
LOSS_PER_SAMPLE = 'loss_per_sample'
PROB = 'prob'
CLASS_PROBS = 'prob_class'
PRED_LABEL = 'pred_label'
@ -59,6 +60,7 @@ class ResultsKey(str, Enum):
TILE_TOP = 'top'
TILE_RIGHT = 'right'
TILE_BOTTOM = 'bottom'
ENTROPY = 'entropy'
class MetricsKey(str, Enum):
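The new LOSS_PER_SAMPLE and ENTROPY members follow the existing ResultsKey pattern: because the enum subclasses str, its members compare equal to their plain string values, which is why they can key the loss cache dictionaries and label the DataFrame columns used throughout this change. A small illustration with a stand-in enum rather than the real ResultsKey:

from enum import Enum

import pandas as pd

class Key(str, Enum):
    SLIDE_ID = "slide_id"
    ENTROPY = "entropy"

cache = {Key.SLIDE_ID: ["id_0", "id_1"], Key.ENTROPY: [0.69, 0.05]}
df = pd.DataFrame(cache)

# Members compare equal to their plain string values and can label DataFrame columns.
assert Key.ENTROPY == "entropy"
print(df[Key.ENTROPY].tolist())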

View file

@ -87,7 +87,7 @@ def mock_panda_tiles_root_dir(
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=10,
n_slides=15,
n_channels=3,
tile_size=28,
img_size=224,

View file

@ -269,7 +269,7 @@ def assert_train_step(module: BaseDeepMILModule, data_module: HistoDataModule, u
train_data_loader = data_module.train_dataloader()
for batch_idx, batch in enumerate(train_data_loader):
batch = move_batch_to_expected_device(batch, use_gpu)
loss = module.training_step(batch, batch_idx)
loss = module.training_step(batch, batch_idx)[ResultsKey.LOSS]
loss.retain_grad()
loss.backward()
assert loss.grad is not None

View file

@ -5,9 +5,12 @@ import numpy as np
import pandas as pd
from pathlib import Path
from unittest.mock import MagicMock
from health_cpath.utils.naming import ResultsKey
from typing import Callable, List
from unittest.mock import MagicMock, patch
from health_cpath.configs.classification.BaseMIL import BaseMIL
from health_cpath.utils.naming import ModelKey, ResultsKey
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCacheDictType
from testhisto.mocks.container import MockDeepSMILETilesPanda
from testhisto.utils.utils_testhisto import run_distributed
@ -22,29 +25,40 @@ def _assert_loss_cache_contains_n_elements(loss_cache: LossCacheDictType, n: int
assert len(loss_cache[key]) == n
def dump_loss_cache_for_epochs(loss_callback: LossAnalysisCallback, epochs: int) -> None:
def get_loss_cache(n_slides: int = 4, rank: int = 0) -> LossCacheDictType:
return {
ResultsKey.LOSS: list(range(1, n_slides + 1)),
ResultsKey.ENTROPY: list(range(1, n_slides + 1)),
ResultsKey.SLIDE_ID: [f"id_{i}" for i in range(rank * n_slides, (rank + 1) * n_slides)],
ResultsKey.TILE_ID: [f"a${i * (rank + 1)}$b" for i in range(rank * n_slides, (rank + 1) * n_slides)],
}
def dump_loss_cache_for_epochs(loss_callback: LossAnalysisCallback, epochs: int, stage: ModelKey) -> None:
for epoch in range(epochs):
loss_callback.loss_cache = get_loss_cache()
loss_callback.save_loss_cache(epoch)
loss_callback.loss_cache[stage] = get_loss_cache(n_slides=4, rank=0)
loss_callback.save_loss_cache(epoch, stage)
def test_loss_callback_outputs_folder_exist(tmp_path: Path) -> None:
@pytest.mark.parametrize("create_outputs_folders", [True, False])
def test_loss_callback_outputs_folder_exist(create_outputs_folders: bool, tmp_path: Path) -> None:
outputs_folder = tmp_path / "outputs"
outputs_folder.mkdir()
callback = LossAnalysisCallback(outputs_folder=outputs_folder)
for folder in [
callback.outputs_folder,
callback.cache_folder,
callback.scatter_folder,
callback.heatmap_folder,
callback.anomalies_folder,
]:
assert folder.exists()
callback = LossAnalysisCallback(outputs_folder=outputs_folder, create_outputs_folders=create_outputs_folders)
for stage in [ModelKey.TRAIN, ModelKey.VAL]:
for folder in [
callback.outputs_folder,
callback.get_cache_folder(stage),
callback.get_scatter_folder(stage),
callback.get_heatmap_folder(stage),
callback.get_anomalies_folder(stage),
]:
assert folder.exists() == create_outputs_folders
@pytest.mark.parametrize("analyse_loss", [True, False])
def test_analyse_loss_param(analyse_loss: bool) -> None:
container = MockDeepSMILETilesPanda(tmp_path=Path("foo"), analyse_loss=analyse_loss)
container = BaseMIL(analyse_loss=analyse_loss)
container.data_module = MagicMock()
callbacks = container.get_callbacks()
assert isinstance(callbacks[-1], LossAnalysisCallback) == analyse_loss
@ -53,7 +67,8 @@ def test_analyse_loss_param(analyse_loss: bool) -> None:
def test_save_tile_ids_param(save_tile_ids: bool) -> None:
callback = LossAnalysisCallback(outputs_folder=Path("foo"), save_tile_ids=save_tile_ids)
assert callback.save_tile_ids == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache) == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.TRAIN]) == save_tile_ids
assert (ResultsKey.TILE_ID in callback.loss_cache[ModelKey.VAL]) == save_tile_ids
@pytest.mark.parametrize("patience", [0, 1, 2])
@ -93,68 +108,76 @@ def test_loss_analysis_epochs_interval(epochs_interval: int) -> None:
current_epoch = 5
assert callback.should_cache_loss_values(current_epoch)
current_epoch = max_epochs # no loss caching for extra validation epoch
assert callback.should_cache_loss_values(current_epoch) is False
def test_on_train_batch_start(tmp_path: Path, mock_panda_tiles_root_dir: Path) -> None:
def test_on_train_and_val_batch_end(tmp_path: Path, mock_panda_tiles_root_dir: Path) -> None:
batch_size = 2
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir)
container.batch_size = batch_size
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir, analyse_loss=True, batch_size=batch_size)
container.setup()
container.create_lightning_module_and_store()
current_epoch = 5
trainer = MagicMock(current_epoch=current_epoch)
batch = next(iter(container.data_module.train_dataloader()))
callback = LossAnalysisCallback(outputs_folder=tmp_path)
_assert_loss_cache_contains_n_elements(callback.loss_cache, 0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.TRAIN], 0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[ModelKey.VAL], 0)
dataloader = iter(container.data_module.train_dataloader())
callback.on_train_batch_start(trainer, container.model, batch, 0, None) # type: ignore
_assert_loss_cache_contains_n_elements(callback.loss_cache, batch_size)
def _call_on_batch_end_hook(on_batch_end_hook: Callable, batch_idx: int) -> None:
batch = next(dataloader)
outputs = container.model.training_step(batch, batch_idx)
on_batch_end_hook(trainer, container.model, outputs, batch, batch_idx, 0) # type: ignore
batch = next(iter(container.data_module.train_dataloader()))
callback.on_train_batch_start(trainer, container.model, batch, 1, None) # type: ignore
_assert_loss_cache_contains_n_elements(callback.loss_cache, 2 * batch_size)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks: List[Callable] = [callback.on_train_batch_end, callback.on_validation_batch_end]
for stage, on_batch_end_hook in zip(stages, hooks):
_call_on_batch_end_hook(on_batch_end_hook, batch_idx=0)
_assert_loss_cache_contains_n_elements(callback.loss_cache[stage], batch_size)
_call_on_batch_end_hook(on_batch_end_hook, batch_idx=1)
_assert_loss_cache_contains_n_elements(callback.loss_cache[stage], 2 * batch_size)
def get_loss_cache(n_slides: int = 4, rank: int = 0) -> LossCacheDictType:
return {
ResultsKey.LOSS: list(range(1, n_slides + 1)),
ResultsKey.SLIDE_ID: [f"id_{i}" for i in range(rank * n_slides, (rank + 1) * n_slides)],
ResultsKey.TILE_ID: [f"a${i * (rank + 1)}$b" for i in range(rank * n_slides, (rank + 1) * n_slides)],
}
def test_on_train_epoch_end(
def test_on_train_and_val_epoch_end(
tmp_path: Path, duplicate: bool = False, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
current_epoch = 5
current_epoch = 2
n_slides_per_process = 4
trainer = MagicMock(current_epoch=current_epoch)
pl_module = MagicMock(global_rank=rank)
loss_callback = LossAnalysisCallback(outputs_folder=tmp_path, num_slides_heatmap=2, num_slides_scatter=2)
loss_callback.loss_cache = get_loss_cache(rank=rank, n_slides=n_slides_per_process)
if duplicate:
# Duplicate slide "id_0" to test that the duplicates are removed
loss_callback.loss_cache[ResultsKey.SLIDE_ID][0] = "id_0"
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, num_slides_heatmap=2, num_slides_scatter=2, max_epochs=10
)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_epoch_end, loss_callback.on_validation_epoch_end]
for stage, on_epoch_hook in zip(stages, hooks):
loss_callback.loss_cache[stage] = get_loss_cache(rank=rank, n_slides=n_slides_per_process)
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache, n_slides_per_process)
loss_callback.on_train_epoch_end(trainer, pl_module)
# Loss cache is flushed after each epoch
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache, 0)
if duplicate:
# Duplicate slide "id_0" to test that the duplicates are removed
loss_callback.loss_cache[stage][ResultsKey.SLIDE_ID][0] = "id_0"
if rank > 0:
time.sleep(10) # Wait for rank 0 to save the loss cache in a csv file
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], n_slides_per_process)
on_epoch_hook(trainer, pl_module)
# Loss cache is flushed after each epoch
_assert_loss_cache_contains_n_elements(loss_callback.loss_cache[stage], 0)
loss_cache_path = loss_callback.get_loss_cache_file(current_epoch)
assert loss_callback.cache_folder.exists()
assert loss_cache_path.exists()
assert loss_cache_path.parent == loss_callback.cache_folder
if rank > 0:
time.sleep(10) # Wait for rank 0 to save the loss cache in a csv file
loss_cache = pd.read_csv(loss_cache_path)
total_slides = n_slides_per_process * world_size if not duplicate else n_slides_per_process * world_size - 1
_assert_loss_cache_contains_n_elements(loss_cache, total_slides)
_assert_is_sorted(loss_cache[ResultsKey.LOSS].values)
loss_cache_path = loss_callback.get_loss_cache_file(current_epoch, stage)
assert loss_callback.get_cache_folder(stage).exists()
assert loss_cache_path.exists()
assert loss_cache_path.parent == loss_callback.get_cache_folder(stage)
loss_cache = pd.read_csv(loss_cache_path)
total_slides = n_slides_per_process * world_size if not duplicate else n_slides_per_process * world_size - 1
_assert_loss_cache_contains_n_elements(loss_cache, total_slides)
_assert_is_sorted(loss_cache[ResultsKey.LOSS].values)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@ -162,39 +185,54 @@ def test_on_train_epoch_end(
def test_on_train_epoch_end_distributed(tmp_path: Path) -> None:
# Test that the loss cache is saved correctly when using multiple GPUs
# First scenario: no duplicates
run_distributed(test_on_train_epoch_end, [tmp_path, False], world_size=2)
run_distributed(test_on_train_and_val_epoch_end, [tmp_path, False], world_size=2)
# Second scenario: introduce duplicates
run_distributed(test_on_train_epoch_end, [tmp_path, True], world_size=2)
run_distributed(test_on_train_and_val_epoch_end, [tmp_path, True], world_size=2)
def test_on_train_end(tmp_path: Path) -> None:
trainer = MagicMock()
pl_module = MagicMock(global_rank=0)
def test_on_train_and_val_end(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
max_epochs = 4
trainer = MagicMock(current_epoch=max_epochs - 1)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
dump_loss_cache_for_epochs(loss_callback, max_epochs)
loss_callback.on_train_end(trainer, pl_module)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
for stage, on_end_hook in zip(stages, hooks):
dump_loss_cache_for_epochs(loss_callback, max_epochs, stage)
on_end_hook(trainer, pl_module)
for epoch in range(max_epochs):
assert loss_callback.get_loss_cache_file(epoch).exists()
for epoch in range(max_epochs):
assert loss_callback.get_loss_cache_file(epoch, stage).exists()
# check save_loss_ranks outputs
assert loss_callback.get_all_epochs_loss_cache_file().exists()
assert loss_callback.get_loss_stats_file().exists()
assert loss_callback.get_loss_ranks_file().exists()
assert loss_callback.get_loss_ranks_stats_file().exists()
# check save_loss_ranks outputs
assert loss_callback.get_all_epochs_loss_cache_file(stage).exists()
assert loss_callback.get_loss_stats_file(stage).exists()
assert loss_callback.get_loss_ranks_file(stage).exists()
assert loss_callback.get_loss_ranks_stats_file(stage).exists()
# check plot_slides_loss_scatter outputs
assert loss_callback.get_scatter_plot_file(loss_callback.HIGHEST).exists()
assert loss_callback.get_scatter_plot_file(loss_callback.LOWEST).exists()
# check plot_slides_loss_scatter outputs
assert loss_callback.get_scatter_plot_file(loss_callback.HIGHEST, stage).exists()
assert loss_callback.get_scatter_plot_file(loss_callback.LOWEST, stage).exists()
# check plot_loss_heatmap_for_slides_of_epoch outputs
for epoch in range(max_epochs):
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.HIGHEST).exists()
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.LOWEST).exists()
# check plot_loss_heatmap_for_slides_of_epoch outputs
for epoch in range(max_epochs):
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.HIGHEST, stage).exists()
assert loss_callback.get_heatmap_plot_file(epoch, loss_callback.LOWEST, stage).exists()
def test_on_validation_end_not_called_if_extra_val_epoch(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=True)
max_epochs = 4
trainer = MagicMock(current_epoch=0)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
with patch.object(loss_callback, "save_loss_outliers_analaysis_results") as mock_func:
loss_callback.on_validation_end(trainer, pl_module)
mock_func.assert_not_called()
def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
@ -203,36 +241,41 @@ def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> Non
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs, num_slides_heatmap=2, num_slides_scatter=2
)
for epoch in range(max_epochs):
loss_callback.loss_cache = get_loss_cache(n_slides)
loss_callback.loss_cache[ResultsKey.LOSS][epoch] = np.nan
loss_callback.save_loss_cache(epoch)
stages = [ModelKey.TRAIN, ModelKey.VAL]
for stage in stages:
for epoch in range(max_epochs):
loss_callback.loss_cache[stage] = get_loss_cache(n_slides)
loss_callback.loss_cache[stage][ResultsKey.LOSS][epoch] = np.nan # Introduce NaNs
loss_callback.save_loss_cache(epoch, stage)
all_slides = loss_callback.select_slides_for_epoch(epoch=0)
all_loss_values_per_slides = loss_callback.select_all_losses_for_selected_slides(all_slides)
loss_callback.sanity_check_loss_values(all_loss_values_per_slides)
all_slides = loss_callback.select_slides_for_epoch(epoch=0, stage=stage)
all_loss_values_per_slides = loss_callback.select_all_losses_for_selected_slides(all_slides, stage)
loss_callback.sanity_check_loss_values(all_loss_values_per_slides, stage)
assert "NaNs found in loss values for slide id_0" in caplog.records[-1].getMessage()
assert "NaNs found in loss values for slide id_1" in caplog.records[0].getMessage()
assert "NaNs found in loss values for slide id_0" in caplog.records[-1].getMessage()
assert "NaNs found in loss values for slide id_1" in caplog.records[0].getMessage()
assert loss_callback.nan_slides == ["id_1", "id_0"]
assert loss_callback.get_nan_slides_file().exists()
assert loss_callback.get_nan_slides_file().parent == loss_callback.anomalies_folder
assert loss_callback.nan_slides[stage] == ["id_1", "id_0"]
assert loss_callback.get_nan_slides_file(stage).exists()
assert loss_callback.get_nan_slides_file(stage).parent == loss_callback.get_anomalies_folder(stage)
@pytest.mark.parametrize("log_exceptions", [True, False])
def test_log_exceptions_flag(log_exceptions: bool, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
max_epochs = 3
trainer = MagicMock()
pl_module = MagicMock(global_rank=0)
trainer = MagicMock(current_epoch=max_epochs - 1)
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs,
num_slides_heatmap=2, num_slides_scatter=2, log_exceptions=log_exceptions
)
message = "Error while detecting loss values outliers:"
if log_exceptions:
loss_callback.on_train_end(trainer, pl_module)
assert message in caplog.records[-1].getMessage()
else:
with pytest.raises(Exception, match=fr"{message}"):
loss_callback.on_train_end(trainer, pl_module)
stages = [ModelKey.TRAIN, ModelKey.VAL]
hooks = [loss_callback.on_train_end, loss_callback.on_validation_end]
for stage, on_end_hook in zip(stages, hooks):
message = "Error while detecting " + stage.value + " loss values outliers"
if log_exceptions:
on_end_hook(trainer, pl_module)
assert message in caplog.records[-1].getMessage()
else:
with pytest.raises(Exception, match=fr"{message}"):
on_end_hook(trainer, pl_module)