Support Multilabel classification models with png files (#391)
This PR adds support for multilabel classification models. Multiclass models (where labels are mutually exclusive) are not supported. Changes made:

- `dataset.csv` can include multiple labels per sample
- Loss functions changed to support multilabel classification tasks
- Extra report for multilabel classification tasks
- Add support for png images

Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
Co-authored-by: melanibe <32590828+melanibe@users.noreply.github.com>
Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
This commit is contained in:
Parent
6f475ffe4c
Commit
917f8d0b30
@ -21,6 +21,13 @@ nodes in AzureML. Example: Add `--num_nodes=2` to the commandline arguments to t
   if `CHANGELOG.md` has been modified.
+- ([#412](https://github.com/microsoft/InnerEye-DeepLearning/pull/412)) Dataset files can now have arbitrary names, and are no longer restricted to be called
+  `dataset.csv`, via the config field `dataset_csv`. This makes it possible to have a single set of image files in a folder, but multiple datasets derived from it.
+- ([#391](https://github.com/microsoft/InnerEye-DeepLearning/pull/391)) Support for multilabel classification tasks.
+  Multilabel models can be trained by adding the parameter `class_names` to the config for classification models.
+  `class_names` should contain the name of each label class in the dataset, and the order of names should match the
+  order of class label indices in `dataset.csv`.
+  `dataset.csv` supports multiple labels (indices corresponding to `class_names`) per subject in the label column.
+  Multiple labels should be encoded as a string with labels separated by a `|`, for example "0|2|4".
+  Note that this PR does not add support for multiclass models, where the labels are mutually exclusive.
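For illustration, a hypothetical `dataset.csv` fragment for a five-class multilabel task (the column names here are examples and are configurable in the model config, as in the test config further down this commit) could look like:

```
ID,channel,path,label
1,blue,image_1.png,0|2|4
2,blue,image_2.png,1
3,blue,image_3.png,
```

Subject 3 has an empty label entry, meaning none of the label classes are positive for that sample.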
 
 ### Changed
 
 - ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) Starting an AzureML run now uses the
@ -0,0 +1,50 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any

import pandas as pd

from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.split_dataset import DatasetSplits
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path


class DummyMulticlassClassification(ScalarModelBase):
    """A config for a dummy image classification model, for debugging purposes."""

    def __init__(self) -> None:
        num_epochs = 4
        super().__init__(
            local_dataset=full_ml_test_data_path("classification_data_multiclass"),
            image_channels=["blue"],
            image_file_column="path",
            label_channels=["blue"],
            class_names=["class0", "class1", "class2", "class3", "class4"],
            label_value_column="label",
            loss_type=ScalarLoss.WeightedCrossEntropyWithLogits,
            num_epochs=num_epochs,
            num_dataload_workers=0,
            use_mixed_precision=False,
            subject_column="ID",
            image_size=(4, 5, 7)
        )
        self.conv_in_3d = True
        self.expected_image_size_zyx = (4, 5, 7)
        # Trying to run DDP from the test suite hangs, hence restrict to single GPU.
        self.max_num_gpus = 1

    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        return DatasetSplits.from_proportions(
            df=dataset_df,
            proportion_train=0.7,
            proportion_test=0.2,
            proportion_val=0.1,
            subject_column=self.subject_column
        )

    def create_model(self) -> Any:
        # Use a local import so that we don't need to import pytorch when creating configs in the runner
        from Tests.ML.models.architectures.DummyScalarModel import DummyScalarModel
        return DummyScalarModel(self.expected_image_size_zyx, num_classes=len(self.class_names))
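A hypothetical usage sketch for this config (not part of the commit; it assumes the test data folder and its `dataset.csv` exist with the columns named above):

```python
# Assumed context: the imports from the test config file above are available.
config = DummyMulticlassClassification()
dataset_df = pd.read_csv(config.local_dataset / "dataset.csv")  # assumes local_dataset is a Path
splits = config.get_model_train_test_dataset_splits(dataset_df)
model = config.create_model()  # DummyScalarModel with num_classes=5
```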
@ -5,6 +5,7 @@
 import logging
 import math
 import sys
+import typing
 from abc import abstractmethod
 from collections import Counter, defaultdict
 from multiprocessing import cpu_count

@ -16,7 +17,6 @@ import pandas as pd
 import torch
 from joblib import Parallel, delayed
 from more_itertools import flatten
-from rich.progress import track
 
 from InnerEye.ML.dataset.full_image_dataset import GeneralDataset
 from InnerEye.ML.dataset.sample import GeneralSampleMetadata
@ -31,57 +31,87 @@ from InnerEye.ML.utils.transforms import Compose3D, Transform3D
 T = TypeVar('T', bound=ScalarDataSource)
 
 
-def extract_label_classification(label_string: Union[str, float], sample_id: str) -> Union[float, int]:
+def extract_label_classification(label_string: str, sample_id: str, num_classes: int,
+                                 is_classification_dataset: bool) -> List[float]:
     """
     Converts a string from a dataset.csv file that contains a model's label to a scalar.
-    The function maps ["1", "true", "yes"] to 1, ["0", "false", "no"] to 0.
-    If the entry in the CSV file was missing (no string given at all), it returns math.nan.
-    :param label_string: The value of the label as read from CSV via a DataFrame.
-    :param sample_id: The sample ID where this label was read from. This is only used for creating error messages.
-    :return:
-    """
-    if isinstance(label_string, float):
-        if math.isnan(label_string):
-            # When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
-            return label_string
-        else:
-            raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
-                             f"dataframe column as a string?")
-    if label_string:
-        label_lower = label_string.lower()
-        if label_lower in ["1", "true", "yes"]:
-            return 1
-        if label_lower in ["0", "false", "no"]:
-            return 0
-        raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'")
-    else:
-        return math.nan
-
-
-def extract_label_regression(label_string: Union[str, float], sample_id: str) -> Union[float, int]:
-    """
-    Converts a string from a dataset.csv file that contains a model's label to a scalar.
+
+    For classification datasets:
+    If num_classes is 1 (binary classification tasks):
+        The function maps ["1", "true", "yes"] to [1], ["0", "false", "no"] to [0].
+        If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
+    If num_classes is greater than 1 (multilabel datasets):
+        The function maps a pipe-separated set of classes to a tensor with ones at the indices
+        of the positive classes and 0 elsewhere (for example if we have a task with 6 label classes,
+        map "1|3|4" to [0, 1, 0, 1, 1, 0]).
+        If the entry in the CSV file was missing (no string given at all) or an empty string,
+        this function returns an all-zero tensor (none of the label classes were positive for this sample).
+
     For regression datasets:
     The function casts a string label to float. Raises an exception if the conversion is
     not possible.
-    If the entry in the CSV file was missing (no string given at all), it returns math.nan.
+    If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
+
     :param label_string: The value of the label as read from CSV via a DataFrame.
     :param sample_id: The sample ID where this label was read from. This is only used for creating error messages.
-    :return:
+    :param num_classes: Number of classes. This should be equal to the size of the model output.
+    For binary classification tasks, num_classes should be one. For multilabel classification tasks, num_classes should
+    correspond to the number of label classes in the problem.
+    :param is_classification_dataset: If the model is a classification model
+    :return: A list of floats with the same size as num_classes
     """
+    if num_classes < 1:
+        raise ValueError(f"Subject {sample_id}: Invalid number of classes: '{num_classes}'")
+
     if isinstance(label_string, float):
         if math.isnan(label_string):
-            return label_string
+            # Pandas special case: When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
+            if num_classes == 1:
+                return [label_string]
+            else:
+                return [0] * num_classes
         else:
             raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
                              f"dataframe column as a string?")
-    if label_string:
-        try:
-            return float(label_string)
-        except ValueError:
-            raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'")
-    else:
-        return math.nan
+
+    if not label_string:
+        if not is_classification_dataset or num_classes == 1:
+            return [math.nan]
+        else:
+            return [0] * num_classes
+
+    if is_classification_dataset:
+        if num_classes == 1:
+            label_lower = label_string.lower()
+            if label_lower in ["true", "yes"]:
+                return [1.0]
+            elif label_lower in ["false", "no"]:
+                return [0.0]
+            elif label_string in ["0", "1"]:
+                return [float(label_string)]
+            else:
+                raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'. "
+                                 f"Should be one of true/false, yes/no or 0/1.")
+
+        if '|' in label_string or label_string.isdigit():
+            classes = [int(a) for a in label_string.split('|')]
+
+            out_of_range = [_class for _class in classes if _class >= num_classes]
+            if out_of_range:
+                raise ValueError(f"Subject {sample_id}: Indices {out_of_range} are out of range, for number of classes "
+                                 f"= {num_classes}")
+
+            one_hot_array = np.zeros(num_classes, dtype=np.float)
+            one_hot_array[classes] = 1.0
+            return one_hot_array.tolist()
+    else:
+        try:
+            return [float(label_string)]
+        except ValueError:
+            pass
+
+    raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'")
 
 
 def _get_single_channel_row(subject_rows: pd.DataFrame,
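To make the label mapping concrete, here is a minimal standalone sketch (not part of the commit) that mirrors the multilabel branch of `extract_label_classification`:

```python
import numpy as np

def one_hot_from_pipe_string(label_string: str, num_classes: int) -> list:
    # "1|3|4" with num_classes=6 -> [0, 1, 0, 1, 1, 0]; empty string -> all zeros
    classes = [int(a) for a in label_string.split('|')] if label_string else []
    one_hot = np.zeros(num_classes)
    one_hot[classes] = 1.0
    return one_hot.tolist()

assert one_hot_from_pipe_string("1|3|4", 6) == [0.0, 1.0, 0.0, 1.0, 1.0, 0.0]
assert one_hot_from_pipe_string("", 6) == [0.0] * 6
```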
@ -149,10 +179,12 @@ def load_single_data_source(subject_rows: pd.DataFrame,
                            categorical_data_encoder: Optional[CategoricalToOneHotEncoder] = None,
                            metadata_columns: Optional[Set[str]] = None,
                            is_classification_dataset: bool = True,
+                           num_classes: int = 1,
                            sequence_position_numeric: Optional[int] = None) -> T:
     """
     Converts a set of dataset rows for a single subject to a ScalarDataSource instance, which contains the
     labels, the non-image features, and the paths to the image files.
+    :param num_classes: Number of classes, this is equivalent to the model output tensor size
     :param channel_column: The name of the column that contains the row identifier ("channels")
     :param metadata_columns: A list of columns that will be added to the item metadata as key/value pairs.
     :param subject_rows: All dataset rows that belong to the same subject.
@ -179,11 +211,12 @@ def load_single_data_source(subject_rows: pd.DataFrame,
         return _get_single_channel_row(subject_rows, channel, subject_id, channel_column)
 
     def _get_label_as_tensor(channel: Optional[str]) -> torch.Tensor:
-        extract_fn = extract_label_classification if is_classification_dataset else extract_label_regression
         label_row = _get_row_for_channel(channel)
         label_string = label_row[label_value_column]
-        return torch.tensor([extract_fn(label_string=label_string, sample_id=subject_id)],
-                            dtype=torch.float)
+        return torch.tensor(
+            extract_label_classification(label_string=label_string, sample_id=subject_id, num_classes=num_classes,
+                                         is_classification_dataset=is_classification_dataset),
+            dtype=torch.float)
 
     def _apply_label_transforms(labels: Any) -> Any:
         """
@ -313,6 +346,7 @@ class DataSourceReader(Generic[T]):
                  subject_column: str = CSV_SUBJECT_HEADER,
                  channel_column: str = CSV_CHANNEL_HEADER,
                  is_classification_dataset: bool = True,
+                 num_classes: int = 1,
                  categorical_data_encoder: Optional[CategoricalToOneHotEncoder] = None):
     """
     :param label_value_column: The column that contains the value for the label scalar or vector.

@ -345,6 +379,7 @@ class DataSourceReader(Generic[T]):
         self.image_file_column = image_file_column
         self.label_value_column = label_value_column
         self.data_frame = data_frame
+        self.num_classes = num_classes
         self.expected_non_image_channels: Union[List[None], Set[str]]
 
         if self.non_image_feature_channels is None:
@ -418,6 +453,7 @@ class DataSourceReader(Generic[T]):
             sequence_column=sequence_column,
             subject_column=args.subject_column,
             channel_column=args.channel_column,
+            num_classes=len(args.class_names),
             is_classification_dataset=args.is_classification_model
         ).load_data_sources(num_dataset_reader_workers=args.num_dataset_reader_workers)

@ -444,7 +480,7 @@ class DataSourceReader(Generic[T]):
         _n_jobs = max(1, num_dataset_reader_workers)
 
         results = Parallel(n_jobs=_n_jobs, backend=_backend)(
-            delayed(self.load_datasources_for_subject)(subject_id) for subject_id in track(subject_ids))
+            delayed(self.load_datasources_for_subject)(subject_id) for subject_id in subject_ids)
 
         return list(flatten(filter(None, results)))
@ -468,6 +504,7 @@ class DataSourceReader(Generic[T]):
             metadata_columns=self.metadata_columns,
             channel_column=self.channel_column,
             is_classification_dataset=self.is_classification_dataset,
+            num_classes=self.num_classes,
             sequence_position_numeric=_sequence_position_numeric
         )
@ -745,15 +782,27 @@ class ScalarDataset(ScalarDatasetBase[ScalarDataSource]):
         Returns a list of all the labels in the dataset. Used to compute
         the sampling weights in Imbalanced Sampler
         """
+        if len(self.args.class_names) > 1:
+            raise NotImplementedError("ImbalancedSampler is not supported for multilabel tasks.")
         return [item.label.item() for item in self.items]
 
-    def get_class_counts(self) -> Dict:
+    def get_class_counts(self) -> Dict[int, int]:
         """
-        Return class weights that are proportional to the inverse frequency of label counts.
-        :return: Dictionary of {"label": count}
+        Return the label counts as a dictionary with the key-value pairs being the class indices and per-class counts.
+        In the binary case, the dictionary will have a single element. The key will be 0 as there is only one class and
+        one class index. The value stored will be the number of samples that belong to the positive class.
+        In the multilabel case, this returns a dictionary with class indices and samples per class as the key-value
+        pairs.
+        :return: Dictionary of {class_index: count}
         """
-        all_labels = [item.label.item() for item in self.items]  # [N, 1]
-        return dict(Counter(all_labels))
+        all_labels = [torch.flatten(torch.nonzero(item.label).int()).tolist() for item in self.items]  # [N, 1]
+        flat_list = list(flatten(all_labels))
+        freq_iter: typing.Counter = Counter()
+        freq_iter.update({x: 0 for x in range(len(self.args.class_names))})
+        freq_iter.update(flat_list)
+        result = dict(freq_iter)
+        return result
 
     def __len__(self) -> int:
         return len(self.items)
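A minimal standalone sketch of the same per-class counting logic (made-up labels, not part of the commit):

```python
from collections import Counter
import torch

# One multi-hot label vector per sample: 3 classes, 4 samples.
labels = [torch.tensor([1., 0., 1.]), torch.tensor([0., 0., 0.]),
          torch.tensor([1., 0., 0.]), torch.tensor([0., 0., 1.])]
counts: Counter = Counter({c: 0 for c in range(3)})  # start every class index at 0
for label in labels:
    counts.update(torch.flatten(torch.nonzero(label)).tolist())  # indices of positive classes
print(dict(counts))  # {0: 2, 1: 0, 2: 2}
```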
@ -135,19 +135,8 @@ class ScalarDataSource(ScalarItemBase):
         :return: An instance of ClassificationItem, with the same label and numerical_non_image_features fields,
         and all images loaded.
         """
-        full_channel_files: List[Path] = []
-        for f in self.channel_files:
-            if f is None:
-                raise ValueError("When loading images, channel_files should no longer contain None entries.")
-            elif file_mapping:
-                if f in file_mapping:
-                    full_channel_files.append(file_mapping[str(f)])
-                else:
-                    raise ValueError(f"File mapping does not contain an entry for {f}")
-            elif root_path:
-                full_channel_files.append(root_path / f)
-            else:
-                raise ValueError("One of the arguments 'file_mapping' or 'root_path' must be given.")
+        full_channel_files = self.get_all_image_filepaths(root_path=root_path,
+                                                          file_mapping=file_mapping)
 
         imaging_data = load_images_and_stack(files=full_channel_files,
                                              load_segmentation=load_segmentation,
@ -175,6 +164,47 @@ class ScalarDataSource(ScalarItemBase):
     def files_valid(self) -> bool:
         return not any(f is None for f in self.channel_files)
 
+    def get_all_image_filepaths(self,
+                                root_path: Optional[Path],
+                                file_mapping: Optional[Dict[str, Path]]) -> List[Path]:
+        """
+        Get a list of image paths for the object. Either root_path or file_mapping must be specified.
+        :param root_path: The root path where all channel files for images are expected. This is ignored if
+        file_mapping is given.
+        :param file_mapping: A mapping from a file name stem (without extension) to its full path.
+        """
+        full_channel_files: List[Path] = []
+        for f in self.channel_files:
+            if not f:
+                raise ValueError(f"Got invalid file path: {f}")
+            full_channel_files.append(self.get_full_image_filepath(f, root_path, file_mapping))
+        return full_channel_files
+
+    @staticmethod
+    def get_full_image_filepath(file: str,
+                                root_path: Optional[Path],
+                                file_mapping: Optional[Dict[str, Path]]) -> Path:
+        """
+        Get the full path of an image file given the path relative to the dataset folder and one of
+        root_path or file_mapping.
+        :param file: Image filepath relative to the dataset folder
+        :param root_path: The root path where all channel files for images are expected. This is ignored if
+        file_mapping is given.
+        :param file_mapping: A mapping from a file name stem (without extension) to its full path.
+        """
+        if file is None:
+            raise ValueError("When loading images, channel_files should no longer contain None entries.")
+        elif file_mapping:
+            if file in file_mapping:
+                return file_mapping[file]
+            else:
+                raise ValueError(f"File mapping does not contain an entry for {file}")
+        elif root_path:
+            return root_path / file
+        else:
+            raise ValueError("One of the arguments 'file_mapping' or 'root_path' must be given.")
 
 
 @dataclass(frozen=True)
 class SequenceDataSource(ScalarDataSource):
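A simplified standalone sketch (illustrative paths, not part of the commit) of the resolution order implemented by `get_full_image_filepath`, where `file_mapping` takes precedence over `root_path`:

```python
from pathlib import Path
from typing import Dict, Optional

def resolve(file: str, root_path: Optional[Path] = None,
            file_mapping: Optional[Dict[str, Path]] = None) -> Path:
    # Mirrors get_full_image_filepath: file_mapping wins over root_path.
    if file_mapping:
        if file in file_mapping:
            return file_mapping[file]
        raise ValueError(f"File mapping does not contain an entry for {file}")
    if root_path:
        return root_path / file
    raise ValueError("One of the arguments 'file_mapping' or 'root_path' must be given.")

print(resolve("subj1/blue.png", root_path=Path("/datasets/classification")))
print(resolve("blue", file_mapping={"blue": Path("/mnt/cache/blue.png")}))
```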
@ -233,6 +233,9 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
             raise ValueError("This class requires a value in the `sequence_column`, specifying where the "
                              "sequence index should be read from.")
 
+        if len(self.args.class_names) > 1:
+            raise ValueError("Multilabel configs not supported for sequence datasets.")
+
         data_sources = self.load_all_data_sources()
         grouped = group_samples_into_sequences(
             data_sources,
@ -278,16 +281,18 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
         return [seq.get_labels_at_target_indices(self.args.get_target_indices())[-1].item()
                 for seq in self.items]
 
-    def get_class_counts(self) -> Dict:
+    def get_class_counts(self) -> Dict[int, int]:
         """
-        Return class weights that are proportional to the inverse frequency of label counts (summed
-        over all target indices).
+        Return the label counts (summed over all target indices).
         :return: Dictionary of {"label": count}
         """
         all_labels_per_target = torch.stack([seq.get_labels_at_target_indices(self.args.get_target_indices())
                                              for seq in self.items])  # [N, T, 1]
-        non_nan_labels = list(filter(lambda x: not np.isnan(x), all_labels_per_target.flatten().tolist()))
-        return dict(Counter(non_nan_labels))
+        non_nan_and_nonzero_labels = list(filter(lambda x: not np.isnan(x) and x != 0,
+                                                 all_labels_per_target.flatten().tolist()))
+        counts = dict(Counter(non_nan_and_nonzero_labels))
+        if not len(counts.keys()) == 1 or 1 not in counts.keys():
+            raise ValueError("get_class_counts supports only binary targets.")
+        return {0: counts[1]}
 
     def __len__(self) -> int:
         return len(self.items)
@ -23,7 +23,6 @@ from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode,\
     create_recovery_checkpoint_path, create_unique_timestamp_id,\
     get_best_checkpoint_path
 
-VISUALIZATION_FOLDER = "Visualizations"
 # A folder inside of the outputs folder that will contain all information for running the model in inference mode
 FINAL_MODEL_FOLDER = "final_model"
 FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"

@ -31,6 +30,8 @@ FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"
 # The checkpoints must be stored inside of the final model folder, if we want to avoid copying
 # them before registration.
 CHECKPOINT_FOLDER = "checkpoints"
+VISUALIZATION_FOLDER = "visualizations"
+
 ARGS_TXT = "args.txt"
 WEIGHTS_FILE = "weights.pth"
@ -185,9 +185,10 @@ class ScalarLightning(InnerEyeLightning):
         else:
             self.loss_fn = raw_loss
             self.target_indices = []
-            self.target_names = [MetricsDict.DEFAULT_HUE_KEY]
+            self.target_names = config.class_names
         self.is_classification_model = config.is_classification_model
         self.use_mean_teacher_model = config.compute_mean_teacher_model
+        self.is_binary_classification_or_regression = True if len(config.class_names) == 1 else False
         self.logits_to_posterior_fn = config.get_post_loss_logits_normalization_function()
         self.loss_type = config.loss_type
         # These two fields store the PyTorch Lightning Metrics objects that will compute metrics on validation

@ -338,7 +339,8 @@ class ScalarLightning(InnerEyeLightning):
         metric_computers = self.train_metric_computers if is_training else self.val_metric_computers
         prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
         for prediction_target, metric_list in metric_computers.items():
-            target_suffix = "" if prediction_target == MetricsDict.DEFAULT_HUE_KEY else f"/{prediction_target}"
+            target_suffix = "" if (prediction_target == MetricsDict.DEFAULT_HUE_KEY
+                                   or self.is_binary_classification_or_regression) else f"/{prediction_target}"
             for metric in metric_list:
                 if metric.has_predictions:
                     # Sequence models can have no predictions at all for particular positions, depending on the data.
@ -339,12 +339,22 @@ def store_epoch_metrics(metrics: DictStrFloat,
     """
     logger_row = {}
     for key, value in metrics.items():
-        if key == MetricType.SECONDS_PER_BATCH.value or key == MetricType.SECONDS_PER_EPOCH.value:
-            continue
-        if key in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
-            logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[key].value] = value
+        tokens = key.split("/")
+        if len(tokens) == 1:
+            metric_name = tokens[0]
+            hue_suffix = ""
+        elif len(tokens) == 2:
+            metric_name = tokens[0]
+            hue_suffix = "/" + tokens[1]
         else:
-            logger_row[key] = value
+            raise ValueError(f"Expected key to have format 'metric_name[/optional_suffix_for_hue]', got {key}")
+
+        if metric_name == MetricType.SECONDS_PER_BATCH.value or metric_name == MetricType.SECONDS_PER_EPOCH.value:
+            continue
+        if metric_name in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
+            logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[metric_name].value + hue_suffix] = value
+        else:
+            logger_row[metric_name + hue_suffix] = value
     logger_row[LoggingColumns.Epoch.value] = epoch
     file_logger.add_record(logger_row)
     file_logger.flush()
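The key format parsed above is `metric_name[/optional_suffix_for_hue]`; a standalone sketch of the split (the metric and hue names are illustrative, not taken from the repo):

```python
def split_metric_key(key: str) -> tuple:
    # "SomeMetric/class1" -> ("SomeMetric", "/class1"); "Loss" -> ("Loss", "")
    tokens = key.split("/")
    if len(tokens) == 1:
        return tokens[0], ""
    if len(tokens) == 2:
        return tokens[0], "/" + tokens[1]
    raise ValueError(f"Expected 'metric_name[/optional_suffix_for_hue]', got {key}")

assert split_metric_key("SomeMetric/class1") == ("SomeMetric", "/class1")
assert split_metric_key("Loss") == ("Loss", "")
```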
@ -192,9 +192,11 @@ class MetricsDict:
         default hue.
         :param is_classification_metrics: If this is a classification metrics dict
         """
-        if hues and MetricsDict.DEFAULT_HUE_KEY in hues:
-            hues.remove(MetricsDict.DEFAULT_HUE_KEY)
-        self.hues_without_default = hues or []
+        _hues = hues.copy() if hues else None
+        if _hues and MetricsDict.DEFAULT_HUE_KEY in _hues:
+            _hues.remove(MetricsDict.DEFAULT_HUE_KEY)
+        self.hues_without_default = _hues or []
         _hue_keys = self.hues_without_default + [MetricsDict.DEFAULT_HUE_KEY]
         self.hues: OrderedDict[str, Hue] = OrderedDict([(x, Hue(name=x)) for x in _hue_keys])
         self.skip_nan_when_averaging: Dict[str, bool] = dict()
@ -387,7 +387,8 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
         return SequenceMetricsDict.create(is_classification_model=config.is_classification_model,
                                           sequence_target_positions=config.sequence_target_positions)
     else:
-        return ScalarMetricsDict(is_classification_metrics=config.is_classification_model)
+        return ScalarMetricsDict(hues=config.class_names,
+                                 is_classification_metrics=config.is_classification_model)
 
 
 def classification_model_test(config: ScalarModelBase,
@ -0,0 +1,165 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
    "    return false;\n",
    "}\n",
    "// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
    "# This cell must carry the tag \"parameters\" in its metadata.\n",
    "from pathlib import Path\n",
    "import pickle\n",
    "import codecs\n",
    "\n",
    "innereye_path = Path.cwd().parent.parent.parent\n",
    "train_metrics_csv = \"\"\n",
    "val_metrics_csv = \"\"\n",
    "test_metrics_csv = \"\"\n",
    "number_best_and_worst_performing = 20\n",
    "config= \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "print(f\"Adding to path: {innereye_path}\")\n",
    "if str(innereye_path) not in sys.path:\n",
    "    sys.path.append(str(innereye_path))\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
    "\n",
    "from InnerEye.ML.reports.notebook_report import print_header\n",
    "from InnerEye.ML.reports.classification_multilabel_report import print_metrics_for_thresholded_output_for_all_prediction_targets\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "plt.rcParams['figure.figsize'] = (20, 10)\n",
    "\n",
    "#convert params to Path\n",
    "train_metrics_csv = Path(train_metrics_csv)\n",
    "val_metrics_csv = Path(val_metrics_csv)\n",
    "test_metrics_csv = Path(test_metrics_csv)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "# Train Metrics (for label combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "if train_metrics_csv.is_file():\n",
    "    print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv=train_metrics_csv,\n",
    "                                                                    test_metrics_csv=train_metrics_csv,\n",
    "                                                                    config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "# Validation Metrics (for label combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "if val_metrics_csv.is_file():\n",
    "    print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv=val_metrics_csv,\n",
    "                                                                    test_metrics_csv=val_metrics_csv,\n",
    "                                                                    config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "# Test Metrics (for label combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
    "    print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv=val_metrics_csv,\n",
    "                                                                    test_metrics_csv=test_metrics_csv,\n",
    "                                                                    config=config)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1,175 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import List, Set, FrozenSet

import pandas as pd
import torch
import math

from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.reports.classification_report import LabelsAndPredictions, print_metrics, get_labels_and_predictions, \
    get_metric, ReportedMetrics
from InnerEye.ML.reports.notebook_report import print_header


def get_unique_prediction_target_combinations(config: ScalarModelBase) -> Set[FrozenSet[str]]:
    """
    Get a list of all the combinations of labels that exist in the dataset.

    For multilabel classification tasks, this function will return all unique combinations of labels that
    occur in the dataset csv.
    For example, if there are 6 samples in the dataset with the following ground truth labels
    Sample1: class1, class2
    Sample2: class0
    Sample3: class1
    Sample4: class2, class3
    Sample5: (all label classes are negative in Sample5)
    Sample6: class1, class2
    This function will return {{"class1", "class2"}, {"class0"}, {"class1"}, {"class2", "class3"}, {}}

    For binary classification tasks (assume class_names has not been changed from ["Default"]):
    This function will return a set with two members - {{"Default"}, {}} if there is at least one positive example
    in the dataset. If there are no positive examples, it returns {{}}.
    """
    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    all_labels = [torch.flatten(torch.nonzero(item.label)).tolist() for item in dataset.items]
    label_set = set(frozenset([config.class_names[i] for i in labels if not math.isnan(i)])
                    for labels in all_labels)

    return label_set


def get_dataframe_with_exact_label_matches(metrics_df: pd.DataFrame,
                                           prediction_target_set_to_match: List[str],
                                           all_prediction_targets: List[str],
                                           thresholds_per_prediction_target: List[float]) -> pd.DataFrame:
    """
    Given a set of prediction targets (a subset of the classes in the classification task), for each sample find
    (i) if the set of ground truth labels matches this set exactly,
    (ii) if the predicted model outputs (after thresholding) match this set exactly.

    Generates an output dataframe with the columns:
    LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput, LoggingColumns.Hue

    The output dataframe is generated according to the following rules:
    - LoggingColumns.Patient: For each sample, the sample id is copied over into this field.
    - LoggingColumns.Label: For each sample, this field is set to 1 if the ground truth value is true
      for every prediction target (i.e. every class) in the given set and false for all other prediction
      targets. It is set to 0 otherwise.
    - LoggingColumns.ModelOutput: For each sample, this field is set to 1 if the model predicts a value exceeding the
      prediction target threshold for every prediction target in the given set and a lower value for all other
      prediction targets. It is set to 0 otherwise.
    - LoggingColumns.Hue: For every sample, this is set to "|".join(prediction_target_set_to_match).

    :param metrics_df: Dataframe with the model predictions (read from the csv written by the inference pipeline).
    The dataframe must have at least the following columns (defined in the LoggingColumns enum):
    LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
    Any other columns will be ignored.
    :param prediction_target_set_to_match: The set of prediction targets to which each sample is compared
    :param all_prediction_targets: The entire set of prediction targets on which the model is trained
    :param thresholds_per_prediction_target: Thresholds per prediction target to decide if the model has predicted
    True or False for the specific prediction target
    :return: Dataframe with generated label and model outputs per sample
    """

    def get_exact_label_match(df: pd.DataFrame) -> pd.DataFrame:
        values_to_return = {LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]]}

        pred_positives = df[df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][
            LoggingColumns.ModelOutput.value].values
        pred_negatives = df[~df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][
            LoggingColumns.ModelOutput.value].values

        if all(pred_positives) and not any(pred_negatives):
            values_to_return[LoggingColumns.ModelOutput.value] = [1]
        else:
            values_to_return[LoggingColumns.ModelOutput.value] = [0]

        true_positives = df[df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][
            LoggingColumns.Label.value].values
        true_negatives = df[~df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][
            LoggingColumns.Label.value].values

        if all(true_positives) and not any(true_negatives):
            values_to_return[LoggingColumns.Label.value] = [1]
        else:
            values_to_return[LoggingColumns.Label.value] = [0]

        return pd.DataFrame.from_dict(values_to_return)

    df = metrics_df.copy()
    for i in range(len(thresholds_per_prediction_target)):
        df_for_prediction_target = df[LoggingColumns.Hue.value] == all_prediction_targets[i]
        df.loc[df_for_prediction_target, LoggingColumns.ModelOutput.value] = \
            df.loc[df_for_prediction_target, LoggingColumns.ModelOutput.value] > thresholds_per_prediction_target[i]

    df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_exact_label_match).reset_index(drop=True)
    df[LoggingColumns.Hue.value] = ["|".join(prediction_target_set_to_match)] * len(df)
    return df
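A worked example of the exact-match rule above (hypothetical values; the column names are simplified stand-ins for the LoggingColumns constants):

```python
import pandas as pd

# Hypothetical inference output for one subject: one row per prediction target ("Hue").
df = pd.DataFrame({"Patient": [1, 1, 1],
                   "Hue": ["class0", "class1", "class2"],
                   "ModelOutput": [0.1, 0.7, 0.9],
                   "Label": [0, 1, 1]})
to_match, threshold = {"class1", "class2"}, 0.5
in_set = df["Hue"].isin(to_match)
thresholded = df["ModelOutput"] > threshold
# Exact match: every target in the set predicted positive, every other target negative.
model_output = int(thresholded[in_set].all() and not thresholded[~in_set].any())
label = int(df["Label"][in_set].all() and not df["Label"][~in_set].any())
print(model_output, label)  # 1 1
```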
def get_labels_and_predictions_for_prediction_target_set(csv: Path,
                                                         prediction_target_set_to_match: List[str],
                                                         all_prediction_targets: List[str],
                                                         thresholds_per_prediction_target: List[float]) -> LabelsAndPredictions:
    """
    Given a CSV file, generate a set of labels and model predictions for the given set of prediction targets
    (in other words, for the given subset of the classes in the classification task).
    NOTE: This CSV file should have results from a single epoch, as in the metrics files written during inference, not
    like the ones written while training.
    """
    metrics_df = pd.read_csv(csv)
    df = get_dataframe_with_exact_label_matches(metrics_df=metrics_df,
                                                prediction_target_set_to_match=prediction_target_set_to_match,
                                                all_prediction_targets=all_prediction_targets,
                                                thresholds_per_prediction_target=thresholds_per_prediction_target)

    labels = df[LoggingColumns.Label.value].to_numpy()
    model_outputs = df[LoggingColumns.ModelOutput.value].to_numpy()
    subjects = df[LoggingColumns.Patient.value].to_numpy()
    return LabelsAndPredictions(subject_ids=subjects, labels=labels, model_outputs=model_outputs)


def print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv: Path,
                                                                    test_metrics_csv: Path,
                                                                    config: ScalarModelBase) -> None:
    """
    Given csvs written during inference for the validation and test sets, print out metrics for every combination of
    prediction targets that exists in the dataset (i.e. for every subset of classes that occurs in the dataset).

    :param val_metrics_csv: Csv written during inference time for the val set. This is used to determine the
    optimal threshold for classification.
    :param test_metrics_csv: Csv written during inference time for the test set. Metrics are calculated for this csv.
    :param config: Model config
    """
    unique_prediction_target_combinations = get_unique_prediction_target_combinations(config)
    all_prediction_target_combinations = list(set([frozenset([prediction_target])
                                                   for prediction_target in config.class_names])
                                              | unique_prediction_target_combinations)

    thresholds_per_prediction_target = []
    for label in config.class_names:
        val_metrics = get_labels_and_predictions(val_metrics_csv, label)
        test_metrics = get_labels_and_predictions(test_metrics_csv, label)
        thresholds_per_prediction_target.append(get_metric(val_labels_and_predictions=val_metrics,
                                                           test_labels_and_predictions=test_metrics,
                                                           metric=ReportedMetrics.OptimalThreshold))

    for labels in all_prediction_target_combinations:
        print_header(f"Class {'|'.join(labels) or 'Negative'}", level=3)
        val_metrics = get_labels_and_predictions_for_prediction_target_set(
            csv=val_metrics_csv,
            prediction_target_set_to_match=list(labels),
            all_prediction_targets=config.class_names,
            thresholds_per_prediction_target=thresholds_per_prediction_target)
        test_metrics = get_labels_and_predictions_for_prediction_target_set(
            csv=test_metrics_csv,
            prediction_target_set_to_match=list(labels),
            all_prediction_targets=config.class_names,
            thresholds_per_prediction_target=thresholds_per_prediction_target)
        print_metrics(val_labels_and_predictions=val_metrics, test_labels_and_predictions=test_metrics,
                      is_thresholded=True)
@ -31,14 +31,15 @@
     "# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
     "# This cell must carry the tag \"parameters\" in its metadata.\n",
     "from pathlib import Path\n",
+    "import pickle\n",
+    "import codecs\n",
     "\n",
     "innereye_path = Path.cwd().parent.parent.parent\n",
     "train_metrics_csv = \"\"\n",
     "val_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'val_metrics_classification.csv'\n",
     "test_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'test_metrics_classification.csv'\n",
     "number_best_and_worst_performing = 20\n",
-    "dataset_csv_path=innereye_path / 'Tests' / 'ML' / 'reports' / 'dataset.csv'\n",
-    "dataset_subject_column=\"subject\"\n",
-    "dataset_file_column=\"filePath\""
+    "config= \"\""
    ]
   },
   {

@ -52,15 +53,20 @@
    },
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "import matplotlib.pyplot as plt\n",
     "import sys\n",
     "\n",
     "print(f\"Adding to path: {innereye_path}\")\n",
     "if str(innereye_path) not in sys.path:\n",
     "    sys.path.append(str(innereye_path))\n",
     "\n",
+    "%matplotlib inline\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
+    "\n",
+    "from InnerEye.ML.reports.notebook_report import print_header\n",
     "from InnerEye.ML.reports.classification_report import plot_pr_and_roc_curves_from_csv, \\\n",
-    "print_k_best_and_worst_performing, print_metrics, plot_k_best_and_worst_performing\n",
+    "print_k_best_and_worst_performing, print_metrics_for_all_prediction_targets, \\\n",
+    "plot_k_best_and_worst_performing, get_labels_and_predictions\n",
     "\n",
     "import warnings\n",
     "warnings.filterwarnings(\"ignore\")\n",

@ -69,8 +75,7 @@
     "#convert params to Path\n",
     "train_metrics_csv = Path(train_metrics_csv)\n",
     "val_metrics_csv = Path(val_metrics_csv)\n",
-    "test_metrics_csv = Path(test_metrics_csv)\n",
-    "dataset_csv_path = Path(dataset_csv_path)"
+    "test_metrics_csv = Path(test_metrics_csv)"
    ]
   },
   {

@ -78,7 +83,7 @@
    "id": "4",
    "metadata": {},
    "source": [
-    "# Validation Metrics"
+    "# Train Metrics"
    ]
   },
   {

@ -88,8 +93,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "if val_metrics_csv.is_file():\n",
-    "    print_metrics(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv)"
+    "if train_metrics_csv.is_file():\n",
+    "    print_metrics_for_all_prediction_targets(val_metrics_csv=train_metrics_csv, test_metrics_csv=train_metrics_csv,\n",
+    "                                             config=config, is_thresholded=False)"
    ]
   },
   {

@ -97,7 +103,7 @@
    "id": "6",
    "metadata": {},
    "source": [
-    "# Test Metrics"
+    "# Validation Metrics"
    ]
   },
   {

@ -107,13 +113,34 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
-    "    print_metrics(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv)"
+    "if val_metrics_csv.is_file():\n",
+    "    print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv,\n",
+    "                                             config=config, is_thresholded=False)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "8",
+   "metadata": {},
+   "source": [
+    "# Test Metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
+    "    print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
+    "                                             config=config, is_thresholded=False)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "10",
    "metadata": {
     "pycharm": {
      "name": "#%% md\n"

@ -121,30 +148,7 @@
    },
    "source": [
     "# AUC and PR curves\n",
-    "## Test Set"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "9",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "if test_metrics_csv.is_file():\n",
-    "    plot_pr_and_roc_curves_from_csv(test_metrics_csv)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "10",
-   "metadata": {},
-   "source": [
-    "## Validation set"
+    "## Train Set"
    ]
   },
   {

@ -158,20 +162,16 @@
    },
    "outputs": [],
    "source": [
-    "if val_metrics_csv.is_file():\n",
-    "    plot_pr_and_roc_curves_from_csv(val_metrics_csv)"
+    "if train_metrics_csv.is_file():\n",
+    "    plot_pr_and_roc_curves_from_csv(metrics_csv=train_metrics_csv, config=config)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "12",
-   "metadata": {
-    "pycharm": {
-     "name": "#%% md\n"
-    }
-   },
+   "metadata": {},
    "source": [
-    "## Training set"
+    "## Validation set"
    ]
   },
   {

@ -185,28 +185,35 @@
    },
    "outputs": [],
    "source": [
-    "if train_metrics_csv.is_file():\n",
-    "    plot_pr_and_roc_curves_from_csv(train_metrics_csv)"
+    "if val_metrics_csv.is_file():\n",
+    "    plot_pr_and_roc_curves_from_csv(metrics_csv=val_metrics_csv, config=config)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "14",
-   "metadata": {},
+   "metadata": {
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   },
    "source": [
-    "# Best and worst samples by ID"
+    "## Test set"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "15",
-   "metadata": {},
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
    "outputs": [],
    "source": [
-    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
-    "    print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv, \n",
-    "                                      k=number_best_and_worst_performing)"
+    "if test_metrics_csv.is_file():\n",
+    "    plot_pr_and_roc_curves_from_csv(metrics_csv=test_metrics_csv, config=config)"
    ]
   },
   {

@ -214,7 +221,7 @@
    "id": "16",
    "metadata": {},
    "source": [
-    "# Plot best and worst sample images"
+    "# Best and worst samples by ID"
    ]
   },
   {

@ -224,20 +231,35 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "if val_metrics_csv.is_file() and test_metrics_csv.is_file() and dataset_csv_path.is_file() and \\\n",
-    "    dataset_subject_column and dataset_file_column:\n",
-    "    plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv, \n",
-    "                                     k=number_best_and_worst_performing, dataset_csv_path=dataset_csv_path,\n",
-    "                                     dataset_subject_column=dataset_subject_column, dataset_file_column=dataset_file_column)"
+    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
+    "    for prediction_target in config.class_names:\n",
+    "        print_header(f\"Class {prediction_target}\", level=3)\n",
+    "        print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
+    "                                          k=number_best_and_worst_performing,\n",
+    "                                          prediction_target=prediction_target)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "18",
+   "metadata": {},
+   "source": [
+    "# Plot best and worst sample images"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "18",
+   "id": "19",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
+    "    for prediction_target in config.class_names:\n",
+    "        print_header(f\"Class {prediction_target}\", level=3)\n",
+    "        plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
+    "                                         k=number_best_and_worst_performing, prediction_target=prediction_target, config=config)"
+   ]
   }
  ],
  "metadata": {

@ -262,4 +284,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
 }
@ -3,10 +3,11 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 import math
+import torch
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import matplotlib.pyplot as plt
 import numpy as np

@ -20,6 +21,8 @@ from InnerEye.Common.metrics_constants import LoggingColumns
 from InnerEye.ML.metrics_dict import MetricsDict, binary_classification_accuracy
 from InnerEye.ML.reports.notebook_report import print_header
 from InnerEye.ML.utils.io_util import load_image_in_known_formats
+from InnerEye.ML.scalar_config import ScalarModelBase
+from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
 
 
 @dataclass

@ -49,19 +52,35 @@ class ReportedMetrics(Enum):
     FalseNegativeRate = "false_negative_rate"
 
 
-def get_results(csv: Path) -> LabelsAndPredictions:
+def read_csv_and_filter_prediction_target(csv: Path, prediction_target: str) -> pd.DataFrame:
     """
-    Given a CSV file, reads the subject IDs, ground truth labels and model outputs for each subject.
-    NOTE: This CSV file should have results from a single epoch, as in the metrics files written during inference, not
-    like the ones written while training.
+    Given one of the csv files written during inference time, read it and select only those rows which belong to the
+    given prediction_target. Also check that there is only a single entry per prediction_target per subject in the
+    file. The csv must have at least the following columns (defined in the LoggingColumns enum):
+    LoggingColumns.Hue, LoggingColumns.Patient.
     """
     df = pd.read_csv(csv)
-    labels = df[LoggingColumns.Label.value]
-    model_outputs = df[LoggingColumns.ModelOutput.value]
-    subjects = df[LoggingColumns.Patient.value]
-    if not subjects.is_unique:
+    df = df[df[LoggingColumns.Hue.value] == prediction_target]  # Filter by prediction target
+    if not df[LoggingColumns.Patient.value].is_unique:
         raise ValueError(f"Subject IDs should be unique, but found duplicate entries "
                          f"in column {LoggingColumns.Patient.value} in the csv file.")
+    return df
+
+
+def get_labels_and_predictions(csv: Path, prediction_target: str) -> LabelsAndPredictions:
+    """
+    Given a CSV file, reads the subject IDs, ground truth labels and model outputs for each subject
+    for the given prediction target.
+    NOTE: This CSV file should have results from a single epoch, as in the metrics files written during inference, not
+    like the ones written while training. It must have at least the following columns (defined in the LoggingColumns
+    enum):
+    LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
+    """
+    df = read_csv_and_filter_prediction_target(csv, prediction_target)
+    labels = df[LoggingColumns.Label.value].to_numpy()
+    model_outputs = df[LoggingColumns.ModelOutput.value].to_numpy()
+    subjects = df[LoggingColumns.Patient.value].to_numpy()
+    return LabelsAndPredictions(subject_ids=subjects, labels=labels, model_outputs=model_outputs)
@ -85,116 +104,172 @@ def plot_auc(x_values: np.ndarray, y_values: np.ndarray, title: str, ax: Axes, p
     ax.annotate(f"{x:0.3f}, {y:0.3f}", xy=(x, y), xytext=(15, 0), textcoords='offset points')
 
 
-def plot_pr_and_roc_curves_from_csv(metrics_csv: Path) -> None:
+def plot_pr_and_roc_curves(labels_and_model_outputs: LabelsAndPredictions) -> None:
     """
-    Given a csv file, read the predicted values and ground truth labels and plot the ROC and PR curves.
+    Given a LabelsAndPredictions object, plot the ROC and PR curves.
     """
     print_header("ROC and PR curves", level=3)
-    results = get_results(metrics_csv)
-
     _, ax = plt.subplots(1, 2)
 
-    fpr, tpr, thresholds = roc_curve(results.labels, results.model_outputs)
+    fpr, tpr, thresholds = roc_curve(labels_and_model_outputs.labels, labels_and_model_outputs.model_outputs)
     plot_auc(fpr, tpr, "ROC Curve", ax[0])
-    precision, recall, thresholds = precision_recall_curve(results.labels, results.model_outputs)
+    precision, recall, thresholds = precision_recall_curve(labels_and_model_outputs.labels,
+                                                           labels_and_model_outputs.model_outputs)
     plot_auc(recall, precision, "PR Curve", ax[1])
 
     plt.show()
 
 
-def get_metric(val_metrics_csv: Path, test_metrics_csv: Path, metric: ReportedMetrics) -> float:
+def plot_pr_and_roc_curves_from_csv(metrics_csv: Path, config: ScalarModelBase) -> None:
     """
-    Given a csv file, read the predicted values and ground truth labels and return the specified metric.
+    Given the csv written during inference time and the model config,
+    plot the ROC and PR curves for all prediction targets.
     """
-    results_val = get_results(val_metrics_csv)
-    fpr, tpr, thresholds = roc_curve(results_val.labels, results_val.model_outputs)
-    optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
-    optimal_threshold = thresholds[optimal_idx]
+    for prediction_target in config.class_names:
+        print_header(f"Class {prediction_target}", level=3)
+        metrics = get_labels_and_predictions(metrics_csv, prediction_target)
+        plot_pr_and_roc_curves(metrics)
+
+
+def get_metric(val_labels_and_predictions: LabelsAndPredictions,
+               test_labels_and_predictions: LabelsAndPredictions,
+               metric: ReportedMetrics,
+               optimal_threshold: Optional[float] = None) -> float:
+    """
+    Given LabelsAndPredictions objects for the validation and test sets, return the specified metric.
+    :param val_labels_and_predictions: This set of ground truth labels and model predictions is used to determine the
+    optimal threshold for classification.
+    :param test_labels_and_predictions: The set of labels and model outputs to calculate metrics for.
+    :param metric: The name of the metric to calculate.
+    :param optimal_threshold: If provided, use this threshold instead of calculating an optimal threshold.
+    """
+    if not optimal_threshold:
+        fpr, tpr, thresholds = roc_curve(val_labels_and_predictions.labels, val_labels_and_predictions.model_outputs)
+        optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
+        optimal_threshold = thresholds[optimal_idx]
+
+    assert optimal_threshold  # for mypy, we have already calculated the optimal threshold if it was set to None
 
     if metric is ReportedMetrics.OptimalThreshold:
         return optimal_threshold
 
-    results_test = get_results(test_metrics_csv)
-    only_one_class_present = len(set(results_test.labels)) < 2
+    only_one_class_present = len(set(test_labels_and_predictions.labels)) < 2
 
     if metric is ReportedMetrics.AUC_ROC:
-        return math.nan if only_one_class_present else roc_auc_score(results_test.labels, results_test.model_outputs)
+        return math.nan if only_one_class_present else roc_auc_score(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs)
     elif metric is ReportedMetrics.AUC_PR:
         if only_one_class_present:
             return math.nan
-        precision, recall, _ = precision_recall_curve(results_test.labels, results_test.model_outputs)
+        precision, recall, _ = precision_recall_curve(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs)
         return auc(recall, precision)
     elif metric is ReportedMetrics.Accuracy:
-        return binary_classification_accuracy(model_output=results_test.model_outputs,
-                                              label=results_test.labels,
+        return binary_classification_accuracy(model_output=test_labels_and_predictions.model_outputs,
+                                              label=test_labels_and_predictions.labels,
                                               threshold=optimal_threshold)
     elif metric is ReportedMetrics.FalsePositiveRate:
-        tnr = recall_score(results_test.labels, results_test.model_outputs >= optimal_threshold, pos_label=0)
+        tnr = recall_score(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs >= optimal_threshold, pos_label=0)
         return 1 - tnr
     elif metric is ReportedMetrics.FalseNegativeRate:
-        return 1 - recall_score(results_test.labels, results_test.model_outputs >= optimal_threshold)
+        return 1 - recall_score(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs >= optimal_threshold)
     else:
         raise ValueError("Unknown metric")
def print_metrics(val_metrics_csv: Path, test_metrics_csv: Path) -> None:
|
||||
def print_metrics(val_labels_and_predictions: LabelsAndPredictions,
|
||||
test_labels_and_predictions: LabelsAndPredictions,
|
||||
is_thresholded: bool = False) -> None:
|
||||
"""
|
||||
Given a csv file, read the predicted values and ground truth labels and print out some metrics.
|
||||
Given LabelsAndPredictions objects for the validation and test sets, print out some metrics.
|
||||
:param val_labels_and_predictions: LabelsAndPredictions object for the val set. This is used to determine the
|
||||
optimal threshold for classification.
|
||||
:param test_labels_and_predictions: LabelsAndPredictions object for the test set. Metrics are calculated for this
|
||||
set.
|
||||
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
|
||||
or are floating point numbers.
|
||||
:return:
|
||||
"""
|
||||
roc_auc = get_metric(val_metrics_csv=val_metrics_csv,
|
||||
test_metrics_csv=test_metrics_csv,
|
||||
metric=ReportedMetrics.AUC_ROC)
|
||||
print_header(f"Area under ROC Curve: {roc_auc:.4f}", level=4)
|
||||
|
||||
pr_auc = get_metric(val_metrics_csv=val_metrics_csv,
|
||||
test_metrics_csv=test_metrics_csv,
|
||||
metric=ReportedMetrics.AUC_PR)
|
||||
print_header(f"Area under PR Curve: {pr_auc:.4f}", level=4)
|
||||
optimal_threshold = 0.5 if is_thresholded else None
|
||||
|
||||
optimal_threshold = get_metric(val_metrics_csv=val_metrics_csv,
|
||||
test_metrics_csv=test_metrics_csv,
|
||||
metric=ReportedMetrics.OptimalThreshold)
|
||||
if not is_thresholded:
|
||||
roc_auc = get_metric(val_labels_and_predictions=val_labels_and_predictions,
|
||||
test_labels_and_predictions=test_labels_and_predictions,
|
||||
metric=ReportedMetrics.AUC_ROC)
|
||||
print_header(f"Area under ROC Curve: {roc_auc:.4f}", level=4)
|
||||
|
||||
print_header(f"Optimal threshold: {optimal_threshold: .4f}", level=4)
|
||||
pr_auc = get_metric(val_labels_and_predictions=val_labels_and_predictions,
|
||||
test_labels_and_predictions=test_labels_and_predictions,
|
||||
metric=ReportedMetrics.AUC_PR)
|
||||
print_header(f"Area under PR Curve: {pr_auc:.4f}", level=4)
|
||||
|
||||
accuracy = get_metric(val_metrics_csv=val_metrics_csv,
|
||||
test_metrics_csv=test_metrics_csv,
|
||||
metric=ReportedMetrics.Accuracy)
|
||||
optimal_threshold = get_metric(val_labels_and_predictions=val_labels_and_predictions,
|
||||
test_labels_and_predictions=test_labels_and_predictions,
|
||||
metric=ReportedMetrics.OptimalThreshold)
|
||||
|
||||
print_header(f"Optimal threshold: {optimal_threshold: .4f}", level=4)
|
||||
|
||||
accuracy = get_metric(val_labels_and_predictions=val_labels_and_predictions,
|
||||
test_labels_and_predictions=test_labels_and_predictions,
|
||||
metric=ReportedMetrics.Accuracy,
|
||||
optimal_threshold=optimal_threshold)
|
||||
print_header(f"Accuracy at optimal threshold: {accuracy:.4f}", level=4)
|
||||
|
||||
fpr = get_metric(val_metrics_csv=val_metrics_csv,
|
||||
test_metrics_csv=test_metrics_csv,
|
||||
metric=ReportedMetrics.FalsePositiveRate)
|
||||
fpr = get_metric(val_labels_and_predictions=val_labels_and_predictions,
|
||||
test_labels_and_predictions=test_labels_and_predictions,
|
||||
metric=ReportedMetrics.FalsePositiveRate,
|
||||
optimal_threshold=optimal_threshold)
|
||||
print_header(f"Specificity at optimal threshold: {1 - fpr:.4f}", level=4)
|
||||
|
||||
fnr = get_metric(val_metrics_csv=val_metrics_csv,
|
||||
test_metrics_csv=test_metrics_csv,
|
||||
metric=ReportedMetrics.FalseNegativeRate)
|
||||
fnr = get_metric(val_labels_and_predictions=val_labels_and_predictions,
|
||||
test_labels_and_predictions=test_labels_and_predictions,
|
||||
metric=ReportedMetrics.FalseNegativeRate,
|
||||
optimal_threshold=optimal_threshold)
|
||||
print_header(f"Sensitivity at optimal threshold: {1 - fnr:.4f}", level=4)
|
||||
print_header("", level=4)
|
||||
|
||||
|
||||
def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_csv: Path) -> Results:
|
||||
def print_metrics_for_all_prediction_targets(val_metrics_csv: Path,
|
||||
test_metrics_csv: Path,
|
||||
config: ScalarModelBase,
|
||||
is_thresholded: bool = False) -> None:
|
||||
"""
|
||||
Given csvs written during inference for the validation and test sets, print out metrics for every prediction target
|
||||
in the config.
|
||||
|
||||
:param val_metrics_csv: Csv written during inference time for the val set. This is used to determine the
|
||||
optimal threshold for classification.
|
||||
:param test_metrics_csv: Csv written during inference time for the test set. Metrics are calculated for this csv.
|
||||
:param config: Model config
|
||||
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
|
||||
or are floating point numbers.
|
||||
"""
|
||||
|
||||
for prediction_target in config.class_names:
|
||||
print_header(f"Class {prediction_target}", level=3)
|
||||
val_metrics = get_labels_and_predictions(val_metrics_csv, prediction_target)
|
||||
test_metrics = get_labels_and_predictions(test_metrics_csv, prediction_target)
|
||||
print_metrics(val_labels_and_predictions=val_metrics, test_labels_and_predictions=test_metrics, is_thresholded=is_thresholded)
|
||||
|
||||
|
||||
def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_csv: Path,
|
||||
prediction_target: str = "Default") -> Results:
|
||||
"""
|
||||
Given the paths to the metrics files for the validation and test sets, get a list of true positives,
|
||||
false positives, false negatives and true negatives.
|
||||
The threshold for classification is obtained by looking at the validation file, and applied to the test set to get
|
||||
label predictions.
|
||||
"""
|
||||
df_val = pd.read_csv(val_metrics_csv)
|
||||
The validation and test csvs must have at least the following columns (defined in the LoggingColumns enum):
|
||||
LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
|
||||
|
||||
if not df_val[LoggingColumns.Patient.value].is_unique:
|
||||
raise ValueError(f"Subject IDs should be unique, but found duplicate entries "
|
||||
f"in column {LoggingColumns.Patient.value} in the csv file.")
|
||||
"""
|
||||
df_val = read_csv_and_filter_prediction_target(val_metrics_csv, prediction_target)
|
||||
|
||||
fpr, tpr, thresholds = roc_curve(df_val[LoggingColumns.Label.value], df_val[LoggingColumns.ModelOutput.value])
|
||||
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
|
||||
optimal_threshold = thresholds[optimal_idx]
|
||||
|
||||
df_test = pd.read_csv(test_metrics_csv)
|
||||
|
||||
if not df_test[LoggingColumns.Patient.value].is_unique:
|
||||
raise ValueError(f"Subject IDs should be unique, but found duplicate entries "
|
||||
f"in column {LoggingColumns.Patient.value} in the csv file.")
|
||||
df_test = read_csv_and_filter_prediction_target(test_metrics_csv, prediction_target)
|
||||
|
||||
df_test["predicted"] = df_test.apply(lambda x: int(x[LoggingColumns.ModelOutput.value] >= optimal_threshold),
|
||||
axis=1)
|
||||
|
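The new code funnels all csv access through `read_csv_and_filter_prediction_target`. A hedged sketch of what that helper plausibly does; the literal column names below are placeholders, the real ones come from the `LoggingColumns` enum:

```python
import pandas as pd
from pathlib import Path

def read_csv_and_filter_prediction_target(csv: Path, prediction_target: str) -> pd.DataFrame:
    # Keep only the rows logged for one prediction target ("hue"), and keep the
    # duplicate-subject check that the old single-class code performed inline.
    df = pd.read_csv(csv)
    df = df[df["prediction_target"] == prediction_target]  # LoggingColumns.Hue
    if not df["subject"].is_unique:                        # LoggingColumns.Patient
        raise ValueError("Subject IDs should be unique for a single prediction target.")
    return df
```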
@@ -210,13 +285,15 @@ def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_c
                    false_negatives=false_negatives)


-def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int) -> Results:
+def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int,
+                                    prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Results:
     """
     Get the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
     top "k" worst predictions (i.e. misclassifications where the model was the most confident).
     """
     results = get_correct_and_misclassified_examples(val_metrics_csv=val_metrics_csv,
-                                                     test_metrics_csv=test_metrics_csv)
+                                                     test_metrics_csv=test_metrics_csv,
+                                                     prediction_target=prediction_target)

     # sort by model_output
     sorted = Results(true_positives=results.true_positives.sort_values(by=LoggingColumns.ModelOutput.value,
@@ -230,14 +307,22 @@ def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pat
     return sorted


-def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int) -> None:
+def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int, prediction_target: str) -> None:
     """
     Print the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
     top "k" worst predictions (i.e. misclassifications where the model was the most confident).
+    :param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
+    used to determine the thresholds for predicting labels on the test set. The best and worst
+    performing subjects will not be printed out for this csv.
+    :param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
+    best and worst performing subjects will be printed out.
+    :param k: Number of subjects of each category to print out.
+    :param prediction_target: The class label to filter on
     """
     results = get_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv,
                                               test_metrics_csv=test_metrics_csv,
-                                              k=k)
+                                              k=k,
+                                              prediction_target=prediction_target)

     print_header(f"Top {k} false positives", level=2)
     for index, (subject, model_output) in enumerate(zip(results.false_positives[LoggingColumns.Patient.value],
@@ -261,33 +346,57 @@ def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: P


 def get_image_filepath_from_subject_id(subject_id: str,
-                                       dataset_df: pd.DataFrame,
-                                       dataset_subject_column: str,
-                                       dataset_file_column: str,
-                                       dataset_dir: Path) -> Optional[Path]:
+                                       dataset: ScalarDataset,
+                                       config: ScalarModelBase) -> List[Path]:
     """
-    Returns the filepath for the image associated with a subject. If the subject is not found, return None.
-    If the csv contains multiple entries per subject (which may happen if the csv uses the channels column) then
-    return None as we do not support these csv types yet.
+    Return the filepaths for images associated with a subject. If the subject is not found, raises a ValueError.
     :param subject_id: Subject to retrieve image for
-    :param dataset_df: Dataset dataframe (from the dataset.csv file)
-    :param dataset_subject_column: Name of the column with the subject IDs
-    :param dataset_file_column: Name of the column with the image filepaths
-    :param dataset_dir: Path to the dataset
-    :return: path to the image file for the patient or None if it is not found.
+    :param dataset: scalar dataset object
+    :param config: model config
+    :return: List of paths to the image files for the patient.
     """
+    for item in dataset.items:
+        if item.metadata.id == subject_id:
+            return item.get_all_image_filepaths(root_path=config.local_dataset,
+                                                file_mapping=dataset.file_to_full_path)
+
+    raise ValueError(f"Could not find subject {subject_id} in the dataset.")
+
+
+def get_image_labels_from_subject_id(subject_id: str,
+                                     dataset: ScalarDataset,
+                                     config: ScalarModelBase) -> List[str]:
+    """
+    Return the ground truth labels associated with a subject. If the subject is not found, raises a ValueError.
+    :param subject_id: Subject to retrieve image for
+    :param dataset: scalar dataset object
+    :param config: model config
+    :return: List of labels for the patient.
+    """
+    labels = None
+
+    for item in dataset.items:
+        if item.metadata.id == subject_id:
+            labels = torch.flatten(torch.nonzero(item.label)).tolist()
+            break
+
+    if labels is None:
+        raise ValueError(f"Could not find subject {subject_id} in the dataset.")
+
+    return [config.class_names[int(label)] for label in labels
+            if not math.isnan(label)]
+
+
+def get_image_outputs_from_subject_id(subject_id: str,
+                                      metrics_df: pd.DataFrame) -> List[Tuple[str, int]]:
+    """
+    Return a list of tuples (Label class name, model output for the class) for a single subject.
+    """

-    if not dataset_df[dataset_subject_column].is_unique:
-        return None
-
-    dataset_df[dataset_subject_column] = dataset_df.apply(lambda x: str(x[dataset_subject_column]), axis=1)
-
-    if subject_id not in dataset_df[dataset_subject_column].unique():
-        return None
-
-    filtered = dataset_df[dataset_df[dataset_subject_column] == subject_id]
-    filepath = filtered.iloc[0][dataset_file_column]
-    return dataset_dir / Path(filepath)
+    filtered = metrics_df[metrics_df[LoggingColumns.Patient.value] == subject_id]
+    outputs = list(zip(filtered[LoggingColumns.Hue.value].values.tolist(),
+                       filtered[LoggingColumns.ModelOutput.value].values.astype(float).tolist()))
+    return outputs


 def plot_image_from_filepath(filepath: Path, im_width: int) -> bool:
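`get_image_labels_from_subject_id` recovers class names from a multi-hot label tensor via `torch.nonzero`. A self-contained sketch of that decoding step, with hypothetical class names:

```python
import torch

# Multi-hot label as stored on a dataset item; hypothetical 5-class setup.
class_names = ["class0", "class1", "class2", "class3", "class4"]
label = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])
positive_indices = torch.flatten(torch.nonzero(label)).tolist()
print([class_names[int(i)] for i in positive_indices])  # ['class0', 'class2', 'class4']
```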
@@ -321,58 +430,81 @@ def plot_image_from_filepath(filepath: Path, im_width: int) -> bool:


 def plot_image_for_subject(subject_id: str,
-                           dataset_df: pd.DataFrame,
-                           dataset_subject_column: str,
-                           dataset_file_column: str,
-                           dataset_dir: Path,
+                           dataset: ScalarDataset,
                            im_width: int,
                            model_output: float,
-                           header: Optional[str]) -> None:
+                           header: Optional[str],
+                           config: ScalarModelBase,
+                           metrics_df: Optional[pd.DataFrame] = None) -> None:
     """
     Given a subject ID, plots the corresponding image.
     :param subject_id: Subject to plot image for
-    :param dataset_df: Dataset dataframe (from the dataset.csv file)
-    :param dataset_subject_column: Name of the column with the subject IDs
-    :param dataset_file_column: Name of the column with the image filepaths
-    :param dataset_dir: Path to the dataset
+    :param dataset: scalar dataset object
     :param im_width: Display width for image
     :param model_output: The predicted value for this image
     :param header: Optional header printed along with the subject ID and score for the image.
+    :param config: model config
+    :param metrics_df: dataframe with the metrics written out during inference time
     """
     print_header("", level=4)
     if header:
         print_header(header, level=4)
-    print_header(f"ID: {subject_id} Score: {model_output}", level=4)
-
-    filepath = get_image_filepath_from_subject_id(subject_id=str(subject_id),
-                                                  dataset_df=dataset_df,
-                                                  dataset_subject_column=dataset_subject_column,
-                                                  dataset_file_column=dataset_file_column,
-                                                  dataset_dir=dataset_dir)
-    if not filepath:
-        print_header(f"Subject ID {subject_id} not found, or found duplicate entries for this subject "
-                     f"in column {dataset_subject_column} in the csv file. "
+    labels = get_image_labels_from_subject_id(subject_id=subject_id,
+                                              dataset=dataset,
+                                              config=config)
+
+    print_header(f"True labels: {', '.join(labels) if labels else 'Negative'}", level=4)
+
+    if metrics_df is not None:
+        all_model_outputs = get_image_outputs_from_subject_id(subject_id=subject_id,
+                                                              metrics_df=metrics_df)
+        print_header(f"ID: {subject_id}", level=4)
+        print_header(f"Model output: {', '.join([':'.join([str(x) for x in output]) for output in all_model_outputs])}",
+                     level=4)
+    else:
+        print_header(f"ID: {subject_id} Score: {model_output}", level=4)
+
+    filepaths = get_image_filepath_from_subject_id(subject_id=str(subject_id),
+                                                   dataset=dataset,
+                                                   config=config)
+
+    if not filepaths:
+        print_header(f"Subject ID {subject_id} not found."
                      f"Note: Reports with datasets that use channel columns in the dataset.csv "
-                     f"are not yet supported.")
+                     f"are not yet supported.", level=0)
         return

-    success = plot_image_from_filepath(filepath, im_width=im_width)
-    if not success:
-        print_header("Unable to plot image: image must be 2D with shape [w, h] or [1, w, h].", level=0)
+    for filepath in filepaths:
+        success = plot_image_from_filepath(filepath, im_width=im_width)
+        if not success:
+            print_header("Unable to plot image: image must be 2D with shape [w, h] or [1, w, h].", level=0)


-def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int, dataset_csv_path: Path,
-                                     dataset_subject_column: str, dataset_file_column: str) -> None:
+def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int,
+                                     prediction_target: str, config: ScalarModelBase) -> None:
     """
     Plot images for the top "k" best predictions (i.e. correct classifications where the model was the most certain)
     and the top "k" worst predictions (i.e. misclassifications where the model was the most confident).
+    :param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
+    used to determine the thresholds for predicting labels on the test set. The best and worst
+    performing subjects will not be printed out for this csv.
+    :param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
+    best and worst performing subjects will be printed out.
+    :param k: Number of subjects of each category to print out.
+    :param prediction_target: The class label to filter on
+    :param config: scalar model config object
+
     """
     results = get_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv,
                                               test_metrics_csv=test_metrics_csv,
-                                              k=k)
+                                              k=k,
+                                              prediction_target=prediction_target)

-    dataset_df = pd.read_csv(dataset_csv_path)
-    dataset_dir = dataset_csv_path.parent
+    test_metrics = pd.read_csv(test_metrics_csv, dtype=str)
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)

     im_width = 800
@@ -381,46 +513,42 @@ def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pa
     for index, (subject, model_output) in enumerate(zip(results.false_positives[LoggingColumns.Patient.value],
                                                         results.false_positives[LoggingColumns.ModelOutput.value])):
         plot_image_for_subject(subject_id=str(subject),
-                               dataset_df=dataset_df,
-                               dataset_subject_column=dataset_subject_column,
-                               dataset_file_column=dataset_file_column,
-                               dataset_dir=dataset_dir,
+                               dataset=dataset,
                                im_width=im_width,
                                model_output=model_output,
-                               header="False Positive")
+                               header="False Positive",
+                               config=config,
+                               metrics_df=test_metrics)

     print_header(f"Top {k} false negatives", level=2)
     for index, (subject, model_output) in enumerate(zip(results.false_negatives[LoggingColumns.Patient.value],
                                                         results.false_negatives[LoggingColumns.ModelOutput.value])):
         plot_image_for_subject(subject_id=str(subject),
-                               dataset_df=dataset_df,
-                               dataset_subject_column=dataset_subject_column,
-                               dataset_file_column=dataset_file_column,
-                               dataset_dir=dataset_dir,
+                               dataset=dataset,
                                im_width=im_width,
                                model_output=model_output,
-                               header="False Negative")
+                               header="False Negative",
+                               config=config,
+                               metrics_df=test_metrics)

     print_header(f"Top {k} true positives", level=2)
     for index, (subject, model_output) in enumerate(zip(results.true_positives[LoggingColumns.Patient.value],
                                                         results.true_positives[LoggingColumns.ModelOutput.value])):
         plot_image_for_subject(subject_id=str(subject),
-                               dataset_df=dataset_df,
-                               dataset_subject_column=dataset_subject_column,
-                               dataset_file_column=dataset_file_column,
-                               dataset_dir=dataset_dir,
+                               dataset=dataset,
                                im_width=im_width,
                                model_output=model_output,
-                               header="True Positive")
+                               header="True Positive",
+                               config=config,
+                               metrics_df=test_metrics)

     print_header(f"Top {k} true negatives", level=2)
     for index, (subject, model_output) in enumerate(zip(results.true_negatives[LoggingColumns.Patient.value],
                                                         results.true_negatives[LoggingColumns.ModelOutput.value])):
         plot_image_for_subject(subject_id=str(subject),
-                               dataset_df=dataset_df,
-                               dataset_subject_column=dataset_subject_column,
-                               dataset_file_column=dataset_file_column,
-                               dataset_dir=dataset_dir,
+                               dataset=dataset,
                                im_width=im_width,
                                model_output=model_output,
-                               header="True Negative")
+                               header="True Negative",
+                               config=config,
+                               metrics_df=test_metrics)
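Threshold selection runs through `MetricsDict.get_optimal_idx` everywhere above; assuming it picks the ROC point maximising tpr - fpr (Youden's J), the computation reduces to this standalone sketch:

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy validation set: pick the threshold that maximises tpr - fpr.
labels = np.array([0, 0, 1, 1, 1, 0])
model_outputs = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])
fpr, tpr, thresholds = roc_curve(labels, model_outputs)
optimal_threshold = thresholds[np.argmax(tpr - fpr)]
print(optimal_threshold)
```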
@@ -5,13 +5,40 @@
 from pathlib import Path
 from typing import Dict, Optional, Union

+import codecs
 import nbformat
 import papermill
+import pickle
 from IPython.display import Markdown, display
 from nbconvert import HTMLExporter
 from nbconvert.writers import FilesWriter

 from InnerEye.Common import fixed_paths
+from InnerEye.ML.scalar_config import ScalarModelBase
+
+
+REPORT_PREFIX = "report"
+REPORT_IPYNB_SUFFIX = ".ipynb"
+REPORT_HTML_SUFFIX = ".html"
+reports_folder = "reports"
+
+
+def get_ipynb_report_name(report_type: str) -> str:
+    """
+    Constructs the name of the report (as an ipython notebook).
+    :param report_type: suffix describing the report, added to the filename
+    :return:
+    """
+    return f"{REPORT_PREFIX}_{report_type}{REPORT_IPYNB_SUFFIX}"
+
+
+def get_html_report_name(report_type: str) -> str:
+    """
+    Constructs the name of the report (as an html file).
+    :param report_type: suffix describing the report, added to the filename
+    :return:
+    """
+    return f"{REPORT_PREFIX}_{report_type}{REPORT_HTML_SUFFIX}"


 def str_or_empty(p: Union[None, str, Path]) -> str:
@@ -89,12 +116,10 @@ def generate_segmentation_notebook(result_notebook: Path,


 def generate_classification_notebook(result_notebook: Path,
+                                     config: ScalarModelBase,
                                      train_metrics: Optional[Path] = None,
                                      val_metrics: Optional[Path] = None,
-                                     test_metrics: Optional[Path] = None,
-                                     dataset_csv_path: Optional[Path] = None,
-                                     dataset_subject_column: Optional[str] = None,
-                                     dataset_file_column: Optional[str] = None) -> Path:
+                                     test_metrics: Optional[Path] = None) -> Path:
     """
     Creates a reporting notebook for a classification model, using the given training, validation, and test set metrics.
     Returns the report file after HTML conversion.
@@ -106,11 +131,35 @@ def generate_classification_notebook(result_notebook: Path,
         'train_metrics_csv': str_or_empty(train_metrics),
         'val_metrics_csv': str_or_empty(val_metrics),
         'test_metrics_csv': str_or_empty(test_metrics),
-        'dataset_csv_path': str_or_empty(dataset_csv_path),
-        "dataset_subject_column": str_or_empty(dataset_subject_column),
-        "dataset_file_column": str_or_empty(dataset_file_column)
+        "config": codecs.encode(pickle.dumps(config), "base64").decode()
     }
     template = Path(__file__).absolute().parent / "classification_report.ipynb"
     return generate_notebook(template,
                              notebook_params=notebook_params,
                              result_notebook=result_notebook)


+def generate_classification_multilabel_notebook(result_notebook: Path,
+                                                config: ScalarModelBase,
+                                                train_metrics: Optional[Path] = None,
+                                                val_metrics: Optional[Path] = None,
+                                                test_metrics: Optional[Path] = None) -> Path:
+    """
+    Creates a reporting notebook for a multilabel classification model, using the given training, validation,
+    and test set metrics. This report adds metrics specific to the multilabel task, and is meant to be used in
+    addition to the standard report created for all classification models.
+    Returns the report file after HTML conversion.
+    """
+
+    notebook_params = \
+        {
+            'innereye_path': str(fixed_paths.repository_root_directory()),
+            'train_metrics_csv': str_or_empty(train_metrics),
+            'val_metrics_csv': str_or_empty(val_metrics),
+            'test_metrics_csv': str_or_empty(test_metrics),
+            "config": codecs.encode(pickle.dumps(config), "base64").decode()
+        }
+    template = Path(__file__).absolute().parent / "classification_multilabel_report.ipynb"
+    return generate_notebook(template,
+                             notebook_params=notebook_params,
+                             result_notebook=result_notebook)
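The notebook parameters pass the whole config as a base64-encoded pickle, since papermill parameters must be plain strings. A self-contained round-trip of that encoding, with a plain dict standing in for the `ScalarModelBase`:

```python
import codecs
import pickle

config = {"class_names": ["class0", "class1"]}  # stand-in for a ScalarModelBase
encoded = codecs.encode(pickle.dumps(config), "base64").decode()
decoded = pickle.loads(codecs.decode(encoded.encode(), "base64"))
assert decoded == config
```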
@@ -40,9 +40,10 @@ from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.model_inference_config import ModelInferenceConfig
 from InnerEye.ML.model_testing import model_test
 from InnerEye.ML.model_training import model_train
-from InnerEye.ML.reports.notebook_report import generate_classification_notebook, generate_segmentation_notebook
-from InnerEye.ML.runner import ModelDeploymentHookSignature, PostCrossValidationHookSignature, REPORT_HTML, \
-    REPORT_IPYNB, get_all_environment_files
+from InnerEye.ML.reports.notebook_report import get_ipynb_report_name, generate_classification_notebook, \
+    generate_segmentation_notebook, \
+    generate_classification_multilabel_notebook, reports_folder
+from InnerEye.ML.runner import ModelDeploymentHookSignature, PostCrossValidationHookSignature, get_all_environment_files
 from InnerEye.ML.scalar_config import ScalarModelBase
 from InnerEye.ML.sequence_config import SequenceModelBase
 from InnerEye.ML.utils import ml_util
@@ -663,7 +664,8 @@ class MLRunner:
                     model_proc=ModelProcessing.ENSEMBLE_CREATION)

                 crossval_dir = self.plot_cross_validation_and_upload_results()
-                self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
+                if self.model_config.generate_report:
+                    self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
                 # CrossValResults should have been uploaded to the parent run, so we don't need it here.
                 remove_file_or_directory(crossval_dir)
                 # We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
@@ -676,8 +678,7 @@ class MLRunner:
                 for subdir in other_runs_ensemble_dir.glob("*"):
                     if subdir.name not in [BASELINE_WILCOXON_RESULTS_FILE,
                                            SCATTERPLOTS_SUBDIR_NAME,
-                                           REPORT_HTML,
-                                           REPORT_IPYNB]:
+                                           reports_folder]:
                         remove_file_or_directory(subdir)
                 PARENT_RUN_CONTEXT.upload_folder(name=BASELINE_COMPARISONS_FOLDER, path=str(other_runs_ensemble_dir))
             else:
@@ -724,21 +725,33 @@ class MLRunner:

             output_dir = config.outputs_folder / OTHER_RUNS_SUBDIR_NAME / ENSEMBLE_SPLIT_NAME \
                 if model_proc == ModelProcessing.ENSEMBLE_CREATION else config.outputs_folder

+            reports_dir = output_dir / reports_folder
+            if not reports_dir.exists():
+                reports_dir.mkdir(exist_ok=False)
+
             if config.model_category == ModelCategory.Segmentation:
-                generate_segmentation_notebook(result_notebook=output_dir / REPORT_IPYNB,
-                                               train_metrics=path_to_best_epoch_train,
-                                               val_metrics=path_to_best_epoch_val,
-                                               test_metrics=path_to_best_epoch_test)
+                generate_segmentation_notebook(
+                    result_notebook=reports_dir / get_ipynb_report_name(config.model_category.value),
+                    train_metrics=path_to_best_epoch_train,
+                    val_metrics=path_to_best_epoch_val,
+                    test_metrics=path_to_best_epoch_test)
             else:
                 if isinstance(config, ScalarModelBase) and not isinstance(config, SequenceModelBase):
-                    dataset_csv_path = config.local_dataset / config.dataset_csv if config.local_dataset else None
-                    generate_classification_notebook(result_notebook=output_dir / REPORT_IPYNB,
-                                                     train_metrics=path_to_best_epoch_train,
-                                                     val_metrics=path_to_best_epoch_val,
-                                                     test_metrics=path_to_best_epoch_test,
-                                                     dataset_csv_path=dataset_csv_path,
-                                                     dataset_subject_column=config.subject_column,
-                                                     dataset_file_column=config.image_file_column)
+                    generate_classification_notebook(
+                        result_notebook=reports_dir / get_ipynb_report_name(config.model_category.value),
+                        config=config,
+                        train_metrics=path_to_best_epoch_train,
+                        val_metrics=path_to_best_epoch_val,
+                        test_metrics=path_to_best_epoch_test)
+
+                    if len(config.class_names) > 1:
+                        generate_classification_multilabel_notebook(
+                            result_notebook=reports_dir / get_ipynb_report_name(f"{config.model_category.value}_multilabel"),
+                            config=config,
+                            train_metrics=path_to_best_epoch_train,
+                            val_metrics=path_to_best_epoch_val,
+                            test_metrics=path_to_best_epoch_test)
                 else:
                     logging.info(f"Cannot create report for config of type {type(config)}.")
         except Exception as ex:
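The report files now land in a `reports` subfolder with names derived from the model category. Assuming the category value is the string `Classification`, the naming helper above resolves as follows:

```python
REPORT_PREFIX = "report"
REPORT_IPYNB_SUFFIX = ".ipynb"

def get_ipynb_report_name(report_type: str) -> str:
    # Same construction as the helper in the notebook_report module above.
    return f"{REPORT_PREFIX}_{report_type}{REPORT_IPYNB_SUFFIX}"

print(get_ipynb_report_name("Classification"))             # report_Classification.ipynb
print(get_ipynb_report_name("Classification_multilabel"))  # report_Classification_multilabel.ipynb
```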
@@ -41,9 +41,6 @@ from InnerEye.ML.config import SegmentationModelBase
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.utils.config_util import ModelConfigLoader

-REPORT_IPYNB = "report.ipynb"
-REPORT_HTML = "report.html"
-
 LOG_FILE_NAME = "stdout.txt"

 PostCrossValidationHookSignature = Callable[[ModelConfigBase, Path], None]
@@ -16,6 +16,7 @@ from InnerEye.Common.generic_parsing import ListOrDictParam
 from InnerEye.Common.type_annotations import TupleInt3
 from InnerEye.ML.common import ModelExecutionMode, OneHotEncoderBase
 from InnerEye.ML.deep_learning_config import ModelCategory
+from InnerEye.ML.metrics_dict import MetricsDict
 from InnerEye.ML.model_config_base import ModelConfigBase, ModelTransformsPerExecutionMode
 from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER
 from InnerEye.ML.utils.split_dataset import DatasetSplits
@@ -102,6 +103,17 @@ class LabelTransformation(Enum):


 class ScalarModelBase(ModelConfigBase):
+    class_names: List[str] = param.List(class_=str,
+                                        default=[MetricsDict.DEFAULT_HUE_KEY],
+                                        bounds=(1, None),
+                                        doc="The label names for each label class in the dataset and model output "
+                                            "in the case of binary and multi-label classification tasks. "
+                                            "The order of the names should match the order of label class indices "
+                                            "in dataset.csv. "
+                                            "For multi-label classification, this field is required. "
+                                            "For binary classification, this field must be a list of size 1, and "
+                                            "is by default ['Default'], but can optionally be set to a more descriptive "
+                                            "name for the positive class.")
     aggregation_type: AggregationType = param.ClassSelector(default=AggregationType.Average, class_=AggregationType,
                                                             doc="The type of global pooling aggregation to use between"
                                                                 " the encoder and the classifier.")
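The ordering contract in the `class_names` doc above — `class_names[i]` names label index `i` — can be illustrated with a label field like the ones in the test fixtures further below:

```python
import numpy as np

class_names = ["class0", "class1", "class2", "class3"]
label_field = "1|2"  # pipe-separated class indices, as in the dataset fixtures below
multi_hot = np.zeros(len(class_names), dtype=np.float32)
multi_hot[[int(i) for i in label_field.split("|")]] = 1.0
print(dict(zip(class_names, multi_hot.tolist())))  # class1 and class2 are positive
```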
@@ -205,6 +217,10 @@ class ScalarModelBase(ModelConfigBase):
         else:
             self.num_dataset_reader_workers = num_dataset_reader_workers

+    def validate(self) -> None:
+        if len(self.class_names) > 1 and not self.is_classification_model:
+            raise ValueError("Multiple label classes supported only for classification tasks.")
+
     @property
     def is_classification_model(self) -> bool:
         """
@@ -385,6 +401,12 @@ class ScalarModelBase(ModelConfigBase):
         assert self._datasets_for_training is not None  # for mypy
         return self._datasets_for_training[ModelExecutionMode.TRAIN].get_class_counts()

+    def get_total_number_of_training_samples(self) -> int:
+        if self._datasets_for_training is None:
+            self.create_and_set_torch_datasets(for_inference=False)
+        assert self._datasets_for_training is not None  # for mypy
+        return len(self._datasets_for_training[ModelExecutionMode.TRAIN])
+
     def create_model(self) -> Any:
         pass
@@ -384,6 +384,13 @@ def load_images_and_stack(files: Iterable[Path],
     return ImageAndSegmentations(images=image_tensor, segmentations=segmentation_tensor)


+def is_png(file: PathOrString) -> bool:
+    """
+    Returns true if file is png
+    """
+    return _file_matches_extension(file, [".png"])
+
+
 def load_image_in_known_formats(file: Path,
                                 load_segmentation: bool) -> ImageAndSegmentations[np.ndarray]:
     """
@@ -404,6 +411,9 @@ def load_image_in_known_formats(file: Path,
         return ImageAndSegmentations(images=load_numpy_image(path=file))
     elif is_dicom_file_path(file):
         return ImageAndSegmentations(images=load_dicom_image(path=file))
+    elif is_png(file):
+        image_with_header = load_image(path=file)
+        return ImageAndSegmentations(images=image_with_header.image)
     else:
         raise ValueError(f"Unsupported image file type for path {file}")
@@ -463,6 +473,11 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW
         image = load_hdf5_dataset_from_file(Path(h5_path), dataset)[channel]
         header = get_unit_image_header()
         return ImageWithHeader(image, header)
+    elif is_png(path):
+        import imageio
+        image = imageio.imread(path).astype(np.float)
+        header = get_unit_image_header()
+        return ImageWithHeader(image, header)
     raise ValueError(f"Invalid file type {path}")
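The png branch simply reads the file into a float array and attaches a unit header (a png carries no spacing information). A self-contained sketch, using `np.float32` where the diff uses the older `np.float` alias:

```python
import imageio
import numpy as np

# Write a small 8-bit grayscale png, then read it back the way load_image does.
imageio.imwrite("tiny.png", np.arange(20, dtype=np.uint8).reshape(4, 5))
image = imageio.imread("tiny.png").astype(np.float32)
print(image.shape)  # (4, 5)
```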
@@ -105,11 +105,14 @@ def create_scalar_loss_function(config: ScalarModelBase) -> torch.nn.Module:
     Returns a torch module that computes a loss function for classification and regression models.
     """
     if config.loss_type == ScalarLoss.BinaryCrossEntropyWithLogits:
-        return BinaryCrossEntropyWithLogitsLoss(smoothing_eps=config.label_smoothing_eps)
+        return BinaryCrossEntropyWithLogitsLoss(num_classes=len(config.class_names),
+                                                smoothing_eps=config.label_smoothing_eps)
     if config.loss_type == ScalarLoss.WeightedCrossEntropyWithLogits:
         return BinaryCrossEntropyWithLogitsLoss(
+            num_classes=len(config.class_names),
             smoothing_eps=config.label_smoothing_eps,
-            class_counts=config.get_training_class_counts())
+            class_counts=config.get_training_class_counts(),
+            num_train_samples=config.get_total_number_of_training_samples())
     elif config.loss_type == ScalarLoss.MeanSquaredError:
         return MSELoss()
     else:
@@ -37,7 +37,7 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
             # Smooth the one-hot target: 1.0 becomes 1.0-eps, 0.0 becomes eps / (nClasses - 1)
             # noinspection PyTypeChecker
             return target * (1.0 - self.smoothing_eps) + \
-                   (1.0 - target) * self.smoothing_eps / (_num_classes - 1.0) # type: ignore
+                   (1.0 - target) * self.smoothing_eps / (_num_classes - 1.0)  # type: ignore

         _input: List[T] = list(input)
         if self.smoothing_eps > 0.0:
@@ -56,10 +56,26 @@ class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
 class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
     """A wrapper function for torch.nn.BCEWithLogitsLoss to enable label smoothing"""

-    def __init__(self, class_counts: Optional[Dict[float, float]] = None, **kwargs: Any):
+    def __init__(self, num_classes: int,
+                 class_counts: Optional[Dict[float, int]] = None,
+                 num_train_samples: Optional[int] = None,
+                 **kwargs: Any):
+        """
+        :param num_classes: The number of classes the model predicts. For binary classification num_classes is one
+        and for multi-label classification tasks num_classes will be greater than one.
+        :param class_counts: The number of positive samples for each class. class_counts is a dictionary with key-value
+        pairs corresponding to each class and the positive sample count for the class.
+        For binary classification tasks, class_counts should have a single key-value pair
+        for the positive class.
+        :param num_train_samples: The total number of training samples in the dataset.
+        """
         super().__init__(is_binary_classification=True, **kwargs)
+        if class_counts and not num_train_samples:
+            raise ValueError("Need to specify the num_train_samples with class_counts")
         self._positive_class_weights = None
         self._class_counts = class_counts
+        self._num_train_samples = num_train_samples
+        self.num_classes = num_classes
         if class_counts:
             self._positive_class_weights = self.get_positive_class_weights()
             if torch.cuda.is_available():
@@ -73,16 +89,18 @@ class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
         target position.
         :return: a list of weights to use for the positive class for each target position.
         """
-        assert self._class_counts is not None
-        labels = list(self._class_counts.keys())
-        if sorted(labels) != [0.0, 1.0]:
-            if labels == [1.0] or labels == [0.0]:
-                return torch.tensor(1.0)
-            else:
-                raise ValueError(f"Expected one-hot encoded binary label."
-                                 f"Found labels {self._class_counts.keys()}")
-        else:
-            return torch.tensor(float(self._class_counts[0.0]) / self._class_counts[1.0], dtype=torch.float32)
+        assert self._class_counts and self._num_train_samples
+        if len(self._class_counts) != self.num_classes:
+            raise ValueError(f"Have {self.num_classes} classes but got counts for {len(self._class_counts)} classes "
+                             f"Note: If this is a binary classification task, there should be a single class count "
+                             f"corresponding to the positive class.")
+        # These weights are given to the pos_weight parameter of Pytorch's BCEWithLogitsLoss.
+        # Weights are calculated as (number of negative samples for class 'i')/(number of positive samples for class 'i')
+        # for every class 'i' in a binary/multi-label classification task.
+        # For a binary classification task, this reduces to (number of false samples / number of true samples).
+        weights = [(self._num_train_samples - value) / value if value != 0 else 1.0 for (key, value) in
+                   sorted(self._class_counts.items())]  # Uses the first number on the tuple to compare
+        return torch.tensor(weights, dtype=torch.float32)

     def forward_minibatch(self, output: T, target: T, **kwargs: Any) -> Any:
         if isinstance(target, PackedSequence) and isinstance(output, PackedSequence):
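The weights computed above land in the `pos_weight` argument of PyTorch's `BCEWithLogitsLoss`: (negatives / positives) per class, so rare positive classes are up-weighted. A standalone numeric check:

```python
import torch

num_train_samples = 10
class_counts = {0: 2, 1: 5}  # positives per class
pos_weight = torch.tensor([(num_train_samples - n) / n if n else 1.0
                           for _, n in sorted(class_counts.items())])
print(pos_weight)  # tensor([4., 1.])
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
targets = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
print(loss_fn(torch.zeros(3, 2), targets))
```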
@@ -15,7 +15,7 @@ from InnerEye.Azure.azure_runner import create_experiment_name, get_or_create_py
 from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, fetch_child_runs, fetch_run, \
     get_cross_validation_split_index, is_cross_validation_child_run, is_run_and_child_runs_completed, \
     merge_conda_dependencies, merge_conda_files, to_azure_friendly_container_path
-from InnerEye.Common.common_util import logging_to_stdout
+from InnerEye.Common.common_util import logging_to_stdout, is_linux
 from InnerEye.Common.fixed_paths import PRIVATE_SETTINGS_FILE, PROJECT_SECRETS_FILE, \
     get_environment_yaml_file, repository_root_directory
 from InnerEye.Common.output_directories import OutputFolderForTests
@@ -85,6 +85,7 @@ def test_is_cross_validation_child_run_ensemble_run() -> None:
     assert all([is_cross_validation_child_run(x) for x in fetch_child_runs(run)])


+@pytest.mark.skipif(is_linux(), reason="Spurious file read/write errors on linux build agents.")
 def test_merge_conda(test_output_dirs: OutputFolderForTests) -> None:
     """
     Tests the logic for merging Conda environment files.
@@ -21,7 +21,7 @@ from InnerEye.Common.output_directories import OutputFolderForTests
 from InnerEye.Common.type_annotations import TupleInt3
 from InnerEye.ML.dataset.sample import GeneralSampleMetadata
 from InnerEye.ML.dataset.scalar_dataset import DataSourceReader, ScalarDataSource, ScalarDataset, \
-    _get_single_channel_row, _string_to_float, extract_label_classification, extract_label_regression, files_by_stem, \
+    _get_single_channel_row, _string_to_float, extract_label_classification, files_by_stem, \
     is_valid_item_index, load_single_data_source
 from InnerEye.ML.photometric_normalization import WindowNormalizationForScalarItem, mri_window
 from InnerEye.ML.scalar_config import LabelTransformation, ScalarLoss, ScalarModelBase
@@ -496,41 +496,49 @@ def test_load_single_item_7() -> None:
     assert torch.all(torch.isnan(item.categorical_non_image_features[4:6]))


-@pytest.mark.parametrize(["text", "expected_classification", "expected_regression"],
+@pytest.mark.parametrize(["text", "is_classification", "num_classes", "expected_label"],
                          [
-                             ("true", 1, None),
-                             ("tRuE", 1, None),
-                             ("false", 0, None),
-                             ("False", 0, None),
-                             ("nO", 0, None),
-                             ("Yes", 1, None),
-                             ("1.23", None, 1.23),
-                             (3.45, None, None),
-                             (math.nan, math.nan, math.nan),
-                             ("", math.nan, math.nan),
-                             (None, math.nan, math.nan),
-                             ("abc", None, None),
-                             ("1", 1, 1.0),
-                             ("-1", None, -1.0)
+                             ("false", True, 0, None),
+                             ("false", False, 0, None),
+                             ("true", True, 1, [1]),
+                             ("tRuE", True, 1, [1]),
+                             ("false", True, 1, [0]),
+                             ("False", True, 1, [0]),
+                             ("nO", True, 1, [0]),
+                             ("Yes", True, 1, [1.0]),
+                             ("1.23", False, 1, [1.23]),
+                             (3.45, True, 1, None),
+                             (3.45, False, 1, None),
+                             ("3.45", False, 1, [3.45]),
+                             (math.nan, True, 1, [math.nan]),
+                             (math.nan, True, 3, [0.0, 0.0, 0.0]),
+                             (math.nan, False, 1, [math.nan]),
+                             ("", True, 1, [math.nan]),
+                             ("", True, 3, [0.0, 0.0, 0.0]),
+                             ("", False, 1, [math.nan]),
+                             ("abc", True, 1, None),
+                             ("abc", True, 3, None),
+                             ("abc", False, 1, None),
+                             ("1", True, 1, [1.0]),
+                             ("1", True, 3, [0.0, 1.0, 0.0]),
+                             ("1", False, 1, [1.0]),
+                             ("-1", False, 1, [-1.0]),
+                             ("1|2", True, 3, [0.0, 1.0, 1.0]),
+                             ("1|5", True, 3, None)
                          ])
-def test_extract_label(text: Union[float, str], expected_classification: Optional[float],
-                       expected_regression: Optional[float]) -> None:
-    _check_label_extraction_function(extract_label_classification, text, expected_classification)
-    _check_label_extraction_function(extract_label_regression, text, expected_regression)
-
-
-def _check_label_extraction_function(extract_fn: Callable, text: Union[float, str], expected: Optional[float]) -> None:
-    if expected is None:
+def test_extract_label(text: str, is_classification: bool, num_classes: int,
+                       expected_label: List[float]) -> None:
+    if expected_label is None:
         with pytest.raises(ValueError) as ex:
-            extract_fn(text, "foo")
-        assert "Subject foo:" in str(ex)
+            extract_label_classification(text, "subject1", num_classes, is_classification)
+        assert "Subject subject1:" in str(ex)
     else:
-        actual = extract_fn(text, "foo")
-        assert isinstance(actual, type(expected))
-        if math.isnan(expected):
-            assert math.isnan(actual)
+        actual = extract_label_classification(text, "subject1", num_classes, is_classification)
+        assert isinstance(actual, type(expected_label))
+        if expected_label == [math.nan]:
+            assert math.isnan(actual[0])
         else:
-            assert actual == expected
+            assert actual == expected_label


 @pytest.mark.parametrize(["text", "expected"],
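Only the multi-label branch of `extract_label_classification` is new behaviour in this table. A hedged re-sketch of just that branch, consistent with the `("1|2", True, 3, [0.0, 1.0, 1.0])` and `("1|5", True, 3, None)` rows:

```python
from typing import List

def multi_hot_label(text: str, num_classes: int) -> List[float]:
    # Indices from "a|b|c" become a multi-hot vector; out-of-range indices
    # (e.g. "1|5" with 3 classes) are rejected, matching the test above.
    label = [0.0] * num_classes
    for index in (int(i) for i in text.split("|")):
        if not 0 <= index < num_classes:
            raise ValueError(f"Invalid label index {index} for {num_classes} classes")
        label[index] = 1.0
    return label

assert multi_hot_label("1|2", 3) == [0.0, 1.0, 1.0]
```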
@@ -791,10 +799,9 @@ def test_imbalanced_sampler() -> None:
     assert count_negative_subjects / float(len(drawn_subjects)) > 0.3


-def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> None:
+def test_get_class_counts_binary(test_output_dirs: OutputFolderForTests) -> None:
     """
-    Test training and testing of sequence models that predicts at multiple time points,
-    when it is started via run_ml.
+    Test the get_class_counts method for binary scalar datasets.
     """
     dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
     dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
@@ -816,4 +823,86 @@ def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> No
     config.set_output_to(test_output_dirs.root_dir)
     train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
     class_counts = train_dataset.get_class_counts()
-    assert class_counts == {0.0: 1, 1.0: 2}
+    assert class_counts == {0: 2}
+
+
+def test_get_class_counts_multilabel(test_output_dirs: OutputFolderForTests) -> None:
+    """
+    Test the get_class_counts method for multilabel scalar datasets.
+    """
+    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
+    dataset_contents = """subject,channel,path,label,CAT1
+S1,week0,scan1.npy,,A
+S1,week1,scan2.npy,0|1|2,A
+S2,week0,scan3.npy,,A
+S2,week1,scan4.npy,1|2,A
+S3,week0,scan1.npy,,A
+S3,week1,scan3.npy,1,A
+"""
+    config = ScalarModelBase(
+        local_dataset=dataset_folder,
+        class_names=["class0", "class1", "class2", "class3"],
+        label_channels=["week1"],
+        label_value_column="label",
+        non_image_feature_channels=["week0", "week1"],
+        should_validate=False
+    )
+    config.set_output_to(test_output_dirs.root_dir)
+    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
+    class_counts = train_dataset.get_class_counts()
+    assert class_counts == {0: 1, 1: 3, 2: 2, 3: 0}
+
+
+def test_get_labels_for_imbalanced_sampler_binary(test_output_dirs: OutputFolderForTests) -> None:
+    """
+    Test the get_labels_for_imbalanced_sampler method for binary scalar datasets.
+    """
+    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
+    dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
+S1,week0,scan1.npy,,1,10,A
+S1,week1,scan2.npy,True,2,20,A
+S2,week0,scan3.npy,,3,30,A
+S2,week1,scan4.npy,False,4,40,A
+S3,week0,scan1.npy,,5,50,A
+S3,week1,scan3.npy,True,6,60,A
+"""
+    config = ScalarModelBase(
+        local_dataset=dataset_folder,
+        label_channels=["week1"],
+        label_value_column="label",
+        non_image_feature_channels=["week0", "week1"],
+        numerical_columns=["numerical1", "numerical2"],
+        should_validate=False
+    )
+    config.set_output_to(test_output_dirs.root_dir)
+    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
+    labels = train_dataset.get_labels_for_imbalanced_sampler()
+    assert labels == [1.0, 0.0, 1.0]
+
+
+def test_get_labels_for_imbalanced_sampler_multilabel(test_output_dirs: OutputFolderForTests) -> None:
+    """
+    Test that the get_labels_for_imbalanced_sampler method raises an error for multilabel scalar datasets.
+    """
+    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
+    dataset_contents = """subject,channel,path,label,CAT1
+S1,week0,scan1.npy,,A
+S1,week1,scan2.npy,0|1|2,A
+S2,week0,scan3.npy,,A
+S2,week1,scan4.npy,1|2,A
+S3,week0,scan1.npy,,A
+S3,week1,scan3.npy,1,A
+"""
+    config = ScalarModelBase(
+        local_dataset=dataset_folder,
+        class_names=["class0", "class1", "class2", "class3"],
+        label_channels=["week1"],
+        label_value_column="label",
+        non_image_feature_channels=["week0", "week1"],
+        should_validate=False
+    )
+    config.set_output_to(test_output_dirs.root_dir)
+    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
+    with pytest.raises(NotImplementedError) as ex:
+        train_dataset.get_labels_for_imbalanced_sampler()
+    assert "ImbalancedSampler is not supported for multilabel tasks." in str(ex)
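The `{0: 1, 1: 3, 2: 2, 3: 0}` expectation above counts, per class index, how many subjects list that index in the pipe-separated label column. The same tally in plain pandas:

```python
import pandas as pd

labels = pd.Series(["0|1|2", "1|2", "1"])  # label column of the fixture above
counts = {c: 0 for c in range(4)}
for entry in labels:
    for index in entry.split("|"):
        counts[int(index)] += 1
print(counts)  # {0: 1, 1: 3, 2: 2, 3: 0}
```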
@@ -599,7 +599,7 @@ S4,0,True,4,40,M2,B1
     assert_tensors_equal(test_items[0].items[0].get_all_non_imaging_features(), [3., 3., 0., 1., 1., 0.])


-def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> None:
+def test_get_class_counts(test_output_dirs: OutputFolderForTests) -> None:
     """
     Test training and testing of sequence models that predict at multiple time points,
     when started via run_ml.
@@ -615,7 +615,7 @@ def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> No
     splits = config.get_dataset_splits()
     train_dataset = config.create_torch_datasets(splits)[ModelExecutionMode.TRAIN]
     class_counts = train_dataset.get_class_counts()
-    assert class_counts == {0.0: 9, 1.0: 2}
+    assert class_counts == {0: 2}


 def test_get_labels_at_target_indices() -> None:
@@ -15,7 +15,8 @@ from InnerEye.ML.models.layers.identity import Identity
 class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
     def __init__(self, expected_image_size_zyx: TupleInt3,
                  activation: torch.nn.Module = Identity(),
-                 use_mixed_precision: bool = False) -> None:
+                 use_mixed_precision: bool = False,
+                 num_classes: int = 1) -> None:
         super().__init__()
         self.expected_image_size_zyx = expected_image_size_zyx
         self._layers = torch.nn.ModuleList()
@@ -25,7 +26,7 @@ class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
         fc_out = fc(torch.zeros((1, 1) + self.expected_image_size_zyx))
         self.feature_size = fc_out.view(-1).shape[0]
         self._layers.append(fc)
-        self.fc = torch.nn.Linear(self.feature_size, 1)
+        self.fc = torch.nn.Linear(self.feature_size, out_features=num_classes)
         self.activation = activation
         self.last_encoder_layer: List[str] = ["_layers", "0"]
         self.conv_in_3d = False
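That one-line change sizes the classifier head by `num_classes`, so a multilabel model emits one logit per class; sigmoid (not softmax) then yields independent per-class posteriors. A standalone sketch:

```python
import torch

num_classes, feature_size = 5, 16  # hypothetical sizes
head = torch.nn.Linear(feature_size, out_features=num_classes)
logits = head(torch.randn(2, feature_size))
print(torch.sigmoid(logits).shape)  # torch.Size([2, 5]); one probability per class
```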
@@ -5,6 +5,7 @@
 import pytest
 import torch
 import torch.optim as optim
+from typing import Dict

 from InnerEye.ML.models.losses.cross_entropy import CrossEntropyLoss
 # Set random seed
@@ -58,8 +59,9 @@ def test_cross_entropy_loss_forward_smoothing(is_segmentation: bool) -> None:
     smoothed_target = torch.tensor([[[0.1, 0.1, 0.9], [0.9, 0.9, 0.1]]], dtype=torch.float32)
     logits = torch.tensor([[[-10, -10, 0], [0, 0, 0]]], dtype=torch.float32)

-    barely_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0)
-    smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1)
+    barely_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(num_classes=1,
+                                                                                            smoothing_eps=0)
+    smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(num_classes=1, smoothing_eps=0.1)
     if is_segmentation:
         # The two loss values are only expected to be the same when no class weighting takes place,
         # because weighting is done on the *unsmoothed* target values.
@@ -143,10 +145,17 @@ def test_weighted_binary_cross_entropy_loss_forward_smoothing() -> None:
     smoothed_target = torch.tensor([[0.9], [0.9], [0.9], [0.9], [0.9], [0.1]], dtype=torch.float32)
     logits = torch.tensor([[-10], [-10], [0], [0], [0], [0]], dtype=torch.float32)
     weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
-        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0, class_counts={0.0: 1.0, 1.0: 5.0})
+        BinaryCrossEntropyWithLogitsLoss(num_classes=1,
+                                         smoothing_eps=0,
+                                         class_counts={1.0: 5},
+                                         num_train_samples=target.shape[0])
     weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
-        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts={0.0: 1.0, 1.0: 5.0})
-    non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1,
+        BinaryCrossEntropyWithLogitsLoss(num_classes=1,
+                                         smoothing_eps=0.1,
+                                         class_counts={1.0: 5},
+                                         num_train_samples=target.shape[0])
+    non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(num_classes=1,
+                                                                                                  smoothing_eps=0.1,
+                                                                                                  class_counts=None)
     w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
     w_loss2 = weighted_smoothed_loss_fn(logits, target)
@@ -157,23 +166,44 @@ def test_weighted_binary_cross_entropy_loss_forward_smoothing() -> None:
    assert torch.all(positive_class_weights == torch.tensor([[0.2]]))


-def test_weighted_binary_cross_entropy_loss_multi_target() -> None:
-    target = torch.tensor([[[1], [0]], [[1], [0]], [[0], [0]]], dtype=torch.float32)
-    smoothed_target = torch.tensor([[[0.9], [0.1]], [[0.9], [0.1]], [[0.1], [0.1]]], dtype=torch.float32)
-    logits = torch.tensor([[[-10], [1]], [[-10], [1]], [[10], [0]]], dtype=torch.float32)
+def test_weighted_binary_cross_entropy_loss_multi_label() -> None:
+    # Class 0 has 2 positive examples, class 1 has none
+    target = torch.tensor([[1, 0], [1, 0], [0, 0]], dtype=torch.float32)
+    smoothed_target = torch.tensor([[0.9, 0.1], [0.9, 0.1], [0.1, 0.1]], dtype=torch.float32)
+    logits = torch.tensor([[-10, 1], [-10, 1], [10, 0]], dtype=torch.float32)
    weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
-        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0, class_counts={1.0: 2, 0.0: 4})
+        BinaryCrossEntropyWithLogitsLoss(num_classes=2,
+                                         smoothing_eps=0,
+                                         class_counts={1.0: 0, 0.0: 2},
+                                         num_train_samples=target.shape[0])
    weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
-        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts={1.0: 2, 0.0: 4})
+        BinaryCrossEntropyWithLogitsLoss(num_classes=2,
+                                         smoothing_eps=0.1,
+                                         class_counts={1.0: 0, 0.0: 2},
+                                         num_train_samples=target.shape[0])
    non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
-        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts=None)
+        BinaryCrossEntropyWithLogitsLoss(num_classes=2,
+                                         smoothing_eps=0.1,
+                                         class_counts=None)
    w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
    w_loss2 = weighted_smoothed_loss_fn(logits, target)
    w_loss3 = non_weighted_smoothed_loss_fn(logits, target)
    positive_class_weights = weighted_smoothed_loss_fn.get_positive_class_weights()  # type: ignore
    assert torch.isclose(w_loss1, w_loss2)
    assert not torch.isclose(w_loss2, w_loss3)
-    assert torch.all(positive_class_weights == torch.tensor(2))
+    assert torch.equal(positive_class_weights, torch.tensor([0.5, 1]))


+@pytest.mark.parametrize("num_classes, class_counts", [(1, {1.0: 0, 0.0: 2}),
+                                                       (3, {1.0: 0, 0.0: 2})])
+def test_invalid_initialization(num_classes: int,
+                                class_counts: Dict[float, int]) -> None:
+    with pytest.raises(ValueError) as ex:
+        BinaryCrossEntropyWithLogitsLoss(num_classes=num_classes,
+                                         smoothing_eps=0,
+                                         class_counts=class_counts,
+                                         num_train_samples=10)
+    assert f"Have {num_classes} classes but got counts for {len(class_counts)} classes" in str(ex)


class ToyNet(torch.nn.Module):
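The weight assertions above are consistent with a per-class negative-to-positive ratio. A minimal sketch of that computation, inferred from the test values alone (not taken from the loss class itself), assuming a class with no positive samples falls back to a weight of 1:

```python
from typing import Dict, List

def positive_class_weights(class_counts: Dict[float, int], num_train_samples: int) -> List[float]:
    # Weight each class's positives by the negative-to-positive ratio:
    # 6 samples with 5 positives -> (6 - 5) / 5 = 0.2, as asserted above;
    # 3 samples with 2 positives -> (3 - 2) / 2 = 0.5; no positives -> 1.0 (assumed fallback).
    return [(num_train_samples - class_counts[c]) / class_counts[c] if class_counts[c] > 0 else 1.0
            for c in sorted(class_counts)]

assert positive_class_weights({1.0: 5}, num_train_samples=6) == [0.2]
assert positive_class_weights({1.0: 0, 0.0: 2}, num_train_samples=3) == [0.5, 1.0]
```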
@@ -21,10 +21,13 @@ from InnerEye.Common.metrics_constants import LoggingColumns, MetricType
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML import model_testing, model_training, runner
from InnerEye.ML.common import ModelExecutionMode
+from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.metrics import InferenceMetricsForClassification, binary_classification_accuracy, \
    compute_scalar_metrics
from InnerEye.ML.metrics_dict import MetricsDict, ScalarMetricsDict
from InnerEye.ML.reports.notebook_report import get_ipynb_report_name, get_html_report_name, \
    generate_classification_notebook, generate_classification_multilabel_notebook
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.config_util import ModelConfigLoader
@@ -36,7 +39,8 @@ from Tests.ML.util import get_default_azure_config, get_default_checkpoint_handl


@pytest.mark.cpu_and_gpu
-def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> None:
+@pytest.mark.parametrize("class_name", [MetricsDict.DEFAULT_HUE_KEY, "foo"])
+def test_train_classification_model(class_name: str, test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and testing of classification models, asserting on the individual results from training and
    testing.
@@ -44,6 +48,7 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting()
+    config.class_names = [class_name]
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=Path(test_output_dirs.root_dir))
@@ -59,6 +64,7 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
    assert len(model_training_result.val_results_per_epoch) == config.num_epochs
    assert len(model_training_result.train_results_per_epoch[0]) >= 11
    assert len(model_training_result.val_results_per_epoch[0]) >= 11

    for metric in [MetricType.ACCURACY_AT_THRESHOLD_05,
                   MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
                   MetricType.AREA_UNDER_PR_CURVE,
@@ -67,10 +73,12 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
                   MetricType.LOSS,
                   MetricType.SECONDS_PER_BATCH,
                   MetricType.SECONDS_PER_EPOCH,
-                   MetricType.SUBJECT_COUNT,
-                   ]:
-        assert metric.value in model_training_result.train_results_per_epoch[0], f"{metric.value} not in training"
-        assert metric.value in model_training_result.val_results_per_epoch[0], f"{metric.value} not in validation"
+                   MetricType.SUBJECT_COUNT]:
+        assert metric.value in model_training_result.train_results_per_epoch[0], \
+            f"{metric.value} not in training"
+        assert metric.value in model_training_result.val_results_per_epoch[0], \
+            f"{metric.value} not in validation"

    actual_train_loss = model_training_result.get_metric(is_training=True, metric_type=MetricType.LOSS.value)
    actual_val_loss = model_training_result.get_metric(is_training=False, metric_type=MetricType.LOSS.value)
    actual_lr = model_training_result.get_metric(is_training=True, metric_type=MetricType.LEARNING_RATE.value)
@@ -81,20 +89,27 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
                                            checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    expected_metrics = [0.636085, 0.735952]
-    assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
+    assert test_results.metrics.values(class_name)[MetricType.CROSS_ENTROPY.value] == \
        pytest.approx(expected_metrics, abs=1e-5)
    # Run detailed logs file check only on CPU, it will contain slightly different metrics on GPU, but here
    # we want to mostly assert that the files look reasonable
    if machine_has_gpu:
        return

    # Check epoch_metrics.csv
    epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME
    # Auto-format will break the long header line, hence the strange way of writing it!
    expected_epoch_metrics = \
-        "loss,cross_entropy,accuracy_at_threshold_05,learning_rate," + \
-        "area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
-        "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
-        "optimal_threshold,subject_count,epoch,cross_validation_split_index\n" + \
+        f"{LoggingColumns.Loss.value},{LoggingColumns.CrossEntropy.value}," \
+        f"{LoggingColumns.AccuracyAtThreshold05.value},{LoggingColumns.LearningRate.value}," + \
+        f"{LoggingColumns.AreaUnderRocCurve.value}," \
+        f"{LoggingColumns.AreaUnderPRCurve.value}," \
+        f"{LoggingColumns.AccuracyAtOptimalThreshold.value}," \
+        f"{LoggingColumns.FalsePositiveRateAtOptimalThreshold.value}," \
+        f"{LoggingColumns.FalseNegativeRateAtOptimalThreshold.value}," \
+        f"{LoggingColumns.OptimalThreshold.value}," \
+        f"{LoggingColumns.SubjectCount.value},{LoggingColumns.Epoch.value}," \
+        f"{LoggingColumns.CrossValidationSplitIndex.value}\n" + \
        """0.6866141557693481,0.6866141557693481,0.5,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,0,-1
0.6864652633666992,0.6864652633666992,0.5,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,1,-1
0.6863163113594055,0.6863162517547607,0.5,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,2,-1
@@ -107,15 +122,15 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
        return
    metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / SUBJECT_METRICS_FILE_NAME
    metrics_expected = \
-        """epoch,subject,prediction_target,model_output,label,data_split,cross_validation_split_index
-0,S2,Default,0.529514,1,Train,-1
-0,S4,Default,0.521659,0,Train,-1
-1,S4,Default,0.521482,0,Train,-1
-1,S2,Default,0.529475,1,Train,-1
-2,S4,Default,0.521305,0,Train,-1
-2,S2,Default,0.529437,1,Train,-1
-3,S2,Default,0.529399,1,Train,-1
-3,S4,Default,0.521128,0,Train,-1
+        f"""epoch,subject,prediction_target,model_output,label,data_split,cross_validation_split_index
+0,S2,{class_name},0.529514,1,Train,-1
+0,S4,{class_name},0.521659,0,Train,-1
+1,S4,{class_name},0.521482,0,Train,-1
+1,S2,{class_name},0.529475,1,Train,-1
+2,S4,{class_name},0.521305,0,Train,-1
+2,S2,{class_name},0.529437,1,Train,-1
+3,S2,{class_name},0.529399,1,Train,-1
+3,S4,{class_name},0.521128,0,Train,-1
        """
    check_log_file(metrics_path, metrics_expected, ignore_columns=[])
    # Check log METRICS_FILE_NAME inside of the folder epoch_004/Train, which is written when we run model_test.
@@ -123,13 +138,97 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
    inference_metrics_path = config.outputs_folder / get_epoch_results_path(ModelExecutionMode.TRAIN) / \
                             SUBJECT_METRICS_FILE_NAME
    inference_metrics_expected = \
-        """prediction_target,subject,model_output,label,cross_validation_split_index,data_split
-Default,S2,0.5293986201286316,1.0,-1,Train
-Default,S4,0.5211275815963745,0.0,-1,Train
+        f"""prediction_target,subject,model_output,label,cross_validation_split_index,data_split
+{class_name},S2,0.5293986201286316,1.0,-1,Train
+{class_name},S4,0.5211275815963745,0.0,-1,Train
        """
    check_log_file(inference_metrics_path, inference_metrics_expected, ignore_columns=[])


+@pytest.mark.cpu_and_gpu
+def test_train_classification_multilabel_model(test_output_dirs: OutputFolderForTests) -> None:
+    """
+    Test training and testing of classification models, asserting on the individual results from training and
+    testing.
+    Expected test results are stored for GPU with and without mixed precision.
+    """
+    logging_to_stdout(logging.DEBUG)
+    config = DummyMulticlassClassification()
+    config.set_output_to(test_output_dirs.root_dir)
+    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
+                                                        project_root=Path(test_output_dirs.root_dir))
+    # Train for 4 epochs, checkpoints at epochs 2 and 4
+    config.num_epochs = 4
+    model_training_result = model_training.model_train(config, checkpoint_handler=checkpoint_handler)
+    assert model_training_result is not None
+    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
+    expected_train_loss = [0.699870228767395, 0.6239662170410156, 0.551329493522644, 0.4825132489204407]
+    expected_val_loss = [0.6299371719360352, 0.5546272993087769, 0.4843321740627289, 0.41909298300743103]
+    # Ensure that all metrics are computed on both training and validation set
+    assert len(model_training_result.train_results_per_epoch) == config.num_epochs
+    assert len(model_training_result.val_results_per_epoch) == config.num_epochs
+    assert len(model_training_result.train_results_per_epoch[0]) >= 11
+    assert len(model_training_result.val_results_per_epoch[0]) >= 11
+    for class_name in config.class_names:
+        for metric in [MetricType.ACCURACY_AT_THRESHOLD_05,
+                       MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
+                       MetricType.AREA_UNDER_PR_CURVE,
+                       MetricType.AREA_UNDER_ROC_CURVE,
+                       MetricType.CROSS_ENTROPY]:
+            assert f'{metric.value}/{class_name}' in model_training_result.train_results_per_epoch[
+                0], f"{metric.value} not in training"
+            assert f'{metric.value}/{class_name}' in model_training_result.val_results_per_epoch[
+                0], f"{metric.value} not in validation"
+    for metric in [MetricType.LOSS,
+                   MetricType.SECONDS_PER_EPOCH,
+                   MetricType.SUBJECT_COUNT]:
+        assert metric.value in model_training_result.train_results_per_epoch[0], f"{metric.value} not in training"
+        assert metric.value in model_training_result.val_results_per_epoch[0], f"{metric.value} not in validation"
+
+    actual_train_loss = model_training_result.get_metric(is_training=True, metric_type=MetricType.LOSS.value)
+    actual_val_loss = model_training_result.get_metric(is_training=False, metric_type=MetricType.LOSS.value)
+    actual_lr = model_training_result.get_metric(is_training=True, metric_type=MetricType.LEARNING_RATE.value)
+    assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6), "Training loss"
+    assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6), "Validation loss"
+    assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5), "Learning rates"
+    test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
+                                            checkpoint_handler=checkpoint_handler)
+    assert isinstance(test_results, InferenceMetricsForClassification)
+
+    expected_metrics = {MetricType.CROSS_ENTROPY: [1.3996, 5.2966, 1.4020, 0.3553, 0.6908],
+                        MetricType.ACCURACY_AT_THRESHOLD_05: [0.0000, 0.0000, 0.0000, 1.0000, 1.0000]
+                        }
+
+    for i, class_name in enumerate(config.class_names):
+        for metric in expected_metrics.keys():
+            assert expected_metrics[metric][i] == pytest.approx(
+                test_results.metrics.get_single_metric(
+                    metric_name=metric,
+                    hue=class_name), 1e-4)
+
+    def get_epoch_path(mode: ModelExecutionMode) -> Path:
+        p = get_epoch_results_path(mode=mode)
+        return config.outputs_folder / p / SUBJECT_METRICS_FILE_NAME
+
+    path_to_best_epoch_train = get_epoch_path(ModelExecutionMode.TRAIN)
+    path_to_best_epoch_val = get_epoch_path(ModelExecutionMode.VAL)
+    path_to_best_epoch_test = get_epoch_path(ModelExecutionMode.TEST)
+    generate_classification_notebook(result_notebook=config.outputs_folder / get_ipynb_report_name(config.model_category.value),
+                                     config=config,
+                                     train_metrics=path_to_best_epoch_train,
+                                     val_metrics=path_to_best_epoch_val,
+                                     test_metrics=path_to_best_epoch_test)
+    assert (config.outputs_folder / get_html_report_name(config.model_category.value)).exists()
+
+    report_name_multilabel = f"{config.model_category.value}_multilabel"
+    generate_classification_multilabel_notebook(result_notebook=config.outputs_folder / get_ipynb_report_name(report_name_multilabel),
+                                                config=config,
+                                                train_metrics=path_to_best_epoch_train,
+                                                val_metrics=path_to_best_epoch_val,
+                                                test_metrics=path_to_best_epoch_test)
+    assert (config.outputs_folder / get_html_report_name(report_name_multilabel)).exists()
+
+
+def _count_lines(s: str) -> int:
+    lines = [line for line in s.splitlines() if line.strip()]
+    return len(lines)
@@ -1,13 +0,0 @@
subject,filePath,label
0,../test_data/classification_data_2d/im1.npy,0
1,../test_data/classification_data_2d/im2.npy,0
2,../test_data/classification_data_2d/im1.npy,0
3,../test_data/classification_data_2d/im2.npy,0
4,../test_data/classification_data_2d/im1.npy,0
5,../test_data/classification_data_2d/im2.npy,0
6,../test_data/classification_data_2d/im1.npy,0
7,../test_data/classification_data_2d/im2.npy,0
8,../test_data/classification_data_2d/im1.npy,0
9,../test_data/classification_data_2d/im2.npy,0
10,../test_data/classification_data_2d/im1.npy,0
11,../test_data/classification_data_2d/im2.npy,0
@@ -0,0 +1,240 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path

import pandas as pd
import numpy as np
import pytest

from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.common_util import is_windows
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.metrics_dict import MetricsDict
from InnerEye.ML.reports.classification_multilabel_report import get_dataframe_with_exact_label_matches, \
    get_labels_and_predictions_for_prediction_target_set, get_unique_prediction_target_combinations
from InnerEye.ML.reports.notebook_report import generate_classification_multilabel_notebook
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX


@pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
def test_generate_classification_multilabel_report(test_output_dirs: OutputFolderForTests) -> None:
    hues = ["Hue1", "Hue2"]

    config = ScalarModelBase(label_value_column="label",
                             image_file_column="filePath",
                             image_channels=["image1", "image2"],
                             label_channels=["image1"])
    config.class_names = hues

    test_metrics_file = test_output_dirs.root_dir / "test_metrics_classification.csv"
    val_metrics_file = test_output_dirs.root_dir / "val_metrics_classification.csv"

    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv_path = config.local_dataset / "dataset.csv"
    image_file_name = "image.npy"

    pd.DataFrame.from_dict({LoggingColumns.Hue.value: [hues[0], hues[1]] * 6,
                            LoggingColumns.Epoch.value: [0] * 12,
                            LoggingColumns.Patient.value: [s for s in range(6) for _ in range(2)],
                            LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.1, 0.9, 0.1, 0.9,
                                                               0.9, 0.9, 0.9, 0.9, 0.9, 0.1],
                            LoggingColumns.Label.value: [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0],
                            LoggingColumns.CrossValidationSplitIndex: [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
                            LoggingColumns.DataSplit.value: [0] * 12,
                            }).to_csv(test_metrics_file, index=False)

    pd.DataFrame.from_dict({LoggingColumns.Hue.value: [hues[0], hues[1]] * 6,
                            LoggingColumns.Epoch.value: [0] * 12,
                            LoggingColumns.Patient.value: [s for s in range(6) for _ in range(2)],
                            LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.1, 0.1, 0.1, 0.9,
                                                               0.9, 0.9, 0.9, 0.1, 0.9, 0.1],
                            LoggingColumns.Label.value: [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0],
                            LoggingColumns.CrossValidationSplitIndex: [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
                            LoggingColumns.DataSplit.value: [0] * 12,
                            }).to_csv(val_metrics_file, index=False)

    pd.DataFrame.from_dict({config.subject_column: [s for s in range(6) for _ in range(2)],
                            config.channel_column: ["image1", "image2"] * 6,
                            config.image_file_column: [f for f in [f"0_{image_file_name}", f"1_{image_file_name}"]
                                                       for _ in range(6)],
                            config.label_value_column: ["", "", "1", "1", "1", "1", "0|1", "0|1", "0|1", "0|1", "0", "0"]
                            }).to_csv(dataset_csv_path, index=False)

    np.save(str(Path(config.local_dataset / f"0_{image_file_name}")),
            np.random.randint(0, 255, [5, 4]))
    np.save(str(Path(config.local_dataset / f"1_{image_file_name}")),
            np.random.randint(0, 255, [5, 4]))

    result_file = test_output_dirs.root_dir / "report.ipynb"
    result_html = generate_classification_multilabel_notebook(result_notebook=result_file,
                                                              config=config,
                                                              val_metrics=val_metrics_file,
                                                              test_metrics=test_metrics_file)
    assert result_file.is_file()
    assert result_html.is_file()
    assert result_html.suffix == ".html"


def test_get_pseudo_labels_and_predictions() -> None:
    reports_folder = Path(__file__).parent
    test_metrics_file = reports_folder / "test_metrics_classification.csv"

    results = get_labels_and_predictions_for_prediction_target_set(test_metrics_file,
                                                                   [MetricsDict.DEFAULT_HUE_KEY],
                                                                   all_prediction_targets=[MetricsDict.DEFAULT_HUE_KEY],
                                                                   thresholds_per_prediction_target=[0.5])
    assert all([results.subject_ids[i] == i for i in range(12)])
    assert all([results.labels[i] == label for i, label in enumerate([1] * 6 + [0] * 6)])
    assert all([results.model_outputs[i] == op for i, op in enumerate([0.0, 0.0, 0.0, 1.0, 1.0, 1.0] * 2)])


def test_get_pseudo_labels_and_predictions_multiple_hues(test_output_dirs: OutputFolderForTests) -> None:
    reports_folder = Path(__file__).parent
    test_metrics_file = reports_folder / "test_metrics_classification.csv"

    # Write a new metrics file with 2 prediction targets,
    # prediction_target_set_to_match will only be associated with one prediction target
    csv = pd.read_csv(test_metrics_file)
    hues = ["Hue1", "Hue2"]
    csv.loc[::2, LoggingColumns.Hue.value] = hues[0]
    csv.loc[1::2, LoggingColumns.Hue.value] = hues[1]
    csv.loc[::2, LoggingColumns.Patient.value] = list(range(len(csv) // 2))
    csv.loc[1::2, LoggingColumns.Patient.value] = list(range(len(csv) // 2))
    csv.loc[::2, LoggingColumns.Label.value] = [0, 0, 0, 1, 1, 1]
    csv.loc[1::2, LoggingColumns.Label.value] = [0, 1, 1, 1, 1, 0]
    csv.loc[::2, LoggingColumns.ModelOutput.value] = [0.1, 0.1, 0.1, 0.9, 0.9, 0.9]
    csv.loc[1::2, LoggingColumns.ModelOutput.value] = [0.1, 0.9, 0.9, 0.9, 0.9, 0.1]
    metrics_csv_multi_hue = test_output_dirs.root_dir / "metrics.csv"
    csv.to_csv(metrics_csv_multi_hue, index=False)

    for h, hue in enumerate(hues):
        results = get_labels_and_predictions_for_prediction_target_set(metrics_csv_multi_hue,
                                                                       prediction_target_set_to_match=[hue],
                                                                       all_prediction_targets=hues,
                                                                       thresholds_per_prediction_target=[0.5, 0.5])
        assert all([results.subject_ids[i] == i for i in range(6)])
        assert all([results.labels[i] == label
                    for i, label in enumerate([0, 0, 0, 0, 0, 1] if h == 0 else [0, 1, 1, 0, 0, 0])])
        assert all([results.model_outputs[i] == op
                    for i, op in enumerate([0, 0, 0, 0, 0, 1] if h == 0 else [0, 1, 1, 0, 0, 0])])


def test_generate_pseudo_labels() -> None:
    metrics_df = pd.DataFrame.from_dict({"prediction_target": ["Hue1", "Hue2", "Hue3"] * 4,
                                         "epoch": [0] * 12,
                                         "subject": [i for i in range(4) for _ in range(3)],
                                         "model_output": [0.5, 0.6, 0.3, 0.5, 0.6, 0.5, 0.5, 0.6, 0.5, 0.5, 0.6, 0.3],
                                         "label": [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1],
                                         "cross_validation_split_index": [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
                                         "data_split": [ModelExecutionMode.TEST.value] * 12
                                         })

    expected_df = pd.DataFrame.from_dict({"subject": list(range(4)),
                                          "model_output": [1, 0, 0, 1],
                                          "label": [1, 0, 1, 0],
                                          "prediction_target": ["Hue1|Hue2"] * 4})

    df = get_dataframe_with_exact_label_matches(metrics_df=metrics_df,
                                                prediction_target_set_to_match=["Hue1", "Hue2"],
                                                all_prediction_targets=["Hue1", "Hue2", "Hue3"],
                                                thresholds_per_prediction_target=[0.4, 0.5, 0.4])
    assert expected_df.equals(df)


def test_generate_pseudo_labels_negative_class() -> None:
    metrics_df = pd.DataFrame.from_dict({"prediction_target": ["Hue1", "Hue2", "Hue3"] * 4,
                                         "epoch": [0] * 12,
                                         "subject": [i for i in range(4) for _ in range(3)],
                                         "model_output": [0.2, 0.3, 0.2, 0.5, 0.6, 0.5, 0.5, 0.6, 0.5, 0.2, 0.3, 0.2],
                                         "label": [0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1],
                                         "cross_validation_split_index": [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
                                         "data_split": [ModelExecutionMode.TEST.value] * 12
                                         })

    expected_df = pd.DataFrame.from_dict({"subject": list(range(4)),
                                          "model_output": [1, 0, 0, 1],
                                          "label": [1, 1, 0, 0],
                                          "prediction_target": [""] * 4})

    df = get_dataframe_with_exact_label_matches(metrics_df=metrics_df,
                                                prediction_target_set_to_match=[],
                                                all_prediction_targets=["Hue1", "Hue2", "Hue3"],
                                                thresholds_per_prediction_target=[0.4, 0.5, 0.4])
    assert expected_df.equals(df)


def test_get_unique_label_combinations_single_label(test_output_dirs: OutputFolderForTests) -> None:
    config = ScalarModelBase(label_channels=["label"],
                             label_value_column="value",
                             image_channels=["image"],
                             image_file_column="path",
                             subject_column="subjectID")
    class_names = config.class_names

    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    dataset_csv.write_text("subjectID,channel,path,value\n"
                           "S1,label,random,1\n"
                           "S1,image,random,\n"
                           "S2,label,random,0\n"
                           "S2,image,random,\n"
                           "S3,label,random,1\n"
                           "S3,image,random,\n")

    unique_labels = get_unique_prediction_target_combinations(config)  # type: ignore
    expected_label_combinations = set(frozenset(class_names[i] for i in labels)  # type: ignore
                                      for labels in [[], [0]])
    assert unique_labels == expected_label_combinations


def test_get_unique_label_combinations_nan(test_output_dirs: OutputFolderForTests) -> None:
    config = ScalarModelBase(label_channels=["label"],
                             label_value_column="value",
                             image_channels=["image"],
                             image_file_column="path",
                             subject_column="subjectID")
    class_names = config.class_names

    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    dataset_csv.write_text("subjectID,channel,path,value\n"
                           "S1,label,random,1\n"
                           "S1,image,random,\n"
                           "S2,label,random,\n"
                           "S2,image,random,\n")
    unique_labels = get_unique_prediction_target_combinations(config)  # type: ignore
    expected_label_combinations = set(frozenset(class_names[i] for i in labels)  # type: ignore
                                      for labels in [[0]])
    assert unique_labels == expected_label_combinations


def test_get_unique_label_combinations_multi_label(test_output_dirs: OutputFolderForTests) -> None:
    config = DummyMulticlassClassification()
    class_names = config.class_names
    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    dataset_csv.write_text("ID,channel,path,label\n"
                           "S1,blue,random,1|2|3\n"
                           "S1,green,random,\n"
                           "S2,blue,random,2|3\n"
                           "S2,green,random,\n"
                           "S3,blue,random,3\n"
                           "S3,green,random,\n"
                           "S4,blue,random,\n"
                           "S4,green,random,\n")
    unique_labels = get_unique_prediction_target_combinations(config)  # type: ignore

    expected_label_combinations = set(frozenset(class_names[i] for i in labels)  # type: ignore
                                      for labels in [[1, 2, 3], [2, 3], [3], []])
    assert unique_labels == expected_label_combinations
@@ -12,38 +12,64 @@ import pytest

from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.common_util import is_windows
from InnerEye.ML.reports.classification_report import ReportedMetrics, get_correct_and_misclassified_examples, \
-    get_image_filepath_from_subject_id, get_k_best_and_worst_performing, get_metric, get_results, \
-    plot_image_from_filepath
+    get_image_filepath_from_subject_id, get_k_best_and_worst_performing, get_metric, get_labels_and_predictions, \
+    plot_image_from_filepath, get_image_labels_from_subject_id, get_image_outputs_from_subject_id
from InnerEye.ML.reports.notebook_report import generate_classification_notebook
from InnerEye.ML.scalar_config import ScalarModelBase
+from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
+from InnerEye.ML.metrics_dict import MetricsDict
+from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
+from InnerEye.ML.dataset.scalar_dataset import ScalarDataset


@pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
def test_generate_classification_report(test_output_dirs: OutputFolderForTests) -> None:
    reports_folder = Path(__file__).parent
    test_metrics_file = reports_folder / "test_metrics_classification.csv"
    val_metrics_file = reports_folder / "val_metrics_classification.csv"
-    dataset_csv_path = reports_folder / 'dataset.csv'
-    dataset_subject_column = "subject"
-    dataset_file_column = "filePath"
-
-    current_dir = test_output_dirs.make_sub_dir("test_classification_report")
-    result_file = current_dir / "report.ipynb"
+
+    config = ScalarModelBase(label_value_column="label",
+                             image_file_column="filePath",
+                             subject_column="subject")
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    image_file_name = "image.npy"
+    dataset_csv.write_text("subject,filePath,label\n"
+                           f"0,0_{image_file_name},0\n"
+                           f"1,1_{image_file_name},0\n"
+                           f"2,0_{image_file_name},0\n"
+                           f"3,1_{image_file_name},0\n"
+                           f"4,0_{image_file_name},0\n"
+                           f"5,1_{image_file_name},0\n"
+                           f"6,0_{image_file_name},0\n"
+                           f"7,1_{image_file_name},0\n"
+                           f"8,0_{image_file_name},0\n"
+                           f"9,1_{image_file_name},0\n"
+                           f"10,0_{image_file_name},0\n"
+                           f"11,1_{image_file_name},0\n")
+
+    np.save(str(Path(config.local_dataset / f"0_{image_file_name}")),
+            np.random.randint(0, 255, [5, 4]))
+    np.save(str(Path(config.local_dataset / f"1_{image_file_name}")),
+            np.random.randint(0, 255, [5, 4]))
+
+    result_file = test_output_dirs.root_dir / "report.ipynb"
    result_html = generate_classification_notebook(result_notebook=result_file,
+                                                   config=config,
                                                    val_metrics=val_metrics_file,
-                                                   test_metrics=test_metrics_file,
-                                                   dataset_csv_path=dataset_csv_path,
-                                                   dataset_subject_column=dataset_subject_column,
-                                                   dataset_file_column=dataset_file_column)
+                                                   test_metrics=test_metrics_file)
    assert result_file.is_file()
    assert result_html.is_file()
    assert result_html.suffix == ".html"


-def test_get_results() -> None:
+def test_get_labels_and_predictions() -> None:
    reports_folder = Path(__file__).parent
    test_metrics_file = reports_folder / "test_metrics_classification.csv"

-    results = get_results(test_metrics_file)
+    results = get_labels_and_predictions(test_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
    assert all([results.subject_ids[i] == i for i in range(12)])
    assert all([results.labels[i] == label for i, label in enumerate([1] * 6 + [0] * 6)])
    assert all([results.model_outputs[i] == op for i, op in enumerate([0.0, 0.2, 0.4, 0.6, 0.8, 1.0] * 2)])
@@ -57,16 +83,18 @@ def test_functions_with_invalid_csv(test_output_dirs: OutputFolderForTests) -> N
    shutil.copyfile(test_metrics_file, invalid_metrics_file)
    # Duplicate a subject
    with open(invalid_metrics_file, "a") as file:
-        file.write("Default,1,5,1.0,1,-1,Test")
+        file.write(f"{MetricsDict.DEFAULT_HUE_KEY},1,5,1.0,1,-1,Test")

-    with pytest.raises(ValueError):
-        get_results(invalid_metrics_file)
+    with pytest.raises(ValueError) as ex:
+        get_labels_and_predictions(invalid_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
+    assert "Subject IDs should be unique" in str(ex)

-    with pytest.raises(ValueError):
-        get_correct_and_misclassified_examples(invalid_metrics_file, test_metrics_file)
+    with pytest.raises(ValueError) as ex:
+        get_correct_and_misclassified_examples(invalid_metrics_file, test_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
+    assert "Subject IDs should be unique" in str(ex)

-    with pytest.raises(ValueError):
-        get_correct_and_misclassified_examples(val_metrics_file, invalid_metrics_file)
+    with pytest.raises(ValueError) as ex:
+        get_correct_and_misclassified_examples(val_metrics_file, invalid_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
+    assert "Subject IDs should be unique" in str(ex)


def test_get_metric() -> None:
@@ -74,41 +102,72 @@ def test_get_metric() -> None:
    test_metrics_file = reports_folder / "test_metrics_classification.csv"
    val_metrics_file = reports_folder / "val_metrics_classification.csv"

+    val_metrics = get_labels_and_predictions(val_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
+    test_metrics = get_labels_and_predictions(test_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
+
-    optimal_threshold = get_metric(test_metrics_csv=test_metrics_file,
-                                   val_metrics_csv=val_metrics_file,
+    optimal_threshold = get_metric(test_labels_and_predictions=test_metrics,
+                                   val_labels_and_predictions=val_metrics,
                                   metric=ReportedMetrics.OptimalThreshold)

    assert optimal_threshold == 0.6

+    optimal_threshold = get_metric(test_labels_and_predictions=test_metrics,
+                                   val_labels_and_predictions=val_metrics,
+                                   metric=ReportedMetrics.OptimalThreshold,
+                                   optimal_threshold=0.3)
+
+    assert optimal_threshold == 0.3
+
-    auc_roc = get_metric(test_metrics_csv=test_metrics_file,
-                         val_metrics_csv=val_metrics_file,
+    auc_roc = get_metric(test_labels_and_predictions=test_metrics,
+                         val_labels_and_predictions=val_metrics,
                         metric=ReportedMetrics.AUC_ROC)
    assert auc_roc == 0.5

-    auc_pr = get_metric(test_metrics_csv=test_metrics_file,
-                        val_metrics_csv=val_metrics_file,
+    auc_pr = get_metric(test_labels_and_predictions=test_metrics,
+                        val_labels_and_predictions=val_metrics,
                        metric=ReportedMetrics.AUC_PR)

    assert math.isclose(auc_pr, 13 / 24, abs_tol=1e-15)

-    accuracy = get_metric(test_metrics_csv=test_metrics_file,
-                          val_metrics_csv=val_metrics_file,
+    accuracy = get_metric(test_labels_and_predictions=test_metrics,
+                          val_labels_and_predictions=val_metrics,
                          metric=ReportedMetrics.Accuracy)

    assert accuracy == 0.5

+    accuracy = get_metric(test_labels_and_predictions=test_metrics,
+                          val_labels_and_predictions=val_metrics,
+                          metric=ReportedMetrics.Accuracy,
+                          optimal_threshold=0.1)
+
+    assert accuracy == 0.5
+
-    fpr = get_metric(test_metrics_csv=test_metrics_file,
-                     val_metrics_csv=val_metrics_file,
+    fpr = get_metric(test_labels_and_predictions=test_metrics,
+                     val_labels_and_predictions=val_metrics,
                     metric=ReportedMetrics.FalsePositiveRate)

    assert fpr == 0.5

+    fpr = get_metric(test_labels_and_predictions=test_metrics,
+                     val_labels_and_predictions=val_metrics,
+                     metric=ReportedMetrics.FalsePositiveRate,
+                     optimal_threshold=0.1)
+
+    assert fpr == 5 / 6
+
-    fnr = get_metric(test_metrics_csv=test_metrics_file,
-                     val_metrics_csv=val_metrics_file,
+    fnr = get_metric(test_labels_and_predictions=test_metrics,
+                     val_labels_and_predictions=val_metrics,
                     metric=ReportedMetrics.FalseNegativeRate)

    assert fnr == 0.5

+    fnr = get_metric(test_labels_and_predictions=test_metrics,
+                     val_labels_and_predictions=val_metrics,
+                     metric=ReportedMetrics.FalseNegativeRate,
+                     optimal_threshold=0.1)
+
+    assert math.isclose(fnr, 1 / 6, abs_tol=1e-15)


def test_get_correct_and_misclassified_examples() -> None:
    reports_folder = Path(__file__).parent
@@ -153,34 +212,188 @@ def test_get_k_best_and_worst_performing() -> None:
    assert worst_false_negatives == [0, 1]


-def test_get_image_filepath_from_subject_id() -> None:
-    reports_folder = Path(__file__).parent
-    dataset_csv_file = reports_folder / "dataset.csv"
-    dataset_df = pd.read_csv(dataset_csv_file)
+def test_get_image_filepath_from_subject_id_single(test_output_dirs: OutputFolderForTests) -> None:
+    config = ScalarModelBase(image_file_column="filePath",
+                             label_value_column="label",
+                             subject_column="subject")
+
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    image_file_name = "image.npy"
+    dataset_csv.write_text(f"subject,filePath,label\n"
+                           f"0,0_{image_file_name},0\n"
+                           f"1,1_{image_file_name},1\n")
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)
+
+    Path(config.local_dataset / f"0_{image_file_name}").touch()
+    Path(config.local_dataset / f"1_{image_file_name}").touch()

    filepath = get_image_filepath_from_subject_id(subject_id="1",
-                                                  dataset_df=dataset_df,
-                                                  dataset_subject_column="subject",
-                                                  dataset_file_column="filePath",
-                                                  dataset_dir=reports_folder)
-    expected_path = Path(reports_folder / "../test_data/classification_data_2d/im2.npy")
+                                                  dataset=dataset,
+                                                  config=config)
+    expected_path = Path(config.local_dataset / f"1_{image_file_name}")

    assert filepath
-    assert expected_path.samefile(filepath)
+    assert len(filepath) == 1
+    assert expected_path.samefile(filepath[0])
+
+    # Check error is raised if the subject does not exist
+    with pytest.raises(ValueError) as ex:
+        get_image_filepath_from_subject_id(subject_id="100",
+                                           dataset=dataset,
+                                           config=config)
+    assert "Could not find subject" in str(ex)


-def test_get_image_filepath_from_subject_id_invalid_id() -> None:
-    reports_folder = Path(__file__).parent
-    dataset_csv_file = reports_folder / "dataset.csv"
-    dataset_df = pd.read_csv(dataset_csv_file)
-
-    filepath = get_image_filepath_from_subject_id(subject_id="100",
-                                                  dataset_df=dataset_df,
-                                                  dataset_subject_column="subject",
-                                                  dataset_file_column="filePath",
-                                                  dataset_dir=reports_folder)
-
-    assert not filepath
+def test_get_image_filepath_from_subject_id_with_image_channels(test_output_dirs: OutputFolderForTests) -> None:
+    config = ScalarModelBase(label_channels=["label"],
+                             image_file_column="filePath",
+                             label_value_column="label",
+                             image_channels=["image"],
+                             subject_column="subject")
+
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    image_file_name = "image.npy"
+    dataset_csv.write_text(f"subject,channel,filePath,label\n"
+                           f"0,label,,0\n"
+                           f"0,image,0_{image_file_name},\n"
+                           f"1,label,,1\n"
+                           f"1,image,1_{image_file_name},\n")
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)
+
+    Path(config.local_dataset / f"0_{image_file_name}").touch()
+    Path(config.local_dataset / f"1_{image_file_name}").touch()
+
+    filepath = get_image_filepath_from_subject_id(subject_id="1",
+                                                  dataset=dataset,
+                                                  config=config)
+    expected_path = Path(config.local_dataset / f"1_{image_file_name}")
+
+    assert filepath
+    assert len(filepath) == 1
+    assert filepath[0].samefile(expected_path)
+
+
+def test_get_image_filepath_from_subject_id_multiple(test_output_dirs: OutputFolderForTests) -> None:
+    config = ScalarModelBase(label_channels=["label"],
+                             image_file_column="filePath",
+                             label_value_column="label",
+                             image_channels=["image1", "image2"],
+                             subject_column="subject")
+
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    image_file_name = "image.npy"
+    dataset_csv.write_text(f"subject,channel,filePath,label\n"
+                           f"0,label,,0\n"
+                           f"0,image1,00_{image_file_name},\n"
+                           f"0,image2,01_{image_file_name},\n"
+                           f"1,label,,1\n"
+                           f"1,image1,10_{image_file_name},\n"
+                           f"1,image2,11_{image_file_name},\n")
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)
+
+    Path(config.local_dataset / f"00_{image_file_name}").touch()
+    Path(config.local_dataset / f"01_{image_file_name}").touch()
+    Path(config.local_dataset / f"10_{image_file_name}").touch()
+    Path(config.local_dataset / f"11_{image_file_name}").touch()
+
+    filepath = get_image_filepath_from_subject_id(subject_id="1",
+                                                  dataset=dataset,
+                                                  config=config)
+    expected_paths = [config.local_dataset / f"10_{image_file_name}",
+                      config.local_dataset / f"11_{image_file_name}"]
+
+    assert filepath
+    assert len(filepath) == 2
+    assert expected_paths[0].samefile(filepath[0])
+    assert expected_paths[1].samefile(filepath[1])
+
+
+def test_image_labels_from_subject_id_single(test_output_dirs: OutputFolderForTests) -> None:
+    config = ScalarModelBase(label_value_column="label",
+                             subject_column="subject")
+
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    dataset_csv.write_text("subject,channel,label\n"
+                           "0,label,0\n"
+                           "1,label,1\n")
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)
+
+    labels = get_image_labels_from_subject_id(subject_id="0",
+                                              dataset=dataset,
+                                              config=config)
+    assert not labels
+
+    labels = get_image_labels_from_subject_id(subject_id="1",
+                                              dataset=dataset,
+                                              config=config)
+    assert labels
+    assert len(labels) == 1
+    assert labels[0] == MetricsDict.DEFAULT_HUE_KEY
+
+
+def test_image_labels_from_subject_id_with_label_channels(test_output_dirs: OutputFolderForTests) -> None:
+    config = ScalarModelBase(label_channels=["label"],
+                             label_value_column="label",
+                             subject_column="subject")
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    dataset_csv.write_text("subject,channel,label\n"
+                           "0,label,0\n"
+                           "0,image,\n"
+                           "1,label,1\n"
+                           "1,image,\n")
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)
+
+    labels = get_image_labels_from_subject_id(subject_id="1",
+                                              dataset=dataset,
+                                              config=config)
+    assert labels
+    assert len(labels) == 1
+    assert labels[0] == MetricsDict.DEFAULT_HUE_KEY
+
+
+def test_image_labels_from_subject_id_multiple(test_output_dirs: OutputFolderForTests) -> None:
+    config = ScalarModelBase(label_channels=["label"],
+                             label_value_column="label",
+                             subject_column="subject",
+                             class_names=["class1", "class2", "class3"])
+    config.local_dataset = test_output_dirs.root_dir / "dataset"
+    config.local_dataset.mkdir()
+    dataset_csv = config.local_dataset / "dataset.csv"
+    dataset_csv.write_text("subject,channel,label\n"
+                           "0,label,0\n"
+                           "0,image,\n"
+                           "1,label,1|2\n"
+                           "1,image,\n")
+
+    df = config.read_dataset_if_needed()
+    dataset = ScalarDataset(args=config, data_frame=df)
+
+    labels = get_image_labels_from_subject_id(subject_id="1",
+                                              dataset=dataset,
+                                              config=config)
+    assert labels
+    assert len(labels) == 2
+    assert set(labels) == {config.class_names[1], config.class_names[2]}


def test_plot_image_from_filepath(test_output_dirs: OutputFolderForTests) -> None:
@@ -197,3 +410,27 @@ def test_plot_image_from_filepath(test_output_dirs: OutputFolderForTests) -> Non
    np.save(invalid_file, array)
    res = plot_image_from_filepath(invalid_file, im_width)
    assert not res


+def test_get_image_outputs_from_subject_id(test_output_dirs: OutputFolderForTests) -> None:
+    hues = ["Hue1", "Hue2"]
+
+    metrics_df = pd.DataFrame.from_dict({LoggingColumns.Hue.value: [hues[0], hues[1]] * 6,
+                                         LoggingColumns.Epoch.value: [0] * 12,
+                                         LoggingColumns.Patient.value: [s for s in range(6) for _ in range(2)],
+                                         LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.1, 0.9, 0.1, 0.9,
+                                                                            0.9, 0.9, 0.9, 0.9, 0.9, 0.1],
+                                         LoggingColumns.Label.value: [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0],
+                                         LoggingColumns.CrossValidationSplitIndex: [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
+                                         LoggingColumns.DataSplit.value: [0] * 12,
+                                         }, dtype=str)
+
+    config = DummyMulticlassClassification()
+    config.class_names = hues
+    config.subject_column = "subject"
+
+    model_output = get_image_outputs_from_subject_id(subject_id="1",
+                                                     metrics_df=metrics_df)
+    assert model_output
+    assert len(model_output) == 2
+    assert all([m == e for m, e in zip(model_output, [(hues[0], 0.1), (hues[1], 0.9)])])
@@ -6,14 +6,17 @@ from io import StringIO
from pathlib import Path

import pandas as pd
import pytest

from InnerEye.Common.metrics_constants import MetricsFileColumns
from InnerEye.Common.output_directories import OutputFolderForTests
+from InnerEye.Common.common_util import is_windows
from InnerEye.ML.reports.notebook_report import generate_segmentation_notebook
from InnerEye.ML.reports.segmentation_report import describe_score, worst_patients_and_outliers
from InnerEye.ML.utils.csv_util import COL_IS_OUTLIER


+@pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
def test_generate_segmentation_report(test_output_dirs: OutputFolderForTests) -> None:
    reports_folder = Path(__file__).parent
    metrics_file = reports_folder / "metrics_hn.csv"
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4450200a7b3c1f1c4b94d844a3adfd429690c698592fe9cefac429e77b6ec56
size 4716

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4450200a7b3c1f1c4b94d844a3adfd429690c698592fe9cefac429e77b6ec56
size 4716
@@ -0,0 +1,7 @@
ID,channel,path,label
S1,blue,1_blue.png,1|2|3
S1,green,1_green.png,
S2,blue,1_blue.png,2|3
S2,green,1_green.png,
S3,blue,1_blue.png,3
S3,green,1_green.png,
@@ -152,7 +152,11 @@ Classification datasets should have a `dataset.csv` and a folder containing the
have at least the following fields:
* subject: The subject ID, a unique positive integer assigned to every image
* path: Path to the image file for this subject
-* value: For classification, a (binary) ground truth label. For regression, a scalar value.
+* value:
+  * For binary classification, a (binary) ground truth label. This can be "true" and "false" or "0" and "1".
+  * For multi-label classification, the set of all positive labels for the image, separated by a `|` character.
+    E.g. "0|2|4" for a sample with true labels 0, 2 and 4, and "" for a sample in which all labels are false.
+  * For regression, a scalar value.

These, and other fields which can be added to dataset.csv are described in the examples below.
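As an illustration of these formats, a small parser (a sketch for this guide only, not the function InnerEye itself uses) could turn a multi-label cell into a multi-hot vector:

```python
from typing import List

def parse_label_cell(value: str, num_classes: int) -> List[float]:
    # "0|2|4" -> classes 0, 2 and 4 are true; "" -> all labels are false.
    positives = {int(index) for index in value.split("|") if index != ""}
    return [1.0 if i in positives else 0.0 for i in range(num_classes)]

assert parse_label_cell("0|2|4", 5) == [1.0, 0.0, 1.0, 0.0, 1.0]
assert parse_label_cell("", 5) == [0.0, 0.0, 0.0, 0.0, 0.0]
```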
@@ -233,7 +237,7 @@ Other recognized fields, apart from subject, channel, file path and label are nu
These are extra scalar and categorical values to be used as model input.

Any *unrecognized* columns (any column which is both not described in the model config and has no default)
-will be converted to a dict of key-value pairs and stored in a object of type `GeneralSampleMetadata` in the sample.
+will be converted to a dict of key-value pairs and stored in an object of type `GeneralSampleMetadata` in the sample.

```
SubjectID, Channel, FilePath, Label, Tag, weight, class
@@ -270,4 +274,58 @@ In this example, `weight` is a scalar feature read from the csv, and `class` is
**Filtering on channels**: This example also shows why filtering values by channel is useful: each subject has 2 images,
taken at different times and with different label values. By using `label_channels=["image_time_2"]`, we can use the
label associated with the second image for all subjects.


#### Multi-label classification datasets
Classification datasets can be multi-label, i.e. they can have more than one label associated with every sample.
In this case, in the label column, separate the (numerical) ground truth labels with a pipe character (`|`) to
provide multiple ground truth labels for the sample.

Note that only *multi-label* datasets are supported; *multi-class* datasets (where the labels are mutually exclusive)
are not supported.
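The distinction matters for the loss: multi-label scores every class with an independent sigmoid, while multi-class normalizes across classes with a softmax. A hedged PyTorch sketch of the two settings (illustrative only, not the InnerEye loss classes):

```python
import torch

logits = torch.tensor([[1.2, -0.7, 0.3]])      # one sample, three classes

# Multi-label (supported): multi-hot target, one independent binary problem per class.
multi_hot_target = torch.tensor([[1.0, 0.0, 1.0]])   # classes 0 and 2 are both true
multilabel_loss = torch.nn.BCEWithLogitsLoss()(logits, multi_hot_target)

# Multi-class (not supported): a single true class index, probabilities sum to 1.
class_index_target = torch.tensor([0])
multiclass_loss = torch.nn.CrossEntropyLoss()(logits, class_index_target)
```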
For example, the `dataset.csv` for a multi-label task with 4 classes (0, 1, 2, 3) would look like the following:

```
SubjectID, Channel, FilePath, Label
1, image_feature_1, images/image_1_feature_1.npy,
1, image_feature_2, images/image_1_feature_2.npy,
1, label, , 0|2|3
2, image_feature_1, images/image_2_feature_1.npy,
2, image_feature_2, images/image_2_feature_2.npy,
2, label, , 1|2
3, image_feature_1, images/image_3_feature_1.npy,
3, image_feature_2, images/image_3_feature_2.npy,
3, label, , 1
4, image_feature_1, images/image_4_feature_1.npy,
4, image_feature_2, images/image_4_feature_2.npy,
4, label, ,
```
Note that the label field for sample 4 is left empty; this indicates that all labels are negative for sample 4.
In multi-label tasks, the negative class (all ground truth classes being false for a sample) should not be
considered a separate class, and should be encoded by an empty label field.

The labels which are true for each sample in the `dataset.csv` shown above are:
* Sample 1: 0, 2, 3
* Sample 2: 1, 2
* Sample 3: 1
* Sample 4: No labels are true for this sample

The config file would be:

```python
class GlaucomaPublicExt(GlaucomaPublic):
    def __init__(self) -> None:
        super().__init__(azure_dataset_id="name_of_your_dataset_on_azure",
                         subject_column="SubjectID",
                         channel_column="Channel",
                         image_channels=["image_feature_1", "image_feature_2"],
                         image_file_column="FilePath",
                         label_channels=["label"],
                         label_value_column="Label",
                         class_names=["class0", "class1", "class2", "class3"])
```

The added parameter `class_names` gives the string name corresponding to each ground truth class index.
In multi-label configs, the `class_names` parameter must be specified, so that InnerEye can recognize that the task
is a multi-label task and parse the `dataset.csv` accordingly. In binary tasks, `class_names` can optionally be set
to a list containing a single string, the name of the positive class.
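With `class_names` in place, per-class metrics and report hues use these names; the tests in this PR look metrics up under keys of the form `<metric>/<class_name>`. For example:

```python
class_names = ["class0", "class1", "class2", "class3"]
per_class_keys = [f"cross_entropy/{name}" for name in class_names]
# ['cross_entropy/class0', 'cross_entropy/class1', 'cross_entropy/class2', 'cross_entropy/class3']
```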