Support multilabel classification models and PNG files (#391)

- This PR adds support for multilabel classification models. Multiclass models (where labels are mutually exclusive) are not supported. Changes made:
  - `dataset.csv` can include multiple labels per sample (see the example sketched after this list)
  - Loss functions changed to support multilabel classification tasks
  - Extra report for multilabel classification tasks
- Add support for PNG images
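
A rough, hypothetical illustration of the new label encoding (column names and file paths below are made up for illustration, not taken from this PR): for a model configured with class_names=["class0", "class1", "class2", "class3", "class4"], the label column of dataset.csv holds a pipe-separated list of positive class indices, and an empty label means that no class is positive for that subject.

ID,channel,path,label
1,blue,images/subj1.png,0|2|4
2,blue,images/subj2.png,1
3,blue,images/subj3.png,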

Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
Co-authored-by: melanibe <32590828+melanibe@users.noreply.github.com>
Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
This commit is contained in:
Javier 2021-03-19 16:03:05 +00:00 committed by GitHub
Parent 6f475ffe4c
Commit 917f8d0b30
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
35 changed files: 1987 additions and 465 deletions

View file

@ -21,6 +21,13 @@ nodes in AzureML. Example: Add `--num_nodes=2` to the commandline arguments to t
if `CHANGELOG.md` has been modified.
- ([#412](https://github.com/microsoft/InnerEye-DeepLearning/pull/412)) Dataset files can now have arbitrary names, and are no longer restricted to be called
`dataset.csv`, via the config field `dataset_csv`. This makes it possible to have a single set of image files in a folder, with multiple datasets derived from it.
- ([#391](https://github.com/microsoft/InnerEye-DeepLearning/pull/391)) Support for multilabel classification tasks.
Multilabel models can be trained by adding the parameter `class_names` to the config for classification models.
`class_names` should contain the name of each label class in the dataset, and the order of names should match the
order of class label indices in `dataset.csv`.
`dataset.csv` supports multiple labels (indices corresponding to `class_names`) per subject in the label column.
Multiple labels should be encoded as a string with labels separated by a `|`, for example "0|2|4".
Note that this PR does not add support for multiclass models, where the labels are mutually exclusive.
### Changed
- ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) Starting an AzureML run now uses the

View file

@ -0,0 +1,50 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any
import pandas as pd
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.split_dataset import DatasetSplits
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
class DummyMulticlassClassification(ScalarModelBase):
"A config file for dummy image classification model for debugging purposes"
def __init__(self) -> None:
num_epochs = 4
super().__init__(
local_dataset=full_ml_test_data_path("classification_data_multiclass"),
image_channels=["blue"],
image_file_column="path",
label_channels=["blue"],
class_names=["class0", "class1", "class2", "class3", "class4"],
label_value_column="label",
loss_type=ScalarLoss.WeightedCrossEntropyWithLogits,
num_epochs=num_epochs,
num_dataload_workers=0,
use_mixed_precision=False,
subject_column="ID",
image_size=(4, 5, 7)
)
self.conv_in_3d = True
self.expected_image_size_zyx = (4, 5, 7)
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
self.max_num_gpus = 1
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
return DatasetSplits.from_proportions(
df=dataset_df,
proportion_train=0.7,
proportion_test=0.2,
proportion_val=0.1,
subject_column=self.subject_column
)
def create_model(self) -> Any:
# Use a local import so that we don't need to import pytorch when creating configs in the runner
from Tests.ML.models.architectures.DummyScalarModel import DummyScalarModel
return DummyScalarModel(self.expected_image_size_zyx, num_classes=len(self.class_names))

View file

@ -5,6 +5,7 @@
import logging
import math
import sys
import typing
from abc import abstractmethod
from collections import Counter, defaultdict
from multiprocessing import cpu_count
@ -16,7 +17,6 @@ import pandas as pd
import torch
from joblib import Parallel, delayed
from more_itertools import flatten
from rich.progress import track
from InnerEye.ML.dataset.full_image_dataset import GeneralDataset
from InnerEye.ML.dataset.sample import GeneralSampleMetadata
@ -31,57 +31,87 @@ from InnerEye.ML.utils.transforms import Compose3D, Transform3D
T = TypeVar('T', bound=ScalarDataSource)
def extract_label_classification(label_string: Union[str, float], sample_id: str) -> Union[float, int]:
def extract_label_classification(label_string: str, sample_id: str, num_classes: int,
is_classification_dataset: bool) -> List[float]:
"""
Converts a string from a dataset.csv file that contains a model's label to a scalar.
The function maps ["1", "true", "yes"] to 1, ["0", "false", "no"] to 0.
If the entry in the CSV file was missing (no string given at all), it returns math.nan.
:param label_string: The value of the label as read from CSV via a DataFrame.
:param sample_id: The sample ID where this label was read from. This is only used for creating error messages.
:return:
"""
if isinstance(label_string, float):
if math.isnan(label_string):
# When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
return label_string
else:
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
f"dataframe column as a string?")
if label_string:
label_lower = label_string.lower()
if label_lower in ["1", "true", "yes"]:
return 1
if label_lower in ["0", "false", "no"]:
return 0
raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'")
else:
return math.nan
For classification datasets:
If num_classes is 1 (binary classification tasks):
The function maps ["1", "true", "yes"] to [1], ["0", "false", "no"] to [0].
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
If num_classes is greater than 1 (multilabel datasets):
The function maps a pipe-separated set of classes to a tensor with ones at the indices
of the positive classes and 0 elsewhere (for example if we have a task with 6 label classes,
map "1|3|4" to [0, 1, 0, 1, 1, 0]).
If the entry in the CSV file was missing (no string given at all) or an empty string,
this function returns an all-zero tensor (none of the label classes were positive for this sample).
def extract_label_regression(label_string: Union[str, float], sample_id: str) -> Union[float, int]:
"""
Converts a string from a dataset.csv file that contains a model's label to a scalar.
For regression datasets:
The function casts a string label to float. Raises an exception if the conversion is
not possible.
If the entry in the CSV file was missing (no string given at all), it returns math.nan.
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
:param label_string: The value of the label as read from CSV via a DataFrame.
:param sample_id: The sample ID where this label was read from. This is only used for creating error messages.
:return:
:param num_classes: Number of classes. This should be equal to the size of the model output.
For binary classification tasks, num_classes should be one. For multilabel classification tasks, num_classes should
correspond to the number of label classes in the problem.
:param is_classification_dataset: Whether the dataset is for a classification model.
:return: A list of floats with the same size as num_classes
"""
if num_classes < 1:
raise ValueError(f"Subject {sample_id}: Invalid number of classes: '{num_classes}'")
if isinstance(label_string, float):
if math.isnan(label_string):
# When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
return label_string
if num_classes == 1:
# Pandas special case: When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
return [label_string]
else:
return [0] * num_classes
else:
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
f"dataframe column as a string?")
if label_string:
try:
return float(label_string)
except ValueError:
raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'")
if not label_string:
if not is_classification_dataset or num_classes == 1:
return [math.nan]
else:
return [0] * num_classes
if is_classification_dataset:
if num_classes == 1:
label_lower = label_string.lower()
if label_lower in ["true", "yes"]:
return [1.0]
elif label_lower in ["false", "no"]:
return [0.0]
elif label_string in ["0", "1"]:
return [float(label_string)]
else:
raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'. "
f"Should be one of true/false, yes/no or 0/1.")
if '|' in label_string or label_string.isdigit():
classes = [int(a) for a in label_string.split('|')]
out_of_range = [_class for _class in classes if _class >= num_classes]
if out_of_range:
raise ValueError(f"Subject {sample_id}: Indices {out_of_range} are out of range, for number of classes "
f"= {num_classes}")
one_hot_array = np.zeros(num_classes, dtype=np.float)
one_hot_array[classes] = 1.0
return one_hot_array.tolist()
else:
return math.nan
try:
return [float(label_string)]
except ValueError:
pass
raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'")
def _get_single_channel_row(subject_rows: pd.DataFrame,
@ -149,10 +179,12 @@ def load_single_data_source(subject_rows: pd.DataFrame,
categorical_data_encoder: Optional[CategoricalToOneHotEncoder] = None,
metadata_columns: Optional[Set[str]] = None,
is_classification_dataset: bool = True,
num_classes: int = 1,
sequence_position_numeric: Optional[int] = None) -> T:
"""
Converts a set of dataset rows for a single subject to a ScalarDataSource instance, which contains the
labels, the non-image features, and the paths to the image files.
:param num_classes: Number of classes; this is equivalent to the model output tensor size
:param channel_column: The name of the column that contains the row identifier ("channels")
:param metadata_columns: A list of columns that will be added to the item metadata as key/value pairs.
:param subject_rows: All dataset rows that belong to the same subject.
@ -179,11 +211,12 @@ def load_single_data_source(subject_rows: pd.DataFrame,
return _get_single_channel_row(subject_rows, channel, subject_id, channel_column)
def _get_label_as_tensor(channel: Optional[str]) -> torch.Tensor:
extract_fn = extract_label_classification if is_classification_dataset else extract_label_regression
label_row = _get_row_for_channel(channel)
label_string = label_row[label_value_column]
return torch.tensor([extract_fn(label_string=label_string, sample_id=subject_id)],
dtype=torch.float)
return torch.tensor(
extract_label_classification(label_string=label_string, sample_id=subject_id, num_classes=num_classes,
is_classification_dataset=is_classification_dataset),
dtype=torch.float)
def _apply_label_transforms(labels: Any) -> Any:
"""
@ -313,6 +346,7 @@ class DataSourceReader(Generic[T]):
subject_column: str = CSV_SUBJECT_HEADER,
channel_column: str = CSV_CHANNEL_HEADER,
is_classification_dataset: bool = True,
num_classes: int = 1,
categorical_data_encoder: Optional[CategoricalToOneHotEncoder] = None):
"""
:param label_value_column: The column that contains the value for the label scalar or vector.
@ -345,6 +379,7 @@ class DataSourceReader(Generic[T]):
self.image_file_column = image_file_column
self.label_value_column = label_value_column
self.data_frame = data_frame
self.num_classes = num_classes
self.expected_non_image_channels: Union[List[None], Set[str]]
if self.non_image_feature_channels is None:
@ -418,6 +453,7 @@ class DataSourceReader(Generic[T]):
sequence_column=sequence_column,
subject_column=args.subject_column,
channel_column=args.channel_column,
num_classes=len(args.class_names),
is_classification_dataset=args.is_classification_model
).load_data_sources(num_dataset_reader_workers=args.num_dataset_reader_workers)
@ -444,7 +480,7 @@ class DataSourceReader(Generic[T]):
_n_jobs = max(1, num_dataset_reader_workers)
results = Parallel(n_jobs=_n_jobs, backend=_backend)(
delayed(self.load_datasources_for_subject)(subject_id) for subject_id in track(subject_ids))
delayed(self.load_datasources_for_subject)(subject_id) for subject_id in subject_ids)
return list(flatten(filter(None, results)))
@ -468,6 +504,7 @@ class DataSourceReader(Generic[T]):
metadata_columns=self.metadata_columns,
channel_column=self.channel_column,
is_classification_dataset=self.is_classification_dataset,
num_classes=self.num_classes,
sequence_position_numeric=_sequence_position_numeric
)
@ -745,15 +782,27 @@ class ScalarDataset(ScalarDatasetBase[ScalarDataSource]):
Returns a list of all the labels in the dataset. Used to compute
the sampling weights in Imbalanced Sampler
"""
if len(self.args.class_names) > 1:
raise NotImplementedError("ImbalancedSampler is not supported for multilabel tasks.")
return [item.label.item() for item in self.items]
def get_class_counts(self) -> Dict:
def get_class_counts(self) -> Dict[int, int]:
"""
Return class weights that are proportional to the inverse frequency of label counts.
:return: Dictionary of {"label": count}
Return the label counts as a dictionary with the key-value pairs being the class indices and per-class counts.
In the binary case, the dictionary will have a single element. The key will be 0 as there is only one class and
one class index. The value stored will be the number of samples that belong to the positive class.
In the multilabel case, this returns a dictionary with class indices and samples per class as the key-value
pairs.
:return: Dictionary of {class_index: count}
"""
all_labels = [item.label.item() for item in self.items] # [N, 1]
return dict(Counter(all_labels))
all_labels = [torch.flatten(torch.nonzero(item.label).int()).tolist() for item in self.items] # [N, 1]
flat_list = list(flatten(all_labels))
freq_iter: typing.Counter = Counter()
freq_iter.update({x: 0 for x in range(len(self.args.class_names))})
freq_iter.update(flat_list)
result = dict(freq_iter)
return result
def __len__(self) -> int:
return len(self.items)
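
For reference, a small standalone sketch (not the library code itself) of the multi-hot mapping that the new extract_label_classification docstring above describes for multilabel classification datasets; an empty label string maps to an all-zero vector:

from typing import List

def multi_hot_from_label_string(label_string: str, num_classes: int) -> List[float]:
    # "" -> all zeros; "1|3|4" -> ones at the listed class indices.
    multi_hot = [0.0] * num_classes
    if not label_string:
        return multi_hot
    for index in (int(part) for part in label_string.split("|")):
        if index >= num_classes:
            raise ValueError(f"Class index {index} is out of range for {num_classes} classes")
        multi_hot[index] = 1.0
    return multi_hot

assert multi_hot_from_label_string("1|3|4", 6) == [0.0, 1.0, 0.0, 1.0, 1.0, 0.0]
assert multi_hot_from_label_string("", 6) == [0.0] * 6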

View file

@ -135,19 +135,8 @@ class ScalarDataSource(ScalarItemBase):
:return: An instance of ClassificationItem, with the same label and numerical_non_image_features fields,
and all images loaded.
"""
full_channel_files: List[Path] = []
for f in self.channel_files:
if f is None:
raise ValueError("When loading images, channel_files should no longer contain None entries.")
elif file_mapping:
if f in file_mapping:
full_channel_files.append(file_mapping[str(f)])
else:
raise ValueError(f"File mapping does not contain an entry for {f}")
elif root_path:
full_channel_files.append(root_path / f)
else:
raise ValueError("One of the arguments 'file_mapping' or 'root_path' must be given.")
full_channel_files = self.get_all_image_filepaths(root_path=root_path,
file_mapping=file_mapping)
imaging_data = load_images_and_stack(files=full_channel_files,
load_segmentation=load_segmentation,
@ -175,6 +164,47 @@ class ScalarDataSource(ScalarItemBase):
def files_valid(self) -> bool:
return not any(f is None for f in self.channel_files)
def get_all_image_filepaths(self,
root_path: Optional[Path],
file_mapping: Optional[Dict[str, Path]]) -> List[Path]:
"""
Get a list of image paths for the object. Either root_path or file_mapping must be specified.
:param root_path: The root path where all channel files for images are expected. This is ignored if
file_mapping is given.
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
"""
full_channel_files: List[Path] = []
for f in self.channel_files:
if not f:
raise ValueError(f"Got invalid file path: {f}")
full_channel_files.append(self.get_full_image_filepath(f, root_path, file_mapping))
return full_channel_files
@staticmethod
def get_full_image_filepath(file: str,
root_path: Optional[Path],
file_mapping: Optional[Dict[str, Path]]) -> Path:
"""
Get the full path of an image file given the path relative to the dataset folder and one of
root_path or file_mapping.
:param file: Image filepath relative to the dataset folder
:param root_path: The root path where all channel files for images are expected. This is ignored if
file_mapping is given.
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
"""
if file is None:
raise ValueError("When loading images, channel_files should no longer contain None entries.")
elif file_mapping:
if file in file_mapping:
return file_mapping[file]
else:
raise ValueError(f"File mapping does not contain an entry for {file}")
elif root_path:
return root_path / file
else:
raise ValueError("One of the arguments 'file_mapping' or 'root_path' must be given.")
@dataclass(frozen=True)
class SequenceDataSource(ScalarDataSource):

View file

@ -233,6 +233,9 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
raise ValueError("This class requires a value in the `sequence_column`, specifying where the "
"sequence index should be read from.")
if len(self.args.class_names) > 1:
raise ValueError("Multilabel configs not supported for sequence datasets.")
data_sources = self.load_all_data_sources()
grouped = group_samples_into_sequences(
data_sources,
@ -278,16 +281,18 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
return [seq.get_labels_at_target_indices(self.args.get_target_indices())[-1].item()
for seq in self.items]
def get_class_counts(self) -> Dict:
def get_class_counts(self) -> Dict[int, int]:
"""
Return class weights that are proportional to the inverse frequency of label counts (summed
over all target indices).
Return the label counts (summed over all target indices).
:return: Dictionary of {"label": count}
"""
all_labels_per_target = torch.stack([seq.get_labels_at_target_indices(self.args.get_target_indices())
for seq in self.items]) # [N, T, 1]
non_nan_labels = list(filter(lambda x: not np.isnan(x), all_labels_per_target.flatten().tolist()))
return dict(Counter(non_nan_labels))
non_nan_and_nonzero_labels = list(filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
counts = dict(Counter(non_nan_and_nonzero_labels))
if not len(counts.keys()) == 1 or 1 not in counts.keys():
raise ValueError("get_class_counts supports only binary targets.")
return {0: counts[1]}
def __len__(self) -> int:
return len(self.items)
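
A rough sketch (with made-up labels) of the counting rule in the updated SequenceDataset.get_class_counts above: NaN and zero labels are dropped, and only the number of positive labels is reported, keyed by class index 0, since only binary targets are supported here.

import math

labels = [1.0, 0.0, math.nan, 1.0, 0.0, 1.0]  # hypothetical flattened per-target labels
positives = [x for x in labels if not math.isnan(x) and x != 0]
class_counts = {0: len(positives)}
assert class_counts == {0: 3}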

View file

@ -23,7 +23,6 @@ from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode,\
create_recovery_checkpoint_path, create_unique_timestamp_id,\
get_best_checkpoint_path
VISUALIZATION_FOLDER = "Visualizations"
# A folder inside of the outputs folder that will contain all information for running the model in inference mode
FINAL_MODEL_FOLDER = "final_model"
FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"
@ -31,6 +30,8 @@ FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"
# The checkpoints must be stored inside of the final model folder, if we want to avoid copying
# them before registration.
CHECKPOINT_FOLDER = "checkpoints"
VISUALIZATION_FOLDER = "visualizations"
ARGS_TXT = "args.txt"
WEIGHTS_FILE = "weights.pth"

View file

@ -185,9 +185,10 @@ class ScalarLightning(InnerEyeLightning):
else:
self.loss_fn = raw_loss
self.target_indices = []
self.target_names = [MetricsDict.DEFAULT_HUE_KEY]
self.target_names = config.class_names
self.is_classification_model = config.is_classification_model
self.use_mean_teacher_model = config.compute_mean_teacher_model
self.is_binary_classification_or_regression = True if len(config.class_names) == 1 else False
self.logits_to_posterior_fn = config.get_post_loss_logits_normalization_function()
self.loss_type = config.loss_type
# These two fields store the PyTorch Lightning Metrics objects that will compute metrics on validation
@ -338,7 +339,8 @@ class ScalarLightning(InnerEyeLightning):
metric_computers = self.train_metric_computers if is_training else self.val_metric_computers
prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
for prediction_target, metric_list in metric_computers.items():
target_suffix = "" if prediction_target == MetricsDict.DEFAULT_HUE_KEY else f"/{prediction_target}"
target_suffix = "" if (prediction_target == MetricsDict.DEFAULT_HUE_KEY
or self.is_binary_classification_or_regression) else f"/{prediction_target}"
for metric in metric_list:
if metric.has_predictions:
# Sequence models can have no predictions at all for particular positions, depending on the data.

View file

@ -339,12 +339,22 @@ def store_epoch_metrics(metrics: DictStrFloat,
"""
logger_row = {}
for key, value in metrics.items():
if key == MetricType.SECONDS_PER_BATCH.value or key == MetricType.SECONDS_PER_EPOCH.value:
continue
if key in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[key].value] = value
tokens = key.split("/")
if len(tokens) == 1:
metric_name = tokens[0]
hue_suffix = ""
elif len(tokens) == 2:
metric_name = tokens[0]
hue_suffix = "/" + tokens[1]
else:
logger_row[key] = value
raise ValueError(f"Expected key to have format 'metric_name[/optional_suffix_for_hue]', got {key}")
if metric_name == MetricType.SECONDS_PER_BATCH.value or metric_name == MetricType.SECONDS_PER_EPOCH.value:
continue
if metric_name in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[metric_name].value + hue_suffix] = value
else:
logger_row[metric_name + hue_suffix] = value
logger_row[LoggingColumns.Epoch.value] = epoch
file_logger.add_record(logger_row)
file_logger.flush()
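
A standalone sketch of the key convention handled in store_epoch_metrics above: metric keys may carry an optional "/<prediction_target>" suffix (the hue), which is split off and re-appended to the mapped logging column name. The metric and class names below are hypothetical.

from typing import Tuple

def split_metric_key(key: str) -> Tuple[str, str]:
    # "accuracy/class1" -> ("accuracy", "/class1"); "loss" -> ("loss", "")
    tokens = key.split("/")
    if len(tokens) == 1:
        return tokens[0], ""
    if len(tokens) == 2:
        return tokens[0], "/" + tokens[1]
    raise ValueError(f"Expected 'metric_name[/prediction_target]', got {key}")

assert split_metric_key("accuracy/class1") == ("accuracy", "/class1")
assert split_metric_key("loss") == ("loss", "")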

View file

@ -192,9 +192,11 @@ class MetricsDict:
default hue.
:param is_classification_metrics: If this is a classification metrics dict
"""
if hues and MetricsDict.DEFAULT_HUE_KEY in hues:
hues.remove(MetricsDict.DEFAULT_HUE_KEY)
self.hues_without_default = hues or []
_hues = hues.copy() if hues else None
if _hues and MetricsDict.DEFAULT_HUE_KEY in _hues:
_hues.remove(MetricsDict.DEFAULT_HUE_KEY)
self.hues_without_default = _hues or []
_hue_keys = self.hues_without_default + [MetricsDict.DEFAULT_HUE_KEY]
self.hues: OrderedDict[str, Hue] = OrderedDict([(x, Hue(name=x)) for x in _hue_keys])
self.skip_nan_when_averaging: Dict[str, bool] = dict()

View file

@ -387,7 +387,8 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
return SequenceMetricsDict.create(is_classification_model=config.is_classification_model,
sequence_target_positions=config.sequence_target_positions)
else:
return ScalarMetricsDict(is_classification_metrics=config.is_classification_model)
return ScalarMetricsDict(hues=config.class_names,
is_classification_metrics=config.is_classification_model)
def classification_model_test(config: ScalarModelBase,

View file

@ -0,0 +1,165 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"%%javascript\n",
"IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
" return false;\n",
"}\n",
"// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
"# This cell must carry the tag \"parameters\" in its metadata.\n",
"from pathlib import Path\n",
"import pickle\n",
"import codecs\n",
"\n",
"innereye_path = Path.cwd().parent.parent.parent\n",
"train_metrics_csv = \"\"\n",
"val_metrics_csv = \"\"\n",
"test_metrics_csv = \"\"\n",
"number_best_and_worst_performing = 20\n",
"config= \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import sys\n",
"print(f\"Adding to path: {innereye_path}\")\n",
"if str(innereye_path) not in sys.path:\n",
" sys.path.append(str(innereye_path))\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
"\n",
"from InnerEye.ML.reports.notebook_report import print_header\n",
"from InnerEye.ML.reports.classification_multilabel_report import print_metrics_for_thresholded_output_for_all_prediction_targets\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"plt.rcParams['figure.figsize'] = (20, 10)\n",
"\n",
"#convert params to Path\n",
"train_metrics_csv = Path(train_metrics_csv)\n",
"val_metrics_csv = Path(val_metrics_csv)\n",
"test_metrics_csv = Path(test_metrics_csv)"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"# Train Metrics (for label combinations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"if train_metrics_csv.is_file():\n",
" print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv=train_metrics_csv,\n",
" test_metrics_csv=train_metrics_csv,\n",
" config=config)"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"# Validation Metrics (for label combinations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file():\n",
" print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv=val_metrics_csv,\n",
" test_metrics_csv=val_metrics_csv,\n",
" config=config)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"# Test Metrics (for label combinations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv=val_metrics_csv,\n",
" test_metrics_csv=test_metrics_csv,\n",
" config=config)"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -0,0 +1,175 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import List, Set, FrozenSet
import pandas as pd
import torch
import math
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.reports.classification_report import LabelsAndPredictions, print_metrics, get_labels_and_predictions, \
get_metric, ReportedMetrics
from InnerEye.ML.reports.notebook_report import print_header
def get_unique_prediction_target_combinations(config: ScalarModelBase) -> Set[FrozenSet[str]]:
"""
Get the set of all combinations of labels that exist in the dataset.
For multilabel classification tasks, this function will return all unique combinations of labels that
occur in the dataset csv.
For example, if there are 6 samples in the dataset with the following ground truth labels
Sample1: class1, class2
Sample2: class0
Sample3: class1
Sample4: class2, class3
Sample5: (all label classes are negative in Sample 5)
Sample6: class1, class2
This function will return {{"class1", "class2"}, {"class0"}, {"class1"}, {"class2", "class3"}, {}}
For binary classification tasks (assume class_names has not been changed from ["Default"]):
This function will return a set with two members - {{"Default"}, {}} if there is at least one positive example
in the dataset. If there are no positive examples, it returns {{}}.
"""
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
all_labels = [torch.flatten(torch.nonzero(item.label)).tolist() for item in dataset.items]
label_set = set(frozenset([config.class_names[i] for i in labels if not math.isnan(i)])
for labels in all_labels)
return label_set
def get_dataframe_with_exact_label_matches(metrics_df: pd.DataFrame,
prediction_target_set_to_match: List[str],
all_prediction_targets: List[str],
thresholds_per_prediction_target: List[float]) -> pd.DataFrame:
"""
Given a set of prediction targets (a subset of the classes in the classification task), for each sample find
(i) if the set of ground truth labels matches this set exactly,
(ii) if the predicted model outputs (after thresholding) match this set exactly
Generates an output dataframe with the columns:
LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput, LoggingColumns.Hue
The output dataframe is generated according to the following rules:
- LoggingColumns.Patient: For each sample, the sample id is copied over into this field
- LoggingColumns.Label: For each sample, this field is set to 1 if the ground truth value is true
for every prediction target (i.e. every class) in the given set and false for all other prediction
targets. It is set to 0 otherwise.
- LoggingColumns.ModelOutput: For each sample, this field is set to 1 if the model predicts a value exceeding the
prediction target threshold for every prediction target in the given set and lower for all other prediction
targets. It is set to 0 otherwise.
- LoggingColumns.Hue: For every sample, this is set to "|".join(prediction_target_set_to_match)
:param metrics_df: Dataframe with the model predictions (read from the csv written by the inference pipeline)
The dataframe must have at least the following columns (defined in the LoggingColumns enum):
LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
Any other columns will be ignored.
:param prediction_target_set_to_match: The set of prediction targets to which each sample is compared
:param all_prediction_targets: The entire set of prediction targets on which the model is trained
:param thresholds_per_prediction_target: Thresholds per prediction target to decide if model has predicted True or
False for the specific prediction target
:return: Dataframe with generated label and model outputs per sample
"""
def get_exact_label_match(df: pd.DataFrame) -> pd.DataFrame:
values_to_return = {LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]]}
pred_positives = df[df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][LoggingColumns.ModelOutput.value].values
pred_negatives = df[~df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][LoggingColumns.ModelOutput.value].values
if all(pred_positives) and not any(pred_negatives):
values_to_return[LoggingColumns.ModelOutput.value] = [1]
else:
values_to_return[LoggingColumns.ModelOutput.value] = [0]
true_positives = df[df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][LoggingColumns.Label.value].values
true_negatives = df[~df[LoggingColumns.Hue.value].isin(prediction_target_set_to_match)][LoggingColumns.Label.value].values
if all(true_positives) and not any(true_negatives):
values_to_return[LoggingColumns.Label.value] = [1]
else:
values_to_return[LoggingColumns.Label.value] = [0]
return pd.DataFrame.from_dict(values_to_return)
df = metrics_df.copy()
for i in range(len(thresholds_per_prediction_target)):
df_for_prediction_target = df[LoggingColumns.Hue.value] == all_prediction_targets[i]
df.loc[df_for_prediction_target, LoggingColumns.ModelOutput.value] = \
df.loc[df_for_prediction_target, LoggingColumns.ModelOutput.value] > thresholds_per_prediction_target[i]
df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_exact_label_match).reset_index(drop=True)
df[LoggingColumns.Hue.value] = ["|".join(prediction_target_set_to_match)] * len(df)
return df
def get_labels_and_predictions_for_prediction_target_set(csv: Path,
prediction_target_set_to_match: List[str],
all_prediction_targets: List[str],
thresholds_per_prediction_target: List[float]) -> LabelsAndPredictions:
"""
Given a CSV file, generate a set of labels and model predictions for the given set of prediction targets
(in other words, for the given subset of the classes in the classification task).
NOTE: This CSV file should have results from a single epoch, as in the metrics files written during inference, not
like the ones written while training.
"""
metrics_df = pd.read_csv(csv)
df = get_dataframe_with_exact_label_matches(metrics_df=metrics_df,
prediction_target_set_to_match=prediction_target_set_to_match,
all_prediction_targets=all_prediction_targets,
thresholds_per_prediction_target=thresholds_per_prediction_target)
labels = df[LoggingColumns.Label.value].to_numpy()
model_outputs = df[LoggingColumns.ModelOutput.value].to_numpy()
subjects = df[LoggingColumns.Patient.value].to_numpy()
return LabelsAndPredictions(subject_ids=subjects, labels=labels, model_outputs=model_outputs)
def print_metrics_for_thresholded_output_for_all_prediction_targets(val_metrics_csv: Path,
test_metrics_csv: Path,
config: ScalarModelBase) -> None:
"""
Given csvs written during inference for the validation and test sets, print out metrics for every combination of
prediction targets that exist in the dataset (i.e. for every subset of classes that occur in the dataset).
:param val_metrics_csv: Csv written during inference time for the val set. This is used to determine the
optimal threshold for classification.
:param test_metrics_csv: Csv written during inference time for the test set. Metrics are calculated for this csv.
:param config: Model config
"""
unique_prediction_target_combinations = get_unique_prediction_target_combinations(config)
all_prediction_target_combinations = list(set([frozenset([prediction_target])
for prediction_target in config.class_names])
| unique_prediction_target_combinations)
thresholds_per_prediction_target = []
for label in config.class_names:
val_metrics = get_labels_and_predictions(val_metrics_csv, label)
test_metrics = get_labels_and_predictions(test_metrics_csv, label)
thresholds_per_prediction_target.append(get_metric(val_labels_and_predictions=val_metrics,
test_labels_and_predictions=test_metrics,
metric=ReportedMetrics.OptimalThreshold))
for labels in all_prediction_target_combinations:
print_header(f"Class {'|'.join(labels) or 'Negative'}", level=3)
val_metrics = get_labels_and_predictions_for_prediction_target_set(
csv=val_metrics_csv,
prediction_target_set_to_match=list(labels),
all_prediction_targets=config.class_names,
thresholds_per_prediction_target=thresholds_per_prediction_target)
test_metrics = get_labels_and_predictions_for_prediction_target_set(
csv=test_metrics_csv,
prediction_target_set_to_match=list(labels),
all_prediction_targets=config.class_names,
thresholds_per_prediction_target=thresholds_per_prediction_target)
print_metrics(val_labels_and_predictions=val_metrics, test_labels_and_predictions=test_metrics, is_thresholded=True)
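
As a standalone illustration (not the library code) of the exact-match rule that get_dataframe_with_exact_label_matches applies per subject: a subject is scored 1 for a prediction target set only when exactly those targets, and no others, are positive (or predicted positive after thresholding).

from typing import Iterable, Set

def exact_label_match(positive_targets: Iterable[str], target_set: Set[str]) -> int:
    # 1 only when the positives are exactly the given set, no more and no less.
    return int(set(positive_targets) == target_set)

assert exact_label_match(["class1", "class2"], {"class1", "class2"}) == 1
assert exact_label_match(["class1", "class2", "class3"], {"class1", "class2"}) == 0
assert exact_label_match(["class1"], {"class1", "class2"}) == 0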

View file

@ -31,14 +31,15 @@
"# Default parameter values. They will be overwritten by papermill notebook parameters.\n",
"# This cell must carry the tag \"parameters\" in its metadata.\n",
"from pathlib import Path\n",
"import pickle\n",
"import codecs\n",
"\n",
"innereye_path = Path.cwd().parent.parent.parent\n",
"train_metrics_csv = \"\"\n",
"val_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'val_metrics_classification.csv'\n",
"test_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'test_metrics_classification.csv'\n",
"number_best_and_worst_performing = 20\n",
"dataset_csv_path=innereye_path / 'Tests' / 'ML' / 'reports' / 'dataset.csv'\n",
"dataset_subject_column=\"subject\"\n",
"dataset_file_column=\"filePath\""
"config= \"\""
]
},
{
@ -52,15 +53,20 @@
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"\n",
"print(f\"Adding to path: {innereye_path}\")\n",
"if str(innereye_path) not in sys.path:\n",
" sys.path.append(str(innereye_path))\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n",
"\n",
"from InnerEye.ML.reports.notebook_report import print_header\n",
"from InnerEye.ML.reports.classification_report import plot_pr_and_roc_curves_from_csv, \\\n",
"print_k_best_and_worst_performing, print_metrics, plot_k_best_and_worst_performing\n",
"print_k_best_and_worst_performing, print_metrics_for_all_prediction_targets, \\\n",
"plot_k_best_and_worst_performing, get_labels_and_predictions\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
@ -69,8 +75,7 @@
"#convert params to Path\n",
"train_metrics_csv = Path(train_metrics_csv)\n",
"val_metrics_csv = Path(val_metrics_csv)\n",
"test_metrics_csv = Path(test_metrics_csv)\n",
"dataset_csv_path = Path(dataset_csv_path)"
"test_metrics_csv = Path(test_metrics_csv)"
]
},
{
@ -78,7 +83,7 @@
"id": "4",
"metadata": {},
"source": [
"# Validation Metrics"
"# Train Metrics"
]
},
{
@ -88,8 +93,9 @@
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file():\n",
" print_metrics(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv)"
"if train_metrics_csv.is_file():\n",
" print_metrics_for_all_prediction_targets(val_metrics_csv=train_metrics_csv, test_metrics_csv=train_metrics_csv,\n",
" config=config, is_thresholded=False)"
]
},
{
@ -97,7 +103,7 @@
"id": "6",
"metadata": {},
"source": [
"# Test Metrics"
"# Validation Metrics"
]
},
{
@ -107,13 +113,34 @@
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" print_metrics(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv)"
"if val_metrics_csv.is_file():\n",
" print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv,\n",
" config=config, is_thresholded=False)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"# Test Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" print_metrics_for_all_prediction_targets(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" config=config, is_thresholded=False)"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {
"pycharm": {
"name": "#%% md\n"
@ -121,30 +148,7 @@
},
"source": [
"# AUC and PR curves\n",
"## Test Set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"if test_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(test_metrics_csv)"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"## Validation set"
"## Train Set"
]
},
{
@ -158,20 +162,16 @@
},
"outputs": [],
"source": [
"if val_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(val_metrics_csv)"
"if train_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(metrics_csv=train_metrics_csv, config=config)"
]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"metadata": {},
"source": [
"## Training set"
"## Validation set"
]
},
{
@ -185,28 +185,35 @@
},
"outputs": [],
"source": [
"if train_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(train_metrics_csv)"
"if val_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(metrics_csv=val_metrics_csv, config=config)"
]
},
{
"cell_type": "markdown",
"id": "14",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Best and worst samples by ID"
"## Test set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv, \n",
" k=number_best_and_worst_performing)"
"if test_metrics_csv.is_file():\n",
" plot_pr_and_roc_curves_from_csv(metrics_csv=test_metrics_csv, config=config)"
]
},
{
@ -214,7 +221,7 @@
"id": "16",
"metadata": {},
"source": [
"# Plot best and worst sample images"
"# Best and worst samples by ID"
]
},
{
@ -224,20 +231,35 @@
"metadata": {},
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file() and dataset_csv_path.is_file() and \\\n",
" dataset_subject_column and dataset_file_column:\n",
" plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv, \n",
" k=number_best_and_worst_performing, dataset_csv_path=dataset_csv_path,\n",
" dataset_subject_column=dataset_subject_column, dataset_file_column=dataset_file_column)"
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.class_names:\n",
" print_header(f\"Class {prediction_target}\", level=3)\n",
" print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing,\n",
" prediction_target=prediction_target)"
]
},
{
"cell_type": "markdown",
"id": "18",
"metadata": {},
"source": [
"# Plot best and worst sample images"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18",
"id": "19",
"metadata": {},
"outputs": [],
"source": []
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.class_names:\n",
" print_header(f\"Class {prediction_target}\", level=3)\n",
" plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing, prediction_target=prediction_target, config=config)"
]
}
],
"metadata": {
@ -262,4 +284,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -3,10 +3,11 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
import torch
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
@ -20,6 +21,8 @@ from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.metrics_dict import MetricsDict, binary_classification_accuracy
from InnerEye.ML.reports.notebook_report import print_header
from InnerEye.ML.utils.io_util import load_image_in_known_formats
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
@dataclass
@ -49,19 +52,35 @@ class ReportedMetrics(Enum):
FalseNegativeRate = "false_negative_rate"
def get_results(csv: Path) -> LabelsAndPredictions:
def read_csv_and_filter_prediction_target(csv: Path, prediction_target: str) -> pd.DataFrame:
"""
Given a CSV file, reads the subject IDs, ground truth labels and model outputs for each subject.
NOTE: This CSV file should have results from a single epoch, as in the metrics files written during inference, not
like the ones written while training.
Given one of the csv files written during inference time, read it and select only those rows which belong to the
given prediction_target. Also check that there is only a single entry per prediction_target per subject in the file.
The csv must have at least the following columns (defined in the LoggingColumns enum):
LoggingColumns.Hue, LoggingColumns.Patient.
"""
df = pd.read_csv(csv)
labels = df[LoggingColumns.Label.value]
model_outputs = df[LoggingColumns.ModelOutput.value]
subjects = df[LoggingColumns.Patient.value]
if not subjects.is_unique:
df = df[df[LoggingColumns.Hue.value] == prediction_target] # Filter by prediction target
if not df[LoggingColumns.Patient.value].is_unique:
raise ValueError(f"Subject IDs should be unique, but found duplicate entries "
f"in column {LoggingColumns.Patient.value} in the csv file.")
return df
def get_labels_and_predictions(csv: Path, prediction_target: str) -> LabelsAndPredictions:
"""
Given a CSV file, reads the subject IDs, ground truth labels and model outputs for each subject
for the given prediction target.
NOTE: This CSV file should have results from a single epoch, as in the metrics files written during inference, not
like the ones written while training. It must have at least the following columns (defined in the LoggingColumns
enum):
LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
"""
df = read_csv_and_filter_prediction_target(csv, prediction_target)
labels = df[LoggingColumns.Label.value].to_numpy()
model_outputs = df[LoggingColumns.ModelOutput.value].to_numpy()
subjects = df[LoggingColumns.Patient.value].to_numpy()
return LabelsAndPredictions(subject_ids=subjects, labels=labels, model_outputs=model_outputs)
@ -85,116 +104,172 @@ def plot_auc(x_values: np.ndarray, y_values: np.ndarray, title: str, ax: Axes, p
ax.annotate(f"{x:0.3f}, {y:0.3f}", xy=(x, y), xytext=(15, 0), textcoords='offset points')
def plot_pr_and_roc_curves_from_csv(metrics_csv: Path) -> None:
def plot_pr_and_roc_curves(labels_and_model_outputs: LabelsAndPredictions) -> None:
"""
Given a csv file, read the predicted values and ground truth labels and plot the ROC and PR curves.
Given a LabelsAndPredictions object, plot the ROC and PR curves.
"""
print_header("ROC and PR curves", level=3)
results = get_results(metrics_csv)
_, ax = plt.subplots(1, 2)
fpr, tpr, thresholds = roc_curve(results.labels, results.model_outputs)
fpr, tpr, thresholds = roc_curve(labels_and_model_outputs.labels, labels_and_model_outputs.model_outputs)
plot_auc(fpr, tpr, "ROC Curve", ax[0])
precision, recall, thresholds = precision_recall_curve(results.labels, results.model_outputs)
precision, recall, thresholds = precision_recall_curve(labels_and_model_outputs.labels,
labels_and_model_outputs.model_outputs)
plot_auc(recall, precision, "PR Curve", ax[1])
plt.show()
def get_metric(val_metrics_csv: Path, test_metrics_csv: Path, metric: ReportedMetrics) -> float:
def plot_pr_and_roc_curves_from_csv(metrics_csv: Path, config: ScalarModelBase) -> None:
"""
Given a csv file, read the predicted values and ground truth labels and return the specified metric.
Given the csv written during inference time and the model config,
plot the ROC and PR curves for all prediction targets.
"""
results_val = get_results(val_metrics_csv)
fpr, tpr, thresholds = roc_curve(results_val.labels, results_val.model_outputs)
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
optimal_threshold = thresholds[optimal_idx]
for prediction_target in config.class_names:
print_header(f"Class {prediction_target}", level=3)
metrics = get_labels_and_predictions(metrics_csv, prediction_target)
plot_pr_and_roc_curves(metrics)
def get_metric(val_labels_and_predictions: LabelsAndPredictions,
test_labels_and_predictions: LabelsAndPredictions,
metric: ReportedMetrics,
optimal_threshold: Optional[float] = None) -> float:
"""
Given LabelsAndPredictions objects for the validation and test sets, return the specified metric.
:param val_labels_and_predictions: This set of ground truth labels and model predictions is used to determine the
optimal threshold for classification.
:param test_labels_and_predictions: The set of labels and model outputs to calculate metrics for.
:param metric: The name of the metric to calculate.
:param optimal_threshold: If provided, use this threshold instead of calculating an optimal threshold.
"""
if not optimal_threshold:
fpr, tpr, thresholds = roc_curve(val_labels_and_predictions.labels, val_labels_and_predictions.model_outputs)
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
optimal_threshold = thresholds[optimal_idx]
assert optimal_threshold # for mypy, we have already calculated optimal threshold if it was set to None
if metric is ReportedMetrics.OptimalThreshold:
return optimal_threshold
results_test = get_results(test_metrics_csv)
only_one_class_present = len(set(results_test.labels)) < 2
only_one_class_present = len(set(test_labels_and_predictions.labels)) < 2
if metric is ReportedMetrics.AUC_ROC:
return math.nan if only_one_class_present else roc_auc_score(results_test.labels, results_test.model_outputs)
return math.nan if only_one_class_present else roc_auc_score(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs)
elif metric is ReportedMetrics.AUC_PR:
if only_one_class_present:
return math.nan
precision, recall, _ = precision_recall_curve(results_test.labels, results_test.model_outputs)
precision, recall, _ = precision_recall_curve(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs)
return auc(recall, precision)
elif metric is ReportedMetrics.Accuracy:
return binary_classification_accuracy(model_output=results_test.model_outputs,
label=results_test.labels,
return binary_classification_accuracy(model_output=test_labels_and_predictions.model_outputs,
label=test_labels_and_predictions.labels,
threshold=optimal_threshold)
elif metric is ReportedMetrics.FalsePositiveRate:
tnr = recall_score(results_test.labels, results_test.model_outputs >= optimal_threshold, pos_label=0)
tnr = recall_score(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs >= optimal_threshold, pos_label=0)
return 1 - tnr
elif metric is ReportedMetrics.FalseNegativeRate:
return 1 - recall_score(results_test.labels, results_test.model_outputs >= optimal_threshold)
return 1 - recall_score(test_labels_and_predictions.labels, test_labels_and_predictions.model_outputs >= optimal_threshold)
else:
raise ValueError("Unknown metric")
def print_metrics(val_metrics_csv: Path, test_metrics_csv: Path) -> None:
def print_metrics(val_labels_and_predictions: LabelsAndPredictions,
test_labels_and_predictions: LabelsAndPredictions,
is_thresholded: bool = False) -> None:
"""
Given a csv file, read the predicted values and ground truth labels and print out some metrics.
Given LabelsAndPredictions objects for the validation and test sets, print out some metrics.
:param val_labels_and_predictions: LabelsAndPredictions object for the val set. This is used to determine the
optimal threshold for classification.
:param test_labels_and_predictions: LabelsAndPredictions object for the test set. Metrics are calculated for this
set.
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
or are floating point numbers.
:return:
"""
roc_auc = get_metric(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
metric=ReportedMetrics.AUC_ROC)
print_header(f"Area under ROC Curve: {roc_auc:.4f}", level=4)
pr_auc = get_metric(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
metric=ReportedMetrics.AUC_PR)
print_header(f"Area under PR Curve: {pr_auc:.4f}", level=4)
optimal_threshold = 0.5 if is_thresholded else None
optimal_threshold = get_metric(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
metric=ReportedMetrics.OptimalThreshold)
if not is_thresholded:
roc_auc = get_metric(val_labels_and_predictions=val_labels_and_predictions,
test_labels_and_predictions=test_labels_and_predictions,
metric=ReportedMetrics.AUC_ROC)
print_header(f"Area under ROC Curve: {roc_auc:.4f}", level=4)
print_header(f"Optimal threshold: {optimal_threshold: .4f}", level=4)
pr_auc = get_metric(val_labels_and_predictions=val_labels_and_predictions,
test_labels_and_predictions=test_labels_and_predictions,
metric=ReportedMetrics.AUC_PR)
print_header(f"Area under PR Curve: {pr_auc:.4f}", level=4)
accuracy = get_metric(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
metric=ReportedMetrics.Accuracy)
optimal_threshold = get_metric(val_labels_and_predictions=val_labels_and_predictions,
test_labels_and_predictions=test_labels_and_predictions,
metric=ReportedMetrics.OptimalThreshold)
print_header(f"Optimal threshold: {optimal_threshold: .4f}", level=4)
accuracy = get_metric(val_labels_and_predictions=val_labels_and_predictions,
test_labels_and_predictions=test_labels_and_predictions,
metric=ReportedMetrics.Accuracy,
optimal_threshold=optimal_threshold)
print_header(f"Accuracy at optimal threshold: {accuracy:.4f}", level=4)
fpr = get_metric(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
metric=ReportedMetrics.FalsePositiveRate)
fpr = get_metric(val_labels_and_predictions=val_labels_and_predictions,
test_labels_and_predictions=test_labels_and_predictions,
metric=ReportedMetrics.FalsePositiveRate,
optimal_threshold=optimal_threshold)
print_header(f"Specificity at optimal threshold: {1 - fpr:.4f}", level=4)
fnr = get_metric(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
metric=ReportedMetrics.FalseNegativeRate)
fnr = get_metric(val_labels_and_predictions=val_labels_and_predictions,
test_labels_and_predictions=test_labels_and_predictions,
metric=ReportedMetrics.FalseNegativeRate,
optimal_threshold=optimal_threshold)
print_header(f"Sensitivity at optimal threshold: {1 - fnr:.4f}", level=4)
print_header("", level=4)
def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_csv: Path) -> Results:
def print_metrics_for_all_prediction_targets(val_metrics_csv: Path,
test_metrics_csv: Path,
config: ScalarModelBase,
is_thresholded: bool = False) -> None:
"""
Given csvs written during inference for the validation and test sets, print out metrics for every prediction target
in the config.
:param val_metrics_csv: Csv written during inference time for the val set. This is used to determine the
optimal threshold for classification.
:param test_metrics_csv: Csv written during inference time for the test set. Metrics are calculated for this csv.
:param config: Model config
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
or are floating point numbers.
"""
for prediction_target in config.class_names:
print_header(f"Class {prediction_target}", level=3)
val_metrics = get_labels_and_predictions(val_metrics_csv, prediction_target)
test_metrics = get_labels_and_predictions(test_metrics_csv, prediction_target)
print_metrics(val_labels_and_predictions=val_metrics, test_labels_and_predictions=test_metrics, is_thresholded=is_thresholded)
def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_csv: Path,
prediction_target: str = "Default") -> Results:
"""
Given the paths to the metrics files for the validation and test sets, get a list of true positives,
false positives, false negatives and true negatives.
The threshold for classification is obtained by looking at the validation file, and applied to the test set to get
label predictions.
"""
df_val = pd.read_csv(val_metrics_csv)
The validation and test csvs must have at least the following columns (defined in the LoggingColumns enum):
LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
if not df_val[LoggingColumns.Patient.value].is_unique:
raise ValueError(f"Subject IDs should be unique, but found duplicate entries "
f"in column {LoggingColumns.Patient.value} in the csv file.")
"""
df_val = read_csv_and_filter_prediction_target(val_metrics_csv, prediction_target)
fpr, tpr, thresholds = roc_curve(df_val[LoggingColumns.Label.value], df_val[LoggingColumns.ModelOutput.value])
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
optimal_threshold = thresholds[optimal_idx]
df_test = pd.read_csv(test_metrics_csv)
if not df_test[LoggingColumns.Patient.value].is_unique:
raise ValueError(f"Subject IDs should be unique, but found duplicate entries "
f"in column {LoggingColumns.Patient.value} in the csv file.")
df_test = read_csv_and_filter_prediction_target(test_metrics_csv, prediction_target)
df_test["predicted"] = df_test.apply(lambda x: int(x[LoggingColumns.ModelOutput.value] >= optimal_threshold),
axis=1)
@ -210,13 +285,15 @@ def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_c
false_negatives=false_negatives)
def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int) -> Results:
def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int,
prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Results:
"""
Get the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
top "k" worst predictions (i.e. misclassifications where the model was the most confident).
"""
results = get_correct_and_misclassified_examples(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv)
test_metrics_csv=test_metrics_csv,
prediction_target=prediction_target)
# sort by model_output
sorted = Results(true_positives=results.true_positives.sort_values(by=LoggingColumns.ModelOutput.value,
@ -230,14 +307,22 @@ def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pat
return sorted
def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int) -> None:
def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int, prediction_target: str) -> None:
"""
Print the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
top "k" worst predictions (i.e. misclassifications where the model was the most confident).
:param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
used to determine the thresholds for predicting labels on the test set. The best and worst
performing subjects will not be printed out for this csv.
:param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
best and worst performing subjects will be printed out.
:param k: Number of subjects of each category to print out.
:param prediction_target: The class label to filter on
"""
results = get_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
k=k)
k=k,
prediction_target=prediction_target)
print_header(f"Top {k} false positives", level=2)
for index, (subject, model_output) in enumerate(zip(results.false_positives[LoggingColumns.Patient.value],
@ -261,33 +346,57 @@ def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: P
def get_image_filepath_from_subject_id(subject_id: str,
dataset_df: pd.DataFrame,
dataset_subject_column: str,
dataset_file_column: str,
dataset_dir: Path) -> Optional[Path]:
dataset: ScalarDataset,
config: ScalarModelBase) -> List[Path]:
"""
Returns the filepath for the image associated with a subject. If the subject is not found, return None.
If the csv contains multiple entries per subject (which may happen if the csv uses the channels column) then
return None as we do not support these csv types yet.
Return the filepaths for images associated with a subject. If the subject is not found, raises a ValueError.
:param subject_id: Subject to retrieve images for
:param dataset_df: Dataset dataframe (from the dataset.csv file)
:param dataset_subject_column: Name of the column with the subject IDs
:param dataset_file_column: Name of the column with the image filepaths
:param dataset_dir: Path to the dataset
:return: path to the image file for the patient or None if it is not found.
:param dataset: scalar dataset object
:param config: model config
:return: List of paths to the image files for the patient.
"""
for item in dataset.items:
if item.metadata.id == subject_id:
return item.get_all_image_filepaths(root_path=config.local_dataset,
file_mapping=dataset.file_to_full_path)
raise ValueError(f"Could not find subject {subject_id} in the dataset.")
def get_image_labels_from_subject_id(subject_id: str,
dataset: ScalarDataset,
config: ScalarModelBase) -> List[str]:
"""
Return the ground truth labels associated with a subject. If the subject is not found, raises a ValueError.
:param subject_id: Subject to retrieve labels for
:param dataset: scalar dataset object
:param config: model config
:return: List of labels for the patient.
"""
labels = None
for item in dataset.items:
if item.metadata.id == subject_id:
labels = torch.flatten(torch.nonzero(item.label)).tolist()
break
if labels is None:
raise ValueError(f"Could not find subject {subject_id} in the dataset.")
return [config.class_names[int(label)] for label in labels
if not math.isnan(label)]
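To make the label lookup above concrete, a multi-hot ground truth tensor maps to class names like this (the class names are illustrative):

import torch

class_names = ["class0", "class1", "class2", "class3"]   # illustrative names
label = torch.tensor([0.0, 1.0, 0.0, 1.0])               # multi-hot ground truth for one subject
indices = torch.flatten(torch.nonzero(label)).tolist()   # -> [1, 3]
print([class_names[int(i)] for i in indices])            # -> ['class1', 'class3']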
def get_image_outputs_from_subject_id(subject_id: str,
metrics_df: pd.DataFrame) -> List[Tuple[str, int]]:
"""
Return a list of tuples (Label class name, model output for the class) for a single subject.
"""
if not dataset_df[dataset_subject_column].is_unique:
return None
dataset_df[dataset_subject_column] = dataset_df.apply(lambda x: str(x[dataset_subject_column]), axis=1)
if subject_id not in dataset_df[dataset_subject_column].unique():
return None
filtered = dataset_df[dataset_df[dataset_subject_column] == subject_id]
filepath = filtered.iloc[0][dataset_file_column]
return dataset_dir / Path(filepath)
filtered = metrics_df[metrics_df[LoggingColumns.Patient.value] == subject_id]
outputs = list(zip(filtered[LoggingColumns.Hue.value].values.tolist(),
filtered[LoggingColumns.ModelOutput.value].values.astype(float).tolist()))
return outputs
def plot_image_from_filepath(filepath: Path, im_width: int) -> bool:
@ -321,58 +430,81 @@ def plot_image_from_filepath(filepath: Path, im_width: int) -> bool:
def plot_image_for_subject(subject_id: str,
dataset_df: pd.DataFrame,
dataset_subject_column: str,
dataset_file_column: str,
dataset_dir: Path,
dataset: ScalarDataset,
im_width: int,
model_output: float,
header: Optional[str]) -> None:
header: Optional[str],
config: ScalarModelBase,
metrics_df: Optional[pd.DataFrame] = None) -> None:
"""
Given a subject ID, plots the corresponding image.
:param subject_id: Subject to plot image for
:param dataset_df: Dataset dataframe (from the dataset.csv file)
:param dataset_subject_column: Name of the column with the subject IDs
:param dataset_file_column: Name of the column with the image filepaths
:param dataset_dir: Path to the dataset
:param dataset: scalar dataset object
:param im_width: Display width for image
:param model_output: The predicted value for this image
:param header: Optional header printed along with the subject ID and score for the image.
:param config: model config
:param metrics_df: dataframe with the metrics written out during inference time
"""
print_header("", level=4)
if header:
print_header(header, level=4)
print_header(f"ID: {subject_id} Score: {model_output}", level=4)
filepath = get_image_filepath_from_subject_id(subject_id=str(subject_id),
dataset_df=dataset_df,
dataset_subject_column=dataset_subject_column,
dataset_file_column=dataset_file_column,
dataset_dir=dataset_dir)
if not filepath:
print_header(f"Subject ID {subject_id} not found, or found duplicate entries for this subject "
f"in column {dataset_subject_column} in the csv file. "
labels = get_image_labels_from_subject_id(subject_id=subject_id,
dataset=dataset,
config=config)
print_header(f"True labels: {', '.join(labels) if labels else 'Negative'}", level=4)
if metrics_df is not None:
all_model_outputs = get_image_outputs_from_subject_id(subject_id=subject_id,
metrics_df=metrics_df)
print_header(f"ID: {subject_id}", level=4)
print_header(f"Model output: {', '.join([':'.join([str(x) for x in output]) for output in all_model_outputs])}",
level=4)
else:
print_header(f"ID: {subject_id} Score: {model_output}", level=4)
filepaths = get_image_filepath_from_subject_id(subject_id=str(subject_id),
dataset=dataset,
config=config)
if not filepaths:
print_header(f"Subject ID {subject_id} not found. "
f"Note: Reports with datasets that use channel columns in the dataset.csv "
f"are not yet supported.")
f"are not yet supported.", level=0)
return
success = plot_image_from_filepath(filepath, im_width=im_width)
if not success:
print_header("Unable to plot image: image must be 2D with shape [w, h] or [1, w, h].", level=0)
for filepath in filepaths:
success = plot_image_from_filepath(filepath, im_width=im_width)
if not success:
print_header("Unable to plot image: image must be 2D with shape [w, h] or [1, w, h].", level=0)
def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int, dataset_csv_path: Path,
dataset_subject_column: str, dataset_file_column: str) -> None:
def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int,
prediction_target: str, config: ScalarModelBase) -> None:
"""
Plot images for the top "k" best predictions (i.e. correct classifications where the model was the most certain)
and the top "k" worst predictions (i.e. misclassifications where the model was the most confident).
:param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
used to determine the thresholds for predicting labels on the test set. The best and worst
performing subjects will not be printed out for this csv.
:param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
best and worst performing subjects will be printed out.
:param k: Number of subjects of each category to print out.
:param prediction_target: The class label to filter on
:param config: scalar model config object
"""
results = get_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
k=k)
k=k,
prediction_target=prediction_target)
dataset_df = pd.read_csv(dataset_csv_path)
dataset_dir = dataset_csv_path.parent
test_metrics = pd.read_csv(test_metrics_csv, dtype=str)
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
im_width = 800
@ -381,46 +513,42 @@ def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pa
for index, (subject, model_output) in enumerate(zip(results.false_positives[LoggingColumns.Patient.value],
results.false_positives[LoggingColumns.ModelOutput.value])):
plot_image_for_subject(subject_id=str(subject),
dataset_df=dataset_df,
dataset_subject_column=dataset_subject_column,
dataset_file_column=dataset_file_column,
dataset_dir=dataset_dir,
dataset=dataset,
im_width=im_width,
model_output=model_output,
header="False Positive")
header="False Positive",
config=config,
metrics_df=test_metrics)
print_header(f"Top {k} false negatives", level=2)
for index, (subject, model_output) in enumerate(zip(results.false_negatives[LoggingColumns.Patient.value],
results.false_negatives[LoggingColumns.ModelOutput.value])):
plot_image_for_subject(subject_id=str(subject),
dataset_df=dataset_df,
dataset_subject_column=dataset_subject_column,
dataset_file_column=dataset_file_column,
dataset_dir=dataset_dir,
dataset=dataset,
im_width=im_width,
model_output=model_output,
header="False Negative")
header="False Negative",
config=config,
metrics_df=test_metrics)
print_header(f"Top {k} true positives", level=2)
for index, (subject, model_output) in enumerate(zip(results.true_positives[LoggingColumns.Patient.value],
results.true_positives[LoggingColumns.ModelOutput.value])):
plot_image_for_subject(subject_id=str(subject),
dataset_df=dataset_df,
dataset_subject_column=dataset_subject_column,
dataset_file_column=dataset_file_column,
dataset_dir=dataset_dir,
dataset=dataset,
im_width=im_width,
model_output=model_output,
header="True Positive")
header="True Positive",
config=config,
metrics_df=test_metrics)
print_header(f"Top {k} true negatives", level=2)
for index, (subject, model_output) in enumerate(zip(results.true_negatives[LoggingColumns.Patient.value],
results.true_negatives[LoggingColumns.ModelOutput.value])):
plot_image_for_subject(subject_id=str(subject),
dataset_df=dataset_df,
dataset_subject_column=dataset_subject_column,
dataset_file_column=dataset_file_column,
dataset_dir=dataset_dir,
dataset=dataset,
im_width=im_width,
model_output=model_output,
header="True Negative")
header="True Negative",
config=config,
metrics_df=test_metrics)
 
@ -5,13 +5,40 @@
from pathlib import Path
from typing import Dict, Optional, Union
import codecs
import nbformat
import papermill
import pickle
from IPython.display import Markdown, display
from nbconvert import HTMLExporter
from nbconvert.writers import FilesWriter
from InnerEye.Common import fixed_paths
from InnerEye.ML.scalar_config import ScalarModelBase
REPORT_PREFIX = "report"
REPORT_IPYNB_SUFFIX = ".ipynb"
REPORT_HTML_SUFFIX = ".html"
reports_folder = "reports"
def get_ipynb_report_name(report_type: str) -> str:
"""
Constructs the name of the report (as an IPython notebook).
:param report_type: suffix describing the report, added to the filename
:return: The notebook filename, in the form report_{report_type}.ipynb
"""
return f"{REPORT_PREFIX}_{report_type}{REPORT_IPYNB_SUFFIX}"
def get_html_report_name(report_type: str) -> str:
"""
Constructs the name of the report (as an HTML file).
:param report_type: suffix describing the report, added to the filename
:return: The HTML report filename, in the form report_{report_type}.html
"""
return f"{REPORT_PREFIX}_{report_type}{REPORT_HTML_SUFFIX}"
def str_or_empty(p: Union[None, str, Path]) -> str:
@ -89,12 +116,10 @@ def generate_segmentation_notebook(result_notebook: Path,
def generate_classification_notebook(result_notebook: Path,
config: ScalarModelBase,
train_metrics: Optional[Path] = None,
val_metrics: Optional[Path] = None,
test_metrics: Optional[Path] = None,
dataset_csv_path: Optional[Path] = None,
dataset_subject_column: Optional[str] = None,
dataset_file_column: Optional[str] = None) -> Path:
test_metrics: Optional[Path] = None) -> Path:
"""
Creates a reporting notebook for a classification model, using the given training, validation, and test set metrics.
Returns the report file after HTML conversion.
@ -106,11 +131,35 @@ def generate_classification_notebook(result_notebook: Path,
'train_metrics_csv': str_or_empty(train_metrics),
'val_metrics_csv': str_or_empty(val_metrics),
'test_metrics_csv': str_or_empty(test_metrics),
'dataset_csv_path': str_or_empty(dataset_csv_path),
"dataset_subject_column": str_or_empty(dataset_subject_column),
"dataset_file_column": str_or_empty(dataset_file_column)
"config": codecs.encode(pickle.dumps(config), "base64").decode()
}
template = Path(__file__).absolute().parent / "classification_report.ipynb"
return generate_notebook(template,
notebook_params=notebook_params,
result_notebook=result_notebook)
def generate_classification_multilabel_notebook(result_notebook: Path,
config: ScalarModelBase,
train_metrics: Optional[Path] = None,
val_metrics: Optional[Path] = None,
test_metrics: Optional[Path] = None) -> Path:
"""
Creates a reporting notebook for a multilabel classification model, using the given training, validation,
and test set metrics. This report adds metrics specific to the multilabel task, and is meant to be used in
addition to the standard report created for all classification models.
Returns the report file after HTML conversion.
"""
notebook_params = \
{
'innereye_path': str(fixed_paths.repository_root_directory()),
'train_metrics_csv': str_or_empty(train_metrics),
'val_metrics_csv': str_or_empty(val_metrics),
'test_metrics_csv': str_or_empty(test_metrics),
"config": codecs.encode(pickle.dumps(config), "base64").decode()
}
template = Path(__file__).absolute().parent / "classification_multilabel_report.ipynb"
return generate_notebook(template,
notebook_params=notebook_params,
result_notebook=result_notebook)
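A minimal usage sketch, assuming the metrics csvs already exist on disk; the paths below are placeholders, and in the training pipeline this call is made by MLRunner.generate_report:

from pathlib import Path
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification

config = DummyMulticlassClassification()
html_report = generate_classification_multilabel_notebook(
    result_notebook=Path("outputs/reports/report_Classification_multilabel.ipynb"),  # placeholder
    config=config,
    val_metrics=Path("outputs/best_epoch/Val/metrics.csv"),    # placeholder
    test_metrics=Path("outputs/best_epoch/Test/metrics.csv"))  # placeholder
# html_report is the path of the notebook after conversion to HTML.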
 
@ -40,9 +40,10 @@ from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.model_inference_config import ModelInferenceConfig
from InnerEye.ML.model_testing import model_test
from InnerEye.ML.model_training import model_train
from InnerEye.ML.reports.notebook_report import generate_classification_notebook, generate_segmentation_notebook
from InnerEye.ML.runner import ModelDeploymentHookSignature, PostCrossValidationHookSignature, REPORT_HTML, \
REPORT_IPYNB, get_all_environment_files
from InnerEye.ML.reports.notebook_report import get_ipynb_report_name, generate_classification_notebook, \
generate_segmentation_notebook, \
generate_classification_multilabel_notebook, reports_folder
from InnerEye.ML.runner import ModelDeploymentHookSignature, PostCrossValidationHookSignature, get_all_environment_files
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import ml_util
@ -663,7 +664,8 @@ class MLRunner:
model_proc=ModelProcessing.ENSEMBLE_CREATION)
crossval_dir = self.plot_cross_validation_and_upload_results()
self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
if self.model_config.generate_report:
self.generate_report(ModelProcessing.ENSEMBLE_CREATION)
# CrossValResults should have been uploaded to the parent run, so we don't need it here.
remove_file_or_directory(crossval_dir)
# We can also remove OTHER_RUNS under the root, as it is no longer useful and only contains copies of files
@ -676,8 +678,7 @@ class MLRunner:
for subdir in other_runs_ensemble_dir.glob("*"):
if subdir.name not in [BASELINE_WILCOXON_RESULTS_FILE,
SCATTERPLOTS_SUBDIR_NAME,
REPORT_HTML,
REPORT_IPYNB]:
reports_folder]:
remove_file_or_directory(subdir)
PARENT_RUN_CONTEXT.upload_folder(name=BASELINE_COMPARISONS_FOLDER, path=str(other_runs_ensemble_dir))
else:
@ -724,21 +725,33 @@ class MLRunner:
output_dir = config.outputs_folder / OTHER_RUNS_SUBDIR_NAME / ENSEMBLE_SPLIT_NAME \
if model_proc == ModelProcessing.ENSEMBLE_CREATION else config.outputs_folder
reports_dir = output_dir / reports_folder
if not reports_dir.exists():
reports_dir.mkdir(exist_ok=False)
if config.model_category == ModelCategory.Segmentation:
generate_segmentation_notebook(result_notebook=output_dir / REPORT_IPYNB,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
generate_segmentation_notebook(
result_notebook=reports_dir / get_ipynb_report_name(config.model_category.value),
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
else:
if isinstance(config, ScalarModelBase) and not isinstance(config, SequenceModelBase):
dataset_csv_path = config.local_dataset / config.dataset_csv if config.local_dataset else None
generate_classification_notebook(result_notebook=output_dir / REPORT_IPYNB,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test,
dataset_csv_path=dataset_csv_path,
dataset_subject_column=config.subject_column,
dataset_file_column=config.image_file_column)
generate_classification_notebook(
result_notebook=reports_dir / get_ipynb_report_name(config.model_category.value),
config=config,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
if len(config.class_names) > 1:
generate_classification_multilabel_notebook(
result_notebook=reports_dir / get_ipynb_report_name(f"{config.model_category.value}_multilabel"),
config=config,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
else:
logging.info(f"Cannot create report for config of type {type(config)}.")
except Exception as ex:


@ -41,9 +41,6 @@ from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.utils.config_util import ModelConfigLoader
REPORT_IPYNB = "report.ipynb"
REPORT_HTML = "report.html"
LOG_FILE_NAME = "stdout.txt"
PostCrossValidationHookSignature = Callable[[ModelConfigBase, Path], None]


@ -16,6 +16,7 @@ from InnerEye.Common.generic_parsing import ListOrDictParam
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.common import ModelExecutionMode, OneHotEncoderBase
from InnerEye.ML.deep_learning_config import ModelCategory
from InnerEye.ML.metrics_dict import MetricsDict
from InnerEye.ML.model_config_base import ModelConfigBase, ModelTransformsPerExecutionMode
from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER
from InnerEye.ML.utils.split_dataset import DatasetSplits
@ -102,6 +103,17 @@ class LabelTransformation(Enum):
class ScalarModelBase(ModelConfigBase):
class_names: List[str] = param.List(class_=str,
default=[MetricsDict.DEFAULT_HUE_KEY],
bounds=(1, None),
doc="The label names for each label class in the dataset and model output "
"in the case of binary and multi-label classification tasks. "
"The order of the names should match the order of label class indices "
"in dataset.csv. "
"For multi-label classification, this field is required. "
"For binary classification, this field must be a list of size 1, and "
"is by default ['Default'], but can optionally be set to a more descriptive "
"name for the positive class.")
aggregation_type: AggregationType = param.ClassSelector(default=AggregationType.Average, class_=AggregationType,
doc="The type of global pooling aggregation to use between"
" the encoder and the classifier.")
@ -205,6 +217,10 @@ class ScalarModelBase(ModelConfigBase):
else:
self.num_dataset_reader_workers = num_dataset_reader_workers
def validate(self) -> None:
if len(self.class_names) > 1 and not self.is_classification_model:
raise ValueError("Multiple label classes supported only for classification tasks.")
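A sketch of what the new guard rejects, assuming MeanSquaredError marks the config as a regression task:

from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase

# Hypothetical invalid config: several class names combined with a regression loss.
config = ScalarModelBase(class_names=["c0", "c1"],
                         loss_type=ScalarLoss.MeanSquaredError,
                         label_value_column="label",
                         should_validate=False)
# config.validate() would now hit the check above and raise
# ValueError("Multiple label classes supported only for classification tasks.")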
@property
def is_classification_model(self) -> bool:
"""
@ -385,6 +401,12 @@ class ScalarModelBase(ModelConfigBase):
assert self._datasets_for_training is not None # for mypy
return self._datasets_for_training[ModelExecutionMode.TRAIN].get_class_counts()
def get_total_number_of_training_samples(self) -> int:
if self._datasets_for_training is None:
self.create_and_set_torch_datasets(for_inference=False)
assert self._datasets_for_training is not None # for mypy
return len(self._datasets_for_training[ModelExecutionMode.TRAIN])
def create_model(self) -> Any:
pass


@ -384,6 +384,13 @@ def load_images_and_stack(files: Iterable[Path],
return ImageAndSegmentations(images=image_tensor, segmentations=segmentation_tensor)
def is_png(file: PathOrString) -> bool:
"""
Returns true if file is png
"""
return _file_matches_extension(file, [".png"])
def load_image_in_known_formats(file: Path,
load_segmentation: bool) -> ImageAndSegmentations[np.ndarray]:
"""
@ -404,6 +411,9 @@ def load_image_in_known_formats(file: Path,
return ImageAndSegmentations(images=load_numpy_image(path=file))
elif is_dicom_file_path(file):
return ImageAndSegmentations(images=load_dicom_image(path=file))
elif is_png(file):
image_with_header = load_image(path=file)
return ImageAndSegmentations(images=image_with_header.image)
else:
raise ValueError(f"Unsupported image file type for path {file}")
@ -463,6 +473,11 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW
image = load_hdf5_dataset_from_file(Path(h5_path), dataset)[channel]
header = get_unit_image_header()
return ImageWithHeader(image, header)
elif is_png(path):
import imageio
image = imageio.imread(path).astype(np.float)
header = get_unit_image_header()
return ImageWithHeader(image, header)
raise ValueError(f"Invalid file type {path}")
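A small round-trip sketch of the new png branch; the file name is a placeholder, and imageio is the same package that the branch above imports:

import imageio
import numpy as np
from pathlib import Path

png_path = Path("scan.png")                                   # placeholder file name
imageio.imwrite(png_path, np.zeros((4, 5), dtype=np.uint8))   # write a tiny grayscale png
assert is_png(png_path)
image_and_seg = load_image_in_known_formats(png_path, load_segmentation=False)
print(image_and_seg.images.shape)                             # -> (4, 5)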


@ -105,11 +105,14 @@ def create_scalar_loss_function(config: ScalarModelBase) -> torch.nn.Module:
Returns a torch module that computes a loss function for classification and regression models.
"""
if config.loss_type == ScalarLoss.BinaryCrossEntropyWithLogits:
return BinaryCrossEntropyWithLogitsLoss(smoothing_eps=config.label_smoothing_eps)
return BinaryCrossEntropyWithLogitsLoss(num_classes=len(config.class_names),
smoothing_eps=config.label_smoothing_eps)
if config.loss_type == ScalarLoss.WeightedCrossEntropyWithLogits:
return BinaryCrossEntropyWithLogitsLoss(
num_classes=len(config.class_names),
smoothing_eps=config.label_smoothing_eps,
class_counts=config.get_training_class_counts())
class_counts=config.get_training_class_counts(),
num_train_samples=config.get_total_number_of_training_samples())
elif config.loss_type == ScalarLoss.MeanSquaredError:
return MSELoss()
else:


@ -37,7 +37,7 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
# Smooth the one-hot target: 1.0 becomes 1.0-eps, 0.0 becomes eps / (nClasses - 1)
# noinspection PyTypeChecker
return target * (1.0 - self.smoothing_eps) + \
(1.0 - target) * self.smoothing_eps / (_num_classes - 1.0) # type: ignore
(1.0 - target) * self.smoothing_eps / (_num_classes - 1.0) # type: ignore
_input: List[T] = list(input)
if self.smoothing_eps > 0.0:
@ -56,10 +56,26 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
"""A wrapper function for torch.nn.BCEWithLogitsLoss to enable label smoothing"""
def __init__(self, class_counts: Optional[Dict[float, float]] = None, **kwargs: Any):
def __init__(self, num_classes: int,
class_counts: Optional[Dict[float, int]] = None,
num_train_samples: Optional[int] = None,
**kwargs: Any):
"""
:param num_classes: The number of classes the model predicts. For binary classification num_classes is one
and for multi-label classification tasks num_classes will be greater than one.
:param class_counts: The number of positive samples for each class. class_counts is a dictionary with key-value
pairs corresponding to each class and the positive sample count for the class.
For binary classification tasks, class_counts should have a single key-value pair
for the positive class.
:param num_train_samples: The total number of training samples in the dataset.
"""
super().__init__(is_binary_classification=True, **kwargs)
if class_counts and not num_train_samples:
raise ValueError("Need to specify the num_train_samples with class_counts")
self._positive_class_weights = None
self._class_counts = class_counts
self._num_train_samples = num_train_samples
self.num_classes = num_classes
if class_counts:
self._positive_class_weights = self.get_positive_class_weights()
if torch.cuda.is_available():
@ -73,16 +89,18 @@ class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
target position.
:return: a list of weights to use for the positive class for each target position.
"""
assert self._class_counts is not None
labels = list(self._class_counts.keys())
if sorted(labels) != [0.0, 1.0]:
if labels == [1.0] or labels == [0.0]:
return torch.tensor(1.0)
else:
raise ValueError(f"Expected one-hot encoded binary label."
f"Found labels {self._class_counts.keys()}")
else:
return torch.tensor(float(self._class_counts[0.0]) / self._class_counts[1.0], dtype=torch.float32)
assert self._class_counts and self._num_train_samples
if len(self._class_counts) != self.num_classes:
raise ValueError(f"Have {self.num_classes} classes but got counts for {len(self._class_counts)} classes. "
f"Note: If this is a binary classification task, there should be a single class count "
f"corresponding to the positive class.")
# These weights are given to the pos_weight parameter of Pytorch's BCEWithLogitsLoss.
# Weights are calculated as (number of negative samples for class 'i')/(number of positive samples for class 'i')
# for every class 'i' in a binary/multi-label classification task.
# For a binary classification task, this reduces to (number of negative samples / number of positive samples).
weights = [(self._num_train_samples - value) / value if value != 0 else 1.0 for (key, value) in
sorted(self._class_counts.items())]  # sorted() orders the counts by class index (the dict key)
return torch.tensor(weights, dtype=torch.float32)
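A worked example of the formula in the comments above, for a hypothetical training set of 10 samples where class 0 has 2 positives and class 1 has 5:

import torch

num_train_samples = 10
class_counts = {0: 2, 1: 5}            # illustrative positive counts per class index
weights = [(num_train_samples - count) / count if count != 0 else 1.0
           for _, count in sorted(class_counts.items())]
print(torch.tensor(weights, dtype=torch.float32))   # -> tensor([4., 1.])

These are the values handed to the pos_weight argument of torch.nn.BCEWithLogitsLoss, so positives of the rarer class contribute proportionally more to the loss.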
def forward_minibatch(self, output: T, target: T, **kwargs: Any) -> Any:
if isinstance(target, PackedSequence) and isinstance(output, PackedSequence):


@ -15,7 +15,7 @@ from InnerEye.Azure.azure_runner import create_experiment_name, get_or_create_py
from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, fetch_child_runs, fetch_run, \
get_cross_validation_split_index, is_cross_validation_child_run, is_run_and_child_runs_completed, \
merge_conda_dependencies, merge_conda_files, to_azure_friendly_container_path
from InnerEye.Common.common_util import logging_to_stdout
from InnerEye.Common.common_util import logging_to_stdout, is_linux
from InnerEye.Common.fixed_paths import PRIVATE_SETTINGS_FILE, PROJECT_SECRETS_FILE, \
get_environment_yaml_file, repository_root_directory
from InnerEye.Common.output_directories import OutputFolderForTests
@ -85,6 +85,7 @@ def test_is_cross_validation_child_run_ensemble_run() -> None:
assert all([is_cross_validation_child_run(x) for x in fetch_child_runs(run)])
@pytest.mark.skipif(is_linux(), reason="Spurious file read/write errors on linux build agents.")
def test_merge_conda(test_output_dirs: OutputFolderForTests) -> None:
"""
Tests the logic for merging Conda environment files.


@ -21,7 +21,7 @@ from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.dataset.sample import GeneralSampleMetadata
from InnerEye.ML.dataset.scalar_dataset import DataSourceReader, ScalarDataSource, ScalarDataset, \
_get_single_channel_row, _string_to_float, extract_label_classification, extract_label_regression, files_by_stem, \
_get_single_channel_row, _string_to_float, extract_label_classification, files_by_stem, \
is_valid_item_index, load_single_data_source
from InnerEye.ML.photometric_normalization import WindowNormalizationForScalarItem, mri_window
from InnerEye.ML.scalar_config import LabelTransformation, ScalarLoss, ScalarModelBase
@ -496,41 +496,49 @@ def test_load_single_item_7() -> None:
assert torch.all(torch.isnan(item.categorical_non_image_features[4:6]))
@pytest.mark.parametrize(["text", "expected_classification", "expected_regression"],
@pytest.mark.parametrize(["text", "is_classification", "num_classes", "expected_label"],
[
("true", 1, None),
("tRuE", 1, None),
("false", 0, None),
("False", 0, None),
("nO", 0, None),
("Yes", 1, None),
("1.23", None, 1.23),
(3.45, None, None),
(math.nan, math.nan, math.nan),
("", math.nan, math.nan),
(None, math.nan, math.nan),
("abc", None, None),
("1", 1, 1.0),
("-1", None, -1.0)
("false", True, 0, None),
("false", False, 0, None),
("true", True, 1, [1]),
("tRuE", True, 1, [1]),
("false", True, 1, [0]),
("False", True, 1, [0]),
("nO", True, 1, [0]),
("Yes", True, 1, [1.0]),
("1.23", False, 1, [1.23]),
(3.45, True, 1, None),
(3.45, False, 1, None),
("3.45", False, 1, [3.45]),
(math.nan, True, 1, [math.nan]),
(math.nan, True, 3, [0.0, 0.0, 0.0]),
(math.nan, False, 1, [math.nan]),
("", True, 1, [math.nan]),
("", True, 3, [0.0, 0.0, 0.0]),
("", False, 1, [math.nan]),
("abc", True, 1, None),
("abc", True, 3, None),
("abc", False, 1, None),
("1", True, 1, [1.0]),
("1", True, 3, [0.0, 1.0, 0.0]),
("1", False, 1, [1.0]),
("-1", False, 1, [-1.0]),
("1|2", True, 3, [0.0, 1.0, 1.0]),
("1|5", True, 3, None)
])
def test_extract_label(text: Union[float, str], expected_classification: Optional[float],
expected_regression: Optional[float]) -> None:
_check_label_extraction_function(extract_label_classification, text, expected_classification)
_check_label_extraction_function(extract_label_regression, text, expected_regression)
def _check_label_extraction_function(extract_fn: Callable, text: Union[float, str], expected: Optional[float]) -> None:
if expected is None:
def test_extract_label(text: str, is_classification: bool, num_classes: int,
expected_label: Optional[List[float]]) -> None:
if expected_label is None:
with pytest.raises(ValueError) as ex:
extract_fn(text, "foo")
assert "Subject foo:" in str(ex)
extract_label_classification(text, "subject1", num_classes, is_classification)
assert "Subject subject1:" in str(ex)
else:
actual = extract_fn(text, "foo")
assert isinstance(actual, type(expected))
if math.isnan(expected):
assert math.isnan(actual)
actual = extract_label_classification(text, "subject1", num_classes, is_classification)
assert isinstance(actual, type(expected_label))
if expected_label == [math.nan]:
assert math.isnan(actual[0])
else:
assert actual == expected
assert actual == expected_label
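For reference, a minimal sketch of the multilabel branch these cases exercise; extract_label_classification itself is not shown in this hunk, so the helper below is only an approximation of its behaviour for num_classes > 1:

import math
from typing import List, Union

def to_multi_hot(text: Union[str, float, None], num_classes: int) -> List[float]:
    # "1|2" with num_classes=3 -> [0.0, 1.0, 1.0]; empty or NaN -> all zeros.
    if text is None or text == "" or (isinstance(text, float) and math.isnan(text)):
        return [0.0] * num_classes
    indices = [int(i) for i in str(text).split("|")]
    if any(i < 0 or i >= num_classes for i in indices):
        raise ValueError(f"Label index out of range for {num_classes} classes: {text}")
    label = [0.0] * num_classes
    for i in indices:
        label[i] = 1.0
    return label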
@pytest.mark.parametrize(["text", "expected"],
@ -791,10 +799,9 @@ def test_imbalanced_sampler() -> None:
assert count_negative_subjects / float(len(drawn_subjects)) > 0.3
def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> None:
def test_get_class_counts_binary(test_output_dirs: OutputFolderForTests) -> None:
"""
Test training and testing of sequence models that predicts at multiple time points,
when it is started via run_ml.
Test the get_class_counts method for binary scalar datasets.
"""
dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
@ -816,4 +823,86 @@ def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> No
config.set_output_to(test_output_dirs.root_dir)
train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
class_counts = train_dataset.get_class_counts()
assert class_counts == {0.0: 1, 1.0: 2}
assert class_counts == {0: 2}
def test_get_class_counts_multilabel(test_output_dirs: OutputFolderForTests) -> None:
"""
Test the get_class_counts method for multilabel scalar datasets.
"""
dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
dataset_contents = """subject,channel,path,label,CAT1
S1,week0,scan1.npy,,A
S1,week1,scan2.npy,0|1|2,A
S2,week0,scan3.npy,,A
S2,week1,scan4.npy,1|2,A
S3,week0,scan1.npy,,A
S3,week1,scan3.npy,1,A
"""
config = ScalarModelBase(
local_dataset=dataset_folder,
class_names=["class0", "class1", "class2", "class3"],
label_channels=["week1"],
label_value_column="label",
non_image_feature_channels=["week0", "week1"],
should_validate=False
)
config.set_output_to(test_output_dirs.root_dir)
train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
class_counts = train_dataset.get_class_counts()
assert class_counts == {0: 1, 1: 3, 2: 2, 3: 0}
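The expected counts in the assert above follow directly from the label column of the toy dataset; a quick check, independent of ScalarDataset:

labels = ["0|1|2", "1|2", "1"]        # label entries for subjects S1, S2, S3 above
num_classes = 4
counts = {i: 0 for i in range(num_classes)}
for entry in labels:
    for index in entry.split("|"):
        counts[int(index)] += 1
print(counts)                         # -> {0: 1, 1: 3, 2: 2, 3: 0}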
def test_get_labels_for_imbalanced_sampler_binary(test_output_dirs: OutputFolderForTests) -> None:
"""
Test the get_labels_for_imbalanced_sampler method for binary scalar datasets.
"""
dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
S1,week0,scan1.npy,,1,10,A
S1,week1,scan2.npy,True,2,20,A
S2,week0,scan3.npy,,3,30,A
S2,week1,scan4.npy,False,4,40,A
S3,week0,scan1.npy,,5,50,A
S3,week1,scan3.npy,True,6,60,A
"""
config = ScalarModelBase(
local_dataset=dataset_folder,
label_channels=["week1"],
label_value_column="label",
non_image_feature_channels=["week0", "week1"],
numerical_columns=["numerical1", "numerical2"],
should_validate=False
)
config.set_output_to(test_output_dirs.root_dir)
train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
labels = train_dataset.get_labels_for_imbalanced_sampler()
assert labels == [1.0, 0.0, 1.0]
def test_get_labels_for_imbalanced_sampler_multilabel(test_output_dirs: OutputFolderForTests) -> None:
"""
Test that the get_labels_for_imbalanced_sampler method raises an error for multilabel scalar datasets.
"""
dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
dataset_contents = """subject,channel,path,label,CAT1
S1,week0,scan1.npy,,A
S1,week1,scan2.npy,0|1|2,A
S2,week0,scan3.npy,,A
S2,week1,scan4.npy,1|2,A
S3,week0,scan1.npy,,A
S3,week1,scan3.npy,1,A
"""
config = ScalarModelBase(
local_dataset=dataset_folder,
class_names=["class0", "class1", "class2", "class3"],
label_channels=["week1"],
label_value_column="label",
non_image_feature_channels=["week0", "week1"],
should_validate=False
)
config.set_output_to(test_output_dirs.root_dir)
train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
with pytest.raises(NotImplementedError) as ex:
train_dataset.get_labels_for_imbalanced_sampler()
assert "ImbalancedSampler is not supported for multilabel tasks." in str(ex)


@ -599,7 +599,7 @@ S4,0,True,4,40,M2,B1
assert_tensors_equal(test_items[0].items[0].get_all_non_imaging_features(), [3., 3., 0., 1., 1., 0.])
def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> None:
def test_get_class_counts(test_output_dirs: OutputFolderForTests) -> None:
"""
Test the get_class_counts method for sequence datasets.
@ -615,7 +615,7 @@ def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> No
splits = config.get_dataset_splits()
train_dataset = config.create_torch_datasets(splits)[ModelExecutionMode.TRAIN]
class_counts = train_dataset.get_class_counts()
assert class_counts == {0.0: 9, 1.0: 2}
assert class_counts == {0: 2}
def test_get_labels_at_target_indices() -> None:


@ -15,7 +15,8 @@ from InnerEye.ML.models.layers.identity import Identity
class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
def __init__(self, expected_image_size_zyx: TupleInt3,
activation: torch.nn.Module = Identity(),
use_mixed_precision: bool = False) -> None:
use_mixed_precision: bool = False,
num_classes: int = 1) -> None:
super().__init__()
self.expected_image_size_zyx = expected_image_size_zyx
self._layers = torch.nn.ModuleList()
@ -25,7 +26,7 @@ class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
fc_out = fc(torch.zeros((1, 1) + self.expected_image_size_zyx))
self.feature_size = fc_out.view(-1).shape[0]
self._layers.append(fc)
self.fc = torch.nn.Linear(self.feature_size, 1)
self.fc = torch.nn.Linear(self.feature_size, out_features=num_classes)
self.activation = activation
self.last_encoder_layer: List[str] = ["_layers", "0"]
self.conv_in_3d = False


@ -5,6 +5,7 @@
import pytest
import torch
import torch.optim as optim
from typing import Dict
from InnerEye.ML.models.losses.cross_entropy import CrossEntropyLoss
# Set random seed
@ -58,8 +59,9 @@ def test_cross_entropy_loss_forward_smoothing(is_segmentation: bool) -> None:
smoothed_target = torch.tensor([[[0.1, 0.1, 0.9], [0.9, 0.9, 0.1]]], dtype=torch.float32)
logits = torch.tensor([[[-10, -10, 0], [0, 0, 0]]], dtype=torch.float32)
barely_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0)
smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1)
barely_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(num_classes=1,
smoothing_eps=0)
smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(num_classes=1, smoothing_eps=0.1)
if is_segmentation:
# The two loss values are only expected to be the same when no class weighting takes place,
# because weighting is done on the *unsmoothed* target values.
@ -143,10 +145,17 @@ def test_weighted_binary_cross_entropy_loss_forward_smoothing() -> None:
smoothed_target = torch.tensor([[0.9], [0.9], [0.9], [0.9], [0.9], [0.1]], dtype=torch.float32)
logits = torch.tensor([[-10], [-10], [0], [0], [0], [0]], dtype=torch.float32)
weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0, class_counts={0.0: 1.0, 1.0: 5.0})
BinaryCrossEntropyWithLogitsLoss(num_classes=1,
smoothing_eps=0,
class_counts={1.0: 5},
num_train_samples=target.shape[0])
weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts={0.0: 1.0, 1.0: 5.0})
non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1,
BinaryCrossEntropyWithLogitsLoss(num_classes=1,
smoothing_eps=0.1,
class_counts={1.0: 5},
num_train_samples=target.shape[0])
non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(num_classes=1,
smoothing_eps=0.1,
class_counts=None)
w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
w_loss2 = weighted_smoothed_loss_fn(logits, target)
@ -157,23 +166,44 @@ def test_weighted_binary_cross_entropy_loss_forward_smoothing() -> None:
assert torch.all(positive_class_weights == torch.tensor([[0.2]]))
def test_weighted_binary_cross_entropy_loss_multi_target() -> None:
target = torch.tensor([[[1], [0]], [[1], [0]], [[0], [0]]], dtype=torch.float32)
smoothed_target = torch.tensor([[[0.9], [0.1]], [[0.9], [0.1]], [[0.1], [0.1]]], dtype=torch.float32)
logits = torch.tensor([[[-10], [1]], [[-10], [1]], [[10], [0]]], dtype=torch.float32)
def test_weighted_binary_cross_entropy_loss_multi_label() -> None:
# Class 0 has 2 positive examples, class 1 has none
target = torch.tensor([[1, 0], [1, 0], [0, 0]], dtype=torch.float32)
smoothed_target = torch.tensor([[0.9, 0.1], [0.9, 0.1], [0.1, 0.1]], dtype=torch.float32)
logits = torch.tensor([[-10, 1], [-10, 1], [10, 0]], dtype=torch.float32)
weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0, class_counts={1.0: 2, 0.0: 4})
BinaryCrossEntropyWithLogitsLoss(num_classes=2,
smoothing_eps=0,
class_counts={1.0: 0, 0.0: 2},
num_train_samples=target.shape[0])
weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts={1.0: 2, 0.0: 4})
BinaryCrossEntropyWithLogitsLoss(num_classes=2,
smoothing_eps=0.1,
class_counts={1.0: 0, 0.0: 2},
num_train_samples=target.shape[0])
non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts=None)
BinaryCrossEntropyWithLogitsLoss(num_classes=2,
smoothing_eps=0.1,
class_counts=None)
w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
w_loss2 = weighted_smoothed_loss_fn(logits, target)
w_loss3 = non_weighted_smoothed_loss_fn(logits, target)
positive_class_weights = weighted_smoothed_loss_fn.get_positive_class_weights() # type: ignore
assert torch.isclose(w_loss1, w_loss2)
assert not torch.isclose(w_loss2, w_loss3)
assert torch.all(positive_class_weights == torch.tensor(2))
assert torch.equal(positive_class_weights, torch.tensor([0.5, 1]))
@pytest.mark.parametrize("num_classes, class_counts", [(1, {1.0: 0, 0.0: 2}),
(3, {1.0: 0, 0.0: 2})])
def test_invalid_initialization(num_classes: int,
class_counts: Dict[float, int]) -> None:
with pytest.raises(ValueError) as ex:
BinaryCrossEntropyWithLogitsLoss(num_classes=num_classes,
smoothing_eps=0,
class_counts=class_counts,
num_train_samples=10)
assert f"Have {num_classes} classes but got counts for {len(class_counts)} classes" in str(ex)
class ToyNet(torch.nn.Module):


@ -21,10 +21,13 @@ from InnerEye.Common.metrics_constants import LoggingColumns, MetricType
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML import model_testing, model_training, runner
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.metrics import InferenceMetricsForClassification, binary_classification_accuracy, \
compute_scalar_metrics
from InnerEye.ML.metrics_dict import MetricsDict, ScalarMetricsDict
from InnerEye.ML.reports.notebook_report import get_ipynb_report_name, get_html_report_name, \
generate_classification_notebook, generate_classification_multilabel_notebook
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.config_util import ModelConfigLoader
@ -36,7 +39,8 @@ from Tests.ML.util import get_default_azure_config, get_default_checkpoint_handl
@pytest.mark.cpu_and_gpu
def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> None:
@pytest.mark.parametrize("class_name", [MetricsDict.DEFAULT_HUE_KEY, "foo"])
def test_train_classification_model(class_name: str, test_output_dirs: OutputFolderForTests) -> None:
"""
Test training and testing of classification models, asserting on the individual results from training and
testing.
@ -44,6 +48,7 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
"""
logging_to_stdout(logging.DEBUG)
config = ClassificationModelForTesting()
config.class_names = [class_name]
config.set_output_to(test_output_dirs.root_dir)
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=Path(test_output_dirs.root_dir))
@ -59,6 +64,7 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
assert len(model_training_result.val_results_per_epoch) == config.num_epochs
assert len(model_training_result.train_results_per_epoch[0]) >= 11
assert len(model_training_result.val_results_per_epoch[0]) >= 11
for metric in [MetricType.ACCURACY_AT_THRESHOLD_05,
MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
MetricType.AREA_UNDER_PR_CURVE,
@ -67,10 +73,12 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
MetricType.LOSS,
MetricType.SECONDS_PER_BATCH,
MetricType.SECONDS_PER_EPOCH,
MetricType.SUBJECT_COUNT,
]:
assert metric.value in model_training_result.train_results_per_epoch[0], f"{metric.value} not in training"
assert metric.value in model_training_result.val_results_per_epoch[0], f"{metric.value} not in validation"
MetricType.SUBJECT_COUNT]:
assert metric.value in model_training_result.train_results_per_epoch[0], \
f"{metric.value} not in training"
assert metric.value in model_training_result.val_results_per_epoch[0], \
f"{metric.value} not in validation"
actual_train_loss = model_training_result.get_metric(is_training=True, metric_type=MetricType.LOSS.value)
actual_val_loss = model_training_result.get_metric(is_training=False, metric_type=MetricType.LOSS.value)
actual_lr = model_training_result.get_metric(is_training=True, metric_type=MetricType.LEARNING_RATE.value)
@ -81,20 +89,27 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
checkpoint_handler=checkpoint_handler)
assert isinstance(test_results, InferenceMetricsForClassification)
expected_metrics = [0.636085, 0.735952]
assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
assert test_results.metrics.values(class_name)[MetricType.CROSS_ENTROPY.value] == \
pytest.approx(expected_metrics, abs=1e-5)
# Run detailed logs file check only on CPU, it will contain slightly different metrics on GPU, but here
# we want to mostly assert that the files look reasonable
if machine_has_gpu:
return
# Check epoch_metrics.csv
epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME
# Auto-format will break the long header line, hence the strange way of writing it!
expected_epoch_metrics = \
"loss,cross_entropy,accuracy_at_threshold_05,learning_rate," + \
"area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
"false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
"optimal_threshold,subject_count,epoch,cross_validation_split_index\n" + \
f"{LoggingColumns.Loss.value},{LoggingColumns.CrossEntropy.value}," \
f"{LoggingColumns.AccuracyAtThreshold05.value},{LoggingColumns.LearningRate.value}," + \
f"{LoggingColumns.AreaUnderRocCurve.value}," \
f"{LoggingColumns.AreaUnderPRCurve.value}," \
f"{LoggingColumns.AccuracyAtOptimalThreshold.value}," \
f"{LoggingColumns.FalsePositiveRateAtOptimalThreshold.value}," \
f"{LoggingColumns.FalseNegativeRateAtOptimalThreshold.value}," \
f"{LoggingColumns.OptimalThreshold.value}," \
f"{LoggingColumns.SubjectCount.value},{LoggingColumns.Epoch.value}," \
f"{LoggingColumns.CrossValidationSplitIndex.value}\n" + \
"""0.6866141557693481,0.6866141557693481,0.5,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,0,-1
0.6864652633666992,0.6864652633666992,0.5,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,1,-1
0.6863163113594055,0.6863162517547607,0.5,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,2,-1
@ -107,15 +122,15 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
return
metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / SUBJECT_METRICS_FILE_NAME
metrics_expected = \
"""epoch,subject,prediction_target,model_output,label,data_split,cross_validation_split_index
0,S2,Default,0.529514,1,Train,-1
0,S4,Default,0.521659,0,Train,-1
1,S4,Default,0.521482,0,Train,-1
1,S2,Default,0.529475,1,Train,-1
2,S4,Default,0.521305,0,Train,-1
2,S2,Default,0.529437,1,Train,-1
3,S2,Default,0.529399,1,Train,-1
3,S4,Default,0.521128,0,Train,-1
f"""epoch,subject,prediction_target,model_output,label,data_split,cross_validation_split_index
0,S2,{class_name},0.529514,1,Train,-1
0,S4,{class_name},0.521659,0,Train,-1
1,S4,{class_name},0.521482,0,Train,-1
1,S2,{class_name},0.529475,1,Train,-1
2,S4,{class_name},0.521305,0,Train,-1
2,S2,{class_name},0.529437,1,Train,-1
3,S2,{class_name},0.529399,1,Train,-1
3,S4,{class_name},0.521128,0,Train,-1
"""
check_log_file(metrics_path, metrics_expected, ignore_columns=[])
# Check log METRICS_FILE_NAME inside of the folder epoch_004/Train, which is written when we run model_test.
@ -123,13 +138,97 @@ def test_train_classification_model(test_output_dirs: OutputFolderForTests) -> N
inference_metrics_path = config.outputs_folder / get_epoch_results_path(ModelExecutionMode.TRAIN) / \
SUBJECT_METRICS_FILE_NAME
inference_metrics_expected = \
"""prediction_target,subject,model_output,label,cross_validation_split_index,data_split
Default,S2,0.5293986201286316,1.0,-1,Train
Default,S4,0.5211275815963745,0.0,-1,Train
f"""prediction_target,subject,model_output,label,cross_validation_split_index,data_split
{class_name},S2,0.5293986201286316,1.0,-1,Train
{class_name},S4,0.5211275815963745,0.0,-1,Train
"""
check_log_file(inference_metrics_path, inference_metrics_expected, ignore_columns=[])
@pytest.mark.cpu_and_gpu
def test_train_classification_multilabel_model(test_output_dirs: OutputFolderForTests) -> None:
"""
Test training and testing of a multilabel classification model, asserting on the individual results from training and
testing.
Expected test results are stored for GPU with and without mixed precision.
"""
logging_to_stdout(logging.DEBUG)
config = DummyMulticlassClassification()
config.set_output_to(test_output_dirs.root_dir)
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=Path(test_output_dirs.root_dir))
# Train for 4 epochs, checkpoints at epochs 2 and 4
config.num_epochs = 4
model_training_result = model_training.model_train(config, checkpoint_handler=checkpoint_handler)
assert model_training_result is not None
expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
expected_train_loss = [0.699870228767395, 0.6239662170410156, 0.551329493522644, 0.4825132489204407]
expected_val_loss = [0.6299371719360352, 0.5546272993087769, 0.4843321740627289, 0.41909298300743103]
# Ensure that all metrics are computed on both training and validation set
assert len(model_training_result.train_results_per_epoch) == config.num_epochs
assert len(model_training_result.val_results_per_epoch) == config.num_epochs
assert len(model_training_result.train_results_per_epoch[0]) >= 11
assert len(model_training_result.val_results_per_epoch[0]) >= 11
for class_name in config.class_names:
for metric in [MetricType.ACCURACY_AT_THRESHOLD_05,
MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
MetricType.AREA_UNDER_PR_CURVE,
MetricType.AREA_UNDER_ROC_CURVE,
MetricType.CROSS_ENTROPY]:
assert f'{metric.value}/{class_name}' in model_training_result.train_results_per_epoch[0], \
f"{metric.value} not in training"
assert f'{metric.value}/{class_name}' in model_training_result.val_results_per_epoch[0], \
f"{metric.value} not in validation"
for metric in [MetricType.LOSS,
MetricType.SECONDS_PER_EPOCH,
MetricType.SUBJECT_COUNT]:
assert metric.value in model_training_result.train_results_per_epoch[0], f"{metric.value} not in training"
assert metric.value in model_training_result.val_results_per_epoch[0], f"{metric.value} not in validation"
actual_train_loss = model_training_result.get_metric(is_training=True, metric_type=MetricType.LOSS.value)
actual_val_loss = model_training_result.get_metric(is_training=False, metric_type=MetricType.LOSS.value)
actual_lr = model_training_result.get_metric(is_training=True, metric_type=MetricType.LEARNING_RATE.value)
assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6), "Training loss"
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6), "Validation loss"
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5), "Learning rates"
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
checkpoint_handler=checkpoint_handler)
assert isinstance(test_results, InferenceMetricsForClassification)
expected_metrics = {MetricType.CROSS_ENTROPY: [1.3996, 5.2966, 1.4020, 0.3553, 0.6908],
MetricType.ACCURACY_AT_THRESHOLD_05: [0.0000, 0.0000, 0.0000, 1.0000, 1.0000]
}
for i, class_name in enumerate(config.class_names):
for metric in expected_metrics.keys():
assert expected_metrics[metric][i] == pytest.approx(
test_results.metrics.get_single_metric(
metric_name=metric,
hue=class_name), 1e-4)
def get_epoch_path(mode: ModelExecutionMode) -> Path:
p = get_epoch_results_path(mode=mode)
return config.outputs_folder / p / SUBJECT_METRICS_FILE_NAME
path_to_best_epoch_train = get_epoch_path(ModelExecutionMode.TRAIN)
path_to_best_epoch_val = get_epoch_path(ModelExecutionMode.VAL)
path_to_best_epoch_test = get_epoch_path(ModelExecutionMode.TEST)
generate_classification_notebook(result_notebook=config.outputs_folder / get_ipynb_report_name(config.model_category.value),
config=config,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
assert (config.outputs_folder / get_html_report_name(config.model_category.value)).exists()
report_name_multilabel = f"{config.model_category.value}_multilabel"
generate_classification_multilabel_notebook(result_notebook=config.outputs_folder / get_ipynb_report_name(report_name_multilabel),
config=config,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
assert (config.outputs_folder / get_html_report_name(report_name_multilabel)).exists()
def _count_lines(s: str) -> int:
lines = [line for line in s.splitlines() if line.strip()]
return len(lines)


@ -1,13 +0,0 @@
subject,filePath,label
0,../test_data/classification_data_2d/im1.npy,0
1,../test_data/classification_data_2d/im2.npy,0
2,../test_data/classification_data_2d/im1.npy,0
3,../test_data/classification_data_2d/im2.npy,0
4,../test_data/classification_data_2d/im1.npy,0
5,../test_data/classification_data_2d/im2.npy,0
6,../test_data/classification_data_2d/im1.npy,0
7,../test_data/classification_data_2d/im2.npy,0
8,../test_data/classification_data_2d/im1.npy,0
9,../test_data/classification_data_2d/im2.npy,0
10,../test_data/classification_data_2d/im1.npy,0
11,../test_data/classification_data_2d/im2.npy,0


@ -0,0 +1,240 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
import pandas as pd
import numpy as np
import pytest
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.common_util import is_windows
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.metrics_dict import MetricsDict
from InnerEye.ML.reports.classification_multilabel_report import get_dataframe_with_exact_label_matches, \
get_labels_and_predictions_for_prediction_target_set, get_unique_prediction_target_combinations
from InnerEye.ML.reports.notebook_report import generate_classification_multilabel_notebook
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
@pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
def test_generate_classification_multilabel_report(test_output_dirs: OutputFolderForTests) -> None:
hues = ["Hue1", "Hue2"]
config = ScalarModelBase(label_value_column="label",
image_file_column="filePath",
image_channels=["image1", "image2"],
label_channels=["image1"])
config.class_names = hues
test_metrics_file = test_output_dirs.root_dir / "test_metrics_classification.csv"
val_metrics_file = test_output_dirs.root_dir / "val_metrics_classification.csv"
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv_path = config.local_dataset / "dataset.csv"
image_file_name = "image.npy"
pd.DataFrame.from_dict({LoggingColumns.Hue.value: [hues[0], hues[1]] * 6,
LoggingColumns.Epoch.value: [0] * 12,
LoggingColumns.Patient.value: [s for s in range(6) for _ in range(2)],
LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.1, 0.9, 0.1, 0.9,
0.9, 0.9, 0.9, 0.9, 0.9, 0.1],
LoggingColumns.Label.value: [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0],
LoggingColumns.CrossValidationSplitIndex.value: [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
LoggingColumns.DataSplit.value: [0] * 12,
}).to_csv(test_metrics_file, index=False)
pd.DataFrame.from_dict({LoggingColumns.Hue.value: [hues[0], hues[1]] * 6,
LoggingColumns.Epoch.value: [0] * 12,
LoggingColumns.Patient.value: [s for s in range(6) for _ in range(2)],
LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.1, 0.1, 0.1, 0.9,
0.9, 0.9, 0.9, 0.1, 0.9, 0.1],
LoggingColumns.Label.value: [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0],
LoggingColumns.CrossValidationSplitIndex.value: [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
LoggingColumns.DataSplit.value: [0] * 12,
}).to_csv(val_metrics_file, index=False)
pd.DataFrame.from_dict({config.subject_column: [s for s in range(6) for _ in range(2)],
config.channel_column: ["image1", "image2"] * 6,
config.image_file_column: [f for f in [f"0_{image_file_name}", f"1_{image_file_name}"]
for _ in range(6)],
config.label_value_column: ["", "", "1", "1", "1", "1", "0|1", "0|1", "0|1", "0|1", "0", "0"]
}).to_csv(dataset_csv_path, index=False)
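# The label column above uses the pipe-separated multi-label encoding: "0|1" marks both classes as
# positive for a subject, and "" marks all classes as negative.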
np.save(str(Path(config.local_dataset / f"0_{image_file_name}")),
np.random.randint(0, 255, [5, 4]))
np.save(str(Path(config.local_dataset / f"1_{image_file_name}")),
np.random.randint(0, 255, [5, 4]))
result_file = test_output_dirs.root_dir / "report.ipynb"
result_html = generate_classification_multilabel_notebook(result_notebook=result_file,
config=config,
val_metrics=val_metrics_file,
test_metrics=test_metrics_file)
assert result_file.is_file()
assert result_html.is_file()
assert result_html.suffix == ".html"
def test_get_pseudo_labels_and_predictions() -> None:
reports_folder = Path(__file__).parent
test_metrics_file = reports_folder / "test_metrics_classification.csv"
results = get_labels_and_predictions_for_prediction_target_set(test_metrics_file,
[MetricsDict.DEFAULT_HUE_KEY],
all_prediction_targets=[MetricsDict.DEFAULT_HUE_KEY],
thresholds_per_prediction_target=[0.5])
assert all([results.subject_ids[i] == i for i in range(12)])
assert all([results.labels[i] == label for i, label in enumerate([1] * 6 + [0] * 6)])
assert all([results.model_outputs[i] == op for i, op in enumerate([0.0, 0.0, 0.0, 1.0, 1.0, 1.0] * 2)])
def test_get_pseudo_labels_and_predictions_multiple_hues(test_output_dirs: OutputFolderForTests) -> None:
reports_folder = Path(__file__).parent
test_metrics_file = reports_folder / "test_metrics_classification.csv"
# Write a new metrics file with 2 prediction targets,
# prediction_target_set_to_match will only be associated with one prediction target
csv = pd.read_csv(test_metrics_file)
hues = ["Hue1", "Hue2"]
csv.loc[::2, LoggingColumns.Hue.value] = hues[0]
csv.loc[1::2, LoggingColumns.Hue.value] = hues[1]
csv.loc[::2, LoggingColumns.Patient.value] = list(range(len(csv)//2))
csv.loc[1::2, LoggingColumns.Patient.value] = list(range(len(csv)//2))
csv.loc[::2, LoggingColumns.Label.value] = [0, 0, 0, 1, 1, 1]
csv.loc[1::2, LoggingColumns.Label.value] = [0, 1, 1, 1, 1, 0]
csv.loc[::2, LoggingColumns.ModelOutput.value] = [0.1, 0.1, 0.1, 0.9, 0.9, 0.9]
csv.loc[1::2, LoggingColumns.ModelOutput.value] = [0.1, 0.9, 0.9, 0.9, 0.9, 0.1]
metrics_csv_multi_hue = test_output_dirs.root_dir / "metrics.csv"
csv.to_csv(metrics_csv_multi_hue, index=False)
for h, hue in enumerate(hues):
results = get_labels_and_predictions_for_prediction_target_set(metrics_csv_multi_hue,
prediction_target_set_to_match=[hue],
all_prediction_targets=hues,
thresholds_per_prediction_target=[0.5, 0.5])
assert all([results.subject_ids[i] == i for i in range(6)])
assert all([results.labels[i] == label
for i, label in enumerate([0, 0, 0, 0, 0, 1] if h == 0 else [0, 1, 1, 0, 0, 0])])
assert all([results.model_outputs[i] == op
for i, op in enumerate([0, 0, 0, 0, 0, 1] if h == 0 else [0, 1, 1, 0, 0, 0])])
def test_generate_pseudo_labels() -> None:
metrics_df = pd.DataFrame.from_dict({"prediction_target": ["Hue1", "Hue2", "Hue3"] * 4,
"epoch": [0] * 12,
"subject": [i for i in range(4) for _ in range(3)],
"model_output": [0.5, 0.6, 0.3, 0.5, 0.6, 0.5, 0.5, 0.6, 0.5, 0.5, 0.6, 0.3],
"label": [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1],
"cross_validation_split_index": [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
"data_split": [ModelExecutionMode.TEST.value] * 12
})
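# Exact-match pseudo-labels asserted below: a subject counts as positive for the set {Hue1, Hue2} only if
# exactly Hue1 and Hue2 (and no other prediction target) are above their thresholds (for model_output)
# or are true (for label).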
expected_df = pd.DataFrame.from_dict({"subject": list(range(4)),
"model_output": [1, 0, 0, 1],
"label": [1, 0, 1, 0],
"prediction_target": ["Hue1|Hue2"] * 4})
df = get_dataframe_with_exact_label_matches(metrics_df=metrics_df,
prediction_target_set_to_match=["Hue1", "Hue2"],
all_prediction_targets=["Hue1", "Hue2", "Hue3"],
thresholds_per_prediction_target=[0.4, 0.5, 0.4])
assert expected_df.equals(df)
def test_generate_pseudo_labels_negative_class() -> None:
metrics_df = pd.DataFrame.from_dict({"prediction_target": ["Hue1", "Hue2", "Hue3"] * 4,
"epoch": [0] * 12,
"subject": [i for i in range(4) for _ in range(3)],
"model_output": [0.2, 0.3, 0.2, 0.5, 0.6, 0.5, 0.5, 0.6, 0.5, 0.2, 0.3, 0.2],
"label": [0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1],
"cross_validation_split_index": [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
"data_split": [ModelExecutionMode.TEST.value] * 12
})
expected_df = pd.DataFrame.from_dict({"subject": list(range(4)),
"model_output": [1, 0, 0, 1],
"label": [1, 1, 0, 0],
"prediction_target": [""] * 4})
df = get_dataframe_with_exact_label_matches(metrics_df=metrics_df,
prediction_target_set_to_match=[],
all_prediction_targets=["Hue1", "Hue2", "Hue3"],
thresholds_per_prediction_target=[0.4, 0.5, 0.4])
assert expected_df.equals(df)
def test_get_unique_label_combinations_single_label(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_channels=["label"],
label_value_column="value",
image_channels=["image"],
image_file_column="path",
subject_column="subjectID")
class_names = config.class_names
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
dataset_csv.write_text("subjectID,channel,path,value\n"
"S1,label,random,1\n"
"S1,image,random,\n"
"S2,label,random,0\n"
"S2,image,random,\n"
"S3,label,random,1\n"
"S3,image,random,\n")
unique_labels = get_unique_prediction_target_combinations(config) # type: ignore
expected_label_combinations = set(frozenset(class_names[i] for i in labels) # type: ignore
for labels in [[], [0]])
assert unique_labels == expected_label_combinations
def test_get_unique_label_combinations_nan(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_channels=["label"],
label_value_column="value",
image_channels=["image"],
image_file_column="path",
subject_column="subjectID")
class_names = config.class_names
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
dataset_csv.write_text("subjectID,channel,path,value\n"
"S1,label,random,1\n"
"S1,image,random,\n"
"S2,label,random,\n"
"S2,image,random,\n")
unique_labels = get_unique_prediction_target_combinations(config) # type: ignore
expected_label_combinations = set(frozenset(class_names[i] for i in labels) # type: ignore
for labels in [[0]])
assert unique_labels == expected_label_combinations
def test_get_unique_label_combinations_multi_label(test_output_dirs: OutputFolderForTests) -> None:
config = DummyMulticlassClassification()
class_names = config.class_names
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
dataset_csv.write_text("ID,channel,path,label\n"
"S1,blue,random,1|2|3\n"
"S1,green,random,\n"
"S2,blue,random,2|3\n"
"S2,green,random,\n"
"S3,blue,random,3\n"
"S3,green,random,\n"
"S4,blue,random,\n"
"S4,green,random,\n")
unique_labels = get_unique_prediction_target_combinations(config) # type: ignore
expected_label_combinations = set(frozenset(class_names[i] for i in labels) # type: ignore
for labels in [[1, 2, 3], [2, 3], [3], []])
assert unique_labels == expected_label_combinations

View file

@ -12,38 +12,64 @@ import pytest
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.common_util import is_windows
from InnerEye.ML.reports.classification_report import ReportedMetrics, get_correct_and_misclassified_examples, \
get_image_filepath_from_subject_id, get_k_best_and_worst_performing, get_metric, get_results, \
plot_image_from_filepath
get_image_filepath_from_subject_id, get_k_best_and_worst_performing, get_metric, get_labels_and_predictions, \
plot_image_from_filepath, get_image_labels_from_subject_id, get_image_outputs_from_subject_id
from InnerEye.ML.reports.notebook_report import generate_classification_notebook
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.metrics_dict import MetricsDict
from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
@pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
def test_generate_classification_report(test_output_dirs: OutputFolderForTests) -> None:
reports_folder = Path(__file__).parent
test_metrics_file = reports_folder / "test_metrics_classification.csv"
val_metrics_file = reports_folder / "val_metrics_classification.csv"
dataset_csv_path = reports_folder / 'dataset.csv'
dataset_subject_column = "subject"
dataset_file_column = "filePath"
current_dir = test_output_dirs.make_sub_dir("test_classification_report")
result_file = current_dir / "report.ipynb"
config = ScalarModelBase(label_value_column="label",
image_file_column="filePath",
subject_column="subject")
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
image_file_name = "image.npy"
dataset_csv.write_text("subject,filePath,label\n"
f"0,0_{image_file_name},0\n"
f"1,1_{image_file_name},0\n"
f"2,0_{image_file_name},0\n"
f"3,1_{image_file_name},0\n"
f"4,0_{image_file_name},0\n"
f"5,1_{image_file_name},0\n"
f"6,0_{image_file_name},0\n"
f"7,1_{image_file_name},0\n"
f"8,0_{image_file_name},0\n"
f"9,1_{image_file_name},0\n"
f"10,0_{image_file_name},0\n"
f"11,1_{image_file_name},0\n")
np.save(str(Path(config.local_dataset / f"0_{image_file_name}")),
np.random.randint(0, 255, [5, 4]))
np.save(str(Path(config.local_dataset / f"1_{image_file_name}")),
np.random.randint(0, 255, [5, 4]))
result_file = test_output_dirs.root_dir / "report.ipynb"
result_html = generate_classification_notebook(result_notebook=result_file,
config=config,
val_metrics=val_metrics_file,
test_metrics=test_metrics_file,
dataset_csv_path=dataset_csv_path,
dataset_subject_column=dataset_subject_column,
dataset_file_column=dataset_file_column)
test_metrics=test_metrics_file)
assert result_file.is_file()
assert result_html.is_file()
assert result_html.suffix == ".html"
def test_get_results() -> None:
def test_get_labels_and_predictions() -> None:
reports_folder = Path(__file__).parent
test_metrics_file = reports_folder / "test_metrics_classification.csv"
results = get_results(test_metrics_file)
results = get_labels_and_predictions(test_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
assert all([results.subject_ids[i] == i for i in range(12)])
assert all([results.labels[i] == label for i, label in enumerate([1] * 6 + [0] * 6)])
assert all([results.model_outputs[i] == op for i, op in enumerate([0.0, 0.2, 0.4, 0.6, 0.8, 1.0] * 2)])
@ -57,16 +83,18 @@ def test_functions_with_invalid_csv(test_output_dirs: OutputFolderForTests) -> N
shutil.copyfile(test_metrics_file, invalid_metrics_file)
# Duplicate a subject
with open(invalid_metrics_file, "a") as file:
file.write("Default,1,5,1.0,1,-1,Test")
file.write(f"{MetricsDict.DEFAULT_HUE_KEY},1,5,1.0,1,-1,Test")
with pytest.raises(ValueError) as ex:
get_labels_and_predictions(invalid_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
assert "Subject IDs should be unique" in str(ex)
with pytest.raises(ValueError):
get_results(invalid_metrics_file)
with pytest.raises(ValueError) as ex:
get_correct_and_misclassified_examples(invalid_metrics_file, test_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
assert "Subject IDs should be unique" in str(ex)
with pytest.raises(ValueError):
get_correct_and_misclassified_examples(invalid_metrics_file, test_metrics_file)
with pytest.raises(ValueError):
get_correct_and_misclassified_examples(val_metrics_file, invalid_metrics_file)
with pytest.raises(ValueError) as ex:
get_correct_and_misclassified_examples(val_metrics_file, invalid_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
assert "Subject IDs should be unique" in str(ex)
def test_get_metric() -> None:
@ -74,41 +102,72 @@ def test_get_metric() -> None:
test_metrics_file = reports_folder / "test_metrics_classification.csv"
val_metrics_file = reports_folder / "val_metrics_classification.csv"
optimal_threshold = get_metric(test_metrics_csv=test_metrics_file,
val_metrics_csv=val_metrics_file,
val_metrics = get_labels_and_predictions(val_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
test_metrics = get_labels_and_predictions(test_metrics_file, MetricsDict.DEFAULT_HUE_KEY)
optimal_threshold = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.OptimalThreshold)
assert optimal_threshold == 0.6
auc_roc = get_metric(test_metrics_csv=test_metrics_file,
val_metrics_csv=val_metrics_file,
optimal_threshold = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.OptimalThreshold,
optimal_threshold=0.3)
assert optimal_threshold == 0.3
auc_roc = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.AUC_ROC)
assert auc_roc == 0.5
auc_pr = get_metric(test_metrics_csv=test_metrics_file,
val_metrics_csv=val_metrics_file,
auc_pr = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.AUC_PR)
assert math.isclose(auc_pr, 13 / 24, abs_tol=1e-15)
accuracy = get_metric(test_metrics_csv=test_metrics_file,
val_metrics_csv=val_metrics_file,
accuracy = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.Accuracy)
assert accuracy == 0.5
fpr = get_metric(test_metrics_csv=test_metrics_file,
val_metrics_csv=val_metrics_file,
accuracy = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.Accuracy,
optimal_threshold=0.1)
assert accuracy == 0.5
fpr = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.FalsePositiveRate)
assert fpr == 0.5
fnr = get_metric(test_metrics_csv=test_metrics_file,
val_metrics_csv=val_metrics_file,
fpr = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.FalsePositiveRate,
optimal_threshold=0.1)
assert fpr == 5 / 6
fnr = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.FalseNegativeRate)
assert fnr == 0.5
fnr = get_metric(test_labels_and_predictions=test_metrics,
val_labels_and_predictions=val_metrics,
metric=ReportedMetrics.FalseNegativeRate,
optimal_threshold=0.1)
assert math.isclose(fnr, 1 / 6, abs_tol=1e-15)
def test_get_correct_and_misclassified_examples() -> None:
reports_folder = Path(__file__).parent
@ -153,34 +212,188 @@ def test_get_k_best_and_worst_performing() -> None:
assert worst_false_negatives == [0, 1]
def test_get_image_filepath_from_subject_id() -> None:
reports_folder = Path(__file__).parent
dataset_csv_file = reports_folder / "dataset.csv"
dataset_df = pd.read_csv(dataset_csv_file)
def test_get_image_filepath_from_subject_id_single(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(image_file_column="filePath",
label_value_column="label",
subject_column="subject")
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
image_file_name = "image.npy"
dataset_csv.write_text(f"subject,filePath,label\n"
f"0,0_{image_file_name},0\n"
f"1,1_{image_file_name},1\n")
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
Path(config.local_dataset / f"0_{image_file_name}").touch()
Path(config.local_dataset / f"1_{image_file_name}").touch()
filepath = get_image_filepath_from_subject_id(subject_id="1",
dataset_df=dataset_df,
dataset_subject_column="subject",
dataset_file_column="filePath",
dataset_dir=reports_folder)
expected_path = Path(reports_folder / "../test_data/classification_data_2d/im2.npy")
dataset=dataset,
config=config)
expected_path = Path(config.local_dataset / f"1_{image_file_name}")
assert filepath
assert expected_path.samefile(filepath)
assert len(filepath) == 1
assert expected_path.samefile(filepath[0])
# Check error is raised if the subject does not exist
with pytest.raises(ValueError) as ex:
get_image_filepath_from_subject_id(subject_id="100",
dataset=dataset,
config=config)
assert "Could not find subject" in str(ex)
def test_get_image_filepath_from_subject_id_invalid_id() -> None:
reports_folder = Path(__file__).parent
dataset_csv_file = reports_folder / "dataset.csv"
dataset_df = pd.read_csv(dataset_csv_file)
def test_get_image_filepath_from_subject_id_with_image_channels(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_channels=["label"],
image_file_column="filePath",
label_value_column="label",
image_channels=["image"],
subject_column="subject")
filepath = get_image_filepath_from_subject_id(subject_id="100",
dataset_df=dataset_df,
dataset_subject_column="subject",
dataset_file_column="filePath",
dataset_dir=reports_folder)
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
image_file_name = "image.npy"
dataset_csv.write_text(f"subject,channel,filePath,label\n"
f"0,label,,0\n"
f"0,image,0_{image_file_name},\n"
f"1,label,,1\n"
f"1,image,1_{image_file_name},\n")
assert not filepath
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
Path(config.local_dataset / f"0_{image_file_name}").touch()
Path(config.local_dataset / f"1_{image_file_name}").touch()
filepath = get_image_filepath_from_subject_id(subject_id="1",
dataset=dataset,
config=config)
expected_path = Path(config.local_dataset / f"1_{image_file_name}")
assert filepath
assert len(filepath) == 1
assert filepath[0].samefile(expected_path)
def test_get_image_filepath_from_subject_id_multiple(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_channels=["label"],
image_file_column="filePath",
label_value_column="label",
image_channels=["image1", "image2"],
subject_column="subject")
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
image_file_name = "image.npy"
dataset_csv.write_text(f"subject,channel,filePath,label\n"
f"0,label,,0\n"
f"0,image1,00_{image_file_name},\n"
f"0,image2,01_{image_file_name},\n"
f"1,label,,1\n"
f"1,image1,10_{image_file_name},\n"
f"1,image2,11_{image_file_name},\n")
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
Path(config.local_dataset / f"00_{image_file_name}").touch()
Path(config.local_dataset / f"01_{image_file_name}").touch()
Path(config.local_dataset / f"10_{image_file_name}").touch()
Path(config.local_dataset / f"11_{image_file_name}").touch()
filepath = get_image_filepath_from_subject_id(subject_id="1",
dataset=dataset,
config=config)
expected_paths = [config.local_dataset / f"10_{image_file_name}",
config.local_dataset / f"11_{image_file_name}"]
assert filepath
assert len(filepath) == 2
assert expected_paths[0].samefile(filepath[0])
assert expected_paths[1].samefile(filepath[1])
def test_image_labels_from_subject_id_single(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_value_column="label",
subject_column="subject")
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
dataset_csv.write_text("subject,channel,label\n"
"0,label,0\n"
"1,label,1\n")
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
labels = get_image_labels_from_subject_id(subject_id="0",
dataset=dataset,
config=config)
assert not labels
labels = get_image_labels_from_subject_id(subject_id="1",
dataset=dataset,
config=config)
assert labels
assert len(labels) == 1
assert labels[0] == MetricsDict.DEFAULT_HUE_KEY
def test_image_labels_from_subject_id_with_label_channels(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_channels=["label"],
label_value_column="label",
subject_column="subject")
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
dataset_csv.write_text("subject,channel,label\n"
"0,label,0\n"
"0,image,\n"
"1,label,1\n"
"1,image,\n")
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
labels = get_image_labels_from_subject_id(subject_id="1",
dataset=dataset,
config=config)
assert labels
assert len(labels) == 1
assert labels[0] == MetricsDict.DEFAULT_HUE_KEY
def test_image_labels_from_subject_id_multiple(test_output_dirs: OutputFolderForTests) -> None:
config = ScalarModelBase(label_channels=["label"],
label_value_column="label",
subject_column="subject",
class_names=["class1", "class2", "class3"])
config.local_dataset = test_output_dirs.root_dir / "dataset"
config.local_dataset.mkdir()
dataset_csv = config.local_dataset / "dataset.csv"
dataset_csv.write_text("subject,channel,label\n"
"0,label,0\n"
"0,image,\n"
"1,label,1|2\n"
"1,image,\n")
df = config.read_dataset_if_needed()
dataset = ScalarDataset(args=config, data_frame=df)
labels = get_image_labels_from_subject_id(subject_id="1",
dataset=dataset,
config=config)
assert labels
assert len(labels) == 2
assert set(labels) == {config.class_names[1], config.class_names[2]}
def test_plot_image_from_filepath(test_output_dirs: OutputFolderForTests) -> None:
@ -197,3 +410,27 @@ def test_plot_image_from_filepath(test_output_dirs: OutputFolderForTests) -> Non
np.save(invalid_file, array)
res = plot_image_from_filepath(invalid_file, im_width)
assert not res
def test_get_image_outputs_from_subject_id(test_output_dirs: OutputFolderForTests) -> None:
hues = ["Hue1", "Hue2"]
metrics_df = pd.DataFrame.from_dict({LoggingColumns.Hue.value: [hues[0], hues[1]] * 6,
LoggingColumns.Epoch.value: [0] * 12,
LoggingColumns.Patient.value: [s for s in range(6) for _ in range(2)],
LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.1, 0.9, 0.1, 0.9,
0.9, 0.9, 0.9, 0.9, 0.9, 0.1],
LoggingColumns.Label.value: [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0],
LoggingColumns.CrossValidationSplitIndex.value: [DEFAULT_CROSS_VALIDATION_SPLIT_INDEX] * 12,
LoggingColumns.DataSplit.value: [0] * 12,
}, dtype=str)
config = DummyMulticlassClassification()
config.class_names = hues
config.subject_column = "subject"
model_output = get_image_outputs_from_subject_id(subject_id="1",
metrics_df=metrics_df)
assert model_output
assert len(model_output) == 2
assert all([m == e for m, e in zip(model_output, [(hues[0], 0.1), (hues[1], 0.9)])])

View file

@ -6,14 +6,17 @@ from io import StringIO
from pathlib import Path
import pandas as pd
import pytest
from InnerEye.Common.metrics_constants import MetricsFileColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.common_util import is_windows
from InnerEye.ML.reports.notebook_report import generate_segmentation_notebook
from InnerEye.ML.reports.segmentation_report import describe_score, worst_patients_and_outliers
from InnerEye.ML.utils.csv_util import COL_IS_OUTLIER
@pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
def test_generate_segmentation_report(test_output_dirs: OutputFolderForTests) -> None:
reports_folder = Path(__file__).parent
metrics_file = reports_folder / "metrics_hn.csv"

View file

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4450200a7b3c1f1c4b94d844a3adfd429690c698592fe9cefac429e77b6ec56
size 4716

View file

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4450200a7b3c1f1c4b94d844a3adfd429690c698592fe9cefac429e77b6ec56
size 4716

View file

@ -0,0 +1,7 @@
ID,channel,path,label
S1,blue,1_blue.png,1|2|3
S1,green,1_green.png,
S2,blue,1_blue.png,2|3
S2,green,1_green.png,
S3,blue,1_blue.png,3
S3,green,1_green.png,

View file

@ -152,7 +152,11 @@ Classification datasets should have a `dataset.csv` and a folder containing the
have at least the following fields:
* subject: The subject ID, a unique positive integer assigned to every image
* path: Path to the image file for this subject
* value: For classification, a (binary) ground truth label. For regression, a scalar value.
* value:
* For binary classification, a (binary) ground truth label. This can be "true" and "false" or "0" and "1".
* For multi-label classification, the set of all positive labels for the image, separated by a `|` character.
Ex: "0|2|4" for a sample with true labels 0, 2 and 4 and "" for a sample in which all labels are false.
* For regression, a scalar value.
These, and other fields which can be added to dataset.csv, are described in the examples below.
@ -233,7 +237,7 @@ Other recognized fields, apart from subject, channel, file path and label are nu
These are extra scalar and categorical values to be used as model input.
Any *unrecognized* columns (any column which is both not described in the model config and has no default)
will be converted to a dict of key-value pairs and stored in a object of type `GeneralSampleMetadata` in the sample.
will be converted to a dict of key-value pairs and stored in an object of type `GeneralSampleMetadata` in the sample.
```
SubjectID, Channel, FilePath, Label, Tag, weight, class
@ -270,4 +274,58 @@ In this example, `weight` is a scalar feature read from the csv, and `class` is
**Filtering on channels**: This example also shows why filtering values by channel is useful: In this example, each subject has 2 images taken at
different times with different label values. By using `label_channels=["image_time_2"]`, we can use the label associated with
the second image for all subjects.
#### Multi-label classification datasets
Classification datasets can be multi-label, i.e. more than one label can be associated with every sample.
In this case, separate the (numerical) ground truth labels in the label column with a pipe character (`|`) to
mark multiple ground truth classes as positive for that sample.
Note that only *multi-label* datasets are supported, *multi-class* datasets (where the labels are mutually exclusive)
are not supported.
For example, the `dataset.csv` for a multi-label task with 4 classes (0, 1, 2, 3) would look like the following:
```
SubjectID, Channel, FilePath, Label
1, image_feature_1, images/image_1_feature_1.npy,
1, image_feature_2, images/image_1_feature_2.npy,
1, label, , 0|2|3
2, image_feature_1, images/image_2_feature_1.npy,
2, image_feature_2, images/image_2_feature_2.npy,
2, label, , 1|2
3, image_feature_1, images/image_3_feature_1.npy,
3, image_feature_2, images/image_3_feature_2.npy,
3, label, , 1
4, image_feature_1, images/image_4_feature_1.npy,
4, image_feature_2, images/image_4_feature_2.npy,
4, label, ,
```
Note that the label field for sample 4 is left empty; this indicates that all labels are negative for sample 4.
In multi-label tasks, the negative class (all ground truth classes being false for a sample) should not be
considered a separate class, and should be encoded by an empty label field.
The labels which are true for each sample in the `dataset.csv` shown above are:
* Sample 1: 0, 2, 3
* Sample 2: 1, 2
* Sample 3: 1
* Sample 4: No labels are true for this sample
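To make the encoding concrete, here is a minimal sketch of how such a pipe-separated label string could be expanded into a multi-hot vector. The helper name is hypothetical and this is not InnerEye's own parsing code; InnerEye handles this internally once `class_names` is set (see the config below).
```python
from typing import List


def multi_hot_from_label_string(label: str, num_classes: int) -> List[float]:
    """Hypothetical helper: expand a pipe-separated label string such as "0|2|3" into a multi-hot vector."""
    # An empty string means that all labels are negative for this sample.
    indices = [int(i) for i in label.split("|")] if label.strip() else []
    vector = [0.0] * num_classes
    for index in indices:
        vector[index] = 1.0
    return vector


assert multi_hot_from_label_string("0|2|3", 4) == [1.0, 0.0, 1.0, 1.0]
assert multi_hot_from_label_string("", 4) == [0.0, 0.0, 0.0, 0.0]
```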
The config file would be:
```python
class GlaucomaPublicExt(GlaucomaPublic):
def __init__(self) -> None:
super().__init__(azure_dataset_id="name_of_your_dataset_on_azure",
subject_column="SubjectID",
channel_column="Channel",
image_channels=["image_feature_1", "image_feature_2"],
image_file_column="FilePath",
label_channels=["label"],
label_value_column="Label",
class_names=["class0", "class1", "class2", "class3"])
```
The added parameter `class_names` gives the string name corresponding to each ground truth class index.
In multi-label configs, `class_names` must be specified so that InnerEye can recognize the task as multi-label
and parse the `dataset.csv` accordingly. In binary tasks, `class_names` can optionally be set to a
single-element list containing the name of the positive class.
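For instance, a binary task could optionally name its single positive class as follows (a sketch only, reusing the `GlaucomaPublic` base config from the example above; the class name is a placeholder):
```python
class GlaucomaPublicSingleLabel(GlaucomaPublic):
    def __init__(self) -> None:
        # "glaucoma_present" is a placeholder name for the positive class of this binary task.
        super().__init__(class_names=["glaucoma_present"])
```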