Clean up legacy code (#671)
* Remove rnns * Fix flake8 * Edit README * Edit README * Remove sequence * Remove sequence * Fix all * Remove more * Remove ignore * Fix tests * Undo config * Fix config * Revert pycharm * Fix tests * Undo outputlogger * Fix flake8 * Fix ignore file * Revert hi-ml * Disable fail on alert
This commit is contained in:
Родитель
8a78ec8c1e
Коммит
1606729c7a
|
@ -166,5 +166,6 @@ InnerEye-DataQuality/name_stats_scoring.png
|
|||
InnerEye-DataQuality/cifar-10-batches-py
|
||||
InnerEye-DataQuality/logs
|
||||
InnerEye-DataQuality/data
|
||||
|
||||
!**/InnerEye/ML/Histopathology/datasets
|
||||
None
|
||||
cifar-10-batches-py
|
||||
cifar-10-python.tar.gz
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"python.testing.pytestArgs": [
|
||||
"Tests"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
|
@ -13,7 +13,7 @@ created.
|
|||
## Upcoming
|
||||
|
||||
### Added
|
||||
|
||||
-([#671](https://github.com/microsoft/InnerEye-DeepLearning/pull/671)) Remove sequence models and unused variables. Simplify README.
|
||||
- ([#678](https://github.com/microsoft/InnerEye-DeepLearning/pull/678)) Add function to get log level name and use it for logging.
|
||||
- ([#666](https://github.com/microsoft/InnerEye-DeepLearning/pull/666)) Replace RadIO with TorchIO for patch-based inference.
|
||||
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train
|
||||
|
|
|
@ -24,8 +24,6 @@ from InnerEye.Common.generic_parsing import GenericConfig
|
|||
# The name of the "azureml" property of AzureConfig
|
||||
AZURECONFIG_SUBMIT_TO_AZUREML = "azureml"
|
||||
|
||||
INPUT_DATA_KEY = "input_data"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GitInformation:
|
||||
|
|
|
@ -21,8 +21,6 @@ from InnerEye.Common.generic_parsing import GenericConfig
|
|||
from InnerEye.ML.common import ModelExecutionMode
|
||||
from InnerEye.ML.utils.config_loader import ModelConfigLoader
|
||||
|
||||
SLEEP_TIME_SECONDS = 30
|
||||
|
||||
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
|
||||
|
||||
# Environment variables used for multi-node training
|
||||
|
|
|
@ -52,8 +52,7 @@ up each score from one set of results with a score from the other set.
|
|||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from itertools import filterfalse, tee
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -67,8 +66,6 @@ from InnerEye.Common.common_util import FULL_METRICS_DATAFRAME_FILE
|
|||
from InnerEye.Common.generic_parsing import GenericConfig
|
||||
from InnerEye.ML.visualizers.metrics_scatterplot import create_scatterplots
|
||||
|
||||
INTERSECT = lambda l, r: np.intersect1d(l, r, False)
|
||||
|
||||
"""
|
||||
The factor by which the Wilcoxon Z value should be divided to allow for incomplete independence of the data.
|
||||
Experimentation (from comparing models built from exactly the same code and data, but different random seeds)
|
||||
|
@ -206,15 +203,6 @@ def compose_pairwise_result(threshold: float, results: Dict[str, Dict[str, float
|
|||
return []
|
||||
|
||||
|
||||
def partition_results(pred: Callable, results: List[Any]) -> Tuple[List[Any], List[Any]]:
|
||||
"""
|
||||
Helper function to partition results into passed/failed
|
||||
"""
|
||||
t1, t2 = tee(results)
|
||||
map_func = lambda r: (r, results[r])
|
||||
return list(map(map_func, filterfalse(pred, t1))), list(map(map_func, filter(pred, t2)))
|
||||
|
||||
|
||||
def read_data(csv_file: str, subset: str = 'all', exclude: Optional[List[str]] = None) \
|
||||
-> Dict[str, Dict[str, Dict[str, float]]]:
|
||||
"""
|
||||
|
|
|
@ -1,171 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from health_ml.utils.bag_utils import BagDataset, multibag_collate
|
||||
from health_ml.utils.common_utils import _create_generator
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
|
||||
from InnerEye.ML.Histopathology.models.transforms import LoadTilesBatchd
|
||||
|
||||
|
||||
class CacheMode(Enum):
|
||||
NONE = 'none'
|
||||
MEMORY = 'memory'
|
||||
DISK = 'disk'
|
||||
|
||||
class CacheLocation(Enum):
|
||||
NONE = 'none'
|
||||
CPU = 'cpu'
|
||||
SAME = 'same'
|
||||
class TilesDataModule(LightningDataModule):
|
||||
"""Base class to load the tiles of a dataset as train, val, test sets"""
|
||||
|
||||
def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
|
||||
seed: Optional[int] = None, transform: Optional[Callable] = None,
|
||||
cache_mode: CacheMode = CacheMode.NONE,
|
||||
precache_location: CacheLocation = CacheLocation.NONE,
|
||||
cache_dir: Optional[Path] = None,
|
||||
number_of_cross_validation_splits: int = 0,
|
||||
cross_validation_split_index: int = 0) -> None:
|
||||
"""
|
||||
:param root_path: Root directory of the source dataset.
|
||||
:param max_bag_size: Upper bound on number of tiles in each loaded bag. If 0 (default),
|
||||
will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
|
||||
random subsets of instances.
|
||||
:param batch_size: Number of slides to load per batch.
|
||||
:param seed: pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in
|
||||
train/val/test splits is handled independently in `get_splits()`. (default: `None`)
|
||||
:param transform: A transform to apply to the source tiles dataset, or a composition of
|
||||
transforms using `monai.transforms.Compose`. By default (`None`), applies `LoadTilesBatchd`.
|
||||
:param cache_mode: The type of caching to perform, i.e. whether the results of all
|
||||
transforms up to the first randomised one should be computed only once and reused in
|
||||
subsequent iterations:
|
||||
- `MEMORY`: MONAI CacheDataset is used, the entire transformed dataset is kept in memory for fastest access;
|
||||
- `DISK`: MONAI PersistentDataset is used, each transformed sample is saved to disk and loaded on-demand;
|
||||
- `NONE` (default): standard MONAI dataset is used, no caching is performed.
|
||||
:param precache_location: Whether to pre-cache the entire transformed dataset upfront and save
|
||||
it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
|
||||
multiple processes can afterwards access the same cache without contention in DDP settings. This parameter also allow to
|
||||
choose if the cache will be re-loaded into CPU or GPU memory:
|
||||
- `NONE (default)`: no pre-cache is performed;
|
||||
- `CPU`: each transformed sample is saved to disk and, if cache_mode is `MEMORY`, reloaded into CPU;
|
||||
- `SAME`: each transformed sample is saved to disk and, if cache_mode is `MEMORY`, reloaded on the same device it was saved from;
|
||||
If cache_mode is `DISK` precache_location `CPU` and `GPU` are equivalent.
|
||||
:param cache_dir: The directory onto which to cache data if caching is enabled.
|
||||
:param number_of_cross_validation_splits: Number of folds to perform.
|
||||
:param cross_validation_split_index: Index of the cross validation split to be performed.
|
||||
"""
|
||||
if precache_location is not CacheLocation.NONE and cache_mode is CacheMode.NONE:
|
||||
raise ValueError("Can only pre-cache if caching is enabled")
|
||||
if precache_location is not CacheLocation.NONE and cache_dir is None:
|
||||
raise ValueError("A cache directory is required for pre-caching")
|
||||
if cache_mode is CacheMode.DISK and cache_dir is None:
|
||||
raise ValueError("A cache directory is required for on-disk caching")
|
||||
super().__init__()
|
||||
|
||||
self.root_path = root_path
|
||||
self.max_bag_size = max_bag_size
|
||||
self.transform = transform
|
||||
self.cache_mode = cache_mode
|
||||
self.precache_location = precache_location
|
||||
self.cache_dir = cache_dir
|
||||
self.batch_size = batch_size
|
||||
self.number_of_cross_validation_splits = number_of_cross_validation_splits
|
||||
self.cross_validation_split_index = cross_validation_split_index
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
|
||||
self.class_weights = self.train_dataset.get_class_weights()
|
||||
self.seed = seed
|
||||
|
||||
def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
|
||||
"""Create the training, validation, and test datasets"""
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
if self.precache_location != CacheLocation.NONE:
|
||||
self._load_dataset(self.train_dataset, stage='train', shuffle=True)
|
||||
self._load_dataset(self.val_dataset, stage='val', shuffle=True)
|
||||
self._load_dataset(self.test_dataset, stage='test', shuffle=True)
|
||||
|
||||
def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
|
||||
if self.cache_dir is None:
|
||||
return None
|
||||
return self.cache_dir / f"{stage}_dataset.pt"
|
||||
|
||||
def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
|
||||
dataset_pickle_path = self._dataset_pickle_path(stage)
|
||||
|
||||
if dataset_pickle_path and dataset_pickle_path.is_file():
|
||||
if self.precache_location == CacheLocation.CPU:
|
||||
memory_location = torch.device('cpu')
|
||||
print(f"Loading dataset from {dataset_pickle_path} into {memory_location}")
|
||||
else:
|
||||
# by default torch.load will reload on the same device it was saved from
|
||||
memory_location = None # type: ignore
|
||||
|
||||
with dataset_pickle_path.open('rb') as f:
|
||||
return torch.load(f, map_location=memory_location)
|
||||
|
||||
generator = _create_generator(self.seed)
|
||||
bag_dataset = BagDataset(tiles_dataset, # type: ignore
|
||||
bag_ids=tiles_dataset.slide_ids,
|
||||
max_bag_size=self.max_bag_size,
|
||||
shuffle_samples=shuffle,
|
||||
generator=generator)
|
||||
transform = self.transform or LoadTilesBatchd(tiles_dataset.IMAGE_COLUMN)
|
||||
|
||||
# Save and restore PRNG state for consistency across (pre-)caching options
|
||||
generator_state = generator.get_state()
|
||||
transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform) # type: ignore
|
||||
generator.set_state(generator_state)
|
||||
|
||||
# Dataset is saved if cache_dir is True, regardless of CacheMode
|
||||
if dataset_pickle_path:
|
||||
dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with dataset_pickle_path.open('wb') as f:
|
||||
torch.save(transformed_bag_dataset, f)
|
||||
|
||||
return transformed_bag_dataset
|
||||
|
||||
def _get_transformed_dataset(self, base_dataset: BagDataset,
|
||||
transform: Union[Sequence[Callable], Callable]) -> Dataset:
|
||||
if self.cache_mode is CacheMode.MEMORY:
|
||||
dataset = CacheDataset(base_dataset, transform, num_workers=1) # type: ignore
|
||||
elif self.cache_mode is CacheMode.DISK:
|
||||
dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir) # type: ignore
|
||||
if self.precache_location != CacheLocation.NONE:
|
||||
import tqdm # TODO: Make optional
|
||||
|
||||
for i in tqdm.trange(len(dataset), desc="Loading dataset"):
|
||||
dataset[i] # empty loop to pre-compute all transformed samples
|
||||
else:
|
||||
dataset = Dataset(base_dataset, transform) # type: ignore
|
||||
return dataset
|
||||
|
||||
def _get_dataloader(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool,
|
||||
**dataloader_kwargs: Any) -> DataLoader:
|
||||
transformed_bag_dataset = self._load_dataset(tiles_dataset, stage=stage, shuffle=shuffle)
|
||||
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
|
||||
generator = bag_dataset.bag_sampler.generator
|
||||
return DataLoader(transformed_bag_dataset, batch_size=self.batch_size,
|
||||
collate_fn=multibag_collate, shuffle=shuffle, generator=generator,
|
||||
pin_memory=False, # disable pinning as loaded data may already be on GPU
|
||||
**dataloader_kwargs)
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.train_dataset, 'train', shuffle=True)
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.val_dataset, 'val', shuffle=True)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.test_dataset, 'test', shuffle=True)
|
|
@ -1,31 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Tuple, Any
|
||||
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
|
||||
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
|
||||
|
||||
class PandaTilesDataModule(TilesDataModule):
|
||||
""" PandaTilesDataModule is the child class of TilesDataModule specific to PANDA dataset
|
||||
Method get_splits() returns the train, val, test splits from the PANDA dataset
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
|
||||
dataset = PandaTilesDataset(self.root_path)
|
||||
splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
|
||||
proportion_train=.8,
|
||||
proportion_test=.1,
|
||||
proportion_val=.1,
|
||||
subject_column=dataset.TILE_ID_COLUMN,
|
||||
group_column=dataset.SLIDE_ID_COLUMN)
|
||||
return (PandaTilesDataset(self.root_path, dataset_df=splits.train),
|
||||
PandaTilesDataset(self.root_path, dataset_df=splits.val),
|
||||
PandaTilesDataset(self.root_path, dataset_df=splits.test))
|
|
@ -1,38 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Tuple, Any
|
||||
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
|
||||
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
|
||||
|
||||
class TcgaCrckTilesDataModule(TilesDataModule):
|
||||
""" TcgaCrckTilesDataModule is the child class of TilesDataModule specific to TCGA-Crck dataset
|
||||
Method get_splits() returns the train, val, test splits from the TCGA-Crck dataset
|
||||
Methods train_dataloader(), val_dataloader() and test_dataloader() override the base class methods for bag loading
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_splits(self) -> Tuple[TcgaCrck_TilesDataset, TcgaCrck_TilesDataset, TcgaCrck_TilesDataset]:
|
||||
trainval_dataset = TcgaCrck_TilesDataset(self.root_path, train=True)
|
||||
splits = DatasetSplits.from_proportions(trainval_dataset.dataset_df.reset_index(),
|
||||
proportion_train=0.8,
|
||||
proportion_test=0.0,
|
||||
proportion_val=0.2,
|
||||
subject_column=trainval_dataset.TILE_ID_COLUMN,
|
||||
group_column=trainval_dataset.SLIDE_ID_COLUMN,
|
||||
random_seed=5)
|
||||
|
||||
if self.number_of_cross_validation_splits > 1:
|
||||
# Function get_k_fold_cross_validation_splits() will concatenate train and val splits
|
||||
splits = splits.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)[self.cross_validation_split_index]
|
||||
|
||||
return (TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.train),
|
||||
TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.val),
|
||||
TcgaCrck_TilesDataset(self.root_path, train=False))
|
|
@ -1,220 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sklearn.utils.class_weight import compute_class_weight
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||
|
||||
|
||||
class TilesDataset(Dataset):
|
||||
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.
|
||||
|
||||
:param TILE_ID_COLUMN: CSV column name for tile ID.
|
||||
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
|
||||
:param IMAGE_COLUMN: CSV column name for relative path to image file.
|
||||
:param PATH_COLUMN: CSV column name for relative path to image file. Replicated to propagate the path to the batch.
|
||||
:param LABEL_COLUMN: CSV column name for tile label.
|
||||
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
|
||||
:param TILE_X_COLUMN: CSV column name for horizontal tile coordinate (optional).
|
||||
:param TILE_Y_COLUMN: CSV column name for vertical tile coordinate (optional).
|
||||
:param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
|
||||
:param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
|
||||
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset rood directory.
|
||||
:param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
|
||||
"""
|
||||
TILE_ID_COLUMN: str = 'tile_id'
|
||||
SLIDE_ID_COLUMN: str = 'slide_id'
|
||||
IMAGE_COLUMN: str = 'image'
|
||||
PATH_COLUMN: str = 'image_path'
|
||||
LABEL_COLUMN: str = 'label'
|
||||
SPLIT_COLUMN: Optional[str] = 'split'
|
||||
TILE_X_COLUMN: Optional[str] = 'tile_x'
|
||||
TILE_Y_COLUMN: Optional[str] = 'tile_y'
|
||||
|
||||
TRAIN_SPLIT_LABEL: str = 'train'
|
||||
TEST_SPLIT_LABEL: str = 'test'
|
||||
|
||||
DEFAULT_CSV_FILENAME: str = "dataset.csv"
|
||||
|
||||
N_CLASSES: int = 1 # binary classification by default
|
||||
|
||||
def __init__(self,
|
||||
root: Union[str, Path],
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None,
|
||||
train: Optional[bool] = None) -> None:
|
||||
"""
|
||||
:param root: Root directory of the dataset.
|
||||
:param dataset_csv: Full path to a dataset CSV file, containing at least
|
||||
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
|
||||
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
|
||||
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
|
||||
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
|
||||
:param train: If `True`, loads only the training split (resp. `False` for test split). By
|
||||
default (`None`), loads the entire dataset as-is.
|
||||
"""
|
||||
if self.SPLIT_COLUMN is None and train is not None:
|
||||
raise ValueError("Train/test split was specified but dataset has no split column")
|
||||
|
||||
self.root_dir = Path(root)
|
||||
|
||||
if dataset_df is not None:
|
||||
self.dataset_csv = None
|
||||
else:
|
||||
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
|
||||
dataset_df = pd.read_csv(self.dataset_csv)
|
||||
|
||||
columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
|
||||
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
|
||||
for column in columns:
|
||||
if column is not None and column not in dataset_df.columns:
|
||||
raise ValueError(f"Expected column '{column}' not found in the dataframe")
|
||||
|
||||
dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
|
||||
if train is None:
|
||||
self.dataset_df = dataset_df
|
||||
else:
|
||||
split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
|
||||
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.dataset_df.shape[0]
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
tile_id = self.dataset_df.index[index]
|
||||
sample = {
|
||||
self.TILE_ID_COLUMN: tile_id,
|
||||
**self.dataset_df.loc[tile_id].to_dict()
|
||||
}
|
||||
sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
|
||||
# we're replicating this column because we want to propagate the path to the batch
|
||||
sample[self.PATH_COLUMN] = sample[self.IMAGE_COLUMN]
|
||||
return sample
|
||||
|
||||
@property
|
||||
def slide_ids(self) -> pd.Series:
|
||||
return self.dataset_df[self.SLIDE_ID_COLUMN]
|
||||
|
||||
def get_slide_labels(self) -> pd.Series:
|
||||
return self.dataset_df.groupby(self.SLIDE_ID_COLUMN)[self.LABEL_COLUMN].agg(pd.Series.mode)
|
||||
|
||||
def get_class_weights(self) -> torch.Tensor:
|
||||
slide_labels = self.get_slide_labels()
|
||||
classes = np.unique(slide_labels)
|
||||
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
|
||||
return torch.as_tensor(class_weights)
|
||||
|
||||
|
||||
class SlidesDataset(Dataset):
|
||||
"""Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.
|
||||
|
||||
The output dictionaries are indexed by `..utils.naming.SlideKey`.
|
||||
|
||||
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
|
||||
:param IMAGE_COLUMN: CSV column name for relative path to image file.
|
||||
:param LABEL_COLUMN: CSV column name for tile label.
|
||||
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
|
||||
:param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
|
||||
:param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
|
||||
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset rood directory.
|
||||
:param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
|
||||
"""
|
||||
SLIDE_ID_COLUMN: str = 'slide_id'
|
||||
IMAGE_COLUMN: str = 'image'
|
||||
LABEL_COLUMN: str = 'label'
|
||||
MASK_COLUMN: Optional[str] = None
|
||||
SPLIT_COLUMN: Optional[str] = None
|
||||
|
||||
TRAIN_SPLIT_LABEL: str = 'train'
|
||||
TEST_SPLIT_LABEL: str = 'test'
|
||||
|
||||
METADATA_COLUMNS: Tuple[str, ...] = ()
|
||||
|
||||
DEFAULT_CSV_FILENAME: str = "dataset.csv"
|
||||
|
||||
N_CLASSES: int = 1 # binary classification by default
|
||||
|
||||
def __init__(self,
|
||||
root: Union[str, Path],
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None,
|
||||
train: Optional[bool] = None,
|
||||
validate_columns: bool = True) -> None:
|
||||
"""
|
||||
:param root: Root directory of the dataset.
|
||||
:param dataset_csv: Full path to a dataset CSV file, containing at least
|
||||
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
|
||||
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
|
||||
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
|
||||
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
|
||||
:param train: If `True`, loads only the training split (resp. `False` for test split). By
|
||||
default (`None`), loads the entire dataset as-is.
|
||||
:param validate_columns: Whether to call `validate_columns()` at the end of `__init__()`.
|
||||
"""
|
||||
if self.SPLIT_COLUMN is None and train is not None:
|
||||
raise ValueError("Train/test split was specified but dataset has no split column")
|
||||
|
||||
self.root_dir = Path(root)
|
||||
|
||||
if dataset_df is not None:
|
||||
self.dataset_csv = None
|
||||
else:
|
||||
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
|
||||
dataset_df = pd.read_csv(self.dataset_csv)
|
||||
|
||||
dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
|
||||
if train is None:
|
||||
self.dataset_df = dataset_df
|
||||
else:
|
||||
split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
|
||||
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]
|
||||
|
||||
if validate_columns:
|
||||
self.validate_columns()
|
||||
|
||||
def validate_columns(self) -> None:
|
||||
"""Check that loaded dataframe contains expected columns, raises `ValueError` otherwise.
|
||||
|
||||
If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
|
||||
call `validate_columns()` after creating derived columns, for example.
|
||||
"""
|
||||
columns = [self.IMAGE_COLUMN, self.LABEL_COLUMN, self.MASK_COLUMN,
|
||||
self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS)
|
||||
for column in columns:
|
||||
if column is not None and column not in self.dataset_df.columns:
|
||||
raise ValueError(f"Expected column '{column}' not found in the dataframe")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.dataset_df.shape[0]
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[SlideKey, Any]:
|
||||
slide_id = self.dataset_df.index[index]
|
||||
slide_row = self.dataset_df.loc[slide_id]
|
||||
sample = {SlideKey.SLIDE_ID: slide_id}
|
||||
|
||||
rel_image_path = slide_row[self.IMAGE_COLUMN]
|
||||
sample[SlideKey.IMAGE] = str(self.root_dir / rel_image_path)
|
||||
# we're replicating this column because we want to propagate the path to the batch
|
||||
sample[SlideKey.IMAGE_PATH] = sample[SlideKey.IMAGE]
|
||||
|
||||
if self.MASK_COLUMN:
|
||||
rel_mask_path = slide_row[self.MASK_COLUMN]
|
||||
sample[SlideKey.MASK] = str(self.root_dir / rel_mask_path)
|
||||
sample[SlideKey.MASK_PATH] = sample[SlideKey.MASK]
|
||||
|
||||
sample[SlideKey.LABEL] = slide_row[self.LABEL_COLUMN]
|
||||
sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.METADATA_COLUMNS}
|
||||
return sample
|
||||
|
||||
@classmethod
|
||||
def has_mask(cls) -> bool:
|
||||
return cls.MASK_COLUMN is not None
|
|
@ -1,15 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
PANDA_DATASET_ID = "PANDA"
|
||||
PANDA_TILES_DATASET_ID = "PANDA_tiles"
|
||||
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
|
||||
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"
|
||||
|
||||
DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
|
||||
PANDA_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_DATASET_ID
|
||||
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
|
||||
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
|
||||
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID
|
|
@ -1,128 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
import pandas as pd
|
||||
from health_ml.utils import box_utils
|
||||
from monai.config import KeysCollection
|
||||
from monai.data.image_reader import ImageReader, WSIReader
|
||||
from monai.transforms import MapTransform
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||
|
||||
try:
|
||||
from cucim import CuImage
|
||||
except:
|
||||
logging.warning("cucim library not available, code may fail.")
|
||||
|
||||
|
||||
class PandaDataset(SlidesDataset):
|
||||
"""Dataset class for loading files from the PANDA challenge dataset.
|
||||
|
||||
Iterating over this dataset returns a dictionary following the `SlideKey` schema plus meta-data
|
||||
from the original dataset (`'data_provider'`, `'isup_grade'`, and `'gleason_score'`).
|
||||
|
||||
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
|
||||
"""
|
||||
SLIDE_ID_COLUMN = 'image_id'
|
||||
IMAGE_COLUMN = 'image'
|
||||
MASK_COLUMN = 'mask'
|
||||
LABEL_COLUMN = 'isup_grade'
|
||||
|
||||
METADATA_COLUMNS = ('data_provider', 'isup_grade', 'gleason_score')
|
||||
|
||||
DEFAULT_CSV_FILENAME = "train.csv"
|
||||
|
||||
def __init__(self,
|
||||
root: Union[str, Path],
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None) -> None:
|
||||
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
|
||||
# PANDA CSV does not come with paths for image and mask files
|
||||
slide_ids = self.dataset_df.index
|
||||
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
|
||||
self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
|
||||
self.validate_columns()
|
||||
|
||||
|
||||
# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
|
||||
class ReadImaged(MapTransform):
|
||||
"""Basic transform to read image files."""
|
||||
|
||||
def __init__(self, reader: ImageReader, keys: KeysCollection,
|
||||
allow_missing_keys: bool = False, **kwargs: Any) -> None:
|
||||
super().__init__(keys, allow_missing_keys=allow_missing_keys)
|
||||
self.reader = reader
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, data: Dict) -> Dict:
|
||||
for key in self.keys:
|
||||
if key in data or not self.allow_missing_keys:
|
||||
data[key] = self.reader.read(data[key], **self.kwargs)
|
||||
return data
|
||||
|
||||
|
||||
class LoadPandaROId(MapTransform):
|
||||
"""Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI).
|
||||
|
||||
Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
|
||||
respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
|
||||
- `'location'` (tuple): top-right coordinates of the bounding box
|
||||
- `'size'` (tuple): width and height of the bounding box
|
||||
- `'level'` (int): chosen magnification level
|
||||
- `'scale'` (float): corresponding scale, loaded from the file
|
||||
"""
|
||||
|
||||
def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str = 'mask',
|
||||
level: int = 0, margin: int = 0, **kwargs: Any) -> None:
|
||||
"""
|
||||
:param reader: And instance of MONAI's `WSIReader`.
|
||||
:param image_key: Image key in the input and output dictionaries.
|
||||
:param mask_key: Mask key in the input and output dictionaries.
|
||||
:param level: Magnification level to load from the raw multi-scale files.
|
||||
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
|
||||
"""
|
||||
super().__init__([image_key, mask_key], allow_missing_keys=False)
|
||||
self.reader = reader
|
||||
self.image_key = image_key
|
||||
self.mask_key = mask_key
|
||||
self.level = level
|
||||
self.margin = margin
|
||||
self.kwargs = kwargs
|
||||
|
||||
def _get_bounding_box(self, mask_obj: 'CuImage') -> box_utils.Box:
|
||||
# Estimate bounding box at the lowest resolution (i.e. highest level)
|
||||
highest_level = mask_obj.resolutions['level_count'] - 1
|
||||
scale = mask_obj.resolutions['level_downsamples'][highest_level]
|
||||
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
|
||||
|
||||
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
|
||||
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
|
||||
return bbox
|
||||
|
||||
def __call__(self, data: Dict) -> Dict:
|
||||
mask_obj: CuImage = self.reader.read(data[self.mask_key])
|
||||
image_obj: CuImage = self.reader.read(data[self.image_key])
|
||||
|
||||
level0_bbox = self._get_bounding_box(mask_obj)
|
||||
|
||||
# cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
|
||||
# but relative region size in pixels at the chosen level
|
||||
scale = mask_obj.resolutions['level_downsamples'][self.level]
|
||||
scaled_bbox = level0_bbox / scale
|
||||
get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y),
|
||||
size=(scaled_bbox.w, scaled_bbox.h),
|
||||
level=self.level)
|
||||
mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs) # type: ignore
|
||||
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
|
||||
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
|
||||
data.update(get_data_kwargs)
|
||||
data['scale'] = scale
|
||||
|
||||
mask_obj.close()
|
||||
image_obj.close()
|
||||
return data
|
|
@ -1,90 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
|
||||
from InnerEye.ML.Histopathology.models.transforms import load_pil_image
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
|
||||
|
||||
|
||||
class PandaTilesDataset(TilesDataset):
|
||||
"""
|
||||
Dataset class for loading PANDA tiles.
|
||||
|
||||
Iterating over this dataset returns a dictionary containing:
|
||||
- `'slide_id'` (str): parent slide ID (`'image_id'` in the PANDA dataset)
|
||||
- `'tile_id'` (str)
|
||||
- `'image'` (`PIL.Image`): RGB tile
|
||||
- `'mask'` (str): path to mask PNG file
|
||||
- `'tile_x'`, `'tile_y'` (int): top-right tile coordinates
|
||||
- `'data_provider'`, `'slide_isup_grade'`, `'slide_gleason_score'` (str): parent slide metadata
|
||||
"""
|
||||
LABEL_COLUMN = "slide_isup_grade"
|
||||
SPLIT_COLUMN = None # PANDA does not have an official train/test split
|
||||
N_CLASSES = 6
|
||||
|
||||
_RELATIVE_ROOT_FOLDER = Path("PANDA_tiles_20210926-135446/panda_tiles_level1_224")
|
||||
|
||||
def __init__(self,
|
||||
root: Path,
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None,
|
||||
occupancy_threshold: Optional[float] = None) -> None:
|
||||
super().__init__(root=Path(root) / self._RELATIVE_ROOT_FOLDER,
|
||||
dataset_csv=dataset_csv,
|
||||
dataset_df=dataset_df,
|
||||
train=None)
|
||||
if occupancy_threshold is not None:
|
||||
dataset_df_filtered = self.dataset_df.loc[self.dataset_df['occupancy'] > occupancy_threshold] # type: ignore
|
||||
self.dataset_df = dataset_df_filtered
|
||||
|
||||
class PandaTilesDatasetReturnImageLabel(VisionDataset):
|
||||
"""
|
||||
Any dataset used in SSL needs to return a tuple where the first element is the image and the second is a
|
||||
class label.
|
||||
"""
|
||||
occupancy_threshold = 0
|
||||
|
||||
def __init__(self,
|
||||
root: Path,
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None,
|
||||
transform: Optional[Callable] = None,
|
||||
**kwargs: Any) -> None:
|
||||
super().__init__(root=root, transform=transform)
|
||||
|
||||
self.base_dataset = PandaTilesDataset(root=root,
|
||||
dataset_csv=dataset_csv,
|
||||
dataset_df=dataset_df,
|
||||
occupancy_threshold=self.occupancy_threshold)
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple: # type: ignore
|
||||
sample = self.base_dataset[index]
|
||||
# TODO change to a meaningful evaluation
|
||||
image = load_pil_image(sample[self.base_dataset.IMAGE_COLUMN])
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
# get binary label
|
||||
label = 0 if sample[self.base_dataset.LABEL_COLUMN] == 0 else 1
|
||||
return image, label
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.base_dataset)
|
||||
|
||||
|
||||
class PandaTilesDatasetWithReturnIndex(InnerEyeDataClassBaseWithReturnIndex, PandaTilesDatasetReturnImageLabel):
|
||||
"""
|
||||
Any dataset used in SSL needs to inherit from InnerEyeDataClassBaseWithReturnIndex as well as VisionData.
|
||||
This class is just a shorthand notation for this double inheritance. Please note that this class needs
|
||||
to override __getitem__(), this is why we need a separate PandaTilesDatasetReturnImageLabel.
|
||||
"""
|
||||
@property
|
||||
def num_classes(self) -> int:
|
||||
return 2
|
|
@ -1,70 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
|
||||
from InnerEye.ML.Histopathology.models.transforms import load_pil_image
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
|
||||
|
||||
|
||||
class TcgaCrck_TilesDataset(TilesDataset):
|
||||
"""Dataset class for loading TCGA-CRCk tiles.
|
||||
|
||||
Iterating over this dataset returns a dictionary containing:
|
||||
- `'slide_id'` (str): parent slide ID
|
||||
- `'tile_id'` (str)
|
||||
- `'image'` (`PIL.Image`): RGB tile
|
||||
- `'label'` (str): MSS (0) vs MSIMUT (1)
|
||||
"""
|
||||
TILE_X_COLUMN = TILE_Y_COLUMN = None # no tile coordinates available
|
||||
# This dataset conforms to all other defaults in TilesDataset
|
||||
|
||||
|
||||
class TcgaCrck_TilesDatasetReturnImageLabel(VisionDataset):
|
||||
"""
|
||||
Any dataset used in SSL needs to return a tuple where the first element is the image and the second is a
|
||||
class label.
|
||||
"""
|
||||
def __init__(self,
|
||||
root: Union[str, Path],
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None,
|
||||
train: Optional[bool] = None,
|
||||
transform: Optional[Callable] = None,
|
||||
**kwargs: Any) -> None:
|
||||
super().__init__(root=root, transform=transform)
|
||||
self.base_dataset = TcgaCrck_TilesDataset(root=root,
|
||||
dataset_csv=dataset_csv,
|
||||
dataset_df=dataset_df,
|
||||
train=train)
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple: # type: ignore
|
||||
sample = self.base_dataset[index]
|
||||
# TODO change to a meaningful evaluation
|
||||
image = load_pil_image(sample[self.base_dataset.IMAGE_COLUMN])
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
return image, sample[self.base_dataset.LABEL_COLUMN]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.base_dataset)
|
||||
|
||||
|
||||
class TcgaCrck_TilesDatasetWithReturnIndex(InnerEyeDataClassBaseWithReturnIndex,
|
||||
TcgaCrck_TilesDatasetReturnImageLabel):
|
||||
"""
|
||||
Any dataset used in SSL needs to inherit from InnerEyeDataClassBaseWithReturnIndex as well as VisionData.
|
||||
This class is just a shorthand notation for this double inheritance. Please note that this class needs
|
||||
to override __getitem__(), this is why we need a separate PandaTilesDatasetReturnImageLabel.
|
||||
"""
|
||||
@property
|
||||
def num_classes(self) -> int:
|
||||
return 2
|
|
@ -1,42 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||
|
||||
|
||||
class TcgaPradDataset(SlidesDataset):
|
||||
"""Dataset class for loading TCGA-PRAD slides.
|
||||
|
||||
Iterating over this dataset returns a dictionary containing:
|
||||
- `'slide_id'` (str)
|
||||
- `'case_id'` (str)
|
||||
- `'image_path'` (str): absolute slide image path
|
||||
- `'label'` (int, 0 or 1): label for predicting positive or negative
|
||||
"""
|
||||
IMAGE_COLUMN: str = 'image_path'
|
||||
LABEL_COLUMN: str = 'label'
|
||||
|
||||
DEFAULT_CSV_FILENAME: str = "dataset.csv"
|
||||
|
||||
def __init__(self, root: Union[str, Path],
|
||||
dataset_csv: Optional[Union[str, Path]] = None,
|
||||
dataset_df: Optional[pd.DataFrame] = None) -> None:
|
||||
"""
|
||||
:param root: Root directory of the dataset.
|
||||
:param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from
|
||||
`"{root}/{DEFAULT_CSV_FILENAME}"`.
|
||||
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
|
||||
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
|
||||
"""
|
||||
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
|
||||
# Example of how to define a custom label column from existing columns:
|
||||
self.dataset_df[self.LABEL_COLUMN] = (self.dataset_df['label1']
|
||||
| self.dataset_df['label2']).astype(int)
|
||||
self.validate_columns()
|
|
@ -1,440 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, List
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import more_itertools as mi
|
||||
|
||||
from pytorch_lightning import LightningModule
|
||||
from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round
|
||||
from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix
|
||||
|
||||
from InnerEye.Common import fixed_paths
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset
|
||||
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
|
||||
from InnerEye.ML.Histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_tiles,
|
||||
plot_scores_hist, plot_heatmap_overlay,
|
||||
plot_slide, plot_normalized_confusion_matrix)
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey
|
||||
from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict
|
||||
from health_ml.utils import log_on_epoch
|
||||
|
||||
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
|
||||
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
|
||||
|
||||
|
||||
def _format_cuda_memory_stats() -> str:
|
||||
return (f"GPU {torch.cuda.current_device()} memory: "
|
||||
f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB allocated, "
|
||||
f"{torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB reserved")
|
||||
|
||||
|
||||
class DeepMILModule(LightningModule):
|
||||
"""Base class for deep multiple-instance learning"""
|
||||
|
||||
def __init__(self,
|
||||
label_column: str,
|
||||
n_classes: int,
|
||||
encoder: TileEncoder,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
pool_hidden_dim: int = 128,
|
||||
pool_out_dim: int = 1,
|
||||
dropout_rate: Optional[float] = None,
|
||||
class_weights: Optional[Tensor] = None,
|
||||
l_rate: float = 5e-4,
|
||||
weight_decay: float = 1e-4,
|
||||
adam_betas: Tuple[float, float] = (0.9, 0.99),
|
||||
verbose: bool = False,
|
||||
slide_dataset: SlidesDataset = None,
|
||||
tile_size: int = 224,
|
||||
level: int = 1,
|
||||
class_names: Optional[List[str]] = None,
|
||||
is_finetune: bool = False) -> None:
|
||||
"""
|
||||
:param label_column: Label key for input batch dictionary.
|
||||
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
|
||||
:param encoder: The tile encoder to use for feature extraction. If no encoding is needed,
|
||||
you should use `IdentityEncoder`.
|
||||
:param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a
|
||||
`torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
|
||||
:param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
|
||||
:param pool_out_dim: Output dimension of pooling layer (default=1).
|
||||
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
|
||||
:param class_weights: Tensor containing class weights (default=None).
|
||||
:param l_rate: Optimiser learning rate.
|
||||
:param weight_decay: Weight decay parameter for L2 regularisation.
|
||||
:param adam_betas: Beta parameters for Adam optimiser.
|
||||
:param verbose: if True statements about memory usage are output at each step.
|
||||
:param slide_dataset: Slide dataset object, if available.
|
||||
:param tile_size: The size of each tile (default=224).
|
||||
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
|
||||
:param class_names: The names of the classes if available (default=None).
|
||||
:param is_finetune: Boolean value to enable/disable finetuning (default=False).
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Dataset specific attributes
|
||||
self.label_column = label_column
|
||||
self.n_classes = n_classes
|
||||
self.pool_hidden_dim = pool_hidden_dim
|
||||
self.pool_out_dim = pool_out_dim
|
||||
self.pooling_layer = pooling_layer
|
||||
self.dropout_rate = dropout_rate
|
||||
self.class_weights = class_weights
|
||||
self.encoder = encoder
|
||||
self.num_encoding = self.encoder.num_encoding
|
||||
|
||||
if class_names is not None:
|
||||
self.class_names = class_names
|
||||
else:
|
||||
if self.n_classes > 1:
|
||||
self.class_names = [str(i) for i in range(self.n_classes)]
|
||||
else:
|
||||
self.class_names = ['0', '1']
|
||||
if self.n_classes > 1 and len(self.class_names) != self.n_classes:
|
||||
raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes})")
|
||||
if self.n_classes == 1 and len(self.class_names) != 2:
|
||||
raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes+1})")
|
||||
|
||||
# Optimiser hyperparameters
|
||||
self.l_rate = l_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.adam_betas = adam_betas
|
||||
|
||||
# Slide specific attributes
|
||||
self.slide_dataset = slide_dataset
|
||||
self.tile_size = tile_size
|
||||
self.level = level
|
||||
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
# Finetuning attributes
|
||||
self.is_finetune = is_finetune
|
||||
|
||||
self.aggregation_fn, self.num_pooling = self.get_pooling()
|
||||
self.classifier_fn = self.get_classifier()
|
||||
self.loss_fn = self.get_loss()
|
||||
self.activation_fn = self.get_activation()
|
||||
|
||||
# Metrics Objects
|
||||
self.train_metrics = self.get_metrics()
|
||||
self.val_metrics = self.get_metrics()
|
||||
self.test_metrics = self.get_metrics()
|
||||
|
||||
def get_pooling(self) -> Tuple[Callable, int]:
|
||||
pooling_layer = self.pooling_layer(self.num_encoding,
|
||||
self.pool_hidden_dim,
|
||||
self.pool_out_dim)
|
||||
num_features = self.num_encoding*self.pool_out_dim
|
||||
return pooling_layer, num_features
|
||||
|
||||
def get_classifier(self) -> Callable:
|
||||
classifier_layer = nn.Linear(in_features=self.num_pooling,
|
||||
out_features=self.n_classes)
|
||||
if self.dropout_rate is None:
|
||||
return classifier_layer
|
||||
elif 0 <= self.dropout_rate < 1:
|
||||
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
|
||||
else:
|
||||
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
|
||||
|
||||
def get_loss(self) -> Callable:
|
||||
if self.n_classes > 1:
|
||||
if self.class_weights is None:
|
||||
return nn.CrossEntropyLoss()
|
||||
else:
|
||||
class_weights = self.class_weights.float()
|
||||
return nn.CrossEntropyLoss(weight=class_weights)
|
||||
else:
|
||||
pos_weight = None
|
||||
if self.class_weights is not None:
|
||||
pos_weight = Tensor([self.class_weights[1]/(self.class_weights[0]+1e-5)])
|
||||
return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
||||
|
||||
def get_activation(self) -> Callable:
|
||||
if self.n_classes > 1:
|
||||
return nn.Softmax()
|
||||
else:
|
||||
return nn.Sigmoid()
|
||||
|
||||
@staticmethod
|
||||
def get_bag_label(labels: Tensor) -> Tensor:
|
||||
# Get bag (batch) labels as majority vote
|
||||
bag_label = mode(labels).values
|
||||
return bag_label.view(1)
|
||||
|
||||
def get_metrics(self) -> nn.ModuleDict:
|
||||
if self.n_classes > 1:
|
||||
return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes, average='micro'),
|
||||
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
|
||||
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
|
||||
else:
|
||||
threshold = 0.5
|
||||
return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
|
||||
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
|
||||
MetricsKey.PRECISION: Precision(threshold=threshold),
|
||||
MetricsKey.RECALL: Recall(threshold=threshold),
|
||||
MetricsKey.F1: F1(threshold=threshold),
|
||||
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
|
||||
|
||||
def log_metrics(self,
|
||||
stage: str) -> None:
|
||||
valid_stages = ['train', 'test', 'val']
|
||||
if stage not in valid_stages:
|
||||
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
|
||||
for metric_name, metric_object in self.get_metrics_dict(stage).items():
|
||||
if metric_name == MetricsKey.CONF_MATRIX:
|
||||
metric_value = metric_object.compute()
|
||||
metric_value_n = metric_value/metric_value.sum(axis=1, keepdims=True)
|
||||
for i in range(metric_value_n.shape[0]):
|
||||
log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
|
||||
else:
|
||||
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
|
||||
|
||||
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
|
||||
with set_grad_enabled(self.is_finetune):
|
||||
instance_features = self.encoder(instances) # N X L x 1 x 1
|
||||
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
|
||||
bag_features = bag_features.view(1, -1)
|
||||
bag_logit = self.classifier_fn(bag_features)
|
||||
return bag_logit, attentions
|
||||
|
||||
def configure_optimizers(self) -> optim.Optimizer:
|
||||
return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay,
|
||||
betas=self.adam_betas)
|
||||
|
||||
def get_metrics_dict(self, stage: str) -> nn.ModuleDict:
|
||||
return getattr(self, f'{stage}_metrics')
|
||||
|
||||
def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsKey, Tensor]:
|
||||
# The batch dict contains lists of tensors of different sizes, for all bags in the batch.
|
||||
# This means we can't stack them along a new axis without padding to the same length.
|
||||
# We could alternatively concatenate them, but this would require other changes (e.g. in
|
||||
# the attention layers) to correctly split the tensors by bag/slide ID.
|
||||
bag_labels_list = []
|
||||
bag_logits_list = []
|
||||
bag_attn_list = []
|
||||
for bag_idx in range(len(batch[self.label_column])):
|
||||
images = batch[TilesDataset.IMAGE_COLUMN][bag_idx]
|
||||
labels = batch[self.label_column][bag_idx]
|
||||
bag_labels_list.append(self.get_bag_label(labels))
|
||||
logit, attn = self(images)
|
||||
bag_logits_list.append(logit.view(-1))
|
||||
bag_attn_list.append(attn)
|
||||
bag_logits = torch.stack(bag_logits_list)
|
||||
bag_labels = torch.stack(bag_labels_list).view(-1)
|
||||
|
||||
if self.n_classes > 1:
|
||||
loss = self.loss_fn(bag_logits, bag_labels.long())
|
||||
else:
|
||||
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
|
||||
|
||||
predicted_probs = self.activation_fn(bag_logits)
|
||||
if self.n_classes > 1:
|
||||
predicted_labels = argmax(predicted_probs, dim=1)
|
||||
probs_perclass = predicted_probs
|
||||
else:
|
||||
predicted_labels = round(predicted_probs)
|
||||
probs_perclass = Tensor([[1.0 - predicted_probs[i][0].item(), predicted_probs[i][0].item()] for i in range(len(predicted_probs))])
|
||||
|
||||
loss = loss.view(-1, 1)
|
||||
predicted_labels = predicted_labels.view(-1, 1)
|
||||
if self.n_classes == 1:
|
||||
predicted_probs = predicted_probs.view(-1, 1)
|
||||
bag_labels = bag_labels.view(-1, 1)
|
||||
|
||||
results = dict()
|
||||
for metric_object in self.get_metrics_dict(stage).values():
|
||||
if self.n_classes > 1:
|
||||
metric_object.update(predicted_probs, bag_labels.squeeze())
|
||||
else:
|
||||
metric_object.update(predicted_probs, bag_labels)
|
||||
results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN],
|
||||
ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN],
|
||||
ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss,
|
||||
ResultsKey.PROB: predicted_probs, ResultsKey.CLASS_PROBS: probs_perclass,
|
||||
ResultsKey.PRED_LABEL: predicted_labels,
|
||||
ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
|
||||
ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})
|
||||
|
||||
if (TilesDataset.TILE_X_COLUMN in batch.keys()) and (TilesDataset.TILE_Y_COLUMN in batch.keys()):
|
||||
results.update({ResultsKey.TILE_X: batch[TilesDataset.TILE_X_COLUMN],
|
||||
ResultsKey.TILE_Y: batch[TilesDataset.TILE_Y_COLUMN]}
|
||||
)
|
||||
else:
|
||||
logging.warning("Coordinates not found in batch. If this is not expected check your input tiles dataset.")
|
||||
|
||||
return results
|
||||
|
||||
def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
|
||||
train_result = self._shared_step(batch, batch_idx, 'train')
|
||||
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
|
||||
sync_dist=True)
|
||||
if self.verbose:
|
||||
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
|
||||
self.log_metrics('train')
|
||||
return train_result[ResultsKey.LOSS]
|
||||
|
||||
def validation_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
|
||||
val_result = self._shared_step(batch, batch_idx, 'val')
|
||||
self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
|
||||
sync_dist=True)
|
||||
self.log_metrics('val')
|
||||
return val_result[ResultsKey.LOSS]
|
||||
|
||||
def test_step(self, batch: Dict, batch_idx: int) -> Dict[ResultsKey, Any]: # type: ignore
|
||||
test_result = self._shared_step(batch, batch_idx, 'test')
|
||||
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
|
||||
sync_dist=True)
|
||||
self.log_metrics('test')
|
||||
return test_result
|
||||
|
||||
def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore
|
||||
# outputs object consists of a list of dictionaries (of metadata and results, including encoded features)
|
||||
# It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx]
|
||||
# example of batch_key ResultsKey.SLIDE_ID_COL
|
||||
# for batch keys that contains multiple values for slides e.g. ResultsKey.BAG_ATTN_COL
|
||||
# outputs[batch_idx][batch_key][bag_idx][tile_idx]
|
||||
# contains the tile value
|
||||
|
||||
# collate the batches
|
||||
results: Dict[str, List[Any]] = {}
|
||||
[results.update({col: []}) for col in outputs[0].keys()]
|
||||
for key in results.keys():
|
||||
for batch_id in range(len(outputs)):
|
||||
results[key] += outputs[batch_id][key]
|
||||
|
||||
print("Saving outputs ...")
|
||||
# collate at slide level
|
||||
list_slide_dicts = []
|
||||
list_encoded_features = []
|
||||
# any column can be used here, the assumption is that the first dimension is the N of slides
|
||||
for slide_idx in range(len(results[ResultsKey.SLIDE_ID])):
|
||||
slide_dict = dict()
|
||||
for key in results.keys():
|
||||
if key not in [ResultsKey.IMAGE, ResultsKey.LOSS]:
|
||||
slide_dict[key] = results[key][slide_idx]
|
||||
list_slide_dicts.append(slide_dict)
|
||||
list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx])
|
||||
|
||||
outputs_path = fixed_paths.repository_parent_directory() / 'outputs'
|
||||
print(f"Metrics results will be output to {outputs_path}")
|
||||
outputs_fig_path = outputs_path / 'fig'
|
||||
csv_filename = outputs_path / 'test_output.csv'
|
||||
encoded_features_filename = outputs_path / 'test_encoded_features.pickle'
|
||||
|
||||
# Collect the list of dictionaries in a list of pandas dataframe and save
|
||||
df_list = []
|
||||
for slide_dict in list_slide_dicts:
|
||||
slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False)
|
||||
df_list.append(pd.DataFrame.from_dict(slide_dict))
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df.to_csv(csv_filename, mode='w', header=True)
|
||||
|
||||
# Collect all features in a list and save
|
||||
features_list = self.move_list_to_device(list_encoded_features, use_gpu=False)
|
||||
torch.save(features_list, encoded_features_filename)
|
||||
|
||||
print("Selecting tiles ...")
|
||||
# Class 0
|
||||
tn_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'highest_att'))
|
||||
tn_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'lowest_att'))
|
||||
fp_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'highest_att'))
|
||||
fp_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'lowest_att'))
|
||||
report_cases = {'TN': [tn_top_tiles, tn_bottom_tiles], 'FP': [fp_top_tiles, fp_bottom_tiles]}
|
||||
|
||||
# Class 1 to n_classes-1
|
||||
n_classes_to_select = self.n_classes if self.n_classes > 1 else 2
|
||||
for i in range(1, n_classes_to_select):
|
||||
fn_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'highest_att'))
|
||||
fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'lowest_att'))
|
||||
tp_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'highest_att'))
|
||||
tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'lowest_att'))
|
||||
report_cases.update({'TP_'+str(i): [tp_top_tiles, tp_bottom_tiles], 'FN_'+str(i): [fn_top_tiles, fn_bottom_tiles]})
|
||||
|
||||
for key in report_cases.keys():
|
||||
print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...")
|
||||
key_folder_path = outputs_fig_path / f'{key}'
|
||||
Path(key_folder_path).mkdir(parents=True, exist_ok=True)
|
||||
nslides = len(report_cases[key][0])
|
||||
for i in range(nslides):
|
||||
slide, score, paths, top_attn = report_cases[key][0][i]
|
||||
fig = plot_attention_tiles(slide, score, paths, top_attn, key + '_top', ncols=4)
|
||||
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_top.png'))
|
||||
|
||||
slide, score, paths, bottom_attn = report_cases[key][1][i]
|
||||
fig = plot_attention_tiles(slide, score, paths, bottom_attn, key + '_bottom', ncols=4)
|
||||
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png'))
|
||||
|
||||
if self.slide_dataset is not None:
|
||||
slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide) # type: ignore
|
||||
_ = load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore
|
||||
slide_image = slide_dict[SlideKey.IMAGE]
|
||||
location_bbox = slide_dict[SlideKey.LOCATION]
|
||||
|
||||
fig = plot_slide(slide_image=slide_image, scale=1.0)
|
||||
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png'))
|
||||
fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
|
||||
location_bbox=location_bbox, tile_size=self.tile_size, level=self.level)
|
||||
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png'))
|
||||
|
||||
print("Plotting histogram ...")
|
||||
fig = plot_scores_hist(results)
|
||||
self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png')
|
||||
|
||||
print("Computing and saving confusion matrix...")
|
||||
metrics_dict = self.get_metrics_dict('test')
|
||||
cf_matrix = metrics_dict[MetricsKey.CONF_MATRIX].compute()
|
||||
cf_matrix = np.array(cf_matrix.cpu())
|
||||
# We can't log tensors in the normal way - just print it to console
|
||||
print('test/confusion matrix:')
|
||||
print(cf_matrix)
|
||||
# Save the normalized confusion matrix as a figure in outputs
|
||||
cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1, keepdims=True)
|
||||
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=self.class_names)
|
||||
self.save_figure(fig=fig, figpath=outputs_fig_path / 'normalized_confusion_matrix.png')
|
||||
|
||||
@staticmethod
|
||||
def save_figure(fig: plt.figure, figpath: Path) -> None:
|
||||
fig.savefig(figpath, bbox_inches='tight')
|
||||
|
||||
@staticmethod
|
||||
def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict:
|
||||
# slide-level dictionaries are processed by making value dimensions uniform and converting to numpy arrays.
|
||||
# these steps are required to convert the dictionary to pandas dataframe.
|
||||
device = 'cuda' if use_gpu else 'cpu'
|
||||
dict_new = dict()
|
||||
bag_size = len(dict_old[ResultsKey.SLIDE_ID])
|
||||
for key, value in dict_old.items():
|
||||
if key not in [ResultsKey.CLASS_PROBS, ResultsKey.PROB]:
|
||||
if isinstance(value, Tensor):
|
||||
value = value.squeeze(0).to(device).numpy()
|
||||
if value.ndim == 0:
|
||||
value = np.full(bag_size, fill_value=value)
|
||||
dict_new[key] = value
|
||||
elif key == ResultsKey.CLASS_PROBS:
|
||||
if isinstance(value, Tensor):
|
||||
value = value.squeeze(0).to(device).numpy()
|
||||
for i in range(len(value)):
|
||||
dict_new[key+str(i)] = np.repeat(value[i], bag_size)
|
||||
return dict_new
|
||||
|
||||
@staticmethod
|
||||
def move_list_to_device(list_encoded_features: List, use_gpu: bool) -> List:
|
||||
# a list of features on cpu obtained from original list on gpu
|
||||
features_list = []
|
||||
device = 'cuda' if use_gpu else 'cpu'
|
||||
for feature in list_encoded_features:
|
||||
feature = feature.squeeze(0).to(device)
|
||||
features_list.append(feature)
|
||||
return features_list
|
|
@ -1,147 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pl_bolts.models.self_supervised import SimCLR
|
||||
from torch import nn
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from InnerEye.ML.Histopathology.utils.layer_utils import (get_imagenet_preprocessing,
|
||||
load_weights_to_model,
|
||||
setup_feature_extractor)
|
||||
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
|
||||
from InnerEye.ML.SSL.utils import create_ssl_image_classifier
|
||||
|
||||
|
||||
class TileEncoder(nn.Module):
|
||||
"""Base tile encoder class for use in dataset transforms or as part of a bigger model"""
|
||||
|
||||
def __init__(self, tile_size: int = 0, n_channels: int = 3,
|
||||
input_dim: Optional[Sequence[int]] = None) -> None:
|
||||
"""The `TileEncoder` constructor should be called after setting any attributes needed in
|
||||
`_get_preprocessing()` or `_get_encoder()`.
|
||||
|
||||
:param tile_size: Tile width/height, in pixels.
|
||||
:param n_channels: Number of channels in the tile (default=3).
|
||||
:param input_dim: Input shape, to override default of `(n_channels, tile_size, tile_size)`.
|
||||
"""
|
||||
super().__init__()
|
||||
if input_dim is None:
|
||||
if tile_size == 0:
|
||||
raise ValueError("Either input_dim or tile_size must be specified")
|
||||
input_dim = (n_channels, tile_size, tile_size)
|
||||
self.input_dim = tuple(input_dim)
|
||||
|
||||
self.preprocessing_fn = self._get_preprocessing()
|
||||
self.feature_extractor_fn, self.num_encoding = self._get_encoder()
|
||||
|
||||
def _get_preprocessing(self) -> Callable:
|
||||
return Compose([])
|
||||
|
||||
def _get_encoder(self) -> Tuple[Callable, int]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
prep_images = self.preprocessing_fn(images)
|
||||
return self.feature_extractor_fn(prep_images)
|
||||
|
||||
|
||||
class IdentityEncoder(TileEncoder):
|
||||
"""Dummy encoder that just flattens the input"""
|
||||
|
||||
def _get_encoder(self) -> Tuple[Callable, int]:
|
||||
return nn.Flatten(), np.prod(self.input_dim)
|
||||
|
||||
|
||||
class ImageNetEncoder(TileEncoder):
|
||||
"""Feature extractor pretrained for classification on ImageNet"""
|
||||
|
||||
def __init__(self, feature_extraction_model: Callable[..., nn.Module],
|
||||
tile_size: int, n_channels: int = 3) -> None:
|
||||
"""
|
||||
:param feature_extraction_model: A function accepting a `pretrained` keyword argument that
|
||||
returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`.
|
||||
:param tile_size: Tile width/height, in pixels.
|
||||
:param n_channels: Number of channels in the tile (default=3).
|
||||
"""
|
||||
self.create_feature_extractor_fn = feature_extraction_model
|
||||
super().__init__(tile_size=tile_size, n_channels=n_channels)
|
||||
|
||||
def _get_preprocessing(self) -> Callable:
|
||||
return get_imagenet_preprocessing()
|
||||
|
||||
def _get_encoder(self) -> Tuple[Callable, int]:
|
||||
pretrained_model = self.create_feature_extractor_fn(pretrained=True)
|
||||
return setup_feature_extractor(pretrained_model, self.input_dim) # type: ignore
|
||||
|
||||
|
||||
class ImageNetSimCLREncoder(TileEncoder):
|
||||
"""SimCLR encoder pretrained on ImageNet"""
|
||||
|
||||
WEIGHTS_URL = ("https://pl-bolts-weights.s3.us-east-2.amazonaws.com/"
|
||||
"simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt")
|
||||
EMBEDDING_DIM = 2048
|
||||
|
||||
def _get_preprocessing(self) -> Callable:
|
||||
return get_imagenet_preprocessing()
|
||||
|
||||
def _get_encoder(self) -> Tuple[SimCLR, int]:
|
||||
simclr = SimCLR.load_from_checkpoint(self.WEIGHTS_URL, strict=False)
|
||||
simclr.freeze()
|
||||
return simclr, self.EMBEDDING_DIM
|
||||
|
||||
|
||||
class InnerEyeSSLEncoder(TileEncoder):
|
||||
"""SSL encoder trained on Azure ML using InnerEye"""
|
||||
|
||||
def __init__(self, pl_checkpoint_path: Path, tile_size: int, n_channels: int = 3) -> None:
|
||||
"""
|
||||
:param pl_checkpoint_path: The path of the downloaded checkpoint file.
|
||||
:param tile_size: Tile width/height, in pixels.
|
||||
:param n_channels: Number of channels in the tile (default=3).
|
||||
"""
|
||||
self.pl_checkpoint_path = pl_checkpoint_path
|
||||
super().__init__(tile_size=tile_size, n_channels=n_channels)
|
||||
|
||||
def _get_encoder(self) -> Tuple[torch.nn.Module, int]:
|
||||
model: SSLClassifier = create_ssl_image_classifier( # type: ignore
|
||||
num_classes=1, # dummy value
|
||||
freeze_encoder=True,
|
||||
pl_checkpoint_path=str(self.pl_checkpoint_path)
|
||||
)
|
||||
encoder = model.encoder # type: ignore
|
||||
for param in encoder.parameters():
|
||||
param.requires_grad = False # freeze_encoder does not disable gradients
|
||||
|
||||
classifier_head = model.classifier_head
|
||||
embedding_dim = classifier_head.n_input # type: ignore
|
||||
|
||||
return encoder, embedding_dim
|
||||
|
||||
|
||||
class HistoSSLEncoder(TileEncoder):
|
||||
"""HistoSSL encoder pretrained on multiple histological datasets
|
||||
|
||||
Reference:
|
||||
- Ciga, Xu, Martel (2021). Self supervised contrastive learning for digital histopathology.
|
||||
arXiv:2011.13971
|
||||
"""
|
||||
|
||||
WEIGHTS_URL = ("https://github.com/ozanciga/self-supervised-histopathology/releases/"
|
||||
"download/tenpercent/tenpercent_resnet18.ckpt")
|
||||
|
||||
def _get_encoder(self) -> Tuple[Callable, int]:
|
||||
resnet18_model = resnet18(pretrained=False)
|
||||
num_features = resnet18_model.fc.in_features
|
||||
histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
|
||||
histossl_encoder.fc = torch.nn.Sequential()
|
||||
for param in histossl_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
return histossl_encoder, num_features # type: ignore
|
|
@ -1,169 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Sequence, Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import PIL
|
||||
from monai.config.type_definitions import KeysCollection
|
||||
from monai.transforms.transform import MapTransform, Randomizable
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
|
||||
|
||||
PathOrString = Union[Path, str]
|
||||
|
||||
|
||||
def load_pil_image(image_path: PathOrString) -> PIL.Image.Image:
|
||||
"""Load a PIL image in RGB format from the given path"""
|
||||
with PIL.PngImagePlugin.PngImageFile(image_path) as pil_png:
|
||||
image = np.asarray(pil_png)
|
||||
return image
|
||||
|
||||
|
||||
def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor:
|
||||
"""Load an image as a tensor from the given path"""
|
||||
pil_image = load_pil_image(image_path)
|
||||
return to_tensor(pil_image)
|
||||
|
||||
|
||||
def load_image_stack_as_tensor(image_paths: Sequence[PathOrString],
|
||||
progress: bool = False) -> torch.Tensor:
|
||||
"""Load a batch of images of the same size as a tensor from the given paths"""
|
||||
loading_generator = (load_image_as_tensor(path) for path in image_paths)
|
||||
if progress:
|
||||
from tqdm import tqdm
|
||||
loading_generator = tqdm(loading_generator, desc="Loading image stack",
|
||||
total=len(image_paths), leave=False)
|
||||
image_tensors = list(loading_generator)
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
class LoadTiled(MapTransform):
|
||||
"""Dictionary transform to load an individual image tile as a tensor from an input path"""
|
||||
|
||||
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
|
||||
"""
|
||||
:param keys: Key(s) for the image path(s) in the input dictionary.
|
||||
:param allow_missing_keys: If `False` (default), raises an exception when an input
|
||||
dictionary is missing any of the specified keys.
|
||||
"""
|
||||
super().__init__(keys, allow_missing_keys)
|
||||
|
||||
def __call__(self, data: Mapping) -> Mapping:
|
||||
out_data = dict(data) # create shallow copy
|
||||
for key in self.key_iterator(out_data):
|
||||
out_data[key] = load_image_as_tensor(data[key])
|
||||
return out_data
|
||||
|
||||
|
||||
class LoadTilesBatchd(MapTransform):
|
||||
"""Dictionary transform to load a batch of image tiles as a tensor from a list of input paths"""
|
||||
|
||||
# Cannot reuse MONAI readers because they support stacking only images with no channels
|
||||
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False,
|
||||
progress: bool = False) -> None:
|
||||
"""
|
||||
:param keys: Key(s) for the image path(s) in the input dictionary.
|
||||
:param allow_missing_keys: If `False` (default), raises an exception when an input
|
||||
dictionary is missing any of the specified keys.
|
||||
:param progress: Whether to display a tqdm progress bar.
|
||||
"""
|
||||
super().__init__(keys, allow_missing_keys)
|
||||
self.progress = progress
|
||||
|
||||
def __call__(self, data: Mapping) -> Mapping:
|
||||
out_data = dict(data) # create shallow copy
|
||||
for key in self.key_iterator(out_data):
|
||||
out_data[key] = load_image_stack_as_tensor(data[key], progress=self.progress)
|
||||
return out_data
|
||||
|
||||
|
||||
class EncodeTilesBatchd(MapTransform):
|
||||
"""Dictionary transform to extract features from a batch tensor of image tiles"""
|
||||
|
||||
def __init__(self,
|
||||
keys: KeysCollection,
|
||||
encoder: TileEncoder,
|
||||
allow_missing_keys: bool = False,
|
||||
chunk_size: int = 0) -> None:
|
||||
"""
|
||||
:param keys: Key(s) for the image tensor(s) in the input dictionary.
|
||||
:param encoder: The tile encoder to use for feature extraction.
|
||||
:param allow_missing_keys: If `False` (default), raises an exception when an input
|
||||
dictionary is missing any of the specified keys.
|
||||
:param chunk_size: if > 0, extracts features in chunks of size chunk_size.
|
||||
"""
|
||||
super().__init__(keys, allow_missing_keys)
|
||||
self.encoder = encoder
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_tiles(self, images: torch.Tensor) -> torch.Tensor:
|
||||
device = next(self.encoder.parameters()).device
|
||||
if self.chunk_size > 0:
|
||||
embeddings = []
|
||||
chunks = torch.split(images, self.chunk_size)
|
||||
# TODO parallelize encoding - keep metadata and images aligned
|
||||
for chunk in chunks:
|
||||
chunk_embeddings = self._encode_images(chunk, device)
|
||||
embeddings.append(chunk_embeddings)
|
||||
return torch.cat(embeddings)
|
||||
else:
|
||||
return self._encode_images(images, device)
|
||||
|
||||
def _encode_images(self, images: torch.Tensor, device: torch.device) -> torch.Tensor:
|
||||
images = images.to(device)
|
||||
embeddings = self.encoder(images)
|
||||
del images
|
||||
torch.cuda.empty_cache()
|
||||
return embeddings
|
||||
|
||||
def __call__(self, data: Mapping) -> Mapping:
|
||||
out_data = dict(data) # create shallow copy
|
||||
for key in self.key_iterator(out_data):
|
||||
out_data[key] = self._encode_tiles(data[key])
|
||||
return out_data
|
||||
|
||||
|
||||
def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
|
||||
if isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
return data[indices]
|
||||
elif isinstance(data, Sequence):
|
||||
return [data[i] for i in indices]
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(data)} is not indexable")
|
||||
|
||||
|
||||
class Subsampled(MapTransform, Randomizable):
|
||||
"""Dictionary transform to randomly subsample the data down to a fixed maximum length"""
|
||||
|
||||
def __init__(self, keys: KeysCollection, max_size: int,
|
||||
allow_missing_keys: bool = False) -> None:
|
||||
"""
|
||||
:param keys: Key(s) for all batch elements that must be subsampled.
|
||||
:param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
|
||||
random down to `max_size` along their first dimension. If shorter, the elements are merely
|
||||
shuffled.
|
||||
:param allow_missing_keys: If `False` (default), raises an exception when an input
|
||||
dictionary is missing any of the specified keys.
|
||||
"""
|
||||
super().__init__(keys, allow_missing_keys=allow_missing_keys)
|
||||
self.max_size = max_size
|
||||
self._indices: np.ndarray
|
||||
|
||||
def randomize(self, total_size: int) -> None:
|
||||
subsample_size = min(self.max_size, total_size)
|
||||
self._indices = self.R.choice(total_size, size=subsample_size)
|
||||
|
||||
def __call__(self, data: Mapping) -> Mapping:
|
||||
out_data = dict(data) # create shallow copy
|
||||
size = len(data[self.keys[0]])
|
||||
self.randomize(size)
|
||||
for key in self.key_iterator(out_data):
|
||||
out_data[key] = take_indices(data[key], self._indices)
|
||||
return out_data
|
|
@ -1,230 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
"""This script is specific to PANDA and is kept only for retrocompatibility.
|
||||
`create_tiles_dataset.py` is the new supported way to process slide datasets.
|
||||
"""
|
||||
import functools
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import traceback
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from monai.data import Dataset
|
||||
from monai.data.image_reader import WSIReader
|
||||
from tqdm import tqdm
|
||||
|
||||
from InnerEye.ML.Histopathology.preprocessing import tiling
|
||||
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||
|
||||
|
||||
CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy',
|
||||
'data_provider', 'slide_isup_grade', 'slide_gleason_score']
|
||||
TMP_SUFFIX = "_tmp"
|
||||
|
||||
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \
|
||||
-> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]:
|
||||
if occupancy_threshold < 0. or occupancy_threshold > 1.:
|
||||
raise ValueError("Tile occupancy threshold must be between 0 and 1")
|
||||
foreground_mask = mask_tile > 0
|
||||
occupancy = foreground_mask.mean(axis=(-2, -1))
|
||||
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze()
|
||||
|
||||
|
||||
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
|
||||
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
|
||||
|
||||
|
||||
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
|
||||
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
|
||||
|
||||
|
||||
def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
|
||||
pil_image = PIL.Image.fromarray(array_hwc)
|
||||
pil_image.convert('RGB').save(path)
|
||||
return pil_image
|
||||
|
||||
|
||||
def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
|
||||
-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
|
||||
image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
|
||||
constant_values=255)
|
||||
mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)
|
||||
|
||||
selected: np.ndarray
|
||||
occupancies: np.ndarray
|
||||
selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
|
||||
n_discarded = (~selected).sum()
|
||||
logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}")
|
||||
|
||||
image_tiles = image_tiles[selected]
|
||||
mask_tiles = mask_tiles[selected]
|
||||
tile_locations = tile_locations[selected]
|
||||
occupancies = occupancies[selected]
|
||||
|
||||
abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)
|
||||
|
||||
return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
|
||||
|
||||
|
||||
# TODO refactor this to separate metadata identification from saving. We might want the metadata
|
||||
# even if the saving fails
|
||||
def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
|
||||
tile_location: Sequence[int], output_dir: Path) -> dict:
|
||||
slide_id = sample['image_id']
|
||||
descriptor = get_tile_descriptor(tile_location)
|
||||
image_tile_filename = f"train_images/{descriptor}.png"
|
||||
mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"
|
||||
|
||||
save_image(image_tile, output_dir / image_tile_filename)
|
||||
save_image(mask_tile, output_dir / mask_tile_filename)
|
||||
|
||||
tile_metadata = {
|
||||
'slide_id': slide_id,
|
||||
'tile_id': get_tile_id(slide_id, tile_location),
|
||||
'image': image_tile_filename,
|
||||
'mask': mask_tile_filename,
|
||||
'tile_x': tile_location[0],
|
||||
'tile_y': tile_location[1],
|
||||
'data_provider': sample['data_provider'],
|
||||
'slide_isup_grade': sample['isup_grade'],
|
||||
'slide_gleason_score': sample['gleason_score'],
|
||||
}
|
||||
|
||||
return tile_metadata
|
||||
|
||||
|
||||
def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
||||
output_dir: Path, tile_progress: bool = False) -> None:
|
||||
slide_id = sample['image_id']
|
||||
slide_dir: Path = output_dir / (slide_id + "/")
|
||||
logging.info(f">>> Slide dir {slide_dir}")
|
||||
if slide_dir.exists(): # already processed slide - skip
|
||||
logging.info(f">>> Skipping {slide_dir} - already processed")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
slide_dir.mkdir(parents=True)
|
||||
|
||||
dataset_csv_path = slide_dir / "dataset.csv"
|
||||
dataset_csv_file = dataset_csv_path.open('w')
|
||||
dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
|
||||
|
||||
tiles_failure = 0
|
||||
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
|
||||
failed_tiles_file = failed_tiles_csv_path.open('w')
|
||||
failed_tiles_file.write('tile_id' + '\n')
|
||||
|
||||
logging.info(f"Loading slide {slide_id} ...")
|
||||
loader = LoadPandaROId(WSIReader(), level=level, margin=margin)
|
||||
sample = loader(sample) # load 'image' and 'mask' from disk
|
||||
|
||||
logging.info(f"Tiling slide {slide_id} ...")
|
||||
image_tiles, mask_tiles, tile_locations, occupancies, _ = \
|
||||
generate_tiles(sample, tile_size, occupancy_threshold)
|
||||
n_tiles = image_tiles.shape[0]
|
||||
|
||||
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
|
||||
try:
|
||||
tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
|
||||
slide_dir)
|
||||
tile_metadata['occupancy'] = occupancies[i]
|
||||
tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
|
||||
tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
|
||||
dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
|
||||
dataset_csv_file.write(dataset_row + '\n')
|
||||
except Exception as e:
|
||||
tiles_failure += 1
|
||||
descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
|
||||
failed_tiles_file.write(descriptor)
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while saving tile "
|
||||
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
|
||||
|
||||
dataset_csv_file.close()
|
||||
failed_tiles_file.close()
|
||||
if tiles_failure > 0:
|
||||
# TODO what we want to do with slides that have some failed tiles?
|
||||
logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
|
||||
|
||||
|
||||
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
|
||||
full_csv = dataset_dir / "dataset.csv"
|
||||
# TODO change how we retrieve these filenames, probably because mounted, the operation is slow
|
||||
# and it seems to find many more files
|
||||
# print("List of files")
|
||||
# print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")])
|
||||
with full_csv.open('w') as full_csv_file:
|
||||
# full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
|
||||
first_file = True
|
||||
for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
|
||||
logging.info(f"Merging slide {slide_csv}")
|
||||
content = slide_csv.read_text()
|
||||
if not first_file:
|
||||
content = content[content.index('\n') + 1:] # discard header row for all but the first file
|
||||
full_csv_file.write(content)
|
||||
first_file = False
|
||||
return full_csv
|
||||
|
||||
|
||||
def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
|
||||
margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:
|
||||
|
||||
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
|
||||
# to select a subsample use keyword n_slides
|
||||
dataset = Dataset(PandaDataset(panda_dir)) # type: ignore
|
||||
|
||||
output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}"
|
||||
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}")
|
||||
|
||||
if overwrite and output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=not overwrite)
|
||||
|
||||
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
|
||||
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
|
||||
tile_progress=not parallel)
|
||||
|
||||
if parallel:
|
||||
import multiprocessing
|
||||
|
||||
pool = multiprocessing.Pool()
|
||||
map_func = pool.imap_unordered # type: ignore
|
||||
else:
|
||||
map_func = map # type: ignore
|
||||
|
||||
list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore
|
||||
|
||||
if parallel:
|
||||
pool.close()
|
||||
|
||||
logging.info("Merging slide files in a single file")
|
||||
merge_dataset_csv_files(output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(panda_dir="/tmp/datasets/PANDA",
|
||||
root_output_dir="/datadrive",
|
||||
level=1,
|
||||
tile_size=224,
|
||||
margin=64,
|
||||
occupancy_threshold=0.05,
|
||||
parallel=True,
|
||||
overwrite=False)
|
|
@ -1,306 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import shutil
|
||||
import traceback
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from monai.data import Dataset
|
||||
from monai.data.image_reader import WSIReader
|
||||
from tqdm import tqdm
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||
from InnerEye.ML.Histopathology.preprocessing import tiling
|
||||
from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId, segment_foreground
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey, TileKey
|
||||
|
||||
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def select_tiles(foreground_mask: np.ndarray, occupancy_threshold: float) \
|
||||
-> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Exclude tiles that are mostly background based on estimated occupancy.
|
||||
|
||||
:param foreground_mask: Boolean array of shape (*, H, W).
|
||||
:param occupancy_threshold: Tiles with lower occupancy (between 0 and 1) will be discarded.
|
||||
:return: A tuple containing which tiles were selected and the estimated occupancies. These will
|
||||
be boolean and float arrays of shape (*,), or scalars if `foreground_mask` is a single tile.
|
||||
"""
|
||||
if occupancy_threshold < 0. or occupancy_threshold > 1.:
|
||||
raise ValueError("Tile occupancy threshold must be between 0 and 1")
|
||||
occupancy = foreground_mask.mean(axis=(-2, -1))
|
||||
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze() # type: ignore
|
||||
|
||||
|
||||
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
|
||||
"""Format the XY tile coordinates into a tile descriptor."""
|
||||
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
|
||||
|
||||
|
||||
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
|
||||
"""Format the slide ID and XY tile coordinates into a unique tile ID."""
|
||||
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
|
||||
|
||||
|
||||
def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
|
||||
"""Save an image array in (C, H, W) format to disk."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
|
||||
pil_image = PIL.Image.fromarray(array_hwc)
|
||||
pil_image.convert('RGB').save(path)
|
||||
return pil_image
|
||||
|
||||
|
||||
def generate_tiles(slide_image: np.ndarray, tile_size: int, foreground_threshold: float,
|
||||
occupancy_threshold: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
|
||||
"""Split the foreground of an input slide image into tiles.
|
||||
|
||||
:param slide_image: The RGB image array in (C, H, W) format.
|
||||
:param tile_size: Lateral dimensions of each tile, in pixels.
|
||||
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
|
||||
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
|
||||
:return: A tuple containing the image tiles (N, C, H, W), tile coordinates (N, 2), occupancies
|
||||
(N,), and total number of discarded empty tiles.
|
||||
"""
|
||||
image_tiles, tile_locations = tiling.tile_array_2d(slide_image, tile_size=tile_size,
|
||||
constant_values=255)
|
||||
foreground_mask, _ = segment_foreground(image_tiles, foreground_threshold)
|
||||
|
||||
selected, occupancies = select_tiles(foreground_mask, occupancy_threshold)
|
||||
n_discarded = (~selected).sum()
|
||||
logging.info(f"Percentage tiles discarded: {n_discarded / len(selected) * 100:.2f}")
|
||||
|
||||
image_tiles = image_tiles[selected]
|
||||
tile_locations = tile_locations[selected]
|
||||
occupancies = occupancies[selected]
|
||||
|
||||
return image_tiles, tile_locations, occupancies, n_discarded
|
||||
|
||||
|
||||
def get_tile_info(sample: Dict[SlideKey, Any], occupancy: float, tile_location: Sequence[int],
|
||||
rel_slide_dir: Path) -> Dict[TileKey, Any]:
|
||||
"""Map slide information and tiling outputs into tile-specific information dictionary.
|
||||
|
||||
:param sample: Slide dictionary.
|
||||
:param occupancy: Estimated tile foreground occuppancy.
|
||||
:param tile_location: Tile XY coordinates.
|
||||
:param rel_slide_dir: Directory where tiles are saved, relative to dataset root.
|
||||
:return: Tile information dictionary.
|
||||
"""
|
||||
slide_id = sample[SlideKey.SLIDE_ID]
|
||||
descriptor = get_tile_descriptor(tile_location)
|
||||
rel_image_path = f"{rel_slide_dir}/{descriptor}.png"
|
||||
|
||||
tile_info = {
|
||||
TileKey.SLIDE_ID: slide_id,
|
||||
TileKey.TILE_ID: get_tile_id(slide_id, tile_location),
|
||||
TileKey.IMAGE: rel_image_path,
|
||||
TileKey.LABEL: sample[SlideKey.LABEL],
|
||||
TileKey.TILE_X: tile_location[0],
|
||||
TileKey.TILE_Y: tile_location[1],
|
||||
TileKey.OCCUPANCY: occupancy,
|
||||
TileKey.SLIDE_METADATA: {TileKey.from_slide_metadata_key(key): value
|
||||
for key, value in sample[SlideKey.METADATA].items()}
|
||||
}
|
||||
|
||||
return tile_info
|
||||
|
||||
|
||||
def format_csv_row(tile_info: Dict[TileKey, Any], keys_to_save: Iterable[TileKey],
|
||||
metadata_keys: Iterable[str]) -> str:
|
||||
"""Format tile information dictionary as a row to write to a dataset CSV tile.
|
||||
|
||||
:param tile_info: Tile information dictionary.
|
||||
:param keys_to_save: Which main keys to include in the row, and in which order.
|
||||
:param metadata_keys: Likewise for metadata keys.
|
||||
:return: The formatted CSV row.
|
||||
"""
|
||||
tile_slide_metadata = tile_info.pop(TileKey.SLIDE_METADATA)
|
||||
fields = [str(tile_info[key]) for key in keys_to_save]
|
||||
fields.extend(str(tile_slide_metadata[key]) for key in metadata_keys)
|
||||
dataset_row = ','.join(fields)
|
||||
return dataset_row
|
||||
|
||||
|
||||
def process_slide(sample: Dict[SlideKey, Any], level: int, margin: int, tile_size: int,
|
||||
foreground_threshold: Optional[float], occupancy_threshold: float, output_dir: Path,
|
||||
tile_progress: bool = False) -> None:
|
||||
"""Load and process a slide, saving tile images and information to a CSV file.
|
||||
|
||||
:param sample: Slide information dictionary, returned by the input slide dataset.
|
||||
:param level: Magnification level at which to process the slide.
|
||||
:param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
|
||||
:param tile_size: Lateral dimensions of each tile, in pixels.
|
||||
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
|
||||
If `None` (default), an optimal threshold will be estimated automatically.
|
||||
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
|
||||
:param output_dir: Root directory for the output dataset; outputs for a single slide will be
|
||||
saved inside `output_dir/slide_id/`.
|
||||
:param tile_progress: Whether to display a progress bar in the terminal.
|
||||
"""
|
||||
slide_metadata: Dict[str, Any] = sample[SlideKey.METADATA]
|
||||
keys_to_save = (TileKey.SLIDE_ID, TileKey.TILE_ID, TileKey.IMAGE, TileKey.LABEL,
|
||||
TileKey.TILE_X, TileKey.TILE_Y, TileKey.OCCUPANCY)
|
||||
metadata_keys = tuple(TileKey.from_slide_metadata_key(key) for key in slide_metadata)
|
||||
csv_columns: Tuple[str, ...] = (*keys_to_save, *metadata_keys)
|
||||
|
||||
slide_id: str = sample[SlideKey.SLIDE_ID]
|
||||
rel_slide_dir = Path(slide_id)
|
||||
slide_dir = output_dir / rel_slide_dir
|
||||
logging.info(f">>> Slide dir {slide_dir}")
|
||||
if slide_dir.exists(): # already processed slide - skip
|
||||
logging.info(f">>> Skipping {slide_dir} - already processed")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
slide_dir.mkdir(parents=True)
|
||||
|
||||
dataset_csv_path = slide_dir / "dataset.csv"
|
||||
dataset_csv_file = dataset_csv_path.open('w')
|
||||
dataset_csv_file.write(','.join(csv_columns) + '\n') # write CSV header
|
||||
|
||||
n_failed_tiles = 0
|
||||
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
|
||||
failed_tiles_file = failed_tiles_csv_path.open('w')
|
||||
failed_tiles_file.write('tile_id' + '\n')
|
||||
|
||||
logging.info(f"Loading slide {slide_id} ...")
|
||||
loader = LoadROId(WSIReader('cuCIM'), level=level, margin=margin,
|
||||
foreground_threshold=foreground_threshold)
|
||||
sample = loader(sample) # load 'image' from disk
|
||||
|
||||
logging.info(f"Tiling slide {slide_id} ...")
|
||||
image_tiles, rel_tile_locations, occupancies, _ = \
|
||||
generate_tiles(sample[SlideKey.IMAGE], tile_size,
|
||||
sample[SlideKey.FOREGROUND_THRESHOLD],
|
||||
occupancy_threshold)
|
||||
|
||||
tile_locations = (sample[SlideKey.SCALE] * rel_tile_locations
|
||||
+ sample[SlideKey.ORIGIN]).astype(int)
|
||||
|
||||
n_tiles = image_tiles.shape[0]
|
||||
|
||||
logging.info(f"Saving tiles for slide {slide_id} ...")
|
||||
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
|
||||
try:
|
||||
tile_info = get_tile_info(sample, occupancies[i], tile_locations[i], rel_slide_dir)
|
||||
save_image(image_tiles[i], output_dir / tile_info[TileKey.IMAGE])
|
||||
dataset_row = format_csv_row(tile_info, keys_to_save, metadata_keys)
|
||||
dataset_csv_file.write(dataset_row + '\n')
|
||||
except Exception as e:
|
||||
n_failed_tiles += 1
|
||||
descriptor = get_tile_descriptor(tile_locations[i])
|
||||
failed_tiles_file.write(descriptor + '\n')
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while saving tile "
|
||||
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
|
||||
|
||||
dataset_csv_file.close()
|
||||
failed_tiles_file.close()
|
||||
if n_failed_tiles > 0:
|
||||
# TODO what we want to do with slides that have some failed tiles?
|
||||
logging.warning(f"{slide_id} is incomplete. {n_failed_tiles} tiles failed.")
|
||||
logging.info(f"Finished processing slide {slide_id}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
|
||||
|
||||
|
||||
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
|
||||
"""Combines all "*/dataset.csv" files into a single "dataset.csv" file in the given directory."""
|
||||
full_csv = dataset_dir / "dataset.csv"
|
||||
# TODO change how we retrieve these filenames, probably because mounted, the operation is slow
|
||||
# and it seems to find many more files
|
||||
# print("List of files")
|
||||
# print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")])
|
||||
with full_csv.open('w') as full_csv_file:
|
||||
# full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
|
||||
first_file = True
|
||||
for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
|
||||
logging.info(f"Merging slide {slide_csv}")
|
||||
content = slide_csv.read_text()
|
||||
if not first_file:
|
||||
content = content[content.index('\n') + 1:] # discard header row for all but the first file
|
||||
full_csv_file.write(content)
|
||||
first_file = False
|
||||
return full_csv
|
||||
|
||||
|
||||
def main(slides_dataset: SlidesDataset, root_output_dir: Union[str, Path],
|
||||
level: int, tile_size: int, margin: int, foreground_threshold: Optional[float],
|
||||
occupancy_threshold: float, parallel: bool = False, overwrite: bool = False,
|
||||
n_slides: Optional[int] = None) -> None:
|
||||
"""Process a slides dataset to produce a tiles dataset.
|
||||
|
||||
:param slides_dataset: Input tiles dataset object.
|
||||
:param root_output_dir: The root directory of the output tiles dataset.
|
||||
:param level: Magnification level at which to process the slide.
|
||||
:param tile_size: Lateral dimensions of each tile, in pixels.
|
||||
:param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
|
||||
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
|
||||
If `None` (default), an optimal threshold will be estimated automatically.
|
||||
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
|
||||
:param parallel: Whether slides should be processed in parallel with multiprocessing.
|
||||
:param overwrite: Whether to overwrite an existing output tiles dataset. If `True`, will delete
|
||||
and recreate `root_output_dir`, otherwise will resume by skipping already processed slides.
|
||||
:param n_slides: If given, limit the total number of slides for debugging.
|
||||
"""
|
||||
|
||||
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
|
||||
# to select a subsample use keyword n_slides
|
||||
dataset = Dataset(slides_dataset)[:n_slides] # type: ignore
|
||||
|
||||
output_dir = Path(root_output_dir)
|
||||
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} "
|
||||
f"{slides_dataset.__class__.__name__} tiles at: {output_dir}")
|
||||
|
||||
if overwrite and output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=not overwrite)
|
||||
|
||||
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
|
||||
foreground_threshold=foreground_threshold,
|
||||
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
|
||||
tile_progress=not parallel)
|
||||
|
||||
if parallel:
|
||||
import multiprocessing
|
||||
|
||||
pool = multiprocessing.Pool()
|
||||
map_func = pool.imap_unordered # type: ignore
|
||||
else:
|
||||
map_func = map # type: ignore
|
||||
|
||||
list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore
|
||||
|
||||
if parallel:
|
||||
pool.close()
|
||||
|
||||
logging.info("Merging slide files in a single file")
|
||||
merge_dataset_csv_files(output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from InnerEye.ML.Histopathology.datasets.tcga_prad_dataset import TcgaPradDataset
|
||||
|
||||
# Example set up for an existing slides dataset:
|
||||
main(slides_dataset=TcgaPradDataset("/tmp/datasets/TCGA-PRAD"),
|
||||
root_output_dir="/datadrive/TCGA-PRAD_tiles",
|
||||
n_slides=5,
|
||||
level=3,
|
||||
tile_size=224,
|
||||
margin=64,
|
||||
foreground_threshold=None,
|
||||
occupancy_threshold=0.05,
|
||||
parallel=False,
|
||||
overwrite=True)
|
|
@ -1,114 +0,0 @@
|
|||
import logging
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import skimage.filters
|
||||
from health_ml.utils import box_utils
|
||||
from monai.data.image_reader import WSIReader
|
||||
from monai.transforms import MapTransform
|
||||
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||
|
||||
try:
|
||||
from cucim import CuImage
|
||||
except:
|
||||
logging.warning("cucim library not available, code may fail.")
|
||||
|
||||
|
||||
def get_luminance(slide: np.ndarray) -> np.ndarray:
|
||||
"""Compute a grayscale version of the input slide.
|
||||
|
||||
:param slide: The RGB image array in (*, C, H, W) format.
|
||||
:return: The single-channel luminance array as (*, H, W).
|
||||
"""
|
||||
# TODO: Consider more sophisticated luminance calculation if necessary
|
||||
return slide.mean(axis=-3) # type: ignore
|
||||
|
||||
|
||||
def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
|
||||
-> Tuple[np.ndarray, float]:
|
||||
"""Segment the given slide by thresholding its luminance.
|
||||
|
||||
:param slide: The RGB image array in (*, C, H, W) format.
|
||||
:param threshold: Pixels with luminance below this value will be considered foreground.
|
||||
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
|
||||
:return: A tuple containing the boolean output array in (*, H, W) format and the threshold used.
|
||||
"""
|
||||
luminance = get_luminance(slide)
|
||||
if threshold is None:
|
||||
threshold = skimage.filters.threshold_otsu(luminance)
|
||||
return luminance < threshold, threshold
|
||||
|
||||
|
||||
def load_slide_at_level(reader: WSIReader, slide_obj: 'CuImage', level: int) -> np.ndarray:
|
||||
"""Load full slide array at the given magnification level.
|
||||
|
||||
This is a manual workaround for a MONAI bug (https://github.com/Project-MONAI/MONAI/issues/3415)
|
||||
fixed in a currently unreleased PR (https://github.com/Project-MONAI/MONAI/pull/3417).
|
||||
|
||||
:param reader: A MONAI `WSIReader` using cuCIM backend.
|
||||
:param slide_obj: The cuCIM image object returned by `reader.read(<image_file>)`.
|
||||
:param level: Index of the desired magnification level as defined in the `slide_obj` headers.
|
||||
:return: The loaded image array in (C, H, W) format.
|
||||
"""
|
||||
size = slide_obj.resolutions['level_dimensions'][level][::-1]
|
||||
slide, _ = reader.get_data(slide_obj, size=size, level=level) # loaded as RGB PIL image
|
||||
return slide
|
||||
|
||||
|
||||
class LoadROId(MapTransform):
|
||||
"""Transform that loads a pathology slide, cropped to an estimated bounding box (ROI).
|
||||
|
||||
Operates on dictionaries, replacing the file path in `image_key` with the loaded array in
|
||||
(C, H, W) format. Also adds the following entries:
|
||||
- `SlideKey.ORIGIN` (tuple): top-right coordinates of the bounding box
|
||||
- `SlideKey.SCALE` (float): corresponding scale, loaded from the file
|
||||
- `SlideKey.FOREGROUND_THRESHOLD` (float): threshold used to segment the foreground
|
||||
"""
|
||||
|
||||
def __init__(self, reader: WSIReader, image_key: str = SlideKey.IMAGE, level: int = 0,
|
||||
margin: int = 0, foreground_threshold: Optional[float] = None) -> None:
|
||||
"""
|
||||
:param reader: And instance of MONAI's `WSIReader`.
|
||||
:param image_key: Image key in the input and output dictionaries.
|
||||
:param level: Magnification level to load from the raw multi-scale file.
|
||||
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
|
||||
:param foreground_threshold: Pixels with luminance below this value will be considered foreground.
|
||||
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
|
||||
"""
|
||||
super().__init__([image_key], allow_missing_keys=False)
|
||||
self.reader = reader
|
||||
self.image_key = image_key
|
||||
self.level = level
|
||||
self.margin = margin
|
||||
self.foreground_threshold = foreground_threshold
|
||||
|
||||
def _get_bounding_box(self, slide_obj: 'CuImage') -> Tuple[box_utils.Box, float]:
|
||||
# Estimate bounding box at the lowest resolution (i.e. highest level)
|
||||
highest_level = slide_obj.resolutions['level_count'] - 1
|
||||
scale = slide_obj.resolutions['level_downsamples'][highest_level]
|
||||
slide = load_slide_at_level(self.reader, slide_obj, level=highest_level)
|
||||
|
||||
foreground_mask, threshold = segment_foreground(slide, self.foreground_threshold)
|
||||
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
|
||||
return bbox, threshold
|
||||
|
||||
def __call__(self, data: Dict) -> Dict:
|
||||
image_obj: CuImage = self.reader.read(data[self.image_key])
|
||||
|
||||
level0_bbox, threshold = self._get_bounding_box(image_obj)
|
||||
|
||||
# cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
|
||||
# but relative region size in pixels at the chosen level
|
||||
origin = (level0_bbox.x, level0_bbox.y)
|
||||
scale = image_obj.resolutions['level_downsamples'][self.level]
|
||||
scaled_bbox = level0_bbox / scale
|
||||
|
||||
data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
|
||||
size=(scaled_bbox.w, scaled_bbox.h))
|
||||
data[SlideKey.ORIGIN] = origin
|
||||
data[SlideKey.SCALE] = scale
|
||||
data[SlideKey.FOREGROUND_THRESHOLD] = threshold
|
||||
|
||||
image_obj.close()
|
||||
return data
|
|
@ -1,128 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
# These tiling implementations are adapted from PANDA Kaggle solutions, for example:
|
||||
# https://github.com/kentaroy47/Kaggle-PANDA-1st-place-solution/blob/master/src/data_process/a00_save_tiles.py
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_1d_padding(length: int, tile_size: int) -> Tuple[int, int]:
|
||||
"""Computes symmetric padding for `length` to be divisible by `tile_size`."""
|
||||
pad = (tile_size - length % tile_size) % tile_size
|
||||
return (pad // 2, pad - pad // 2)
|
||||
|
||||
|
||||
def pad_for_tiling_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True,
|
||||
**pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Symmetrically pads a 2D `array` such that both dimensions are divisible by `tile_size`.
|
||||
|
||||
:param array: 2D image array.
|
||||
:param tile_size: Width/height of each tile in pixels.
|
||||
:param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout.
|
||||
:param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`).
|
||||
:return: A tuple containing:
|
||||
- `padded_array`: Resulting array, in the same CHW/HWC layout as the input.
|
||||
- `offset`: XY offset introduced by the padding. Add this to coordinates relative to the
|
||||
original array to obtain indices for the padded array.
|
||||
"""
|
||||
height, width = array.shape[1:] if channels_first else array.shape[:-1]
|
||||
padding_h = get_1d_padding(height, tile_size)
|
||||
padding_w = get_1d_padding(width, tile_size)
|
||||
padding = [padding_h, padding_w]
|
||||
channels_axis = 0 if channels_first else 2
|
||||
padding.insert(channels_axis, (0, 0)) # zero padding on channels axis
|
||||
padded_array = np.pad(array, padding, **pad_kwargs)
|
||||
offset = (padding_w[0], padding_h[0])
|
||||
return padded_array, np.array(offset)
|
||||
|
||||
|
||||
def tile_array_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True,
|
||||
**pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Split an image array into square non-overlapping tiles.
|
||||
|
||||
The array will be padded symmetrically if its dimensions are not exact multiples of `tile_size`.
|
||||
|
||||
:param array: Image array.
|
||||
:param tile_size: Width/height of each tile in pixels.
|
||||
:param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`).
|
||||
:param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout.
|
||||
:return: A tuple containing:
|
||||
- `tiles`: A batch of tiles in NCHW layout.
|
||||
- `coords`: XY coordinates of each tile, in the same order.
|
||||
"""
|
||||
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first, **pad_kwargs)
|
||||
if channels_first:
|
||||
channels, height, width = padded_array.shape
|
||||
else:
|
||||
height, width, channels = padded_array.shape
|
||||
n_tiles_h = height // tile_size
|
||||
n_tiles_w = width // tile_size
|
||||
|
||||
if channels_first:
|
||||
intermediate_shape = (channels, n_tiles_h, tile_size, n_tiles_w, tile_size)
|
||||
axis_order = (1, 3, 0, 2, 4) # (n_tiles_h, n_tiles_w, channels, tile_size, tile_size)
|
||||
output_shape = (n_tiles_h * n_tiles_w, channels, tile_size, tile_size)
|
||||
else:
|
||||
intermediate_shape = (n_tiles_h, tile_size, n_tiles_w, tile_size, channels)
|
||||
axis_order = (0, 2, 1, 3, 4) # (n_tiles_h, n_tiles_w, tile_size, tile_size, channels)
|
||||
output_shape = (n_tiles_h * n_tiles_w, tile_size, tile_size, channels)
|
||||
|
||||
tiles = padded_array.reshape(intermediate_shape) # Split width and height axes
|
||||
tiles = tiles.transpose(axis_order)
|
||||
tiles = tiles.reshape(output_shape) # Flatten tile batch dimension
|
||||
|
||||
# Compute top-left coordinates of every tile, relative to the original array's origin
|
||||
coords_h = tile_size * np.arange(n_tiles_h) - offset_h
|
||||
coords_w = tile_size * np.arange(n_tiles_w) - offset_w
|
||||
# Shape: (n_tiles_h * n_tiles_w, 2)
|
||||
coords = np.stack(np.meshgrid(coords_w, coords_h), axis=-1).reshape(-1, 2)
|
||||
|
||||
return tiles, coords
|
||||
|
||||
|
||||
def assemble_tiles_2d(tiles: np.ndarray, coords: np.ndarray, fill_value: Optional[float] = np.nan,
|
||||
channels_first: Optional[bool] = True) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Assembles a 2D array from sequences of tiles and coordinates.
|
||||
|
||||
:param tiles: Stack of tiles with batch dimension first.
|
||||
:param coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]).
|
||||
:param tile_size: Size of each tile; must be >0.
|
||||
:param fill_value: Value to assign to empty elements (default: `NaN`).
|
||||
:param channels_first: Whether each tile is in CHW (`True`, default) or HWC (`False`) layout.
|
||||
:return: A tuple containing:
|
||||
- `array`: The reassembled 2D array with the smallest dimensions to contain all given tiles.
|
||||
- `offset`: The lowest XY coordinates.
|
||||
- `offset`: XY offset introduced by the assembly. Add this to tile coordinates to obtain
|
||||
indices for the assembled array.
|
||||
"""
|
||||
if coords.shape[0] != tiles.shape[0]:
|
||||
raise ValueError(f"Tile coordinates and values must have the same length, "
|
||||
f"got {coords.shape[0]} and {tiles.shape[0]}")
|
||||
|
||||
if channels_first:
|
||||
n_tiles, channels, tile_size, _ = tiles.shape
|
||||
else:
|
||||
n_tiles, tile_size, _, channels = tiles.shape
|
||||
tile_xs, tile_ys = coords.T
|
||||
|
||||
x_min, x_max = min(tile_xs), max(tile_xs + tile_size)
|
||||
y_min, y_max = min(tile_ys), max(tile_ys + tile_size)
|
||||
width = x_max - x_min
|
||||
height = y_max - y_min
|
||||
output_shape = (channels, height, width) if channels_first else (height, width, channels)
|
||||
array = np.full(output_shape, fill_value)
|
||||
|
||||
offset = np.array([-x_min, -y_min])
|
||||
for idx in range(n_tiles):
|
||||
row = coords[idx, 1] + offset[1]
|
||||
col = coords[idx, 0] + offset[0]
|
||||
if channels_first:
|
||||
array[:, row:row + tile_size, col:col + tile_size] = tiles[idx]
|
||||
else:
|
||||
array[row:row + tile_size, col:col + tile_size, :] = tiles[idx]
|
||||
|
||||
return array, offset
|
|
@ -1,40 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Script to find mean and standard deviation of desired metrics from cross validation child runs.
|
||||
"""
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
from health_azure import aggregate_hyperdrive_metrics, get_workspace
|
||||
|
||||
from InnerEye.Common import fixed_paths
|
||||
|
||||
|
||||
def get_cross_validation_metrics_df(run_id: str) -> pd.DataFrame:
|
||||
"""
|
||||
Function to aggregate the metric over cross-validation runs
|
||||
:param run_id: run id of the hyperdrive run containing child runs
|
||||
"""
|
||||
aml_workspace = get_workspace()
|
||||
os.chdir(fixed_paths.repository_root_directory())
|
||||
df = aggregate_hyperdrive_metrics(run_id=run_id,
|
||||
child_run_arg_name="cross_validation_split_index",
|
||||
aml_workspace=aml_workspace)
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
metrics_list = ['test/accuracy', 'test/auroc', 'test/f1score', 'test/precision', 'test/recall']
|
||||
run_id = "hsharma_features_viz:HD_eff4c009-2f9f-4c2c-94c6-c0c84944a412"
|
||||
metrics_df = get_cross_validation_metrics_df(run_id=run_id)
|
||||
for metric in metrics_list:
|
||||
if metric in metrics_df.index.values:
|
||||
mean = metrics_df.loc[[metric]].mean(axis=1)[metric]
|
||||
std = metrics_df.loc[[metric]].std(axis=1)[metric]
|
||||
print(f"{metric}: {round(mean,4)} ± {round(std,4)}")
|
||||
else:
|
||||
print(f"Metric {metric} not found in the Hyperdrive run metrics for run id {run_id}.")
|
|
@ -1,26 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from health_azure import DatasetConfig
|
||||
from health_azure.utils import get_workspace
|
||||
|
||||
|
||||
def mount_dataset(dataset_id: str) -> str:
|
||||
ws = get_workspace()
|
||||
target_folder = "/tmp/datasets/"
|
||||
dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
|
||||
dataset_mount_folder, mount_ctx = dataset.to_input_dataset_local(ws)
|
||||
mount_ctx.start()
|
||||
assert next(dataset_mount_folder.iterdir()), "Mounted data folder is empty"
|
||||
return str(dataset_mount_folder)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--id', type=str, dest='dataset_id',
|
||||
help='Name of the Azure dataset e.g. PANDA or TCGA-CRCk')
|
||||
args = parser.parse_args()
|
||||
mount_dataset(args.dataset_id)
|
|
@ -1,98 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
|
||||
import umap
|
||||
from sklearn.manifold import TSNE
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
|
||||
"""
|
||||
Get the t-sne projection of high dimensional data in a lower dimensional space
|
||||
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
|
||||
:param **kwargs: keyword arguments to be passed to TSNE()
|
||||
:return: list of features in lower dimensional space (n x c for n samples and c components)
|
||||
"""
|
||||
tsne_2d = TSNE(n_components=n_components, n_jobs=n_jobs, **kwargs)
|
||||
tsne_proj = tsne_2d.fit_transform(features)
|
||||
return tsne_proj
|
||||
|
||||
|
||||
def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
|
||||
"""
|
||||
Get the umap projection of high dimensional data in a lower dimensional space
|
||||
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
|
||||
:param **kwargs: keyword arguments to be passed to UMAP()
|
||||
:return: list of features in lower dimensional space (n x c for n samples and c components)
|
||||
"""
|
||||
umap_2d = umap.UMAP(n_components=n_components, n_jobs=n_jobs, **kwargs)
|
||||
umap_proj = umap_2d.fit_transform(features)
|
||||
return umap_proj
|
||||
|
||||
|
||||
def normalize_array_minmax(arr: List[float]) -> List[float]:
|
||||
"""
|
||||
Normalize an array in range 0 to 1
|
||||
:param arr: array to be normalized
|
||||
:return: normalized array
|
||||
"""
|
||||
return (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
|
||||
|
||||
|
||||
def normalize_array_mean(arr: List[float]) -> List[float]:
|
||||
"""
|
||||
Normalize an array with zero mean and unit variance
|
||||
:param arr: array to be normalized
|
||||
:return: normalized array
|
||||
"""
|
||||
return (arr - np.mean(arr)) / np.std(arr)
|
||||
|
||||
|
||||
def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str], title: str = "") -> None:
|
||||
"""
|
||||
Plot a scatter plot of projected features in two dimensions
|
||||
:param data: features projected in 2d space (nx2)
|
||||
:param labels: corresponding labels of the data (nx1)
|
||||
:param classes: list of classes in the dataset
|
||||
:param title: plot title string
|
||||
"""
|
||||
plt.figure()
|
||||
scatter = plt.scatter(data[:, 0], data[:, 1], 20, labels)
|
||||
plt.legend(handles=scatter.legend_elements()[0], labels=classes)
|
||||
plt.title(title)
|
||||
|
||||
|
||||
def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outliers: bool, title: str = "") -> None:
|
||||
"""
|
||||
Plot a box whisker plot of column data
|
||||
:param columns: data to be plotted in columns
|
||||
:param column_names: names of the columns
|
||||
:param show_outliers: whether outliers need to be shown
|
||||
:param title: plot title string
|
||||
"""
|
||||
plt.figure()
|
||||
_, ax = plt.subplots()
|
||||
ax.boxplot(data_list, showfliers=show_outliers)
|
||||
positions = range(1, len(column_names)+1)
|
||||
means = []
|
||||
for i in range(len(data_list)):
|
||||
means.append(np.mean(data_list[i]))
|
||||
ax.plot(positions, means, 'rs')
|
||||
plt.xticks(positions, column_names)
|
||||
plt.title(title)
|
||||
|
||||
|
||||
def plot_histogram(data: List[Any], title: str = "") -> None:
|
||||
"""
|
||||
Plot a histogram given some data
|
||||
:param data: data to be plotted
|
||||
:param title: plot title string
|
||||
"""
|
||||
plt.figure()
|
||||
plt.hist(data, bins=50)
|
||||
plt.gca().set(title=title, xlabel='Values', ylabel='Frequency')
|
|
@ -1,36 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from health_azure import download_files_from_run_id, get_workspace
|
||||
from InnerEye.Common import fixed_paths
|
||||
|
||||
|
||||
def download_file_if_necessary(run_id: str, remote_dir: Path, download_dir: Path, filename: str) -> None:
|
||||
"""
|
||||
Function to download any file from an AML run if it doesn't exist locally
|
||||
:param run_id: run ID of the AML run
|
||||
:param remote_dir: remote directory from where the file is downloaded
|
||||
:param download_dir: local directory where to save the downloaded file
|
||||
:param filename: name of the file to be downloaded (e.g. `"test_output.csv"`).
|
||||
"""
|
||||
aml_workspace = get_workspace()
|
||||
os.chdir(fixed_paths.repository_root_directory())
|
||||
local_path = download_dir / run_id.split(":")[1] / "outputs" / filename
|
||||
remote_path = remote_dir / filename
|
||||
if local_path.exists():
|
||||
print("File already exists at", local_path)
|
||||
else:
|
||||
local_dir = local_path.parent.parent
|
||||
local_dir.mkdir(exist_ok=True, parents=True)
|
||||
download_files_from_run_id(run_id=run_id,
|
||||
output_folder=local_dir,
|
||||
prefix=str(remote_path),
|
||||
aml_workspace=aml_workspace,
|
||||
validate_checksum=True)
|
||||
assert local_path.exists()
|
||||
print("File is downloaded at", local_path)
|
|
@ -1,31 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
|
||||
def location_selected_tiles(tile_coords: np.ndarray,
|
||||
location_bbox: List[int],
|
||||
level: int) -> np.ndarray:
|
||||
""" Return the scaled and shifted tile co-ordinates for selected tiles in the slide.
|
||||
:param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]) in original resolution.
|
||||
:param location_bbox: Location of the bounding box on the slide in original resolution.
|
||||
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available.
|
||||
(e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled).
|
||||
"""
|
||||
level_dict = {0: 1, 1: 4, 2: 16}
|
||||
factor = level_dict[level]
|
||||
|
||||
x_tr, y_tr = location_bbox
|
||||
tile_xs, tile_ys = tile_coords.T
|
||||
tile_xs = tile_xs - x_tr
|
||||
tile_ys = tile_ys - y_tr
|
||||
tile_xs = tile_xs//factor
|
||||
tile_ys = tile_ys//factor
|
||||
|
||||
sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()])
|
||||
|
||||
return sel_coords
|
|
@ -1,58 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from torch import as_tensor, device, nn, no_grad, prod, rand
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
|
||||
def get_imagenet_preprocessing() -> nn.Module:
|
||||
return Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def setup_feature_extractor(pretrained_model: nn.Module,
|
||||
input_dim: Tuple[int, int, int]) -> Tuple[nn.Module, int]:
|
||||
try:
|
||||
# Attempt to auto-detect final classification layer:
|
||||
num_features: int = pretrained_model.fc.in_features # type: ignore
|
||||
pretrained_model.fc = nn.Flatten()
|
||||
feature_extractor = pretrained_model
|
||||
except AttributeError:
|
||||
# Otherwise fallback to sequence of child modules:
|
||||
layers = list(pretrained_model.children())[:-1]
|
||||
layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps
|
||||
feature_extractor = nn.Sequential(*layers)
|
||||
with no_grad():
|
||||
feature_shape = feature_extractor(rand(1, *input_dim)).shape
|
||||
num_features = int(prod(as_tensor(feature_shape)).item())
|
||||
# fix weights, no fine-tuning
|
||||
for param in feature_extractor.parameters():
|
||||
param.requires_grad = False
|
||||
return feature_extractor, num_features
|
||||
|
||||
|
||||
def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module:
|
||||
"""
|
||||
Load weights to the histoSSL model from the given URL
|
||||
https://github.com/ozanciga/self-supervised-histopathology
|
||||
"""
|
||||
map_location = device('cpu')
|
||||
state = load_state_dict_from_url(weights_url, map_location=map_location)
|
||||
state_dict = state['state_dict']
|
||||
model_dict = model.state_dict()
|
||||
|
||||
new_weights = {}
|
||||
for key, value in state_dict.items():
|
||||
model_key = key.replace('model.', '').replace('resnet.', '')
|
||||
if model_key in model_dict:
|
||||
new_weights[model_key] = value
|
||||
if len(new_weights) == 0:
|
||||
raise RuntimeError("Weights could not be loaded.")
|
||||
model_dict.update(new_weights) # type: ignore
|
||||
|
||||
model.load_state_dict(model_dict) # type: ignore
|
||||
return model
|
|
@ -1,184 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Tuple, List, Any, Dict
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from math import ceil
|
||||
import numpy as np
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.collections as collection
|
||||
import seaborn as sns
|
||||
|
||||
from InnerEye.ML.Histopathology.models.transforms import load_pil_image
|
||||
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
|
||||
from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles
|
||||
|
||||
|
||||
def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1,
|
||||
select: Tuple = ('lowest_pred', 'highest_att'),
|
||||
slide_col: str = ResultsKey.SLIDE_ID, gt_col: str = ResultsKey.TRUE_LABEL,
|
||||
attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.CLASS_PROBS,
|
||||
return_col: str = ResultsKey.IMAGE_PATH) -> List[Tuple[Any, Any, List[Any], List[Any]]]:
|
||||
"""
|
||||
:param results: List that contains slide_level dicts
|
||||
:param n_tiles: number of tiles to be selected for each slide
|
||||
:param n_slides: number of slides to be selected
|
||||
:param label: which label to use to select slides
|
||||
:param select: criteria to be used to sort the slides (select[0]) and the tiles (select[1])
|
||||
:param slide_col: column name that contains slide identifiers
|
||||
:param gt_col: column name that contains labels
|
||||
:param attn_col: column name that contains scores used to sort tiles
|
||||
:param prob_col: column name that contains scores used to sort slides
|
||||
:param return_col: column name of the values we want to return for each tile
|
||||
:return: tuple containing the slides id, the slide score, the tile ids, the tiles scores
|
||||
"""
|
||||
tmp_s = [(results[prob_col][i][label], i) for i, gt in enumerate(results[gt_col]) if gt == label] # type ignore
|
||||
if select[0] == 'lowest_pred':
|
||||
tmp_s.sort(reverse=False)
|
||||
elif select[0] == 'highest_pred':
|
||||
tmp_s.sort(reverse=True)
|
||||
else:
|
||||
ValueError('select value not recognised')
|
||||
_, sorted_idx = zip(*tmp_s)
|
||||
k_idx = []
|
||||
if select[1] == 'highest_att':
|
||||
descending = True
|
||||
elif select[1] == 'lowest_att':
|
||||
descending = False
|
||||
for _, slide_idx in enumerate(sorted_idx[:n_slides]):
|
||||
tmp = results[attn_col][slide_idx]
|
||||
_, t_indices = torch.sort(tmp, descending=descending)
|
||||
k_tiles = []
|
||||
scores = []
|
||||
for t_idx in t_indices[0][:n_tiles]:
|
||||
k_tiles.append(results[return_col][slide_idx][t_idx])
|
||||
scores.append(results[attn_col][slide_idx][0][t_idx])
|
||||
# slide_ids are duplicated
|
||||
k_idx.append((results[slide_col][slide_idx][0],
|
||||
results[prob_col][slide_idx],
|
||||
k_tiles, scores))
|
||||
return k_idx
|
||||
|
||||
|
||||
def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.CLASS_PROBS,
|
||||
gt_col: str = ResultsKey.TRUE_LABEL) -> plt.figure:
|
||||
"""
|
||||
:param results: List that contains slide_level dicts
|
||||
:param prob_col: column name that contains the scores
|
||||
:param gt_col: column name that contains the true label
|
||||
:return: matplotlib figure of the scores histogram by class
|
||||
"""
|
||||
n_classes = len(results[prob_col][0])
|
||||
scores_class = []
|
||||
for j in range(n_classes):
|
||||
scores = [results[prob_col][i][j].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == j]
|
||||
scores_class.append(scores)
|
||||
fig, ax = plt.subplots()
|
||||
ax.hist(scores_class, label=[str(i) for i in range(n_classes)], alpha=0.5)
|
||||
ax.set_xlabel("Predicted Score")
|
||||
ax.legend()
|
||||
return fig
|
||||
|
||||
|
||||
def plot_attention_tiles(slide: str, scores: List[float], paths: List, attn: List, case: str, ncols: int = 5,
|
||||
size: Tuple = (10, 10)) -> plt.figure:
|
||||
"""
|
||||
:param slide: slide identifier
|
||||
:param scores: predicted scores of each class for the slide
|
||||
:param paths: list of paths to tiles belonging to the slide
|
||||
:param attn: list of scores belonging to the tiles in paths. paths and attn are expected to have the same shape
|
||||
:param case: string used to define the title of the plot e.g. TP
|
||||
:param ncols: number of cols the produced figure should have
|
||||
:param size: size of the plot
|
||||
:return: matplotlib figure of each tile in paths with attn score
|
||||
"""
|
||||
nrows = int(ceil(len(paths) / ncols))
|
||||
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=size)
|
||||
fig.suptitle(f"{case}: {slide} P=%.2f" % max(scores))
|
||||
for i in range(len(paths)):
|
||||
img = load_pil_image(paths[i])
|
||||
axs.ravel()[i].imshow(img, clim=(0, 255), cmap='gray')
|
||||
axs.ravel()[i].set_title("%.6f" % attn[i].cpu().item())
|
||||
for i in range(len(axs.ravel())):
|
||||
axs.ravel()[i].set_axis_off()
|
||||
return fig
|
||||
|
||||
|
||||
def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure:
|
||||
"""Plots a slide thumbnail from a given slide image and scale.
|
||||
:param slide_image: Numpy array of the slide image (shape: [3, H, W]).
|
||||
:return: matplotlib figure of the slide thumbnail.
|
||||
"""
|
||||
fig, ax = plt.subplots()
|
||||
slide_image = slide_image.transpose(1, 2, 0)
|
||||
ax.imshow(slide_image)
|
||||
ax.set_axis_off()
|
||||
original_size = fig.get_size_inches()
|
||||
fig.set_size_inches((original_size[0]*scale, original_size[1]*scale))
|
||||
return fig
|
||||
|
||||
|
||||
def plot_heatmap_overlay(slide: str,
|
||||
slide_image: np.ndarray,
|
||||
results: Dict[str, List[Any]],
|
||||
location_bbox: List[int],
|
||||
tile_size: int = 224,
|
||||
level: int = 1) -> plt.figure:
|
||||
"""Plots heatmap of selected tiles (e.g. tiles in a bag) overlay on the corresponding slide.
|
||||
:param slide: slide identifier.
|
||||
:param slide_image: Numpy array of the slide image (shape: [3, H, W]).
|
||||
:param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
|
||||
:param tile_size: Size of each tile. Default 224.
|
||||
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1.
|
||||
:param location_bbox: Location of the bounding box of the slide.
|
||||
:return: matplotlib figure of the heatmap of the given tiles on slide.
|
||||
"""
|
||||
fig, ax = plt.subplots()
|
||||
slide_image = slide_image.transpose(1, 2, 0)
|
||||
ax.imshow(slide_image)
|
||||
ax.set_xlim(0, slide_image.shape[1])
|
||||
ax.set_ylim(slide_image.shape[0], 0)
|
||||
|
||||
coords = []
|
||||
slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]]
|
||||
slide_idx = slide_ids.index(slide)
|
||||
attentions = results[ResultsKey.BAG_ATTN][slide_idx]
|
||||
|
||||
# for each tile in the bag
|
||||
for tile_idx in range(len(results[ResultsKey.IMAGE_PATH][slide_idx])):
|
||||
tile_coords = np.transpose(np.array([results[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(),
|
||||
results[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()]))
|
||||
coords.append(tile_coords)
|
||||
|
||||
coords = np.array(coords)
|
||||
attentions = np.array(attentions.cpu()).reshape(-1)
|
||||
|
||||
sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level)
|
||||
cmap = plt.cm.get_cmap('Reds')
|
||||
|
||||
tile_xs, tile_ys = sel_coords.T
|
||||
rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)]
|
||||
|
||||
pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None)
|
||||
pc.set_array(np.array(attentions))
|
||||
pc.set_clim([0, 1])
|
||||
ax.add_collection(pc)
|
||||
plt.colorbar(pc, ax=ax)
|
||||
return fig
|
||||
|
||||
|
||||
def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: List[str]) -> plt.figure:
|
||||
"""Plots a normalized confusion matrix and returns the figure.
|
||||
param cm: Normalized confusion matrix to be plotted.
|
||||
param class_names: List of class names.
|
||||
"""
|
||||
fig, ax = plt.subplots()
|
||||
ax = sns.heatmap(cm, annot=True, cmap='Blues', fmt=".2%")
|
||||
ax.set_xlabel('Predicted')
|
||||
ax.set_ylabel('True')
|
||||
ax.xaxis.set_ticklabels(class_names)
|
||||
ax.yaxis.set_ticklabels(class_names)
|
||||
return fig
|
|
@ -1,67 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SlideKey(str, Enum):
|
||||
SLIDE_ID = 'slide_id'
|
||||
IMAGE = 'image'
|
||||
IMAGE_PATH = 'image_path'
|
||||
MASK = 'mask'
|
||||
MASK_PATH = 'mask_path'
|
||||
LABEL = 'label'
|
||||
SPLIT = 'split'
|
||||
SCALE = 'scale'
|
||||
ORIGIN = 'origin'
|
||||
FOREGROUND_THRESHOLD = 'foreground_threshold'
|
||||
METADATA = 'metadata'
|
||||
LOCATION = 'location'
|
||||
|
||||
|
||||
class TileKey(str, Enum):
|
||||
TILE_ID = 'tile_id'
|
||||
SLIDE_ID = 'slide_id'
|
||||
IMAGE = 'image'
|
||||
IMAGE_PATH = 'image_path'
|
||||
MASK = 'mask'
|
||||
MASK_PATH = 'mask_path'
|
||||
LABEL = 'label'
|
||||
SPLIT = 'split'
|
||||
TILE_X = 'tile_x'
|
||||
TILE_Y = 'tile_y'
|
||||
OCCUPANCY = 'occupancy'
|
||||
FOREGROUND_THRESHOLD = 'foreground_threshold'
|
||||
SLIDE_METADATA = 'slide_metadata'
|
||||
|
||||
@staticmethod
|
||||
def from_slide_metadata_key(slide_metadata_key: str) -> str:
|
||||
return 'slide_' + slide_metadata_key
|
||||
|
||||
|
||||
class ResultsKey(str, Enum):
|
||||
SLIDE_ID = 'slide_id'
|
||||
TILE_ID = 'tile_id'
|
||||
IMAGE = 'image'
|
||||
IMAGE_PATH = 'image_path'
|
||||
LOSS = 'loss'
|
||||
PROB = 'prob'
|
||||
CLASS_PROBS = 'prob_class'
|
||||
PRED_LABEL = 'pred_label'
|
||||
TRUE_LABEL = 'true_label'
|
||||
BAG_ATTN = 'bag_attn'
|
||||
TILE_X = "x"
|
||||
TILE_Y = "y"
|
||||
|
||||
|
||||
class MetricsKey(str, Enum):
|
||||
ACC = 'accuracy'
|
||||
ACC_MACRO = 'macro_accuracy'
|
||||
ACC_WEIGHTED = 'weighted_accuracy'
|
||||
CONF_MATRIX = 'confusion_matrix'
|
||||
AUROC = 'auroc'
|
||||
PRECISION = 'precision'
|
||||
RECALL = 'recall'
|
||||
F1 = 'f1score'
|
|
@ -1,21 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def extract_fields(row: pd.Series) -> dict:
|
||||
# Paths are structured as follows:
|
||||
# "CRC_DX_[TEST|TRAIN]/[MSS|MSIMUT]/blk-{tile_id}-{slide_id}-01Z-00-DX1.png"
|
||||
# - tile_id is an uppercase string of 12 letters
|
||||
# - slide_id is "TCGA-XX-XXXX"
|
||||
parts = row.image.split('/')
|
||||
return {
|
||||
'slide_id': parts[2][17:29],
|
||||
'tile_id': parts[2][4:16],
|
||||
'image': row.image,
|
||||
'label': {'MSS': 0, 'MSIMUT': 1}[parts[1]],
|
||||
'split': parts[0][7:].lower(),
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
from typing import Any, Dict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from monai.data.dataset import Dataset
|
||||
from monai.data.image_reader import WSIReader
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||
|
||||
|
||||
def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
|
||||
"""
|
||||
Load image from metadata dictionary
|
||||
:param sample: dict describing image metadata. Example:
|
||||
{'image_id': ['1ca999adbbc948e69783686e5b5414e4'],
|
||||
'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'],
|
||||
'mask': ['/tmp/datasets/PANDA/train_label_masks/1ca999adbbc948e69783686e5b5414e4_mask.tiff'],
|
||||
'data_provider': ['karolinska'],
|
||||
'isup_grade': tensor([0]),
|
||||
'gleason_score': ['0+0']}
|
||||
:param level: level of resolution to be loaded
|
||||
:param margin: margin to be included
|
||||
:return: a dict containing the image data and metadata
|
||||
"""
|
||||
loader = LoadPandaROId(WSIReader('cuCIM'), level=level, margin=margin)
|
||||
img = loader(sample)
|
||||
return img
|
||||
|
||||
|
||||
def plot_panda_data_sample(panda_dir: str, nsamples: int, ncols: int, level: int, margin: int,
|
||||
title_key: str = 'data_provider') -> None:
|
||||
"""
|
||||
:param panda_dir: path to the dataset, it's expected a file called "train.csv" exists at the path.
|
||||
Look at the PandaDataset for more detail
|
||||
:param nsamples: number of random samples to be visualized
|
||||
:param ncols: number of columns in the figure grid. Nrows is automatically inferred
|
||||
:param level: level of resolution to be loaded
|
||||
:param margin: margin to be included
|
||||
:param title_key: metadata key in image_dict used to label each subplot
|
||||
"""
|
||||
panda_dataset = Dataset(PandaDataset(root=panda_dir))[:nsamples] # type: ignore
|
||||
loader = DataLoader(panda_dataset, batch_size=1)
|
||||
|
||||
nrows = math.ceil(nsamples/ncols)
|
||||
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 9))
|
||||
|
||||
for dict_images, ax in zip(loader, axes.flat):
|
||||
slide_id = dict_images[SlideKey.SLIDE_ID]
|
||||
title = dict_images[SlideKey.METADATA][title_key]
|
||||
print(f">>> Slide {slide_id}")
|
||||
img = load_image_dict(dict_images, level=level, margin=margin)
|
||||
ax.imshow(img[SlideKey.IMAGE].transpose(1, 2, 0))
|
||||
ax.set_title(title)
|
||||
fig.tight_layout()
|
|
@ -1,122 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
"""BaseMIL is an abstract container defining basic functionality for running MIL experiments.
|
||||
It is responsible for instantiating the encoder and full DeepMIL model. Subclasses should define
|
||||
their datamodules and configure experiment-specific parameters.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type # noqa
|
||||
|
||||
import param
|
||||
from torch import nn
|
||||
from torchvision.models.resnet import resnet18
|
||||
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
|
||||
from InnerEye.ML.lightning_container import LightningContainer
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
|
||||
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
|
||||
from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
|
||||
ImageNetEncoder, ImageNetSimCLREncoder,
|
||||
InnerEyeSSLEncoder, TileEncoder)
|
||||
|
||||
|
||||
class BaseMIL(LightningContainer):
|
||||
# Model parameters:
|
||||
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
|
||||
is_finetune: bool = param.Boolean(doc="Whether to fine-tune the encoder. Options:"
|
||||
"`False` (default), or `True`.")
|
||||
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
|
||||
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
|
||||
|
||||
# Encoder parameters:
|
||||
encoder_type: str = param.String(doc="Name of the encoder class to use.")
|
||||
tile_size: int = param.Integer(224, bounds=(1, None), doc="Tile width/height, in pixels.")
|
||||
n_channels: int = param.Integer(3, bounds=(1, None), doc="Number of channels in the tile.")
|
||||
|
||||
# Data module parameters:
|
||||
batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
|
||||
max_bag_size: int = param.Integer(1000, bounds=(0, None),
|
||||
doc="Upper bound on number of tiles in each loaded bag. "
|
||||
"If 0 (default), will return all samples in each bag. "
|
||||
"If > 0, bags larger than `max_bag_size` will yield "
|
||||
"random subsets of instances.")
|
||||
cache_mode: CacheMode = param.ClassSelector(default=CacheMode.MEMORY, class_=CacheMode,
|
||||
doc="The type of caching to perform: "
|
||||
"'memory' (default), 'disk', or 'none'.")
|
||||
precache_location: str = param.ClassSelector(default=CacheLocation.NONE, class_=CacheLocation,
|
||||
doc="Whether to pre-cache the entire transformed dataset upfront "
|
||||
"and save it to disk and if re-load in cpu or gpu. Options:"
|
||||
"`none` (default),`cpu`, `gpu`")
|
||||
encoding_chunk_size: int = param.Integer(0, doc="If > 0 performs encoding in chunks, by loading"
|
||||
"enconding_chunk_size tiles per chunk")
|
||||
# local_dataset (used as data module root_path) is declared in DatasetParams superclass
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
raise NotImplementedError
|
||||
|
||||
def setup(self) -> None:
|
||||
if self.encoder_type == InnerEyeSSLEncoder.__name__:
|
||||
raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")
|
||||
|
||||
self.encoder = self.get_encoder()
|
||||
if not self.is_finetune:
|
||||
self.encoder.eval()
|
||||
|
||||
def get_encoder(self) -> TileEncoder:
|
||||
if self.encoder_type == ImageNetEncoder.__name__:
|
||||
return ImageNetEncoder(feature_extraction_model=resnet18,
|
||||
tile_size=self.tile_size, n_channels=self.n_channels)
|
||||
|
||||
elif self.encoder_type == ImageNetSimCLREncoder.__name__:
|
||||
return ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)
|
||||
|
||||
elif self.encoder_type == HistoSSLEncoder.__name__:
|
||||
return HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)
|
||||
|
||||
elif self.encoder_type == InnerEyeSSLEncoder.__name__:
|
||||
return InnerEyeSSLEncoder(pl_checkpoint_path=self.downloader.local_checkpoint_path,
|
||||
tile_size=self.tile_size, n_channels=self.n_channels)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
|
||||
|
||||
def get_pooling_layer(self) -> Type[nn.Module]:
|
||||
if self.pooling_type == AttentionLayer.__name__:
|
||||
return AttentionLayer
|
||||
elif self.pooling_type == GatedAttentionLayer.__name__:
|
||||
return GatedAttentionLayer
|
||||
elif self.pooling_type == MeanPoolingLayer.__name__:
|
||||
return MeanPoolingLayer
|
||||
else:
|
||||
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
|
||||
|
||||
def create_model(self) -> DeepMILModule:
|
||||
self.data_module = self.get_data_module()
|
||||
# Encoding is done in the datamodule, so here we provide instead a dummy
|
||||
# no-op IdentityEncoder to be used inside the model
|
||||
if self.is_finetune:
|
||||
self.model_encoder = self.encoder
|
||||
for params in self.model_encoder.parameters():
|
||||
params.requires_grad = True
|
||||
else:
|
||||
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
|
||||
return DeepMILModule(encoder=self.model_encoder,
|
||||
label_column=self.data_module.train_dataset.LABEL_COLUMN,
|
||||
n_classes=self.data_module.train_dataset.N_CLASSES,
|
||||
pooling_layer=self.get_pooling_layer(),
|
||||
dropout_rate=self.dropout_rate,
|
||||
class_weights=self.data_module.class_weights,
|
||||
l_rate=self.l_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
adam_betas=self.adam_betas)
|
||||
|
||||
def get_data_module(self) -> TilesDataModule:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_slide_dataset(self) -> SlidesDataset:
|
||||
raise NotImplementedError
|
|
@ -1,175 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
"""DeepSMILECrck is the container for experiments relating to DeepSMILE using the TCGA-CRCk dataset.
|
||||
Run using `python InnerEyePrivate/ML/runner.py --model=DeepSMILECrck --encoder_type=<encoder class name>`
|
||||
|
||||
For convenience, this module also defines encoder-specific containers that can be invoked without
|
||||
additional arguments, e.g. `python InnerEyePrivate/ML/runner.py --model=TcgaCrckImageNetMIL`
|
||||
|
||||
Reference:
|
||||
- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA
|
||||
damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
|
||||
"""
|
||||
from typing import Any, List
|
||||
from pathlib import Path
|
||||
import os
|
||||
from monai.transforms import Compose
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
|
||||
from health_azure.utils import CheckpointDownloader
|
||||
from health_azure.utils import get_workspace
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer
|
||||
from InnerEye.Common import fixed_paths
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
|
||||
from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
|
||||
from InnerEye.ML.common import get_best_checkpoint_path
|
||||
|
||||
from InnerEye.ML.Histopathology.models.transforms import (
|
||||
EncodeTilesBatchd,
|
||||
LoadTilesBatchd,
|
||||
)
|
||||
from InnerEye.ML.Histopathology.models.encoders import (
|
||||
HistoSSLEncoder,
|
||||
ImageNetEncoder,
|
||||
ImageNetSimCLREncoder,
|
||||
InnerEyeSSLEncoder,
|
||||
)
|
||||
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
|
||||
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
|
||||
|
||||
class DeepSMILECrck(BaseMIL):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# Define dictionary with default params that can be overriden from subclasses or CLI
|
||||
default_kwargs = dict(
|
||||
# declared in BaseMIL:
|
||||
pooling_type=AttentionLayer.__name__,
|
||||
encoding_chunk_size=60,
|
||||
cache_mode=CacheMode.MEMORY,
|
||||
precache_location=CacheLocation.CPU,
|
||||
# declared in DatasetParams:
|
||||
local_dataset=Path("/tmp/datasets/TCGA-CRCk"),
|
||||
azure_dataset_id="TCGA-CRCk",
|
||||
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
|
||||
# declared in TrainerParams:
|
||||
num_epochs=50,
|
||||
# declared in WorkflowParams:
|
||||
number_of_cross_validation_splits=5,
|
||||
cross_validation_split_index=0,
|
||||
# declared in OptimizerParams:
|
||||
l_rate=5e-4,
|
||||
weight_decay=1e-4,
|
||||
adam_betas=(0.9, 0.99),
|
||||
)
|
||||
default_kwargs.update(kwargs)
|
||||
super().__init__(**default_kwargs)
|
||||
|
||||
self.best_checkpoint_filename = "checkpoint_max_val_auroc"
|
||||
self.best_checkpoint_filename_with_suffix = (
|
||||
self.best_checkpoint_filename + ".ckpt"
|
||||
)
|
||||
self.checkpoint_folder_path = "outputs/checkpoints/"
|
||||
|
||||
best_checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=self.checkpoint_folder_path,
|
||||
monitor="val/auroc",
|
||||
filename=self.best_checkpoint_filename,
|
||||
auto_insert_metric_name=False,
|
||||
mode="max",
|
||||
)
|
||||
self.callbacks = best_checkpoint_callback
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
return Path(
|
||||
f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/"
|
||||
)
|
||||
|
||||
def setup(self) -> None:
|
||||
if self.encoder_type == InnerEyeSSLEncoder.__name__:
|
||||
from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint_crck_4ws
|
||||
self.downloader = CheckpointDownloader(
|
||||
azure_config_json_path=get_workspace(),
|
||||
run_id=innereye_ssl_checkpoint_crck_4ws,
|
||||
checkpoint_filename="best_checkpoint.ckpt",
|
||||
download_dir="outputs/",
|
||||
remote_checkpoint_dir=Path("outputs/checkpoints")
|
||||
)
|
||||
os.chdir(fixed_paths.repository_parent_directory())
|
||||
self.downloader.download_checkpoint_if_necessary()
|
||||
|
||||
self.encoder = self.get_encoder()
|
||||
self.encoder.cuda()
|
||||
self.encoder.eval()
|
||||
|
||||
def get_data_module(self) -> TilesDataModule:
|
||||
image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN
|
||||
transform = Compose(
|
||||
[
|
||||
LoadTilesBatchd(image_key, progress=True),
|
||||
EncodeTilesBatchd(image_key, self.encoder),
|
||||
]
|
||||
)
|
||||
return TcgaCrckTilesDataModule(
|
||||
root_path=self.local_dataset,
|
||||
max_bag_size=self.max_bag_size,
|
||||
batch_size=self.batch_size,
|
||||
transform=transform,
|
||||
cache_mode=self.cache_mode,
|
||||
precache_location=self.precache_location,
|
||||
cache_dir=self.cache_dir,
|
||||
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
|
||||
cross_validation_split_index=self.cross_validation_split_index,
|
||||
)
|
||||
|
||||
def get_callbacks(self) -> List[Callback]:
|
||||
return super().get_callbacks() + [self.callbacks]
|
||||
|
||||
def get_path_to_best_checkpoint(self) -> Path:
|
||||
"""
|
||||
Returns the full path to a checkpoint file that was found to be best during training, whatever criterion
|
||||
was applied there.
|
||||
"""
|
||||
# absolute path is required for registering the model.
|
||||
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
|
||||
self.checkpoint_folder_path,
|
||||
self.best_checkpoint_filename_with_suffix)
|
||||
if absolute_checkpoint_path.is_file():
|
||||
return absolute_checkpoint_path
|
||||
|
||||
absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
|
||||
self.checkpoint_folder_path,
|
||||
self.best_checkpoint_filename_with_suffix)
|
||||
if absolute_checkpoint_path_parent.is_file():
|
||||
return absolute_checkpoint_path_parent
|
||||
|
||||
checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path))
|
||||
if checkpoint_path.is_file():
|
||||
return checkpoint_path
|
||||
|
||||
raise ValueError("Path to best checkpoint not found")
|
||||
|
||||
|
||||
class TcgaCrckImageNetMIL(DeepSMILECrck):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
class TcgaCrckInnerEyeSSLMIL(DeepSMILECrck):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
class TcgaCrckHistoSSLMIL(DeepSMILECrck):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs)
|
|
@ -1,208 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, List
|
||||
from pathlib import Path
|
||||
import os
|
||||
from monai.transforms import Compose
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
|
||||
from health_azure.utils import CheckpointDownloader
|
||||
from health_azure.utils import get_workspace, is_running_in_azure_ml
|
||||
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
|
||||
from InnerEye.Common import fixed_paths
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
|
||||
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
|
||||
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
|
||||
from InnerEye.ML.common import get_best_checkpoint_path
|
||||
|
||||
from InnerEye.ML.Histopathology.models.transforms import (
|
||||
EncodeTilesBatchd,
|
||||
LoadTilesBatchd,
|
||||
)
|
||||
from InnerEye.ML.Histopathology.models.encoders import (
|
||||
HistoSSLEncoder,
|
||||
ImageNetEncoder,
|
||||
ImageNetSimCLREncoder,
|
||||
InnerEyeSSLEncoder,
|
||||
IdentityEncoder
|
||||
)
|
||||
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
|
||||
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset
|
||||
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
|
||||
|
||||
|
||||
class DeepSMILEPanda(BaseMIL):
|
||||
"""`is_finetune` sets the fine-tuning mode. If this is set, setting cache_mode=CacheMode.NONE takes ~30 min/epoch and
|
||||
cache_mode=CacheMode.MEMORY, precache_location=CacheLocation.CPU takes ~[5-10] min/epoch.
|
||||
Fine-tuning with caching completes using batch_size=4, max_bag_size=1000, num_epochs=20, max_num_gpus=1 on PANDA.
|
||||
"""
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
default_kwargs = dict(
|
||||
# declared in BaseMIL:
|
||||
pooling_type=GatedAttentionLayer.__name__,
|
||||
# average number of tiles is 56 for PANDA
|
||||
encoding_chunk_size=60,
|
||||
cache_mode=CacheMode.MEMORY,
|
||||
precache_location=CacheLocation.CPU,
|
||||
is_finetune=False,
|
||||
|
||||
# declared in DatasetParams:
|
||||
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
|
||||
azure_dataset_id="PANDA_tiles",
|
||||
extra_azure_dataset_ids=["PANDA"],
|
||||
extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")],
|
||||
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
|
||||
# declared in TrainerParams:
|
||||
num_epochs=200,
|
||||
# use_mixed_precision = True,
|
||||
|
||||
# declared in WorkflowParams:
|
||||
number_of_cross_validation_splits=5,
|
||||
cross_validation_split_index=0,
|
||||
|
||||
# declared in OptimizerParams:
|
||||
l_rate=5e-4,
|
||||
weight_decay=1e-4,
|
||||
adam_betas=(0.9, 0.99))
|
||||
default_kwargs.update(kwargs)
|
||||
super().__init__(**default_kwargs)
|
||||
super().__init__(**default_kwargs)
|
||||
if not is_running_in_azure_ml():
|
||||
self.num_epochs = 1
|
||||
self.best_checkpoint_filename = "checkpoint_max_val_auroc"
|
||||
self.best_checkpoint_filename_with_suffix = (
|
||||
self.best_checkpoint_filename + ".ckpt"
|
||||
)
|
||||
self.checkpoint_folder_path = "outputs/checkpoints/"
|
||||
best_checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=self.checkpoint_folder_path,
|
||||
monitor="val/accuracy",
|
||||
filename=self.best_checkpoint_filename,
|
||||
auto_insert_metric_name=False,
|
||||
mode="max",
|
||||
)
|
||||
self.callbacks = best_checkpoint_callback
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
return Path(
|
||||
f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/"
|
||||
)
|
||||
|
||||
def setup(self) -> None:
|
||||
if self.encoder_type == InnerEyeSSLEncoder.__name__:
|
||||
from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint_binary
|
||||
self.downloader = CheckpointDownloader(
|
||||
aml_workspace=get_workspace(),
|
||||
run_id=innereye_ssl_checkpoint_binary, # innereye_ssl_checkpoint
|
||||
checkpoint_filename="best_checkpoint.ckpt", # "last.ckpt",
|
||||
download_dir="outputs/",
|
||||
remote_checkpoint_dir=Path("outputs/checkpoints")
|
||||
)
|
||||
os.chdir(fixed_paths.repository_parent_directory())
|
||||
self.downloader.download_checkpoint_if_necessary()
|
||||
self.encoder = self.get_encoder()
|
||||
if not self.is_finetune:
|
||||
self.encoder.eval()
|
||||
|
||||
def get_data_module(self) -> PandaTilesDataModule:
|
||||
image_key = PandaTilesDataset.IMAGE_COLUMN
|
||||
if self.is_finetune:
|
||||
transform = Compose([LoadTilesBatchd(image_key, progress=True)])
|
||||
else:
|
||||
transform = Compose([
|
||||
LoadTilesBatchd(image_key, progress=True),
|
||||
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
|
||||
])
|
||||
|
||||
return PandaTilesDataModule(
|
||||
root_path=self.local_dataset,
|
||||
max_bag_size=self.max_bag_size,
|
||||
batch_size=self.batch_size,
|
||||
transform=transform,
|
||||
cache_mode=self.cache_mode,
|
||||
precache_location=self.precache_location,
|
||||
cache_dir=self.cache_dir,
|
||||
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
|
||||
cross_validation_split_index=self.cross_validation_split_index,
|
||||
)
|
||||
|
||||
def create_model(self) -> DeepMILModule:
|
||||
self.data_module = self.get_data_module()
|
||||
# Encoding is done in the datamodule, so here we provide instead a dummy
|
||||
# no-op IdentityEncoder to be used inside the model
|
||||
self.slide_dataset = self.get_slide_dataset()
|
||||
self.level = 1
|
||||
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
|
||||
if self.is_finetune:
|
||||
self.model_encoder = self.encoder
|
||||
for params in self.model_encoder.parameters():
|
||||
params.requires_grad = True
|
||||
else:
|
||||
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
|
||||
return DeepMILModule(encoder=self.model_encoder,
|
||||
label_column=self.data_module.train_dataset.LABEL_COLUMN,
|
||||
n_classes=self.data_module.train_dataset.N_CLASSES,
|
||||
pooling_layer=self.get_pooling_layer(),
|
||||
class_weights=self.data_module.class_weights,
|
||||
l_rate=self.l_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
adam_betas=self.adam_betas,
|
||||
slide_dataset=self.get_slide_dataset(),
|
||||
tile_size=self.tile_size,
|
||||
level=self.level,
|
||||
class_names=self.class_names,
|
||||
is_finetune=self.is_finetune)
|
||||
|
||||
def get_slide_dataset(self) -> PandaDataset:
|
||||
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore
|
||||
|
||||
def get_callbacks(self) -> List[Callback]:
|
||||
return super().get_callbacks() + [self.callbacks]
|
||||
|
||||
def get_path_to_best_checkpoint(self) -> Path:
|
||||
"""
|
||||
Returns the full path to a checkpoint file that was found to be best during training, whatever criterion
|
||||
was applied there.
|
||||
"""
|
||||
# absolute path is required for registering the model.
|
||||
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
|
||||
self.checkpoint_folder_path,
|
||||
self.best_checkpoint_filename_with_suffix)
|
||||
if absolute_checkpoint_path.is_file():
|
||||
return absolute_checkpoint_path
|
||||
|
||||
absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
|
||||
self.checkpoint_folder_path,
|
||||
self.best_checkpoint_filename_with_suffix)
|
||||
if absolute_checkpoint_path_parent.is_file():
|
||||
return absolute_checkpoint_path_parent
|
||||
|
||||
checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path))
|
||||
if checkpoint_path.is_file():
|
||||
return checkpoint_path
|
||||
|
||||
raise ValueError("Path to best checkpoint not found")
|
||||
|
||||
class PandaImageNetMIL(DeepSMILEPanda):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
class PandaImageNetSimCLRMIL(DeepSMILEPanda):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
class PandaInnerEyeSSLMIL(DeepSMILEPanda):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs)
|
||||
|
||||
|
||||
class PandaHistoSSLMIL(DeepSMILEPanda):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs)
|
|
@ -1,7 +0,0 @@
|
|||
innereye_ssl_checkpoint = "hsharma_panda_explore:hsharma_panda_explore_1638437076_357167ae"
|
||||
innereye_ssl_checkpoint_binary = "hsharma_panda_tiles_ssl:hsharma_panda_tiles_ssl_1639766433_161e03b9"
|
||||
innereye_ssl_checkpoint_crck_4ws = "ModifyOldSSLCheckpoint:a9259fdb-3964-4c5b-8962-4660e0b79d44"
|
||||
innereye_ssl_checkpoint_crck_radiomics = "ModifyOldSSLCheckpoint:704b1af8-7c75-46ed-8460-d80a0e603194"
|
||||
|
||||
# outdated checkpoints
|
||||
# innereye_ssl_checkpoint_crck_radiomics = updated_transforms:updated_transforms_1636471522_5473e3ff
|
|
@ -10,7 +10,7 @@ from abc import abstractmethod
|
|||
from collections import Counter, defaultdict
|
||||
from multiprocessing import cpu_count
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Set, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -22,13 +22,10 @@ from InnerEye.ML.dataset.full_image_dataset import GeneralDataset
|
|||
from InnerEye.ML.dataset.sample import GeneralSampleMetadata
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarDataSource, ScalarItem, SequenceDataSource
|
||||
from InnerEye.ML.scalar_config import LabelTransformation, ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER
|
||||
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
|
||||
from InnerEye.ML.utils.features_util import FeatureStatistics
|
||||
|
||||
T = TypeVar('T', bound=ScalarDataSource)
|
||||
|
||||
|
||||
def extract_label_classification(label_string: str, sample_id: str, num_classes: int,
|
||||
is_classification_dataset: bool) -> List[float]:
|
||||
|
@ -180,7 +177,7 @@ def load_single_data_source(subject_rows: pd.DataFrame,
|
|||
metadata_columns: Optional[Set[str]] = None,
|
||||
is_classification_dataset: bool = True,
|
||||
num_classes: int = 1,
|
||||
sequence_position_numeric: Optional[int] = None) -> T:
|
||||
sequence_position_numeric: Optional[int] = None) -> ScalarDataSource:
|
||||
"""
|
||||
Converts a set of dataset rows for a single subject to a ScalarDataSource instance, which contains the
|
||||
labels, the non-image features, and the paths to the image files.
|
||||
|
@ -328,7 +325,7 @@ def load_single_data_source(subject_rows: pd.DataFrame,
|
|||
return datasource # type: ignore
|
||||
|
||||
|
||||
class DataSourceReader(Generic[T]):
|
||||
class DataSourceReader():
|
||||
"""
|
||||
Class that allows reading of data sources from a scalar dataset data frame.
|
||||
"""
|
||||
|
@ -421,7 +418,7 @@ class DataSourceReader(Generic[T]):
|
|||
|
||||
@staticmethod
|
||||
def load_data_sources_as_per_config(data_frame: pd.DataFrame,
|
||||
args: ScalarModelBase) -> List[T]:
|
||||
args: ScalarModelBase) -> List[ScalarDataSource]:
|
||||
"""
|
||||
Loads dataset items from the given dataframe, where all column and channel configurations are taken from their
|
||||
respective model config elements.
|
||||
|
@ -436,11 +433,7 @@ class DataSourceReader(Generic[T]):
|
|||
if args.categorical_feature_encoder is not None:
|
||||
assert isinstance(args.categorical_feature_encoder, CategoricalToOneHotEncoder) # mypy
|
||||
|
||||
sequence_column = None
|
||||
if isinstance(args, SequenceModelBase):
|
||||
sequence_column = args.sequence_column
|
||||
|
||||
return DataSourceReader[T](
|
||||
return DataSourceReader(
|
||||
data_frame=data_frame,
|
||||
image_channels=args.image_channels,
|
||||
image_file_column=args.image_file_column,
|
||||
|
@ -450,14 +443,13 @@ class DataSourceReader(Generic[T]):
|
|||
non_image_feature_channels=args.get_non_image_feature_channels_dict(),
|
||||
numerical_columns=args.numerical_columns,
|
||||
categorical_data_encoder=args.categorical_feature_encoder,
|
||||
sequence_column=sequence_column,
|
||||
subject_column=args.subject_column,
|
||||
channel_column=args.channel_column,
|
||||
num_classes=len(args.class_names),
|
||||
is_classification_dataset=args.is_classification_model
|
||||
).load_data_sources(num_dataset_reader_workers=args.num_dataset_reader_workers)
|
||||
|
||||
def load_data_sources(self, num_dataset_reader_workers: int = 0) -> List[T]:
|
||||
def load_data_sources(self, num_dataset_reader_workers: int = 0) -> List[ScalarDataSource]:
|
||||
"""
|
||||
Extracts information from a dataframe to create a list of ClassificationItem. This will create one entry per
|
||||
unique
|
||||
|
@ -484,12 +476,12 @@ class DataSourceReader(Generic[T]):
|
|||
|
||||
return list(flatten(filter(None, results)))
|
||||
|
||||
def load_datasources_for_subject(self, subject_id: str) -> Optional[List[T]]:
|
||||
def load_datasources_for_subject(self, subject_id: str) -> Optional[List[ScalarDataSource]]:
|
||||
|
||||
rows = self.data_frame[np.in1d(self.data_frame[self.subject_column].values, [subject_id])]
|
||||
|
||||
def _load_single_data_source(_rows: pd.DataFrame,
|
||||
_sequence_position_numeric: Optional[int] = None) -> T:
|
||||
_sequence_position_numeric: Optional[int] = None) -> ScalarDataSource:
|
||||
return load_single_data_source(
|
||||
subject_rows=_rows,
|
||||
subject_id=subject_id,
|
||||
|
@ -507,20 +499,6 @@ class DataSourceReader(Generic[T]):
|
|||
num_classes=self.num_classes,
|
||||
sequence_position_numeric=_sequence_position_numeric
|
||||
)
|
||||
|
||||
def _load_sequence_data_source(_sequence_position: Any) -> T:
|
||||
_sequence_position_numeric = int(_sequence_position)
|
||||
if _sequence_position_numeric < 0:
|
||||
raise ValueError(
|
||||
f"Sequence positions must be non-negative integers, but got: {_sequence_position}")
|
||||
else:
|
||||
seq_rows = rows[np.in1d(rows[self.sequence_column].values, [_sequence_position])]
|
||||
return _load_single_data_source(seq_rows, _sequence_position_numeric)
|
||||
|
||||
if self.sequence_column:
|
||||
seq_positions = rows[self.sequence_column].unique()
|
||||
return list(map(_load_sequence_data_source, seq_positions))
|
||||
else:
|
||||
if len(self.expected_channels) > 0:
|
||||
missing_channels = self.expected_channels - set(rows[self.channel_column])
|
||||
if len(missing_channels) > 0:
|
||||
|
@ -582,10 +560,10 @@ def is_valid_item_index(item: ScalarDataSource,
|
|||
return min_sequence_position_value <= item.metadata.sequence_position <= max_sequence_position_value
|
||||
|
||||
|
||||
def filter_valid_classification_data_sources_items(items: Iterable[T],
|
||||
def filter_valid_classification_data_sources_items(items: Iterable[ScalarDataSource],
|
||||
file_to_path_mapping: Optional[Dict[str, Path]],
|
||||
max_sequence_position_value: Optional[int] = None,
|
||||
min_sequence_position_value: int = 0) -> List[T]:
|
||||
min_sequence_position_value: int = 0) -> List[ScalarDataSource]:
|
||||
"""
|
||||
Consumes a list of classification data sources, and removes all of those that have missing file names,
|
||||
or that have NaN or Inf features. If the file_to_path_mapping is given too, all items that have any missing files
|
||||
|
@ -601,7 +579,7 @@ def filter_valid_classification_data_sources_items(items: Iterable[T],
|
|||
:return: A list of items, all of which are valid now.
|
||||
"""
|
||||
|
||||
def all_files_present(item: T) -> bool:
|
||||
def all_files_present(item: ScalarDataSource) -> bool:
|
||||
if file_to_path_mapping:
|
||||
return all(f in file_to_path_mapping for f in item.channel_files)
|
||||
else:
|
||||
|
@ -678,14 +656,14 @@ class ScalarItemAugmentation:
|
|||
return item
|
||||
|
||||
|
||||
class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
||||
class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
|
||||
"""
|
||||
A base class for datasets for classification tasks. It contains logic for loading images from disk,
|
||||
either from a fixed folder or traversing into subfolders.
|
||||
"""
|
||||
one_hot_encoder: Optional[CategoricalToOneHotEncoder] = None
|
||||
status: str = ""
|
||||
items: List[T]
|
||||
items: List[ScalarDataSource]
|
||||
|
||||
def __init__(self, args: ScalarModelBase,
|
||||
data_frame: Optional[pd.DataFrame] = None,
|
||||
|
@ -710,7 +688,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
self.file_to_full_path = files_by_stem(args.local_dataset)
|
||||
logging.info("Finished traversing folder.")
|
||||
|
||||
def load_all_data_sources(self) -> List[T]:
|
||||
def load_all_data_sources(self) -> List[ScalarDataSource]:
|
||||
"""
|
||||
Uses the dataframe to create data sources to be used by the dataset.
|
||||
:return:
|
||||
|
@ -721,7 +699,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
self.status += f"After filtering: {self.create_status_string(all_data_sources)}"
|
||||
return all_data_sources
|
||||
|
||||
def filter_valid_data_sources_items(self, data_sources: List[T]) -> List[T]:
|
||||
def filter_valid_data_sources_items(self, data_sources: List[ScalarDataSource]) -> List[ScalarDataSource]:
|
||||
raise NotImplementedError("filter_valid_data_source_items must be implemented by child classes")
|
||||
|
||||
@abstractmethod
|
||||
|
@ -737,7 +715,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
If None, they will be computed from the data in the present object.
|
||||
"""
|
||||
if self.items:
|
||||
self.feature_statistics = self.feature_statistics or FeatureStatistics[T].from_data_sources(self.items)
|
||||
self.feature_statistics = self.feature_statistics or FeatureStatistics.from_data_sources(self.items)
|
||||
self.items = self.feature_statistics.standardize(self.items)
|
||||
|
||||
def load_item(self, item: ScalarDataSource) -> ScalarItem:
|
||||
|
@ -757,7 +735,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
|
||||
return self.transform(sample)
|
||||
|
||||
def create_status_string(self, items: List[T]) -> str:
|
||||
def create_status_string(self, items: List[ScalarDataSource]) -> str:
|
||||
"""
|
||||
Creates a human readable string that contains the number of items, and the distinct number of subjects.
|
||||
:param items: Use the items provided to create the string
|
||||
|
@ -767,14 +745,14 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
return f"{len(items)} items for {distinct} subjects. "
|
||||
|
||||
|
||||
class ScalarDataset(ScalarDatasetBase[ScalarDataSource]):
|
||||
class ScalarDataset(ScalarDatasetBase):
|
||||
"""
|
||||
A dataset class that can read CSV files with a flexible schema, and extract image file paths and non-image features.
|
||||
"""
|
||||
|
||||
def __init__(self, args: ScalarModelBase,
|
||||
data_frame: Optional[pd.DataFrame] = None,
|
||||
feature_statistics: Optional[FeatureStatistics[ScalarDataSource]] = None,
|
||||
feature_statistics: Optional[FeatureStatistics] = None,
|
||||
name: Optional[str] = None,
|
||||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
|
|
|
@ -1,304 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Callable, DefaultDict, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDatasetBase, ScalarItemAugmentation, \
|
||||
filter_valid_classification_data_sources_items
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem, SequenceDataSource
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence, ListOfSequences
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.features_util import FeatureStatistics
|
||||
|
||||
|
||||
def get_longest_contiguous_sequence(items: List[SequenceDataSource],
|
||||
min_sequence_position_value: int = 0,
|
||||
max_sequence_position_value: Optional[int] = None) -> List[SequenceDataSource]:
|
||||
"""
|
||||
From a list of classification items, extract the longest contiguous sequence of items starting
|
||||
at position value min_sequence_position_value.
|
||||
|
||||
For example:
|
||||
|
||||
if min_sequence_position_value = 1 and the
|
||||
input has sequence positions [0, 1, 3, 4], the result retains the items for positions [1].
|
||||
|
||||
if min_sequence_position_value = 1 and max_sequence_position_value = 2, then if
|
||||
input has sequence positions [0, 1, 2, 3], the result retains the items for positions [1, 2].
|
||||
|
||||
if min_sequence_position_value = 1 and max_sequence_position_value = 4, then if
|
||||
input has sequence positions [0, 1, 2, 3], the result retains the items for positions [1, 2, 3].
|
||||
|
||||
If the input sequence is [2, 3], the result is an empty list
|
||||
(the longest sequence must start at 1 but there is no item with position 1)
|
||||
|
||||
:param items: A list of classification items, sorted by sequence_position.
|
||||
:param min_sequence_position_value: The minimum sequence position all sequences start from, 0 is default.
|
||||
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
|
||||
end with. Longer sequences will be truncated. None is default.
|
||||
:return: A list of classification items, with a maximum sequence_position that is the
|
||||
len(result) - 1.
|
||||
"""
|
||||
result: List[SequenceDataSource] = []
|
||||
|
||||
# make sure the input sequence is sorted by sequence position first
|
||||
items = list(sorted(items, key=lambda x: x.metadata.sequence_position))
|
||||
|
||||
_last_seq_item = next((x for x in items if x.metadata.sequence_position == min_sequence_position_value), None)
|
||||
|
||||
if _last_seq_item is None:
|
||||
return result
|
||||
else:
|
||||
result.append(_last_seq_item)
|
||||
for item in items[items.index(_last_seq_item) + 1:]:
|
||||
if max_sequence_position_value is not None and \
|
||||
item.metadata.sequence_position > max_sequence_position_value:
|
||||
break
|
||||
elif item.metadata.sequence_position - _last_seq_item.metadata.sequence_position == 1:
|
||||
_last_seq_item = item
|
||||
result.append(item)
|
||||
else:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def group_samples_into_sequences(items: Iterable[SequenceDataSource],
|
||||
min_sequence_position_value: int = 0,
|
||||
max_sequence_position_value: Optional[int] = None) -> ListOfSequences:
|
||||
"""
|
||||
Turns a flat list of classification items into a list of per-subject classification items. The resulting list
|
||||
has one entry per unique sample ID in the input. With a single sample ID, the items
|
||||
are sorted by metadata.sequence_position in ascending order.
|
||||
Also, all subject data is restricted to the largest contiguous sequence starting at 0
|
||||
(e.g., if sequence positions are [0, 1, 4], only [0, 1] are retained,
|
||||
if sequence positions are [1, 2, 3] nothing is retained)
|
||||
:param items: The items that should be grouped.
|
||||
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
|
||||
end with. Longer sequences will be truncated. None is default.
|
||||
up to and including this value. Entries beyond that sequence_position will be dropped.
|
||||
:param min_sequence_position_value: All sequences must have a entries with sequence_position starting
|
||||
from and including this value, 0 is default.
|
||||
:return:
|
||||
"""
|
||||
if min_sequence_position_value < 0:
|
||||
raise ValueError("Argument min_sequence_position_value must be >= 0")
|
||||
|
||||
if max_sequence_position_value:
|
||||
if max_sequence_position_value < min_sequence_position_value:
|
||||
raise ValueError(f"Argument max_sequence_position_value: {max_sequence_position_value} must "
|
||||
f"be >= min_sequence_position_value: {min_sequence_position_value}")
|
||||
|
||||
grouped: DefaultDict[str, List[SequenceDataSource]] = defaultdict(list)
|
||||
for item in items:
|
||||
grouped[item.id].append(item)
|
||||
result: List[ClassificationItemSequence[SequenceDataSource]] = []
|
||||
for sample_id, items in grouped.items():
|
||||
unique_positions = set(x.metadata.sequence_position for x in items)
|
||||
if len(unique_positions) != len(items):
|
||||
raise ValueError(f"The set of sequence positions for subject {sample_id} contains duplicates.")
|
||||
|
||||
group_sorted = get_longest_contiguous_sequence(
|
||||
items=items,
|
||||
min_sequence_position_value=min_sequence_position_value,
|
||||
max_sequence_position_value=max_sequence_position_value
|
||||
)
|
||||
|
||||
if len(group_sorted) > 0:
|
||||
result.append(ClassificationItemSequence(id=sample_id, items=group_sorted))
|
||||
else:
|
||||
# No contiguous sequence at all
|
||||
logging.warning(f"Skipped sequence for subject {sample_id} as it was not contiguous")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_difference_features(sequences: ListOfSequences, feature_indices: List[int]) -> ListOfSequences:
|
||||
"""
|
||||
For each sequence in the argument, compute feature differences to the first sequence element, and adds them
|
||||
as new features at the end of the non-image features. Feature differences are only compute for those columns
|
||||
in numerical_non_image_features that are given in the feature_indices argument.
|
||||
The first sequence elements gets feature differences that are all zeros. The i.th sequence element will get
|
||||
additional features that are the differences of numerical_non_image_features[:,j] and the same element in the
|
||||
0.th sequence
|
||||
element.
|
||||
:param sequences: The input sequences.
|
||||
:param feature_indices: The column indices in numerical_non_image_features for which differences should be computed.
|
||||
:return: A new list of sequences with the feature differences added as columns in the
|
||||
numerical_non_image_features field.
|
||||
"""
|
||||
|
||||
def add_features(seq: ClassificationItemSequence) -> ClassificationItemSequence:
|
||||
items_mapped: List[SequenceDataSource] = []
|
||||
feature_baseline = None
|
||||
for item_index, item in enumerate(seq.items):
|
||||
if item_index == 0:
|
||||
feature_baseline = torch.stack([item.numerical_non_image_features[:, i] for i in feature_indices],
|
||||
dim=0)
|
||||
features_for_diff = torch.stack([item.numerical_non_image_features[:, i] for i in feature_indices], dim=0)
|
||||
diff = features_for_diff - feature_baseline
|
||||
new_features = torch.cat([item.numerical_non_image_features, diff.t()], dim=1)
|
||||
items_mapped.append(item.clone_with_overrides(numerical_non_image_features=new_features))
|
||||
return ClassificationItemSequence(id=seq.id, items=items_mapped)
|
||||
|
||||
return list(map(add_features, sequences))
|
||||
|
||||
|
||||
"""
|
||||
Example for the use of SequenceDataset:
|
||||
|
||||
A sequence dataset groups rows not only by subject ID (as the normal ClassificationDataset does), but also
|
||||
by a sequence position. That sequence position is read out from a column specified in the `sequence_column`
|
||||
field of the model configuration.
|
||||
|
||||
Within a given subject, a sequence dataset returns instance of ClassificationItemSequence, each of which contains
|
||||
a ClassificationItem for each individual sequence position.
|
||||
|
||||
Example use case:
|
||||
subject,POSITION,measure0,measure0,image,Label
|
||||
1,0,92,362,img1,0
|
||||
1,1,92,357,img1,1
|
||||
1,2,92,400,,0
|
||||
2,0,82,477,img2,0
|
||||
2,1,82,,img2,1
|
||||
2,2,82,220,img2,0
|
||||
|
||||
To read images and measure1 as a non-imaging feature from this file, you would specify:
|
||||
image_channels = []
|
||||
image_file_column = "image"
|
||||
label_channel = None
|
||||
label_value_column = "Label"
|
||||
non_image_feature_channels = []
|
||||
numerical_columns = ["measure1"]
|
||||
sequence_column = "POSITION"
|
||||
|
||||
All of the "*_channel" arguments can be left empty. After grouping by subject and sequence position,
|
||||
only 1 row remains, and it is hence clear to the data loader which row to read from.
|
||||
|
||||
After reading the CSV files, the data loader will remove all rows where
|
||||
* there is no image file path given in the file
|
||||
* there is a missing value in the non-image features (missing measure1 in the example above)
|
||||
* If the traverse_dirs_when_loading is given in the model config, the data loader will also remove items where
|
||||
the image file does not exist.
|
||||
|
||||
After this filtering, the data loader will group the items by subject, and sort by position within a subject.
|
||||
Within a subject, the sequences must start at position 0, and are kept up to the first "gap". Hence, if only positions
|
||||
0, 1, and 2 are valid, the sequence that is kept contains items [0, 1]
|
||||
|
||||
Assuming that the image files all exist, this would return
|
||||
* result[0] containing "1" with POSITION numbers 0 and 1 (position 2 has no image file)
|
||||
* result[1] containing "2" with POSITION number 0 only (position 1 has missing measure1, and hence position 2 has to
|
||||
be dropped as well)
|
||||
"""
|
||||
|
||||
|
||||
class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
|
||||
"""
|
||||
A dataset class that groups its raw dataset rows by subject ID and a sequence index. Each item in the dataset
|
||||
has all the rows for a given subject, and within each subject, a sorted sequence of rows.
|
||||
"""
|
||||
items: List[ClassificationItemSequence[SequenceDataSource]] # type: ignore
|
||||
|
||||
def __init__(self,
|
||||
args: SequenceModelBase,
|
||||
data_frame: pd.DataFrame,
|
||||
feature_statistics: Optional[
|
||||
FeatureStatistics[ClassificationItemSequence[SequenceDataSource]]] = None,
|
||||
name: Optional[str] = None,
|
||||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
Creates a new sequence dataset from a dataframe.
|
||||
:param args: The model configuration object.
|
||||
:param data_frame: The dataframe to read from.
|
||||
:param feature_statistics: If given, the normalization factor for the non-image features is taken
|
||||
:param sample_transform: Transformation to apply to each sample in the loading step. By default, no
|
||||
transformation is applied.
|
||||
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
|
||||
:param name: Name of the dataset, used for logging
|
||||
"""
|
||||
super().__init__(args=args,
|
||||
data_frame=data_frame,
|
||||
feature_statistics=feature_statistics,
|
||||
name=name,
|
||||
sample_transform=sample_transform)
|
||||
if self.args.sequence_column is None:
|
||||
raise ValueError("This class requires a value in the `sequence_column`, specifying where the "
|
||||
"sequence index should be read from.")
|
||||
|
||||
if len(self.args.class_names) > 1:
|
||||
raise ValueError("Multilabel configs not supported for sequence datasets.")
|
||||
|
||||
data_sources = self.load_all_data_sources()
|
||||
grouped = group_samples_into_sequences(
|
||||
data_sources,
|
||||
min_sequence_position_value=self.args.min_sequence_position_value,
|
||||
max_sequence_position_value=self.args.max_sequence_position_value
|
||||
)
|
||||
if self.args.add_differences_for_features:
|
||||
missing_columns = set(self.args.add_differences_for_features) - set(self.args.numerical_columns)
|
||||
if len(missing_columns) > 0:
|
||||
raise ValueError(f"Unable to add differences for these columns because they have not been specified "
|
||||
f"in the `non_image_feature_channels` property: {missing_columns}")
|
||||
feature_indices = [self.args.numerical_columns.index(f) for f in self.args.add_differences_for_features]
|
||||
grouped = add_difference_features(grouped, feature_indices)
|
||||
self.status += f"After grouping: {len(grouped)} subjects."
|
||||
self.items = grouped
|
||||
self.standardize_non_imaging_features()
|
||||
|
||||
def get_status(self) -> str:
|
||||
"""
|
||||
Creates a human readable string that describes the contents of the dataset.
|
||||
"""
|
||||
return self.status
|
||||
|
||||
def filter_valid_data_sources_items(self, data_sources: List[SequenceDataSource]) -> List[SequenceDataSource]:
|
||||
return filter_valid_classification_data_sources_items(
|
||||
items=data_sources,
|
||||
file_to_path_mapping=self.file_to_full_path,
|
||||
min_sequence_position_value=self.args.min_sequence_position_value,
|
||||
max_sequence_position_value=self.args.max_sequence_position_value
|
||||
)
|
||||
|
||||
def get_labels_for_imbalanced_sampler(self) -> List[float]:
|
||||
"""
|
||||
Returns a list of all the labels at the target_index position. Is used to
|
||||
compute the weights in the ImbalancedSampler. If more than on target position
|
||||
is specified the ImbalancedSampler cannot be used.
|
||||
:return:
|
||||
"""
|
||||
if len(self.args.get_target_indices()) > 1:
|
||||
raise NotImplementedError("You cannot use the ImbalancedSampler if you"
|
||||
"want to predict more than one sequence position."
|
||||
"Use loss weighting instead.")
|
||||
return [seq.get_labels_at_target_indices(self.args.get_target_indices())[-1].item()
|
||||
for seq in self.items]
|
||||
|
||||
def get_class_counts(self) -> Dict[int, int]:
|
||||
"""
|
||||
Return the label counts (summed over all target indices).
|
||||
:return: Dictionary of {"label": count}
|
||||
"""
|
||||
all_labels_per_target = torch.stack([seq.get_labels_at_target_indices(self.args.get_target_indices())
|
||||
for seq in self.items]) # [N, T, 1]
|
||||
non_nan_and_nonzero_labels = list(
|
||||
filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
|
||||
counts = dict(Counter(non_nan_and_nonzero_labels))
|
||||
if not len(counts.keys()) == 1 or 1 not in counts.keys():
|
||||
raise ValueError("get_class_counts supports only binary targets.")
|
||||
return {0: counts[1]}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.items)
|
||||
|
||||
def __getitem__(self, i: int) -> Dict[str, Any]:
|
||||
loaded = list(map(self.load_item, self.items[i].items))
|
||||
return vars(ClassificationItemSequence(id=self.items[i].id, items=loaded))
|
|
@ -1,77 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from InnerEye.Common.common_util import check_properties_are_not_none
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem, SequenceDataSource
|
||||
from InnerEye.ML.utils.sequence_utils import sequences_to_padded_tensor
|
||||
|
||||
T = TypeVar('T', SequenceDataSource, ScalarItem)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClassificationItemSequence(Generic[T]):
|
||||
"""
|
||||
A class that holds a sequence of samples for a given patient ID.
|
||||
"""
|
||||
id: str
|
||||
items: List[T]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
check_properties_are_not_none(self)
|
||||
|
||||
@staticmethod
|
||||
def create_labels_tensor_for_minibatch(sequences: List[ClassificationItemSequence[ScalarItem]],
|
||||
target_indices: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
Create label tensor for a minibatch training from a list of sequences for the provided
|
||||
target indices. If sequences are unequal then they are padded with a NaN value.
|
||||
:param sequences: sequences to create label tensor from.
|
||||
:param target_indices: label indices for which to extract label for from the provided sequences.
|
||||
:return: A label tensor with NaN padding if required.
|
||||
"""
|
||||
return sequences_to_padded_tensor(
|
||||
sequences=[seq.get_labels_at_target_indices(target_indices) for seq in sequences],
|
||||
padding_value=np.nan
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_minibatch(minibatch: Dict[str, Any]) -> List[ClassificationItemSequence[ScalarItem]]:
|
||||
"""
|
||||
Creates a list of ClassificationItemSequence from the output of a data loader. The data loader returns a
|
||||
dictionary with collated items, this function is effectively the inverse.
|
||||
:param minibatch: A dictionary that contains the collated fields of ClassificationItemSequence objects.
|
||||
:return: A list of ClassificationItemSequence objects.
|
||||
"""
|
||||
# batched is a de-generate ClassificationItemSequence, with id being a list of strings, and items being
|
||||
# a list of lists.
|
||||
batched = ClassificationItemSequence(**minibatch)
|
||||
return [ClassificationItemSequence(id=sample_id, items=items)
|
||||
for (sample_id, items) in zip(batched.id, batched.items)]
|
||||
|
||||
def get_labels_at_target_indices(self, target_indices: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
Gets the label fields for the sequence elements with the given zero-based indices, if they exist
|
||||
otherwise fill with NaN.
|
||||
"""
|
||||
target_indices = sorted(target_indices)
|
||||
nan = torch.tensor([np.nan]).type_as(self.items[0].label)
|
||||
|
||||
def _get_label_or_nan(idx: int) -> torch.Tensor:
|
||||
return self.items[idx].label if idx < len(self.items) else nan
|
||||
|
||||
if any(p < 0 for p in target_indices):
|
||||
raise ValueError("Argument target_indices cannot contain negative values")
|
||||
|
||||
return torch.stack(list(map(_get_label_or_nan, target_indices)))
|
||||
|
||||
|
||||
ListOfSequences = List[ClassificationItemSequence[SequenceDataSource]]
|
|
@ -19,11 +19,9 @@ from InnerEye.ML.metrics import compute_dice_across_patches
|
|||
from InnerEye.ML.metrics_dict import DataframeLogger, MetricsDict
|
||||
from InnerEye.ML.model_config_base import ModelConfigBase
|
||||
from InnerEye.ML.scalar_config import ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils import image_util, metrics_util, model_util
|
||||
from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example
|
||||
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
|
||||
from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
SUBJECT_OUTPUT_PER_RANK_PREFIX = f"{SUBJECT_METRICS_FILE_NAME}.rank"
|
||||
|
@ -184,12 +182,7 @@ class ScalarLightning(InnerEyeLightning):
|
|||
super().__init__(config, *args, **kwargs)
|
||||
self.model = config.create_model()
|
||||
raw_loss = model_util.create_scalar_loss_function(config)
|
||||
if isinstance(config, SequenceModelBase):
|
||||
self.loss_fn = lambda model_output, loss: apply_sequence_model_loss(raw_loss, model_output, loss)
|
||||
self.target_indices = config.get_target_indices()
|
||||
else:
|
||||
self.loss_fn = raw_loss
|
||||
self.target_indices = []
|
||||
|
||||
self.target_names = config.target_names
|
||||
self.is_classification_model = config.is_classification_model
|
||||
|
@ -202,11 +195,6 @@ class ScalarLightning(InnerEyeLightning):
|
|||
self.train_metric_computers = config.create_metric_computers()
|
||||
self.val_metric_computers = config.create_metric_computers()
|
||||
self.compute_and_log_metrics = config.compute_and_log_metrics
|
||||
# if config.compute_grad_cam:
|
||||
# model_to_evaluate = self.train_val_params.mean_teacher_model if \
|
||||
# config.compute_mean_teacher_model else self.train_val_params.model
|
||||
# self.guided_grad_cam = VisualizationMaps(model_to_evaluate, config)
|
||||
# config.visualization_folder.mkdir(exist_ok=True)
|
||||
|
||||
def forward(self, *model_inputs: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
|
@ -246,7 +234,7 @@ class ScalarLightning(InnerEyeLightning):
|
|||
:param batch_index: The index of the present batch (supplied only for diagnostics).
|
||||
Runs a minibatch of training or validation data through the model.
|
||||
"""
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, self.target_indices, sample)
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, sample)
|
||||
labels = model_inputs_and_labels.labels
|
||||
if is_training:
|
||||
logits = self.model(*model_inputs_and_labels.model_inputs)
|
||||
|
|
|
@ -33,7 +33,6 @@ from InnerEye.ML.pipelines.scalar_inference import ScalarEnsemblePipeline, Scala
|
|||
ScalarInferencePipelineBase
|
||||
from InnerEye.ML.reports.segmentation_report import boxplot_per_structure
|
||||
from InnerEye.ML.scalar_config import ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils import io_util, ml_util
|
||||
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array
|
||||
from InnerEye.ML.utils.io_util import ImageHeader, MedicalImageFileType, load_nifti_image, save_lines_to_file
|
||||
|
@ -397,10 +396,6 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
|
|||
Create an instance of either a ScalarMetricsDict or SequenceMetricsDict, depending on whether the given
|
||||
configuration is a sequence model configuration or not.
|
||||
"""
|
||||
if isinstance(config, SequenceModelBase):
|
||||
return SequenceMetricsDict.create(is_classification_model=config.is_classification_model,
|
||||
sequence_target_positions=config.sequence_target_positions)
|
||||
else:
|
||||
return ScalarMetricsDict(hues=config.target_names,
|
||||
is_classification_metrics=config.is_classification_model)
|
||||
|
||||
|
@ -424,25 +419,19 @@ def classification_model_test(config: ScalarModelBase,
|
|||
checkpoint_paths=checkpoint_paths)
|
||||
if pipeline is None:
|
||||
raise ValueError("Inference pipeline could not be created.")
|
||||
|
||||
# for mypy
|
||||
assert isinstance(pipeline, ScalarInferencePipelineBase)
|
||||
|
||||
ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
|
||||
ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
num_dataload_workers=0
|
||||
)
|
||||
|
||||
logging.info(f"Starting to evaluate model on {data_split.value} set.")
|
||||
results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
|
||||
os.makedirs(str(results_folder), exist_ok=True)
|
||||
metrics_dict = create_metrics_dict_for_scalar_models(config)
|
||||
if not isinstance(config, SequenceModelBase):
|
||||
output_logger: Optional[DataframeLogger] = DataframeLogger(csv_path=results_folder / MODEL_OUTPUT_CSV)
|
||||
else:
|
||||
output_logger = None
|
||||
|
||||
for sample in ds:
|
||||
result = pipeline.predict(sample)
|
||||
|
@ -463,15 +452,11 @@ def classification_model_test(config: ScalarModelBase,
|
|||
labels=label,
|
||||
loss_type=config.loss_type)
|
||||
logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")
|
||||
|
||||
average = metrics_dict.average(across_hues=False)
|
||||
logging.info(average.to_string())
|
||||
|
||||
if isinstance(metrics_dict, ScalarMetricsDict):
|
||||
csv_file = results_folder / SUBJECT_METRICS_FILE_NAME
|
||||
|
||||
logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")
|
||||
|
||||
# If we are running inference after a training run, the validation set metrics may have been written
|
||||
# during train time. If this is not the case, or we are running on the test set, create the metrics
|
||||
# file.
|
||||
|
|
|
@ -1,101 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import RNNCellBase
|
||||
|
||||
from InnerEye.ML.models.layers.identity import Identity
|
||||
|
||||
|
||||
class LayerNormGRUCell(RNNCellBase):
|
||||
"""
|
||||
Implements GRUCell with layer normalisation and zone-out on top.
|
||||
It inherits the base RNN cell whose trainable weight matrices are used.
|
||||
|
||||
References:
|
||||
[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer Normalization." (2016).
|
||||
[2] Krueger, David, et al. "Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations." (2016).
|
||||
|
||||
:param input_size: Number of input features to the cell
|
||||
:param hidden_size: Number of hidden states in the cell
|
||||
:param use_layer_norm: If set to True, layer normalisation is applied to
|
||||
reset, update and new tensors before activation.
|
||||
:param dropout: Dropout probability for the hidden states [0,1]
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, use_layer_norm: bool = False, dropout: float = 0.0):
|
||||
super(LayerNormGRUCell, self).__init__(input_size, hidden_size, bias=False, num_chunks=3)
|
||||
|
||||
self.dropout = dropout
|
||||
self.ln_r = nn.LayerNorm(self.hidden_size) if use_layer_norm else Identity()
|
||||
self.ln_z = nn.LayerNorm(self.hidden_size) if use_layer_norm else Identity()
|
||||
self.ln_n = nn.LayerNorm(self.hidden_size) if use_layer_norm else Identity()
|
||||
|
||||
def forward(self, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> torch.Tensor: # type: ignore
|
||||
if hx is None:
|
||||
hx = input.new_zeros(size=(input.size(0), self.hidden_size), requires_grad=False)
|
||||
|
||||
ih = input.mm(self.weight_ih.t())
|
||||
hh = hx.mm(self.weight_hh.t())
|
||||
|
||||
i_r, i_z, i_n = ih.chunk(3, dim=1)
|
||||
h_r, h_z, h_n = hh.chunk(3, dim=1)
|
||||
|
||||
# Activations with layer normalisation
|
||||
r = torch.sigmoid(self.ln_r(i_r + h_r))
|
||||
z = torch.sigmoid(self.ln_z(i_z + h_z))
|
||||
n = torch.tanh(self.ln_n(i_n + r * h_n))
|
||||
new_h = (torch.tensor(1.0) - z) * n + z * hx
|
||||
|
||||
# Apply zoneout drop-out on hidden states
|
||||
if self.dropout > 0.0:
|
||||
bernouli_mask = F.dropout(torch.ones_like(new_h), p=self.dropout, training=bool(self.training))
|
||||
new_h = bernouli_mask * new_h + (torch.tensor(1.0) - bernouli_mask) * hx
|
||||
|
||||
return new_h
|
||||
|
||||
|
||||
class LayerNormGRU(nn.Module):
|
||||
"""
|
||||
Implements a stacked GRU layers. Differs from torch.nn.GRU implementation by
|
||||
the use of layer normalisation and hidden state preserving drop-out techniques
|
||||
(zone-out) which are currently not provided in the default implementation.
|
||||
|
||||
https://arxiv.org/pdf/1607.06450.pdf
|
||||
https://arxiv.org/pdf/1606.01305.pdf
|
||||
|
||||
:param input_size: Number of input features.
|
||||
:param hidden_size: Number of hidden states in GRU, it is used for all layers.
|
||||
:param num_layers: Number of stacked GRU layers in the module.
|
||||
:param batch_first: If set to true, input tensor should have the batch dimension in the first axis.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
num_layers: int = 1,
|
||||
**kwargs: Any):
|
||||
super(LayerNormGRU, self).__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.cells = nn.ModuleList([
|
||||
LayerNormGRUCell(input_size if i == 0 else hidden_size, hidden_size, **kwargs)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor, hx: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
seq_axis = 1
|
||||
for i, cell in enumerate(self.cells): # type: ignore
|
||||
y = []
|
||||
hidden = hx[i]
|
||||
for xc in x.chunk(x.size(seq_axis), dim=seq_axis):
|
||||
xc = xc.squeeze(seq_axis)
|
||||
hidden = cell(xc, hidden)
|
||||
y.append(hidden.unsqueeze(0))
|
||||
x = torch.stack(y, dim=seq_axis + 1).squeeze(0)
|
||||
return x
|
|
@ -1,200 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
|
||||
from InnerEye.ML.models.architectures.base_model import DeviceAwareModule
|
||||
from InnerEye.ML.models.architectures.sequential.gru import LayerNormGRU
|
||||
from InnerEye.ML.models.layers.identity import Identity
|
||||
from InnerEye.ML.utils.sequence_utils import sequences_to_padded_tensor
|
||||
|
||||
|
||||
class RNNClassifier(DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor]):
|
||||
"""
|
||||
Recurrent neural network (GRU) to perform a binary classification of sequence datasets.
|
||||
The class scores that the model outputs are results of a log softmax.
|
||||
:param input_dim: Number of input channels for the GRU layer.
|
||||
:param hidden_dim: Number of hidden states
|
||||
:param output_dim: Number of model output channels
|
||||
:param num_rnn_layers: Number of RNN layers to be stacked in the classifier. By default, a single GRU layer is used.
|
||||
:param use_layer_norm: If set to True, hidden state activations are normalised at each time step.
|
||||
:param target_indices: Output target indices. For many input to one output sequential model,
|
||||
it should be equal to the last index `-1`. If a tensor of indices are provided,
|
||||
it will return sequential model outputs at given time indices.
|
||||
:param ref_indices: Optional, if set then the hidden states from these reference indices is concatenated to the
|
||||
hidden state of the target position before computing the class posterior.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
target_indices: List[int],
|
||||
num_rnn_layers: int = 1,
|
||||
use_layer_norm: bool = False,
|
||||
rnn_dropout: float = 0.00,
|
||||
ref_indices: Optional[List[int]] = None) -> None:
|
||||
super().__init__()
|
||||
self.target_indices = target_indices or [-1]
|
||||
self.input_dim = input_dim
|
||||
self.ref_indices = ref_indices
|
||||
self.hidden_dim = hidden_dim
|
||||
# The GRU takes embeddings as inputs, and outputs hidden states
|
||||
# with dimensionality hidden_dim.
|
||||
self.gru: LayerNormGRU = LayerNormGRU(input_size=input_dim,
|
||||
hidden_size=hidden_dim,
|
||||
num_layers=num_rnn_layers,
|
||||
use_layer_norm=use_layer_norm,
|
||||
dropout=rnn_dropout)
|
||||
|
||||
# The linear layer that maps from hidden state space to class space
|
||||
if self.ref_indices is None:
|
||||
self.hidden2class = nn.Linear(hidden_dim, output_dim)
|
||||
else:
|
||||
self.hidden2class = nn.Linear((len(ref_indices) + 1) * hidden_dim, output_dim) # type: ignore
|
||||
|
||||
# Create a parameter to learn the initial hidden state
|
||||
hidden_size = torch.Size([num_rnn_layers, 1, hidden_dim])
|
||||
self.h0 = nn.Parameter(torch.zeros(size=hidden_size), requires_grad=True)
|
||||
self.initialise_parameters()
|
||||
|
||||
def forward(self, *input_seq: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
input_seq: Input sequence (batch_size x sequence_length x input_dim)
|
||||
return: Sequence classification output (batch_size x target_indices x output_dim)
|
||||
"""
|
||||
batch_size, seq_length, _ = input_seq[0].size()
|
||||
# GRU forward pass and linear mapping from hidden state to the output space.
|
||||
# gru_out of shape [batch_size, seq_length, hidden_dim]
|
||||
gru_out: torch.Tensor = self.gru(input_seq[0], self.h0.repeat(1, batch_size, 1))
|
||||
# pad the gru output if required to ensure values for each target index
|
||||
gru_out = self.pad_gru_output(gru_out)
|
||||
|
||||
if self.ref_indices is None:
|
||||
return self.hidden2class(gru_out[:, self.target_indices, :])
|
||||
else:
|
||||
predictions = []
|
||||
for target_index in self.target_indices:
|
||||
input_to_classifier = gru_out[:, self.ref_indices + [target_index], :].view(batch_size, -1)
|
||||
predictions.append(self.hidden2class(input_to_classifier))
|
||||
return torch.stack(predictions, dim=1)
|
||||
|
||||
def pad_gru_output(self, input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Pad the GRU output with zeros if required to make sure to make sure there is a value for each target index
|
||||
this RNN classifier is initialized with.
|
||||
"""
|
||||
current_sequence_length = len(range(input.shape[0]))
|
||||
required_sequence_length = len(range(max(self.target_indices) + 1))
|
||||
pad_size = max(required_sequence_length - current_sequence_length, 0)
|
||||
return F.pad(input=input, pad=[0, 0, 0, pad_size], mode='constant', value=0)
|
||||
|
||||
def initialise_parameters(self) -> None:
|
||||
"""
|
||||
Initialises the initial hidden state parameter in GRU.
|
||||
"""
|
||||
# Disable type checking here because these parameters are created via setattr, and hence
|
||||
# confuse mypy
|
||||
nn.init.xavier_normal_(self.h0, gain=nn.init.calculate_gain('tanh'))
|
||||
|
||||
def get_input_tensors(self, sequences: List[ClassificationItemSequence]) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns the input tensors as a List where the first element corresponds to the non-imaging features.
|
||||
"""
|
||||
seq_flattened = [torch.stack([i.get_all_non_imaging_features() for i in seq.items], dim=0)
|
||||
for seq in sequences]
|
||||
return [sequences_to_padded_tensor(seq_flattened)]
|
||||
|
||||
|
||||
class RNNClassifierWithEncoder(RNNClassifier):
|
||||
"""
|
||||
RNN classifier for a combination of imaging and non-imaging features. The images are first encoded using
|
||||
an image encoder that is passed to the constructor.
|
||||
:param image_encode: torch module to use to encode the image features. For example a ImageEncoder could be
|
||||
for a U-Net like encoding of the images.
|
||||
:param input_dim: Number of input channels for the GRU layer i.e. number of non_imaging features + number
|
||||
of features at the output of the image encoder (if images/segmentations are used).
|
||||
:param hidden_dim: Number of hidden states
|
||||
:param output_dim: Number of model output channels
|
||||
:param num_rnn_layers: Number of RNN layers to be stacked in the classifier. By default, a single GRU layer is used.
|
||||
:param use_layer_norm: If set to True, hidden state activations are normalised at each time step.
|
||||
:param target_indices: Output target indices. For many input to one output sequential model,
|
||||
it should be equal to the last index `-1`. If a tensor of indices are provided,
|
||||
it will return sequential model outputs at given time indices.
|
||||
:param ref_indices: Optional, if set then the hidden state from these reference indices is concatenated to the
|
||||
hidden state of the target position before computing the class posterior.
|
||||
:param use_encoder_batch_norm: If True, apply batchNorm to the encoded features at the output of the image_encoder
|
||||
module prior to feeding them to the GRU layers. If False, the raw output from the image encoder is fed to the GRU
|
||||
layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
use_encoder_batch_norm: bool = False,
|
||||
image_encoder: Optional[DeviceAwareModule[ScalarItem, torch.Tensor]] = None,
|
||||
**kwargs: Any) -> None:
|
||||
super().__init__(input_dim, hidden_dim, output_dim, **kwargs)
|
||||
self.image_encoder = image_encoder
|
||||
if self.image_encoder is not None:
|
||||
# Adding necessary attributes required by GradCam computation.
|
||||
self.imaging_feature_type = image_encoder.imaging_feature_type # type: ignore
|
||||
self.num_non_image_features = image_encoder.num_non_image_features # type: ignore
|
||||
self.last_encoder_layer = ["image_encoder"] + image_encoder.get_last_encoder_layer_names() # type: ignore
|
||||
self.conv_in_3d = self.image_encoder.conv_in_3d
|
||||
self.encode_channels_jointly = False
|
||||
self.use_encoder_batch_norm = use_encoder_batch_norm
|
||||
self.layer_norm = nn.BatchNorm1d(input_dim) if use_encoder_batch_norm else Identity()
|
||||
|
||||
def forward(self, *input_seq: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
input_seq: Input sequence (batch_size x sequence_length x input_dim)
|
||||
return: Sequence classification output (batch_size x target_indices x output_dim)
|
||||
"""
|
||||
batch_size, seq_length = input_seq[0].size()[:2]
|
||||
# If we have a model that uses images the input will be a List with 2 elements
|
||||
# (imaging_features, non_imaging_features). Else it will be a List with only one
|
||||
# element (non_imaging_features).
|
||||
non_imaging_seq = input_seq[0] if len(input_seq) == 1 else input_seq[1]
|
||||
encoded_seq = []
|
||||
if self.image_encoder is not None:
|
||||
imaging_seq = input_seq[0]
|
||||
for seq in range(seq_length):
|
||||
encoded_features = self.image_encoder(imaging_seq[:, seq, :], non_imaging_seq[:, seq, :])
|
||||
if self.training and batch_size == 1 and self.use_encoder_batch_norm:
|
||||
# This check is necessary as BatchNorm fails if the
|
||||
# batch_size is equal to 1 (can't compute the variance).
|
||||
logging.warning("BatchNorm will not be applied to the encoded image features as the"
|
||||
"effective batch size is 1 on this device.")
|
||||
else:
|
||||
encoded_features = self.layer_norm(encoded_features)
|
||||
encoded_seq.append(encoded_features)
|
||||
encoded_input = non_imaging_seq if encoded_seq == [] else torch.stack(encoded_seq, dim=1)
|
||||
return super().forward(encoded_input)
|
||||
|
||||
def get_last_encoder_layer_names(self) -> List[str]:
|
||||
return self.last_encoder_layer
|
||||
|
||||
def get_input_tensors(self, sequences: List[ClassificationItemSequence]) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns the input tensors as a List where the first element corresponds to the non-imaging features.
|
||||
The second corresponds to the images loaded as required by the image encoder.
|
||||
"""
|
||||
non_imaging_seq = super().get_input_tensors(sequences)[0]
|
||||
if self.image_encoder is not None:
|
||||
seq_flattened_imaging = [torch.stack([self.image_encoder.get_input_tensors(item)[0]
|
||||
for item in seq.items], dim=0)
|
||||
for seq in sequences]
|
||||
imaging_seq = sequences_to_padded_tensor(seq_flattened_imaging)
|
||||
return [imaging_seq, non_imaging_seq]
|
||||
else:
|
||||
return [non_imaging_seq]
|
|
@ -15,7 +15,6 @@ from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_in
|
|||
from InnerEye.ML.lightning_models import ScalarLightning
|
||||
from InnerEye.ML.pipelines.inference import InferencePipelineBase
|
||||
from InnerEye.ML.scalar_config import EnsembleAggregationType, ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
|
||||
|
||||
|
||||
|
@ -103,10 +102,7 @@ class ScalarInferencePipeline(ScalarInferencePipelineBase):
|
|||
:return: Returns ScalarInferencePipelineBase.Result with the subject ids, ground truth labels and predictions.
|
||||
"""
|
||||
assert isinstance(self.model_config, ScalarModelBase)
|
||||
target_indices = self.model_config.get_target_indices() \
|
||||
if isinstance(self.model_config, SequenceModelBase) else []
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model.model,
|
||||
target_indices=target_indices,
|
||||
sample=sample)
|
||||
model_inputs_and_labels.move_to_device(self.model.device)
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -54,7 +54,6 @@ from InnerEye.ML.reports.notebook_report import generate_classification_crossval
|
|||
generate_classification_multilabel_notebook, generate_classification_notebook, generate_segmentation_notebook, \
|
||||
get_ipynb_report_name, reports_folder
|
||||
from InnerEye.ML.scalar_config import ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler, download_all_checkpoints_from_run
|
||||
from InnerEye.ML.visualizers import activation_maps
|
||||
from InnerEye.ML.visualizers.plot_cross_validation import \
|
||||
|
@ -96,9 +95,7 @@ def is_classification_model(model: Any) -> bool:
|
|||
"""
|
||||
Returns True if the given object is an InnerEye classification, but not a sequence model.
|
||||
"""
|
||||
return (isinstance(model, ScalarModelBase)
|
||||
and model.is_classification_model
|
||||
and not isinstance(model, SequenceModelBase))
|
||||
return isinstance(model, ScalarModelBase)
|
||||
|
||||
|
||||
class MLRunner:
|
||||
|
@ -843,7 +840,7 @@ class MLRunner:
|
|||
val_metrics=path_to_best_epoch_val,
|
||||
test_metrics=path_to_best_epoch_test)
|
||||
else:
|
||||
if isinstance(config, ScalarModelBase) and not isinstance(config, SequenceModelBase):
|
||||
if isinstance(config, ScalarModelBase):
|
||||
generate_classification_notebook(
|
||||
result_notebook=reports_dir / get_ipynb_report_name(config.model_category.value),
|
||||
config=config,
|
||||
|
|
|
@ -1,157 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
import param
|
||||
from pandas import DataFrame
|
||||
|
||||
from InnerEye.Common.metrics_constants import LoggingColumns
|
||||
from InnerEye.ML.common import ModelExecutionMode
|
||||
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
|
||||
from InnerEye.ML.metrics_dict import SequenceMetricsDict
|
||||
from InnerEye.ML.scalar_config import ScalarModelBase
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
|
||||
SEQUENCE_LENGTH_STATS_FILE = "sequence_length_stats.txt"
|
||||
SEQUENCE_LENGTH_FILE = "sequence_length.csv"
|
||||
|
||||
|
||||
class SequenceModelBase(ScalarModelBase):
|
||||
sequence_column: Optional[str] = \
|
||||
param.String(allow_None=True, default=None,
|
||||
doc="If provided, create a sequence dataset, ordering by the column given here. The value in that "
|
||||
"column is expected to be numeric, starting at 0.")
|
||||
|
||||
min_sequence_position_value: int = \
|
||||
param.Integer(default=0, bounds=(0, None),
|
||||
doc="When creating a sequence dataset, restrict it to items with a sequence index at min this "
|
||||
"value. For example, if sequence_min_index==2, only items with sequence positions >= 2 will "
|
||||
"be retained, and sequences that don't have all positions up to (including) 2 will be "
|
||||
"discarded.")
|
||||
|
||||
max_sequence_position_value: Optional[int] = \
|
||||
param.Integer(default=None, allow_None=True, bounds=(0, None),
|
||||
doc="When creating a sequence dataset, restrict it to items with a sequence index "
|
||||
"at most this value. For example, if sequence_max_index==3, only items with sequence "
|
||||
"positions 0, 1, 2, and 3 will be retained.")
|
||||
|
||||
sequence_target_positions: List[int] = \
|
||||
param.List(class_=int,
|
||||
doc="Stores the sequence positions for which the model should make predictions. Sequence positions "
|
||||
"are given by the value in the sequence_column. For example, if a sequence consists of items "
|
||||
"with positions [2, 3, 4, 5], and sequence_target_position==[2,5], the model would be evaluated "
|
||||
"on the first and last sequence elements.")
|
||||
|
||||
temperature_scaling_config: Optional[TemperatureScalingConfig] = param.ClassSelector(
|
||||
class_=TemperatureScalingConfig,
|
||||
allow_None=True,
|
||||
default=None,
|
||||
doc="If a config is provided then it will be used to learn a temperature scaling parameter using the "
|
||||
"validation set to calibrate the model logits see: https://arxiv.org/abs/1706.04599 for each "
|
||||
"epoch that requires a checkpoint to be saved. Turned off by default.")
|
||||
|
||||
def __init__(self, **params: Any):
|
||||
super().__init__(**params)
|
||||
# For sequence models, create a hook for computing dataset statistics by default, because sequence
|
||||
# length is expected to have a major impact on performance. If an alternative hook is needed,
|
||||
# overwrite the hook in a derived class or after instantiating the model configuration.
|
||||
self.dataset_stats_hook = self.compute_dataset_stats_hook
|
||||
if len(self.sequence_target_positions) == 0:
|
||||
raise ValueError("sequence_target_positions must not be empty")
|
||||
if self.temperature_scaling_config:
|
||||
logging.info(f"Temperature scaling will be performed on the "
|
||||
f"validation set using the config: {self.temperature_scaling_config}")
|
||||
|
||||
def validate(self) -> None:
|
||||
self.target_names = [SequenceMetricsDict.get_hue_name_from_target_index(p)
|
||||
for p in self.sequence_target_positions]
|
||||
|
||||
def get_target_indices(self) -> List[int]:
|
||||
"""
|
||||
Computes the zero based array indices inside of a sequence of items
|
||||
for which the model should make predictions.
|
||||
"""
|
||||
return [pos - self.min_sequence_position_value for pos in self.sequence_target_positions]
|
||||
|
||||
def get_total_number_of_numerical_non_imaging_features(self) -> int:
|
||||
return len(self.numerical_columns)
|
||||
|
||||
def get_total_number_of_categorical_non_imaging_features(self) -> int:
|
||||
if self.categorical_feature_encoder:
|
||||
return sum([self.categorical_feature_encoder.get_feature_length(col) for col in self.categorical_columns])
|
||||
else:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def is_non_imaging_model(self) -> bool:
|
||||
"""
|
||||
Returns whether the model uses non image features only
|
||||
"""
|
||||
return self.image_file_column is None
|
||||
|
||||
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
|
||||
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
|
||||
sample_transform = self.get_scalar_item_transform()
|
||||
assert sample_transform.train is not None # for mypy
|
||||
assert sample_transform.val is not None # for mypy
|
||||
assert sample_transform.test is not None # for mypy
|
||||
|
||||
train = SequenceDataset(self,
|
||||
dataset_splits.train,
|
||||
name="training",
|
||||
sample_transform=sample_transform.train)
|
||||
val = SequenceDataset(self,
|
||||
dataset_splits.val,
|
||||
feature_statistics=train.feature_statistics,
|
||||
name="validation",
|
||||
sample_transform=sample_transform.val)
|
||||
test = SequenceDataset(self,
|
||||
dataset_splits.test,
|
||||
feature_statistics=train.feature_statistics,
|
||||
name="test",
|
||||
sample_transform=sample_transform.test)
|
||||
|
||||
return {
|
||||
ModelExecutionMode.TRAIN: train,
|
||||
ModelExecutionMode.VAL: val,
|
||||
ModelExecutionMode.TEST: test
|
||||
}
|
||||
|
||||
def compute_dataset_stats_hook(self, datasets: Dict[ModelExecutionMode, Any]) -> None:
|
||||
"""
|
||||
Writes files with details and summary statistics about the datasets for each of the 3 dataset
|
||||
splits (train/val/test).
|
||||
"""
|
||||
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
|
||||
mode_series = []
|
||||
id_series = []
|
||||
length_series = []
|
||||
for mode in ModelExecutionMode:
|
||||
dataset = datasets[mode]
|
||||
assert isinstance(dataset, SequenceDataset)
|
||||
for seq in dataset.items:
|
||||
mode_series.append(mode.value)
|
||||
id_series.append(seq.id)
|
||||
length_series.append(len(seq.items))
|
||||
# Add a constant column that is the cross validation index, so that we can more easily merge these files later
|
||||
# in the post-crossvalidation hook.
|
||||
df = DataFrame.from_dict({
|
||||
LoggingColumns.CrossValidationSplitIndex.value: [self.cross_validation_split_index] * len(mode_series),
|
||||
LoggingColumns.DataSplit.value: mode_series,
|
||||
LoggingColumns.Patient.value: id_series,
|
||||
LoggingColumns.SequenceLength.value: length_series
|
||||
})
|
||||
self.logs_folder.mkdir(exist_ok=True, parents=True)
|
||||
details_file = self.logs_folder / SEQUENCE_LENGTH_FILE
|
||||
df.to_csv(details_file, index=False)
|
||||
# Drop all columns apart from the sequence length column, so that the stats file will also contain
|
||||
# the name of the series that is described
|
||||
stats = df.drop(columns=[LoggingColumns.Patient.value, LoggingColumns.CrossValidationSplitIndex.value]) \
|
||||
.groupby(by=LoggingColumns.DataSplit.value).describe()
|
||||
out_file = self.logs_folder / SEQUENCE_LENGTH_STATS_FILE
|
||||
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 150):
|
||||
out_file.write_text(str(stats))
|
|
@ -5,19 +5,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, List, TypeVar
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
|
||||
from InnerEye.Common.common_util import check_properties_are_not_none
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDataSource, SequenceDataSource
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
|
||||
|
||||
FT = TypeVar('FT', ClassificationItemSequence[SequenceDataSource], ScalarDataSource)
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDataSource
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FeatureStatistics(Generic[FT]):
|
||||
class FeatureStatistics:
|
||||
"""
|
||||
Class to store statistics (mean and standard deviation) of a set of features in a given dataset.
|
||||
Allows to perform feature standardization for this set of features.
|
||||
|
@ -29,7 +26,7 @@ class FeatureStatistics(Generic[FT]):
|
|||
check_properties_are_not_none(self)
|
||||
|
||||
@staticmethod
|
||||
def from_data_sources(sources: List[FT]) -> FeatureStatistics:
|
||||
def from_data_sources(sources: List[ScalarDataSource]) -> FeatureStatistics:
|
||||
"""
|
||||
For the provided data sources, compute the mean and std across all non-image features across all entries.
|
||||
|
||||
|
@ -40,11 +37,7 @@ class FeatureStatistics(Generic[FT]):
|
|||
if len(sources) == 0:
|
||||
raise ValueError("sources must have a length greater than 0")
|
||||
|
||||
data_sources: List[Any] # for mypy
|
||||
if isinstance(sources[0], ClassificationItemSequence):
|
||||
data_sources = [item for seq in sources for item in seq.items]
|
||||
else:
|
||||
data_sources = sources
|
||||
data_sources: List[Any] = sources
|
||||
|
||||
numerical_non_image_features = [x.numerical_non_image_features for x in data_sources]
|
||||
if len(numerical_non_image_features) == 0:
|
||||
|
@ -88,7 +81,7 @@ class FeatureStatistics(Generic[FT]):
|
|||
std = torch.sqrt(torch.max(variance, torch.zeros_like(variance)))
|
||||
return FeatureStatistics(mean=mean, std=std)
|
||||
|
||||
def standardize(self, sources: List[FT]) -> List[FT]:
|
||||
def standardize(self, sources: List[ScalarDataSource]) -> List[ScalarDataSource]:
|
||||
"""
|
||||
For the provided data sources, apply standardization to the non-image features in each source. This will
|
||||
standardize them to mean 0, variance 1 across all sequences.
|
||||
|
@ -104,14 +97,7 @@ class FeatureStatistics(Generic[FT]):
|
|||
new_features[zero_or_nan] = source.numerical_non_image_features[zero_or_nan]
|
||||
return source.clone_with_overrides(numerical_non_image_features=new_features)
|
||||
|
||||
def apply_sequence(seq: ClassificationItemSequence) -> ClassificationItemSequence:
|
||||
# noinspection PyTypeChecker
|
||||
return ClassificationItemSequence(id=seq.id, items=list(map(apply_source, seq.items)))
|
||||
|
||||
if len(sources) > 0:
|
||||
if isinstance(sources[0], ClassificationItemSequence):
|
||||
return list(map(apply_sequence, sources)) # type: ignore
|
||||
else:
|
||||
return list(map(apply_source, sources)) # type: ignore
|
||||
else:
|
||||
return sources
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import MSELoss
|
||||
|
@ -18,7 +18,6 @@ from InnerEye.ML.common import ModelExecutionMode
|
|||
from InnerEye.ML.config import ModelArchitectureConfig, PaddingMode, SegmentationLoss, SegmentationModelBase, \
|
||||
basic_size_shrinkage
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
|
||||
from InnerEye.ML.deep_learning_config import OptimizerParams, OptimizerType
|
||||
from InnerEye.ML.model_config_base import ModelConfigBase
|
||||
from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel, CropSizeConstraints
|
||||
|
@ -30,11 +29,9 @@ from InnerEye.ML.models.losses.cross_entropy import CrossEntropyLoss
|
|||
from InnerEye.ML.models.losses.mixture import MixtureLoss
|
||||
from InnerEye.ML.models.losses.soft_dice import SoftDiceLoss
|
||||
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
|
||||
from InnerEye.ML.utils.ml_util import RandomStateSnapshot
|
||||
from InnerEye.ML.utils.supervised_criterion import BinaryCrossEntropyWithLogitsLoss, SupervisedLearningCriterion
|
||||
from InnerEye.ML.utils.temperature_scaling import ModelWithTemperature
|
||||
from InnerEye.ML.visualizers.model_summary import ModelSummary
|
||||
|
||||
|
||||
|
@ -221,9 +218,7 @@ def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAware
|
|||
# get_model_input function to convert the dataset item to input tensors, and feed them through the model.
|
||||
train_dataset = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN)
|
||||
train_item_0 = next(iter(train_dataset.as_data_loader(shuffle=False, batch_size=1, num_dataload_workers=0)))
|
||||
target_indices = config.get_target_indices() if isinstance(config, SequenceModelBase) else []
|
||||
model_inputs = get_scalar_model_inputs_and_labels(model,
|
||||
target_indices=target_indices,
|
||||
sample=train_item_0)
|
||||
# The model inputs may already be converted to float16, assuming that we would do mixed precision.
|
||||
# However, the model is not yet converted to float16 when this function is called, hence convert back to float32
|
||||
|
@ -248,16 +243,11 @@ def create_model_with_temperature_scaling(config: ModelConfigBase) -> Any:
|
|||
"""
|
||||
# wrap the model around a temperature scaling model if required
|
||||
model = config.create_model()
|
||||
if isinstance(config, SequenceModelBase) and config.temperature_scaling_config:
|
||||
model = ModelWithTemperature(model, config.temperature_scaling_config)
|
||||
return model
|
||||
|
||||
|
||||
E = TypeVar('E', List[ClassificationItemSequence[ScalarItem]], ScalarItem)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScalarModelInputsAndLabels(Generic[E]):
|
||||
class ScalarModelInputsAndLabels():
|
||||
"""
|
||||
Holds the results of calling get_scalar_model_inputs_and_labels: For a given sample returned by the data loader,
|
||||
create the model inputs, the labels, the list of subjects (data loader sample can be batched),
|
||||
|
@ -266,7 +256,7 @@ class ScalarModelInputsAndLabels(Generic[E]):
|
|||
model_inputs: List[torch.Tensor]
|
||||
labels: torch.Tensor
|
||||
subject_ids: List[str]
|
||||
data_item: E
|
||||
data_item: ScalarItem
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
common_util.check_properties_are_not_none(self)
|
||||
|
@ -281,7 +271,6 @@ class ScalarModelInputsAndLabels(Generic[E]):
|
|||
|
||||
|
||||
def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
|
||||
target_indices: List[int],
|
||||
sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
|
||||
"""
|
||||
For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
|
||||
|
@ -293,29 +282,12 @@ def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
|
|||
:return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors,
|
||||
label tensor, subject IDs, and the data item reconstructed from the data loader output
|
||||
"""
|
||||
if target_indices:
|
||||
sequence_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model # type: ignore
|
||||
sequences = ClassificationItemSequence.from_minibatch(sample)
|
||||
subject_ids = [x.id for x in sequences]
|
||||
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
|
||||
sequences=sequences,
|
||||
target_indices=target_indices
|
||||
)
|
||||
model_inputs = sequence_model.get_input_tensors(sequences)
|
||||
|
||||
return ScalarModelInputsAndLabels[List[ClassificationItemSequence]](
|
||||
model_inputs=model_inputs,
|
||||
labels=labels,
|
||||
subject_ids=subject_ids,
|
||||
data_item=sequences
|
||||
)
|
||||
else:
|
||||
scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model # type: ignore
|
||||
scalar_item = ScalarItem.from_dict(sample)
|
||||
subject_ids = [str(x.id) for x in scalar_item.metadata] # type: ignore
|
||||
model_inputs = scalar_model.get_input_tensors(scalar_item)
|
||||
|
||||
return ScalarModelInputsAndLabels[ScalarItem](
|
||||
return ScalarModelInputsAndLabels(
|
||||
model_inputs=model_inputs,
|
||||
labels=scalar_item.label,
|
||||
subject_ids=subject_ids,
|
||||
|
|
|
@ -3,15 +3,10 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import abc
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
|
||||
from InnerEye.ML.utils.sequence_utils import map_packed_sequence_data
|
||||
|
||||
T = TypeVar('T', torch.Tensor, PackedSequence)
|
||||
|
||||
|
||||
class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
|
||||
|
@ -27,7 +22,7 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
|
|||
self.smoothing_eps = smoothing_eps
|
||||
self.is_binary_classification = is_binary_classification
|
||||
|
||||
def forward(self, *input: T, **kwargs: Any) -> Any:
|
||||
def forward(self, *input: torch.Tensor, **kwargs: Any) -> Any:
|
||||
def _smooth_target(target: torch.Tensor) -> torch.Tensor:
|
||||
if self.is_binary_classification or len(target.shape) <= 2:
|
||||
_num_classes = 2
|
||||
|
@ -39,11 +34,8 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
|
|||
return target * (1.0 - self.smoothing_eps) + \
|
||||
(1.0 - target) * self.smoothing_eps / (_num_classes - 1.0) # type: ignore
|
||||
|
||||
_input: List[T] = list(input)
|
||||
_input: List[torch.Tensor] = list(input)
|
||||
if self.smoothing_eps > 0.0:
|
||||
if isinstance(_input[1], PackedSequence):
|
||||
_input[1] = map_packed_sequence_data(_input[1], _smooth_target)
|
||||
else:
|
||||
_input[1] = _smooth_target(_input[1])
|
||||
|
||||
return self.forward_minibatch(*_input, **kwargs)
|
||||
|
@ -102,8 +94,5 @@ class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
|
|||
sorted(self._class_counts.items())] # Uses the first number on the tuple to compare
|
||||
return torch.tensor(weights, dtype=torch.float32)
|
||||
|
||||
def forward_minibatch(self, output: T, target: T, **kwargs: Any) -> Any:
|
||||
if isinstance(target, PackedSequence) and isinstance(output, PackedSequence):
|
||||
return self._loss_fn(output.data.view(-1, 1), target.data.view(-1, 1))
|
||||
else:
|
||||
def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor, **kwargs: Any) -> Any:
|
||||
return self._loss_fn(output, target)
|
||||
|
|
|
@ -1,661 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from InnerEye.Common.metrics_constants import SEQUENCE_POSITION_HUE_NAME_PREFIX
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
|
||||
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType
|
||||
from InnerEye.ML.reports.notebook_report import convert_to_html
|
||||
from InnerEye.ML.scalar_config import ScalarModelBase
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
|
||||
from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES
|
||||
from InnerEye.ML.visualizers.model_hooks import HookBasedFeatureExtractor
|
||||
|
||||
|
||||
def _tensor_as_numpy(tensor: torch.Tensor) -> np.ndarray:
|
||||
return tensor.detach().cpu().numpy()
|
||||
|
||||
|
||||
class GradientBasedFeatureExtractor(HookBasedFeatureExtractor):
|
||||
"""
|
||||
Base class for GradCam and BackPropagation classes.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Module,
|
||||
config: ScalarModelBase,
|
||||
target_layer: Any,
|
||||
target_pos: int = -1):
|
||||
if not config.is_classification_model:
|
||||
raise NotImplementedError("Visualizations maps with GradCam are only"
|
||||
"implemented for classification models.")
|
||||
super().__init__(model, target_layer)
|
||||
self.config = config
|
||||
self.hooks: List[Any] = []
|
||||
if config.use_gpu:
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
self.encode_jointly = getattr(self.net, "encode_channels_jointly", True)
|
||||
self.imaging_feature_type = getattr(self.net, "imaging_feature_type", ImagingFeatureType.Image)
|
||||
self.num_non_image_features = getattr(self.net, "num_non_image_features", 0)
|
||||
self.target_label_index = target_pos
|
||||
self.logits = torch.Tensor()
|
||||
self.probabilities = torch.Tensor()
|
||||
|
||||
def remove_hooks(self) -> None:
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
self.hooks = []
|
||||
|
||||
def get_target_score(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the target score i.e. logits for a positive predicted class,
|
||||
negative logits for negative target class.
|
||||
|
||||
:param probabilities: Probabilities associated to the logits
|
||||
:param logits: Output of the network before Sigmoid
|
||||
"""
|
||||
if self.logits.shape[-1] != 1:
|
||||
raise NotImplementedError("More than one output class")
|
||||
return torch.where(self.probabilities > 0.5, self.logits, -self.logits)
|
||||
|
||||
def backward(self) -> None:
|
||||
"""
|
||||
Defines the backward pass. Computes first the target scores to use to
|
||||
for backpropagation and then updates the `gradients` attribute based on
|
||||
the current value of the `logits` attribute (set in the forward pass).
|
||||
"""
|
||||
target_scores = self.get_target_score() # [B, num_targets, 1] or [B, 1]
|
||||
self.model.zero_grad()
|
||||
|
||||
# If we have a sequence model, with potentially multiple labels.
|
||||
# Only backpropagate the gradients for the given target_pos
|
||||
if isinstance(self.config, SequenceModelBase):
|
||||
gradients_to_propagate = torch.zeros(target_scores.shape, device=self.device)
|
||||
gradients_to_propagate[:, self.target_label_index, :] = 1
|
||||
else:
|
||||
gradients_to_propagate = torch.ones(target_scores.shape, device=self.device)
|
||||
target_scores.backward(gradient=gradients_to_propagate)
|
||||
self.remove_hooks()
|
||||
|
||||
|
||||
class GradCam(GradientBasedFeatureExtractor):
|
||||
"""
|
||||
Class to generate GradCam maps for images, "Pseudo-GradCam" (i.e. ReLu(input x gradients))
|
||||
for non-images features of one batch for the given classification model. Tested and maintained for
|
||||
ImageEncoderWithMLP and RNNClassifier (models that take both images and non-imaging feautres as input).
|
||||
|
||||
GradCam computes Relu(Gradients x Activations) at the output of the encoder of the network
|
||||
(before the global pooling layer). "PseudoGradCam" for non-imaging features denotes
|
||||
ReLu(input x gradients) for non-imaging features. "PseudoGradCam" is
|
||||
used to compare relative feature importance of various non-imaging features for the final classification
|
||||
task.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[DeviceAwareModule, torch.nn.DataParallel],
|
||||
config: ScalarModelBase) -> None:
|
||||
"""
|
||||
|
||||
:param model: The model to analyse
|
||||
:param config: The ScalarModelBase config defining the parameters of this model.
|
||||
"""
|
||||
self.total_num_categorical_features = config.get_total_number_of_categorical_non_imaging_features()
|
||||
self.total_number_of_numerical_non_imaging_features = \
|
||||
config.get_total_number_of_numerical_non_imaging_features()
|
||||
self.is_non_imaging_model = config.is_non_imaging_model
|
||||
if self.is_non_imaging_model:
|
||||
super().__init__(model, config=config, target_layer=None)
|
||||
else:
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
_model: DeviceAwareModule = model.module # type: ignore
|
||||
target_layer = _model.get_last_encoder_layer_names()
|
||||
self.conv_in_3d = bool(_model.conv_in_3d)
|
||||
else:
|
||||
target_layer = model.get_last_encoder_layer_names()
|
||||
self.conv_in_3d = bool(model.conv_in_3d)
|
||||
super().__init__(model=model, config=config, target_layer=target_layer)
|
||||
self.gradients: Dict = {}
|
||||
self.activations: Dict = {}
|
||||
|
||||
def backward_hook_fn(self, module: Module, grad_in: torch.Tensor, grad_out: torch.Tensor) -> None:
|
||||
"""
|
||||
Backward hook to save the gradients per device (to allow GradCam to be computed
|
||||
with DataParallel models when training on multiple GPUs).
|
||||
"""
|
||||
device = str(grad_out[0].get_device())
|
||||
if device not in self.gradients:
|
||||
self.gradients[device] = []
|
||||
self.gradients[device].append(grad_out[0].data.clone())
|
||||
|
||||
def forward_hook_fn(self, module: Module, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
"""
|
||||
Forward hook to save the activations of a given layer (per device to allow GradCam to be computed
|
||||
with DataParallel models when training on multiple GPUs.
|
||||
"""
|
||||
device = str(output[0].get_device())
|
||||
if device not in self.activations:
|
||||
self.activations[device] = []
|
||||
if isinstance(output, tuple):
|
||||
self.activations[device].append([output[index].data.clone() for index in range(len(output))])
|
||||
else:
|
||||
self.activations[device].append(output.data.clone())
|
||||
|
||||
def forward(self, *input) -> None: # type: ignore
|
||||
"""
|
||||
Triggers the call to the forward pass of the module. Prior to call the forward model function, we
|
||||
set the forward and backward passes. When calling this function, the `activations` attribute
|
||||
will containing the activations of the target layer for the given `input` batch passed as an
|
||||
argument to this function.
|
||||
"""
|
||||
self.activations = {}
|
||||
if self.layer_name is not None:
|
||||
submodule = self.net
|
||||
for el in self.layer_name:
|
||||
submodule = submodule._modules[el] # type: ignore
|
||||
target_layer = submodule
|
||||
self.hooks.append(target_layer.register_forward_hook(self.forward_hook_fn))
|
||||
self.hooks.append(target_layer.register_backward_hook(self.backward_hook_fn)) # type: ignore
|
||||
|
||||
self.logits = self.model(*input)
|
||||
if isinstance(self.logits, List):
|
||||
self.logits = torch.nn.parallel.gather(self.logits, target_device=self.device)
|
||||
self.probabilities = torch.nn.Sigmoid()(self.logits)
|
||||
|
||||
def backward(self) -> None:
|
||||
"""
|
||||
Defines the backward pass. Computes first the target scores to use to
|
||||
for backpropagation and then updates the `gradients` attribute based on
|
||||
the current value of the `logits` attribute (set in the forward pass).
|
||||
"""
|
||||
self.gradients = {}
|
||||
super().backward()
|
||||
|
||||
def _get_image_grad_cam(self, input: List[torch.Tensor]) -> np.ndarray:
|
||||
"""
|
||||
Get GradCam mps for images input. GradCam computes
|
||||
Relu(Gradients x Activations) at the output of the encoder of the network
|
||||
(before the global pooling layer).
|
||||
|
||||
:param input: input batch
|
||||
:return: the GradCam maps
|
||||
"""
|
||||
list_gradients = []
|
||||
list_activations = []
|
||||
# put all channels in one tensor per device
|
||||
for device in self.gradients:
|
||||
list_gradients.append(torch.stack(self.gradients[device], dim=1)) # [B, C_in, C_out, Z, X, Y]
|
||||
list_activations.append(torch.stack(self.activations[device], dim=1)) # [B, C_in, C_out, Z, X, Y]
|
||||
|
||||
if self.config.use_gpu:
|
||||
activations = torch.nn.parallel.gather(list_activations, target_device=self.device)
|
||||
gradients = torch.nn.parallel.gather(list_gradients, target_device=self.device)
|
||||
|
||||
else:
|
||||
assert len(list_activations) == 1
|
||||
activations = list_activations[0]
|
||||
gradients = list_gradients[0]
|
||||
self.gradients = {}
|
||||
self.activations = {}
|
||||
|
||||
B, C_in = input[0].shape[:2]
|
||||
Z, X, Y = input[0].shape[-3:]
|
||||
B_act, _, C_act, Z_act, X_act, Y_act = activations.shape
|
||||
if self.conv_in_3d:
|
||||
weights = torch.mean(gradients, dim=(3, 4, 5), keepdim=True)
|
||||
Z_low = Z_act
|
||||
else:
|
||||
weights = torch.mean(gradients, dim=(4, 5), keepdim=True)
|
||||
Z_low = Z
|
||||
del list_gradients, gradients
|
||||
|
||||
low_dim_cam = torch.nn.functional.relu(torch.mul(activations, weights).sum(dim=2))
|
||||
del weights, list_activations, activations
|
||||
|
||||
# Case one separate encoding per channel i.e. one GradCam map per channel
|
||||
if not self.encode_jointly:
|
||||
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
|
||||
assert low_dim_cam.shape == (B, C_in, Z_low, X_act, Y_act) \
|
||||
or low_dim_cam.shape == (B, C_in / HDF5_NUM_SEGMENTATION_CLASSES, Z_low, X_act, Y_act)
|
||||
|
||||
elif self.imaging_feature_type == ImagingFeatureType.Image:
|
||||
assert low_dim_cam.shape == (B, C_in, Z_low, X_act, Y_act)
|
||||
# Case one global encoding i.e. one GradCam map per image
|
||||
else:
|
||||
assert low_dim_cam.shape == (B, 1, Z_low, X_act, Y_act)
|
||||
|
||||
grad_cam = torch.nn.functional.interpolate(
|
||||
low_dim_cam,
|
||||
(Z, X, Y),
|
||||
mode="trilinear"
|
||||
)
|
||||
return _tensor_as_numpy(grad_cam)
|
||||
|
||||
def _get_non_imaging_grad_cam(self) -> np.ndarray:
|
||||
"""
|
||||
Computes the "Pseudo GradCam" for non-imaging features i.e.
|
||||
ReLu(non_imaging_inputs x gradients).
|
||||
"""
|
||||
assert self.non_image_input.grad is not None
|
||||
total_pseudo_cam_non_image = _tensor_as_numpy(torch.nn.functional.relu(
|
||||
torch.mul(self.non_image_input, self.non_image_input.grad)))
|
||||
batch_size = self.non_image_input.shape[0]
|
||||
non_image_input = _tensor_as_numpy(self.non_image_input.detach())
|
||||
if self.total_num_categorical_features > 0:
|
||||
if len(total_pseudo_cam_non_image.shape) == 2:
|
||||
total_pseudo_cam_non_image = total_pseudo_cam_non_image.reshape(batch_size, 1, -1)
|
||||
non_image_input = self.non_image_input.reshape(batch_size, 1, -1)
|
||||
|
||||
pseudo_cam_numerical = total_pseudo_cam_non_image[:, :,
|
||||
:self.total_number_of_numerical_non_imaging_features]
|
||||
|
||||
pseudo_cam_one_hot = total_pseudo_cam_non_image[:, :,
|
||||
self.total_number_of_numerical_non_imaging_features:]
|
||||
categorical_input_one_hot = non_image_input[:, :, self.total_number_of_numerical_non_imaging_features:]
|
||||
|
||||
# Back to "not one hot", only one value per feature is non zero
|
||||
batch_size, number_positions = pseudo_cam_one_hot.shape[:2]
|
||||
if isinstance(self.config, SequenceModelBase):
|
||||
pseudo_cam_categorical = np.zeros((batch_size, number_positions, len(self.config.categorical_columns)))
|
||||
for b in range(batch_size):
|
||||
for t in range(number_positions):
|
||||
# Some features come from sequence padding, for those the entire row is 0 i.e. the feature
|
||||
# is not really one-hot encoded.
|
||||
if np.any(categorical_input_one_hot[b, t] != 0):
|
||||
pseudo_cam_categorical[b, t] = pseudo_cam_one_hot[
|
||||
b, t, categorical_input_one_hot[b, t] != 0]
|
||||
else:
|
||||
# For a non-sequence model a categorical feature might appear several times for several channels but
|
||||
# there
|
||||
# is no padding. Hence we handle the conversion differently.
|
||||
pseudo_cam_categorical = pseudo_cam_one_hot[categorical_input_one_hot.cpu() != 0].reshape(
|
||||
(batch_size, number_positions, -1))
|
||||
|
||||
return np.concatenate([pseudo_cam_numerical, pseudo_cam_categorical], axis=2)
|
||||
else:
|
||||
return total_pseudo_cam_non_image
|
||||
|
||||
def generate(self, input: List[torch.Tensor], target_position: int = -1, target_label_index: int = -1) \
|
||||
-> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Generates the GradCam for images, PseudoGradCam for non-imaging features
|
||||
of one batch for the given classification model.
|
||||
|
||||
GradCam computes Relu(Gradients x Activations) at the output of the encoder of the network
|
||||
(before the global pooling layer). "PseudoGradCam" for non-imaging features denotes
|
||||
ReLu(input x gradients) for non-imaging features. "PseudoGradCam" is used to compare relative feature importance
|
||||
of various non-imaging features for the final classification task.
|
||||
|
||||
:param input: input image [B, C, Z, X, Y]
|
||||
:param target_position: in case of sequence model with multiple target weeks, specify which target
|
||||
position prediction should be visualized. By default the last one.
|
||||
:param target_label_index: index of the target label in the array of targets labels i.e. if target
|
||||
positions are [2,3,5], the target_label_index for position 3 is 1.
|
||||
:return: grad_cam: grad_cam maps [B, Z, X, Y]
|
||||
"""
|
||||
self.target_label_index = target_label_index
|
||||
|
||||
self.model.eval()
|
||||
if self.num_non_image_features > 0:
|
||||
self.non_image_input = input[1].clone().to(self.device).requires_grad_(True)
|
||||
self.forward(*[input[0], self.non_image_input])
|
||||
elif self.is_non_imaging_model:
|
||||
self.non_image_input = input[0].clone().to(self.device).requires_grad_(True)
|
||||
self.forward(self.non_image_input)
|
||||
else:
|
||||
self.forward(*input)
|
||||
self.backward()
|
||||
|
||||
with torch.no_grad():
|
||||
grad_cam_image = None
|
||||
pseudo_cam_non_image = None
|
||||
if not self.is_non_imaging_model:
|
||||
grad_cam_image = self._get_image_grad_cam(input)
|
||||
if target_position > -1:
|
||||
grad_cam_image = grad_cam_image[:, :(target_position + 1), ...]
|
||||
if self.num_non_image_features > 0 or self.is_non_imaging_model:
|
||||
pseudo_cam_non_image = self._get_non_imaging_grad_cam()
|
||||
if target_position > -1:
|
||||
pseudo_cam_non_image = pseudo_cam_non_image[:, :(target_position + 1), ...]
|
||||
if self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
if self.encode_jointly:
|
||||
# In this case broadcasting can happen automatically in GuidedGradCam
|
||||
# computation.
|
||||
assert grad_cam_image is not None # for mypy
|
||||
assert grad_cam_image.shape[1] == 1
|
||||
else:
|
||||
# Otherwise, copy GradCam output twice to compute GuidedGradCam (once for images and
|
||||
# once for segmentations).
|
||||
grad_cam_image = np.concatenate([grad_cam_image, grad_cam_image], axis=1)
|
||||
return grad_cam_image, pseudo_cam_non_image, _tensor_as_numpy(self.probabilities[:, target_label_index])
|
||||
|
||||
|
||||
class GuidedBackPropagation(GradientBasedFeatureExtractor):
|
||||
"""
|
||||
Class to compute GuidedBackPropagation maps for images features.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Module, config: ScalarModelBase) -> None:
|
||||
super().__init__(model=model, config=config, target_layer=None)
|
||||
|
||||
def guided_backprop_hook(self, module: Module, grad_in: torch.Tensor, grad_out: torch.Tensor) \
|
||||
-> Optional[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
Backward hook for guided Backpropagation.
|
||||
Propagate only positive gradient when backpropagating through ReLu layers.
|
||||
"""
|
||||
# For all ReLU layers propagate only positive gradients
|
||||
if isinstance(module, torch.nn.ReLU):
|
||||
return torch.nn.functional.relu(grad_in[0]),
|
||||
return None
|
||||
|
||||
def forward(self, *input): # type: ignore
|
||||
"""
|
||||
Triggers the call to the forward pass of the module and the registration of the backward hook.
|
||||
"""
|
||||
for layer in self.net.modules():
|
||||
# Type check disabled: the type is correct but the PyTorch documentation is not.
|
||||
# noinspection PyTypeChecker
|
||||
self.hooks.append(layer.register_backward_hook(self.guided_backprop_hook))
|
||||
|
||||
self.image_input_grad = input[0].clone().requires_grad_(True)
|
||||
if self.num_non_image_features > 0:
|
||||
self.logits = self.model(self.image_input_grad, input[1])
|
||||
else:
|
||||
self.logits = self.model(self.image_input_grad)
|
||||
if isinstance(self.logits, List):
|
||||
self.logits = torch.nn.parallel.gather(self.logits, target_device=self.device)
|
||||
self.probabilities = torch.nn.Sigmoid()(self.logits)
|
||||
|
||||
def generate(self, input: List[torch.Tensor],
|
||||
target_position: int = -1,
|
||||
target_label_index: int = -1) -> np.ndarray:
|
||||
"""
|
||||
Generate Guided Backpropagation maps for one input batch.
|
||||
|
||||
:param input: input batch
|
||||
:param target_position: in case of sequence model with multiple target weeks, specify which target
|
||||
position prediction should be visualized. By default the last one.
|
||||
:param target_label_index: index of the target label in the array of targets labels i.e. if target
|
||||
positions are [2,3,5], the target_label_index for position 3 is 1.
|
||||
:return: guided backprop maps, size [B, C, Z, X, Y]
|
||||
"""
|
||||
self.target_label_index = target_label_index
|
||||
self.model.eval()
|
||||
self.forward(*input)
|
||||
if self.config.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
self.backward()
|
||||
|
||||
B, C = input[0].shape[:2]
|
||||
Z, X, Y = input[0].shape[-3:]
|
||||
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
|
||||
grads_of_one_hot = -_tensor_as_numpy(self.image_input_grad.grad)
|
||||
one_hot_input = _tensor_as_numpy(self.image_input_grad)
|
||||
backprop_map = grads_of_one_hot * one_hot_input
|
||||
backprop_map = backprop_map.reshape((B, -1, HDF5_NUM_SEGMENTATION_CLASSES, Z, X, Y))
|
||||
backprop_map = backprop_map.sum(axis=2) # [B, C, Z, X, Y]
|
||||
if target_position > -1:
|
||||
backprop_map = backprop_map[:, :(target_position + 1), ...]
|
||||
return backprop_map
|
||||
elif self.imaging_feature_type == ImagingFeatureType.Image:
|
||||
backprop_map = self.image_input_grad.grad.detach().cpu().numpy().reshape((B, C, Z, X, Y))
|
||||
if target_position > -1:
|
||||
backprop_map = backprop_map[:, :(target_position + 1), ...]
|
||||
return backprop_map
|
||||
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
grads = _tensor_as_numpy(self.image_input_grad.grad) # [1, CLASSES, Z, X, Y] or [1, week, CLASSES, Z, X, Y]
|
||||
input_image = _tensor_as_numpy(self.image_input_grad)
|
||||
if len(grads.shape) == 5:
|
||||
grads = grads.reshape((B, -1, HDF5_NUM_SEGMENTATION_CLASSES + 1, Z, X, Y))
|
||||
input_image = input_image.reshape((B, -1, HDF5_NUM_SEGMENTATION_CLASSES + 1, Z, X, Y))
|
||||
one_hot_input = input_image[:, :, :-1, ...]
|
||||
grads_of_one_hot = grads[:, :, :-1, ...]
|
||||
backprop_map_segmentation = (grads_of_one_hot * one_hot_input).sum(axis=2) # [B, C, Z, X, Y]
|
||||
backprop_map_image = grads[:, :, -1, ...] # [B, C, Z, X, Y]
|
||||
if target_position > -1:
|
||||
backprop_map_segmentation = backprop_map_segmentation[:, :(target_position + 1), ...]
|
||||
backprop_map_image = backprop_map_image[:, :(target_position + 1), ...]
|
||||
return np.concatenate([backprop_map_segmentation, backprop_map_image], axis=1) # [B, 2*C, Z, X, Y]
|
||||
else:
|
||||
ValueError("This imaging feature type is not supported.")
|
||||
|
||||
|
||||
class VisualizationMaps:
|
||||
"""
|
||||
Wrapper class to compute GradCam maps, GuidedGradCam and "Pseudo-GradCam" maps
|
||||
for a specific model.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[DeviceAwareModule, torch.nn.DataParallel],
|
||||
config: ScalarModelBase) -> None:
|
||||
self.config = config
|
||||
self.is_non_imaging_model = config.is_non_imaging_model
|
||||
self.grad_cam: GradCam = GradCam(model, config)
|
||||
if not self.is_non_imaging_model:
|
||||
self.guided_backprop: GuidedBackPropagation = GuidedBackPropagation(model, config)
|
||||
self.encode_channels_jointly: bool = self.guided_backprop.encode_jointly
|
||||
self.imaging_feature_type = self.grad_cam.imaging_feature_type
|
||||
|
||||
def generate(self, input: List[torch.Tensor],
|
||||
target_position: int = -1,
|
||||
target_label_index: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Generates the GuidedGradCam, GradCam and PseudoGradCam maps (for non-imaging data)
|
||||
for one batch of data.
|
||||
|
||||
:param input: input batch sixe [B, C, Z, X, Y]
|
||||
:param target_position: in case of sequence model with multiple target weeks, specify which target
|
||||
position prediction should be visualized. By default the last one.
|
||||
:param target_label_index: index of the target label in the array of targets labels i.e. if target
|
||||
positions are [2,3,5], the target_label_index for position 3 is 1.
|
||||
:return: A tuple of GuidedGradCam maps (for image input), GradCam maps (for image input),
|
||||
PseudoGradCam (for non-image inputs), posteriors predicted for the given target position.
|
||||
by the model for this batch of data.
|
||||
"""
|
||||
image_gcam, pseudo_cam_non_img, probability = self.grad_cam.generate(input, target_position, target_label_index)
|
||||
if self.is_non_imaging_model:
|
||||
image_guided_gcam = None
|
||||
else:
|
||||
guided_bp = self.guided_backprop.generate(input, target_position, target_label_index)
|
||||
image_guided_gcam = image_gcam * guided_bp
|
||||
return image_guided_gcam, \
|
||||
image_gcam, \
|
||||
pseudo_cam_non_img, \
|
||||
probability
|
||||
|
||||
def save_visualizations_in_notebook(self,
|
||||
classification_sequence: Union[ScalarItem,
|
||||
List[ClassificationItemSequence[ScalarItem]]],
|
||||
input_batch: List[torch.Tensor],
|
||||
filenames: List[str],
|
||||
ground_truth_labels: np.ndarray,
|
||||
gradcam_dir: Path
|
||||
) -> None:
|
||||
"""
|
||||
Generate, plot and save the visualizations for one batch of
|
||||
data for a sequence model. The visualization are produced in
|
||||
a Jupyter Notebook for readability. There is one notebook generated
|
||||
for each subject. This notebook can be viewed in the AML UI. Additionally
|
||||
a HTML is produced containing only the cells' output.
|
||||
|
||||
:param input_batch: input to the network
|
||||
:param classification_sequence: classification item for current batch
|
||||
:param filenames: a list of filenames for the plots size: [Batch]
|
||||
:param ground_truth_labels: the labels for this input_batch
|
||||
:param gradcam_dir: directory where to save the plots.
|
||||
"""
|
||||
non_image_features = self.config.numerical_columns + self.config.categorical_columns
|
||||
has_non_image_features = len(non_image_features) > 0
|
||||
batch_size = len(filenames)
|
||||
if isinstance(self.config, SequenceModelBase):
|
||||
target_indices = self.config.get_target_indices()
|
||||
if target_indices is None:
|
||||
target_indices = [-1]
|
||||
else:
|
||||
target_indices = [-1]
|
||||
for label_index in range(len(target_indices)):
|
||||
target_position = target_indices[label_index]
|
||||
current_output_dir = self.config.visualization_folder / f"{SEQUENCE_POSITION_HUE_NAME_PREFIX}_" \
|
||||
f"{target_position}"
|
||||
current_output_dir.mkdir(exist_ok=True)
|
||||
guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = self.generate(input_batch,
|
||||
target_position,
|
||||
label_index)
|
||||
for i in range(batch_size):
|
||||
if not self.is_non_imaging_model:
|
||||
non_imaging_labels = self._get_non_imaging_plot_labels(
|
||||
classification_sequence, # type: ignore
|
||||
non_image_features,
|
||||
index=i,
|
||||
target_position=target_position)
|
||||
if isinstance(self.config, SequenceModelBase):
|
||||
image = self._get_image_attributes_for_sequence_item(classification_sequence, # type: ignore
|
||||
index=i,
|
||||
target_position=target_position)
|
||||
else:
|
||||
image = self._get_image_attributes_for_scalar_item(classification_sequence, i) # type: ignore
|
||||
|
||||
# Need to temporarily save the variables to access them from the notebook.
|
||||
# Because papermill does not support passing numpy array as parameters.
|
||||
np.save(str(gradcam_dir / "image.npy"), image)
|
||||
np.save(str(gradcam_dir / "gradcam.npy"), grad_cams[i])
|
||||
np.save(str(gradcam_dir / "guided_grad_cam.npy"), guided_grad_cams[i])
|
||||
if has_non_image_features:
|
||||
np.save(str(gradcam_dir / "non_image_pseudo_cam.npy"), pseudo_cam_non_img[i])
|
||||
has_image_features = True
|
||||
else:
|
||||
non_imaging_labels = self._get_non_imaging_plot_labels(
|
||||
classification_sequence, non_image_features, index=i, target_position=target_position)
|
||||
has_non_image_features = True
|
||||
has_image_features = False
|
||||
self.encode_channels_jointly = False
|
||||
self.imaging_feature_type = ImagingFeatureType.Image
|
||||
np.save(str(gradcam_dir / "non_image_pseudo_cam.npy"), pseudo_cam_non_img[i])
|
||||
|
||||
current_label = ground_truth_labels[i, label_index]
|
||||
|
||||
# If the label is NaN it means that we don't have data for this position and
|
||||
# we used padding for the input. Hence do not save visualizations for this position.
|
||||
if not np.isnan(current_label):
|
||||
params_dict = dict(subject_id=filenames[i],
|
||||
target_position=target_position,
|
||||
gradcam_dir=str(gradcam_dir),
|
||||
has_non_image_features=has_non_image_features,
|
||||
probas=str(probas[i]),
|
||||
ground_truth_labels=str(current_label),
|
||||
non_image_labels=non_imaging_labels,
|
||||
encode_jointly=self.encode_channels_jointly,
|
||||
imaging_feature_type=self.imaging_feature_type.value,
|
||||
has_image_features=has_image_features,
|
||||
value_image_and_segmentation=ImagingFeatureType.ImageAndSegmentation.value, )
|
||||
|
||||
result_path = str(current_output_dir.joinpath(f"{filenames[i]}.ipynb"))
|
||||
import papermill
|
||||
papermill.execute_notebook(os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
"gradcam_visualization.ipynb"),
|
||||
result_path,
|
||||
parameters=params_dict,
|
||||
progress_bar=False)
|
||||
convert_to_html(Path(result_path))
|
||||
|
||||
def _get_non_imaging_plot_labels(self, classification_item: Union[ScalarItem,
|
||||
List[ClassificationItemSequence[ScalarItem]]],
|
||||
non_image_features: List[str],
|
||||
index: int,
|
||||
target_position: int = -1) -> List[str]:
|
||||
"""
|
||||
Gets labels to use for the plots of non-imaging feature importance.
|
||||
|
||||
:param classification_item: The classification item for which the return the
|
||||
label (can vary from subject to subject as they might not all contain the same
|
||||
position in case of a sequence model).
|
||||
:param non_image_features: The name of the imaging features used by the model.
|
||||
:param index: The index of the subject in the batch (used only for sequence models).
|
||||
:return: the labels (list of string)
|
||||
"""
|
||||
if isinstance(self.config, SequenceModelBase):
|
||||
channels = []
|
||||
for item in classification_item[index].items: # type: ignore
|
||||
if (item.metadata.sequence_position - self.config.min_sequence_position_value) <= target_position \
|
||||
or target_position == -1:
|
||||
channels.append(item.metadata.sequence_position)
|
||||
return [f"{col}_{channel}" for channel in channels for col in
|
||||
non_image_features] # type: ignore
|
||||
else:
|
||||
non_imaging_labels = []
|
||||
non_image_features = self.config.numerical_columns + self.config.categorical_columns
|
||||
non_image_feature_channels_dict = self.config.get_non_image_feature_channels_dict()
|
||||
for col in non_image_features:
|
||||
non_imaging_labels.extend(
|
||||
[f"{col}_{channel}" for channel in non_image_feature_channels_dict[col]]) # type: ignore
|
||||
return non_imaging_labels
|
||||
|
||||
def _get_image_attributes_for_sequence_item(self,
|
||||
classification_sequence: List[ClassificationItemSequence[
|
||||
ScalarItem]],
|
||||
index: int,
|
||||
target_position: int) -> np.ndarray:
|
||||
"""
|
||||
Extract the image and/or the segmentation for the classification item to be able to
|
||||
produce the visualizations.
|
||||
|
||||
:param classification_sequence: The classification sequence for which to plot (contains
|
||||
the entire batch)
|
||||
:param index: the exact subject for which to plot.
|
||||
:return: An array containing the imaging input to plot.
|
||||
"""
|
||||
images = []
|
||||
segmentations = []
|
||||
for item in classification_sequence[index].items:
|
||||
if (item.metadata.sequence_position - self.config.min_sequence_position_value) <= target_position \
|
||||
or target_position == -1:
|
||||
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
|
||||
segmentations.append(item.segmentations)
|
||||
elif self.imaging_feature_type == ImagingFeatureType.Image:
|
||||
images.append(item.images)
|
||||
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
segmentations.append(item.segmentations)
|
||||
images.append(item.images)
|
||||
if self.imaging_feature_type == ImagingFeatureType.Image:
|
||||
return _tensor_as_numpy(torch.cat(images, dim=0)).astype(float) # type: ignore
|
||||
elif self.imaging_feature_type == ImagingFeatureType.Segmentation:
|
||||
return _tensor_as_numpy(torch.cat(segmentations, dim=0)).astype(int) # type: ignore
|
||||
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
return np.concatenate(
|
||||
[_tensor_as_numpy(torch.cat(images, dim=0)).astype(float), # type: ignore
|
||||
_tensor_as_numpy(torch.cat(segmentations, dim=0)).astype(int)], # type: ignore
|
||||
axis=0)
|
||||
|
||||
def _get_image_attributes_for_scalar_item(self, classification_item: ScalarItem, index: int) -> np.ndarray:
|
||||
"""
|
||||
Extract the image and/or the segmentation for the classification item to be able to
|
||||
produce the visualizations.
|
||||
|
||||
:param classification_item: The classification items for which to plot (contains the
|
||||
entire batch data).
|
||||
:param index: the exact subject for which to plot.
|
||||
:return: An array containing the imaging input to plot.
|
||||
"""
|
||||
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
|
||||
if classification_item.segmentations is None:
|
||||
raise ValueError("Expected classification_item.segmentations to not be None")
|
||||
return _tensor_as_numpy(classification_item.segmentations[index]).astype(int)
|
||||
elif self.imaging_feature_type == ImagingFeatureType.Image:
|
||||
return _tensor_as_numpy(classification_item.images[index]).astype(float)
|
||||
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
image = _tensor_as_numpy(classification_item.images[index]).astype(float) # [C, Z, X, Y]
|
||||
assert classification_item.segmentations is not None # for mypy
|
||||
seg = _tensor_as_numpy(classification_item.segmentations[index]).astype(int)
|
||||
return np.concatenate([seg, image], axis=0)
|
||||
else:
|
||||
raise ValueError("The provided imaging feature type is not supported.")
|
17
README.md
17
README.md
|
@ -11,7 +11,6 @@ On the modelling side, this toolbox supports
|
|||
|
||||
- Segmentation models
|
||||
- Classification and regression models
|
||||
- Sequence models
|
||||
- Adding cloud support to any PyTorch Lightning model, via a [bring-your-own-model setup](docs/bring_your_own_model.md)
|
||||
- Active label cleaning and noise robust learning toolbox (stand-alone folder)
|
||||
|
||||
|
@ -48,8 +47,7 @@ often seen with medical images.
|
|||
- Easy creation of new models via a configuration-based approach, and inheritance from an existing
|
||||
architecture.
|
||||
|
||||
Once training in AzureML is done, the models can be deployed from within AzureML or via
|
||||
[Azure Stack Hub](https://azure.microsoft.com/en-us/products/azure-stack/hub/).
|
||||
Once training in AzureML is done, the models can be deployed from within AzureML.
|
||||
|
||||
## Getting started
|
||||
|
||||
|
@ -180,14 +178,7 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope
|
|||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Credits
|
||||
## This toolbox is maintained by the
|
||||
[Microsoft Medical Image Analysis team](https://www.microsoft.com/en-us/research/project/medical-image-analysis/).
|
||||
|
||||
|
||||
This toolbox is maintained by the
|
||||
[Microsoft InnerEye team](https://www.microsoft.com/en-us/research/project/medical-image-analysis/),
|
||||
and has received valuable contributions from a number
|
||||
of people outside our team. We would like to thank in particular our interns,
|
||||
[Yao Quin](http://cseweb.ucsd.edu/~yaq007/), [Zoe Landgraf](https://www.linkedin.com/in/zoe-landgraf-a2212293),
|
||||
[Padmaja Jonnalagedda](https://www.linkedin.com/in/jspadmaja/),
|
||||
[Mathias Perslev](https://github.com/perslev), as well as the AI Residents
|
||||
[Patricia Gillespie](https://www.microsoft.com/en-us/research/people/t-pagill/) and
|
||||
[Guilherme Ilunga](https://gilunga.github.io/).
|
||||
|
|
|
@ -185,7 +185,7 @@ S1,image2,img12.nii,True
|
|||
S2,image2,image22.nii,False
|
||||
""")
|
||||
df = pd.read_csv(csv_string, sep=",", dtype=str)
|
||||
items: List[ScalarDataSource] = DataSourceReader[ScalarDataSource](
|
||||
items: List[ScalarDataSource] = DataSourceReader(
|
||||
data_frame=df,
|
||||
image_channels=["image1", "image2"],
|
||||
image_file_column="path",
|
||||
|
|
|
@ -1,676 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import math
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.common import ModelExecutionMode
|
||||
from InnerEye.ML.dataset.full_image_dataset import collate_with_metadata
|
||||
from InnerEye.ML.dataset.sample import GeneralSampleMetadata
|
||||
from InnerEye.ML.dataset.scalar_dataset import DataSourceReader, filter_valid_classification_data_sources_items
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarDataSource, ScalarItem, SequenceDataSource
|
||||
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset, add_difference_features, \
|
||||
group_samples_into_sequences
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence, ListOfSequences
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.features_util import FeatureStatistics
|
||||
from InnerEye.ML.utils.io_util import ImageAndSegmentations
|
||||
from InnerEye.ML.utils.ml_util import set_random_seed
|
||||
from InnerEye.ML.utils.sequence_utils import sequences_to_padded_tensor
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
from Tests.ML.models.architectures.sequential.test_rnn_classifier import ToyMultiLabelSequenceModel, \
|
||||
_get_multi_label_sequence_dataframe
|
||||
from Tests.ML.util import assert_tensors_equal, create_dataset_csv_file
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
|
||||
|
||||
def test_load_items_seq() -> None:
|
||||
"""
|
||||
Test loading file paths and labels from a datafrome if
|
||||
"""
|
||||
csv_string = StringIO("""subject,seq,path,value,scalar1,scalar2,META
|
||||
S1,0,foo.nii,,0,0,M1
|
||||
S1,1,,True,1.1,1.2,M2
|
||||
S2,1,bar.nii,False,2.1,2.2,
|
||||
""")
|
||||
df = pd.read_csv(csv_string, sep=",", dtype=str)
|
||||
items: List[SequenceDataSource] = DataSourceReader[SequenceDataSource](
|
||||
data_frame=df,
|
||||
image_channels=None,
|
||||
image_file_column="path",
|
||||
label_channels=None,
|
||||
label_value_column="value",
|
||||
numerical_columns=["scalar1", "scalar2"],
|
||||
sequence_column="seq").load_data_sources()
|
||||
|
||||
assert len(items) == 3
|
||||
assert isinstance(items[0].metadata, GeneralSampleMetadata)
|
||||
assert items[0].metadata.id == "S1"
|
||||
assert items[0].metadata.props == {"META": "M1"}
|
||||
assert items[0].metadata.sequence_position == 0
|
||||
assert len(items[0].label.tolist()) == 1
|
||||
assert math.isnan(items[0].label.item())
|
||||
assert items[0].channel_files == ["foo.nii"]
|
||||
assert_tensors_equal(items[0].numerical_non_image_features, [0.0, 0.0])
|
||||
assert isinstance(items[1].metadata, GeneralSampleMetadata)
|
||||
assert items[1].metadata.id == "S1"
|
||||
assert items[1].metadata.props == {"META": "M2"}
|
||||
assert items[1].metadata.sequence_position == 1
|
||||
assert_tensors_equal(items[1].label, [1.0])
|
||||
assert items[1].channel_files == ['']
|
||||
assert_tensors_equal(items[1].numerical_non_image_features, [1.1, 1.2])
|
||||
assert isinstance(items[2].metadata, GeneralSampleMetadata)
|
||||
assert items[2].metadata.id == "S2"
|
||||
assert items[2].metadata.props == {"META": ''}
|
||||
assert items[2].metadata.sequence_position == 1
|
||||
assert_tensors_equal(items[2].label, [0.0])
|
||||
assert items[2].channel_files == ["bar.nii"]
|
||||
assert_tensors_equal(items[2].numerical_non_image_features, [2.1, 2.2])
|
||||
|
||||
|
||||
def test_load_items_seq_from_dataset() -> None:
|
||||
"""
|
||||
Test loading a sequence dataset with numerical, categorical features and images.
|
||||
"""
|
||||
dummy_dataset = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
|
||||
df = pd.read_csv(dummy_dataset, sep=",", dtype=str)
|
||||
items: List[SequenceDataSource] = DataSourceReader[SequenceDataSource](
|
||||
data_frame=df,
|
||||
image_channels=None,
|
||||
image_file_column="IMG",
|
||||
label_channels=None,
|
||||
label_value_column="Label",
|
||||
numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
|
||||
sequence_column="Position").load_data_sources()
|
||||
assert len(items) == 3 * 9 # 3 subjects, 9 visits each, no missing
|
||||
assert items[0].metadata.id == "2137.00005"
|
||||
assert items[0].metadata.sequence_position == 0
|
||||
assert items[0].metadata.props["CAT2"] == "category_A"
|
||||
# One of the labels is missing, missing labels should be encoded as NaN
|
||||
assert math.isnan(items[0].label[0])
|
||||
assert items[0].channel_files == ["img_1"]
|
||||
assert str(items[0].numerical_non_image_features.tolist()) == str([362.0, np.nan, np.nan, 71.0])
|
||||
assert items[8].metadata.id == "2137.00005"
|
||||
assert items[8].metadata.sequence_position == 8
|
||||
assert items[8].label.tolist() == [0.0]
|
||||
assert items[8].channel_files == ['']
|
||||
assert str(items[8].numerical_non_image_features.tolist()) == str([350.0, np.nan, np.nan, 8.0])
|
||||
assert items[16].metadata.id == "2627.00001"
|
||||
assert items[16].label.tolist() == [0.0]
|
||||
assert items[16].channel_files == ["img_2"]
|
||||
assert_tensors_equal(items[16].numerical_non_image_features, [217.0, 0.0, 0.01, 153.0])
|
||||
assert items[26].metadata.id == "3250.00005"
|
||||
assert items[26].metadata.sequence_position == 8
|
||||
assert_tensors_equal(items[26].label, [0.0])
|
||||
assert items[26].channel_files == ["img_11"]
|
||||
assert_tensors_equal(items[26].numerical_non_image_features, [238.0, 0.0, 0.02, 84.0])
|
||||
|
||||
grouped = group_samples_into_sequences(
|
||||
filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
|
||||
max_sequence_position_value=None))
|
||||
# There are 3 patients total, but one of them has missing measurements for all visits
|
||||
assert len(grouped) == 2
|
||||
assert grouped[0].id == "2627.00001"
|
||||
assert grouped[1].id == "3250.00005"
|
||||
# 2627.00001 has full information for weeks 0, 4, and 8
|
||||
assert len(grouped[0].items) == 3
|
||||
assert grouped[0].items[0].metadata["VISIT"] == "V1"
|
||||
assert grouped[0].items[2].metadata["VISIT"] == "VST 3"
|
||||
assert len(grouped[1].items) == 9
|
||||
assert items[16].metadata.sequence_position == 7
|
||||
|
||||
|
||||
def test_seq_dataset_loader() -> None:
|
||||
dummy_dataset = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
|
||||
df = pd.read_csv(dummy_dataset, sep=",", dtype=str)
|
||||
dataset = SequenceDataset(
|
||||
args=SequenceModelBase(
|
||||
image_file_column="IMG",
|
||||
label_value_column="Label",
|
||||
numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
|
||||
sequence_target_positions=[8],
|
||||
sequence_column="Position",
|
||||
local_dataset=Path(),
|
||||
should_validate=False
|
||||
),
|
||||
data_frame=df
|
||||
)
|
||||
assert len(dataset) == 2
|
||||
# Patch the load_images function that well be called once we access a dataset item
|
||||
with mock.patch('InnerEye.ML.dataset.scalar_sample.load_images_and_stack',
|
||||
return_value=ImageAndSegmentations[torch.Tensor](images=torch.ones(1),
|
||||
segmentations=torch.empty(0))):
|
||||
item0 = ClassificationItemSequence(**dataset[0])
|
||||
item1 = ClassificationItemSequence(**dataset[1])
|
||||
assert item0.id == "2627.00001"
|
||||
len_2627 = 3
|
||||
assert len(item0.items) == len_2627
|
||||
assert item1.id == "3250.00005"
|
||||
len_3250 = 9
|
||||
assert len(item1.items) == len_3250
|
||||
|
||||
# Data loaders use a customized collate function, that must work with the sequences too.
|
||||
collated = collate_with_metadata([dataset[0], dataset[1]])
|
||||
assert collated["id"] == ["2627.00001", "3250.00005"]
|
||||
# All subject sequences should be turned into lists of lists.
|
||||
assert isinstance(collated["items"], list)
|
||||
assert len(collated["items"]) == 2
|
||||
assert isinstance(collated["items"][0], list)
|
||||
assert isinstance(collated["items"][1], list)
|
||||
assert len(collated["items"][0]) == len_2627
|
||||
assert len(collated["items"][1]) == len_3250
|
||||
back_to_items = ClassificationItemSequence(**collated)
|
||||
assert back_to_items.id == ["2627.00001", "3250.00005"]
|
||||
|
||||
|
||||
def test_group_items() -> None:
|
||||
"""
|
||||
Test if grouping and filtering of sequence data sets works.
|
||||
"""
|
||||
|
||||
def _create(id: str, sequence_position: int, file: Optional[str], metadata: str) -> SequenceDataSource:
|
||||
return SequenceDataSource(channel_files=[file],
|
||||
numerical_non_image_features=torch.tensor([]),
|
||||
categorical_non_image_features=torch.tensor([]),
|
||||
label=torch.tensor([]),
|
||||
metadata=GeneralSampleMetadata(id=id, sequence_position=sequence_position,
|
||||
props={"M": metadata}))
|
||||
|
||||
items = [
|
||||
_create("a", 1, "f", "a.1"),
|
||||
_create("a", 0, "f", "a.0"),
|
||||
_create("a", 4, "f", "a.4"),
|
||||
_create("b", 1, None, "b.1"),
|
||||
_create("b", 0, None, "b.0"),
|
||||
_create("c", 0, "f", "c.0"),
|
||||
_create("d", 1, "f", "d.1"),
|
||||
]
|
||||
grouped = group_samples_into_sequences(items)
|
||||
assert len(grouped) == 3
|
||||
|
||||
def assert_group(group: ClassificationItemSequence, subject: str, props: List[str]) -> None:
|
||||
assert isinstance(group, ClassificationItemSequence)
|
||||
assert group.id == subject
|
||||
assert [i.metadata.props["M"] for i in group.items] == props
|
||||
|
||||
# For subject a, item a.4 should be dropped because the consecutive sequence is only [0, 1]
|
||||
assert_group(grouped[0], "a", ["a.0", "a.1"])
|
||||
assert_group(grouped[1], "b", ["b.0", "b.1"])
|
||||
assert_group(grouped[2], "c", ["c.0"])
|
||||
# Group should not contain subject d because its only item is at index 1
|
||||
|
||||
|
||||
def _create_item(id: str, sequence_position: int, metadata: str, label: Optional[float] = None) -> SequenceDataSource:
|
||||
return SequenceDataSource(channel_files=["foo"],
|
||||
numerical_non_image_features=torch.tensor([]),
|
||||
categorical_non_image_features=torch.tensor([]),
|
||||
label=(torch.tensor([label]) if label else torch.tensor([])),
|
||||
metadata=GeneralSampleMetadata(id=id, sequence_position=sequence_position,
|
||||
props={"M": metadata}))
|
||||
|
||||
|
||||
def _assert_group(group: ClassificationItemSequence, subject: str, props: List[str]) -> None:
|
||||
assert group.id == subject
|
||||
assert [i.metadata.props["M"] for i in group.items] == props
|
||||
|
||||
|
||||
def test_group_items_with_min_and_max_sequence_position_values() -> None:
|
||||
"""
|
||||
Test if grouping of sequence data works when requiring a full set of items.
|
||||
"""
|
||||
items = [
|
||||
_create_item("a", 1, "a.1"),
|
||||
_create_item("a", 0, "a.0"),
|
||||
_create_item("a", 2, "a.2"),
|
||||
_create_item("b", 1, "b.1"),
|
||||
_create_item("b", 0, "b.0"),
|
||||
]
|
||||
# When not providing a max_sequence_position_value, sequences of any length are OK.
|
||||
grouped = group_samples_into_sequences(items, max_sequence_position_value=None)
|
||||
assert len(grouped) == 2
|
||||
_assert_group(grouped[0], "a", ["a.0", "a.1", "a.2"])
|
||||
_assert_group(grouped[1], "b", ["b.0", "b.1"])
|
||||
# With a max_sequence_position_value, the set must be complete up to the given index.
|
||||
grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=2)
|
||||
assert len(grouped) == 2
|
||||
_assert_group(grouped[0], "a", ["a.1", "a.2"])
|
||||
# When a max position is given, the sequence will be truncated to at most contain the given value.
|
||||
grouped = group_samples_into_sequences(items, min_sequence_position_value=0, max_sequence_position_value=1)
|
||||
assert len(grouped) == 2
|
||||
_assert_group(grouped[0], "a", ["a.0", "a.1"])
|
||||
_assert_group(grouped[1], "b", ["b.0", "b.1"])
|
||||
grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=1)
|
||||
assert len(grouped) == 2
|
||||
_assert_group(grouped[0], "a", ["a.1"])
|
||||
_assert_group(grouped[1], "b", ["b.1"])
|
||||
# Allow sequences upto max_sequence_position_value=2
|
||||
grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=2)
|
||||
assert len(grouped) == 2
|
||||
_assert_group(grouped[0], "a", ["a.1", "a.2"])
|
||||
_assert_group(grouped[1], "b", ["b.1"])
|
||||
|
||||
# There are no items that have sequence position == 3, hence the next two calls should not return any items.
|
||||
grouped = group_samples_into_sequences(items, min_sequence_position_value=3)
|
||||
assert len(grouped) == 0
|
||||
# Check that items upto max_sequence_position_value=3 are included
|
||||
grouped = group_samples_into_sequences(items, max_sequence_position_value=3)
|
||||
assert len(grouped) == 2
|
||||
|
||||
# Sequence positions must be unique
|
||||
with pytest.raises(ValueError) as ex:
|
||||
group_samples_into_sequences([_create_item("a", 0, "a.0")] * 2)
|
||||
assert "contains duplicates" in str(ex)
|
||||
|
||||
|
||||
def test_group_items_with_label_positions() -> None:
|
||||
items = [
|
||||
_create_item("a", 0, "a.0", 1),
|
||||
_create_item("a", 3, "a.3", math.inf),
|
||||
_create_item("a", 1, "a.1", 0),
|
||||
_create_item("a", 2, "a.2", 1),
|
||||
]
|
||||
|
||||
# Extracting the sequence from 2 to 3
|
||||
grouped = group_samples_into_sequences(items, min_sequence_position_value=2, max_sequence_position_value=3)
|
||||
assert len(grouped) == 1
|
||||
_assert_group(grouped[0], "a", ["a.2", 'a.3'])
|
||||
|
||||
|
||||
def test_filter_valid_items() -> None:
|
||||
"""
|
||||
Test if filtering of sequence data sets works.
|
||||
"""
|
||||
|
||||
def _create(id: str, sequence_position: int, file: Optional[str], metadata: str) -> SequenceDataSource:
|
||||
return SequenceDataSource(channel_files=[file],
|
||||
numerical_non_image_features=torch.tensor([]),
|
||||
categorical_non_image_features=torch.tensor([]),
|
||||
label=torch.tensor([]),
|
||||
metadata=GeneralSampleMetadata(id=id, sequence_position=sequence_position,
|
||||
props={"M": metadata}))
|
||||
|
||||
items = [
|
||||
_create("a", 1, "f1", "a.1"), # Valid item
|
||||
_create("b", 0, None, "b.0"), # Invalid because no file
|
||||
_create("b", 1, "d", "b.1"), # valid
|
||||
_create("c", 0, "f3", "c.0"), # valid item for subject "c"
|
||||
]
|
||||
|
||||
def assert_items(filtered: List[SequenceDataSource], props: List[str]) -> None:
|
||||
assert [i.metadata.props["M"] for i in filtered] == props
|
||||
|
||||
# Standard filtering should remove items with missing file name only, that is b.0
|
||||
filtered1 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
|
||||
max_sequence_position_value=None)
|
||||
assert_items(filtered1, ["a.1", "b.1", "c.0"])
|
||||
|
||||
# Filtering also for max_sequence_position_value
|
||||
filtered2 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
|
||||
max_sequence_position_value=1)
|
||||
assert_items(filtered2, ["a.1", "b.1", "c.0"])
|
||||
filtered3 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
|
||||
max_sequence_position_value=0)
|
||||
assert_items(filtered3, ["c.0"])
|
||||
|
||||
# Filtering also for min_sequence_position_value
|
||||
filtered4 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
|
||||
min_sequence_position_value=1,
|
||||
max_sequence_position_value=None)
|
||||
assert_items(filtered4, ["a.1", "b.1"])
|
||||
|
||||
filtered5 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
|
||||
min_sequence_position_value=2,
|
||||
max_sequence_position_value=None)
|
||||
assert_items(filtered5, [])
|
||||
|
||||
# Now also filter by file name mapping: only "d" is in the mapping, hence only b.1 should survive
|
||||
file_mapping = {"d": Path("d"), "foo": Path("bar")}
|
||||
filtered4 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=file_mapping,
|
||||
max_sequence_position_value=1)
|
||||
assert_items(filtered4, ["b.1"])
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def test_sequence_dataloader() -> None:
|
||||
"""
|
||||
Test if we can create a data loader from the dataset, and recover the items as expected in batched form.
|
||||
Including instances where not all elements of the sequence have labels.
|
||||
"""
|
||||
csv_string = StringIO("""subject,seq,path,value,scalar1,scalar2,META
|
||||
S1,0,foo.nii,,0,0,M1
|
||||
S1,1,,True,1.1,1.2,M2
|
||||
S2,0,bar.nii,False,2.1,2.2,M3
|
||||
S2,1,,False,2.0,2.0,M4
|
||||
""")
|
||||
df = pd.read_csv(csv_string, sep=",", dtype=str)
|
||||
config = SequenceModelBase(
|
||||
image_file_column=None,
|
||||
label_value_column="value",
|
||||
numerical_columns=["scalar1"],
|
||||
sequence_target_positions=[1],
|
||||
sequence_column="seq",
|
||||
local_dataset=Path.cwd(),
|
||||
should_validate=False
|
||||
)
|
||||
dataset = SequenceDataset(config, data_frame=df)
|
||||
assert len(dataset) == 2
|
||||
data_loader = dataset.as_data_loader(shuffle=False, batch_size=2, num_dataload_workers=0)
|
||||
# We have 2 subjects, with a batch size of 2 those should be turned into 1 batch
|
||||
data_loader_output = list(i for i in data_loader)
|
||||
assert len(data_loader_output) == 1
|
||||
loaded = list(ClassificationItemSequence(**i) for i in data_loader_output)
|
||||
assert loaded[0].id == ["S1", "S2"]
|
||||
assert isinstance(loaded[0].items[0][0], ScalarItem)
|
||||
assert loaded[0].items[0][0].metadata.id == "S1"
|
||||
assert loaded[0].items[0][1].metadata.id == "S1"
|
||||
assert loaded[0].items[1][0].metadata.id == "S2"
|
||||
assert loaded[0].items[1][1].metadata.id == "S2"
|
||||
|
||||
# The batched sequence data are awkward to work with. Check if we can un-roll them correctly via
|
||||
# from_minibatch
|
||||
un_batched = ClassificationItemSequence.from_minibatch(data_loader_output[0])
|
||||
assert len(un_batched) == 2
|
||||
for i in range(2):
|
||||
assert un_batched[i].id == dataset.items[i].id
|
||||
assert len(un_batched[i].items) == len(dataset.items[i].items)
|
||||
for j in range(len(un_batched[i].items)):
|
||||
assert un_batched[i].items[j].metadata.id == dataset.items[i].items[j].metadata.id
|
||||
|
||||
|
||||
def test_standardize_features() -> None:
|
||||
"""
|
||||
Test if the non-image feature can be normalized to mean 0, std 1.
|
||||
:return:
|
||||
"""
|
||||
set_random_seed(1234)
|
||||
expected_mean = torch.tensor([[123, 2, 3], [4, 5, 6]])
|
||||
expected_std = torch.tensor([[0, 2, 3], [3, 4, 4]])
|
||||
feature_size = (2, 3)
|
||||
sequences: List[ClassificationItemSequence] = []
|
||||
for s in range(1000):
|
||||
items = []
|
||||
seq_length = torch.randint(low=3, high=6, size=(1,)).item()
|
||||
for i in range(seq_length): # type: ignore
|
||||
# All features are random Gaussian, apart from feature 0 which is constant.
|
||||
# Normalization must be able to deal with constant features when dividing by standard deviation.
|
||||
features = torch.randn(size=feature_size, dtype=torch.float32) * expected_std + expected_mean
|
||||
# Randomly put some infinite values in the vector
|
||||
features[s % 2, s % 3] = np.inf if torch.rand(1) > 0.9 else features[s % 2, s % 3]
|
||||
features[0, 0] = expected_mean[0, 0]
|
||||
item = ScalarItem(metadata=GeneralSampleMetadata(id="foo"),
|
||||
numerical_non_image_features=features,
|
||||
categorical_non_image_features=features,
|
||||
label=torch.tensor([]),
|
||||
images=torch.tensor([]),
|
||||
segmentations=torch.tensor([]))
|
||||
items.append(item)
|
||||
sequences.append(ClassificationItemSequence(id="foo", items=items))
|
||||
mean_std = FeatureStatistics.from_data_sources(sequences)
|
||||
assert mean_std.mean.shape == feature_size
|
||||
assert mean_std.std.shape == feature_size
|
||||
|
||||
assert_tensors_equal(mean_std.mean, expected_mean, 0.07)
|
||||
assert_tensors_equal(mean_std.std, expected_std, 0.07)
|
||||
|
||||
# After normalization, mean should be 0, and std should be 1.
|
||||
standardized_seq = mean_std.standardize(sequences)
|
||||
mean_std_from_standardized = FeatureStatistics.from_data_sources(standardized_seq)
|
||||
# After normalization, the mean should be 0, apart from the constant feature, which should be left untouched,
|
||||
# hence its mean is the original feature value.
|
||||
expected_mean_from_standardized = torch.zeros(feature_size)
|
||||
expected_mean_from_standardized[0, 0] = expected_mean[0, 0]
|
||||
expected_std_from_standardized = torch.ones(feature_size)
|
||||
expected_std_from_standardized[0, 0] = 0.0
|
||||
assert_tensors_equal(mean_std_from_standardized.mean, expected_mean_from_standardized, abs=1e-5)
|
||||
assert_tensors_equal(mean_std_from_standardized.std, expected_std_from_standardized, abs=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_sequence", [True, False])
|
||||
def test_standardize_features_when_singleton(is_sequence: bool) -> None:
|
||||
"""
|
||||
Test how feature standardize copes with datasets that only have 1 entry.
|
||||
"""
|
||||
numerical_features = torch.ones((1, 3))
|
||||
categorical_features = torch.tensor([[0, 1, 1], [1, 0, 0]])
|
||||
item: Union[SequenceDataSource, ScalarDataSource]
|
||||
sources: Union[ListOfSequences, List[ScalarDataSource]]
|
||||
if is_sequence:
|
||||
item = SequenceDataSource(metadata=GeneralSampleMetadata(id="foo"),
|
||||
numerical_non_image_features=numerical_features,
|
||||
categorical_non_image_features=categorical_features,
|
||||
label=torch.tensor([]),
|
||||
channel_files=[])
|
||||
sources = [ClassificationItemSequence(id="foo", items=[item])]
|
||||
mean_std = FeatureStatistics.from_data_sources(sources)
|
||||
else:
|
||||
item = ScalarDataSource(metadata=GeneralSampleMetadata(id="foo"),
|
||||
numerical_non_image_features=numerical_features,
|
||||
categorical_non_image_features=categorical_features,
|
||||
label=torch.tensor([]),
|
||||
channel_files=[])
|
||||
|
||||
sources = [item]
|
||||
mean_std = FeatureStatistics.from_data_sources(sources)
|
||||
|
||||
assert_tensors_equal(mean_std.mean, numerical_features)
|
||||
# Standard deviation can't be computed because there is only one element, hence becomes nan.
|
||||
assert torch.all(torch.isnan(mean_std.std))
|
||||
# When applying such a standardization to the sequences, they should not be changed (similar to features that
|
||||
# are constant)
|
||||
standardized_sources = mean_std.standardize(sources)
|
||||
if is_sequence:
|
||||
assert_tensors_equal(standardized_sources[0].items[0].numerical_non_image_features, numerical_features)
|
||||
assert_tensors_equal(standardized_sources[0].items[0].categorical_non_image_features, categorical_features)
|
||||
else:
|
||||
assert_tensors_equal(standardized_sources[0].numerical_non_image_features, numerical_features)
|
||||
assert_tensors_equal(standardized_sources[0].categorical_non_image_features, categorical_features)
|
||||
|
||||
|
||||
def test_add_difference_features() -> None:
|
||||
"""
|
||||
Test if we can add difference features for sequence data sets (differences from position i compared to position 0
|
||||
in the sequence)
|
||||
"""
|
||||
|
||||
def _create(features: List) -> SequenceDataSource:
|
||||
return SequenceDataSource(metadata=GeneralSampleMetadata(id="foo"),
|
||||
channel_files=[],
|
||||
label=torch.tensor([]),
|
||||
categorical_non_image_features=torch.tensor([]),
|
||||
numerical_non_image_features=torch.tensor(features).float())
|
||||
|
||||
item1 = _create([[1, 2, 3], [4, 5, 6]])
|
||||
item2 = _create([[11, 22, 33], [44, 55, 66]])
|
||||
items = [ClassificationItemSequence[SequenceDataSource](id="bar", items=[item1, item2])]
|
||||
updated = add_difference_features(items, [0, 2])
|
||||
# The two difference features should be added along dimension 1 of the tensor
|
||||
assert updated[0].items[0].numerical_non_image_features.shape == (2, 5)
|
||||
# Item 0 should have differences of 0
|
||||
assert_tensors_equal(updated[0].items[0].numerical_non_image_features[:, 0:3], item1.numerical_non_image_features)
|
||||
assert_tensors_equal(updated[0].items[0].numerical_non_image_features[:, 3:5], [[0, 0], [0, 0]])
|
||||
# Item 1 should have non-zero diff, and keep the original non-image features in the first few dim
|
||||
assert_tensors_equal(updated[0].items[1].numerical_non_image_features[:, 0:3], item2.numerical_non_image_features)
|
||||
assert_tensors_equal(updated[0].items[1].numerical_non_image_features[:, 3:5], [[10, 30], [40, 60]])
|
||||
|
||||
|
||||
def test_seq_to_tensor() -> None:
|
||||
"""
|
||||
Test if we can create a tensor from a variable length sequence.
|
||||
"""
|
||||
|
||||
def _create(features: List) -> torch.Tensor:
|
||||
return ScalarItem(
|
||||
segmentations=torch.empty(0),
|
||||
metadata=GeneralSampleMetadata(id="foo"),
|
||||
images=torch.tensor([]),
|
||||
label=torch.tensor([]),
|
||||
categorical_non_image_features=torch.tensor(features).float(),
|
||||
numerical_non_image_features=torch.tensor(features).float()
|
||||
).get_all_non_imaging_features()
|
||||
|
||||
item1 = _create([1, 2, 3, 4, 5, 6])
|
||||
item2 = _create([11, 22, 33])
|
||||
items = [item1, item1, item2, item1]
|
||||
stacked = sequences_to_padded_tensor(items)
|
||||
assert torch.is_tensor(stacked)
|
||||
# pad_sequence will pad the tensors to the maximum sequence length
|
||||
assert stacked.shape == (len(items), item1.numel())
|
||||
|
||||
|
||||
def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Check that the sequence dataset works end-to-end, including applying the right standardization.
|
||||
"""
|
||||
csv_string = """subject,seq,value,scalar1,scalar2,META,BETA
|
||||
S1,0,False,0,0,M1,B1
|
||||
S1,1,True,1,10,M2,B2
|
||||
S2,0,False,2,20,M2,B1
|
||||
S3,0,True,3,30,M1,B1
|
||||
S4,0,True,4,40,M2,B1
|
||||
"""
|
||||
csv_path = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
|
||||
config = SequenceModelBase(
|
||||
local_dataset=csv_path,
|
||||
image_file_column=None,
|
||||
label_value_column="value",
|
||||
numerical_columns=["scalar1", "scalar2"],
|
||||
sequence_target_positions=[0],
|
||||
categorical_columns=["META", "BETA"],
|
||||
sequence_column="seq",
|
||||
num_dataload_workers=0,
|
||||
train_batch_size=2,
|
||||
should_validate=False,
|
||||
shuffle=False
|
||||
)
|
||||
config.read_dataset_if_needed()
|
||||
df = config.dataset_data_frame
|
||||
assert df is not None
|
||||
df1 = df[df.subject.isin(["S1", "S2"])]
|
||||
df2 = df[df.subject == "S3"]
|
||||
df3 = df[df.subject == "S4"]
|
||||
splits = DatasetSplits(train=df1, val=df2, test=df3)
|
||||
with mock.patch.object(SequenceModelBase,
|
||||
'get_model_train_test_dataset_splits',
|
||||
return_value=splits):
|
||||
train_val_loaders = config.create_data_loaders()
|
||||
# Expected feature mean: Mean of the training data (0, 0), (1, 10), (2, 20) = (1, 10)
|
||||
# Expected (biased corrected) std estimate: Std of (0, 0), (1, 10), (2, 20) = (1, 10)
|
||||
feature_stats = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN).feature_statistics
|
||||
assert feature_stats is not None
|
||||
assert_tensors_equal(feature_stats.mean, [1, 10])
|
||||
assert_tensors_equal(feature_stats.std, [1, 10])
|
||||
|
||||
train_items = list(ClassificationItemSequence.from_minibatch(b)
|
||||
for b in train_val_loaders[ModelExecutionMode.TRAIN])
|
||||
assert len(train_items) == 1, "2 items in training set with batch size of 2 should return 1 minibatch"
|
||||
assert len(train_items[0]) == 2
|
||||
assert train_items[0][0].id == "S1"
|
||||
assert_tensors_equal(train_items[0][0].items[0].get_all_non_imaging_features(), [-1., -1., 1., 0., 1., 0.])
|
||||
assert_tensors_equal(train_items[0][0].items[1].get_all_non_imaging_features(), [0., 0., 0., 1., 0., 1.])
|
||||
assert train_items[0][1].id == "S2"
|
||||
assert_tensors_equal(train_items[0][1].items[0].get_all_non_imaging_features(), [1., 1., 0., 1., 1., 0.])
|
||||
val_items = list(ClassificationItemSequence.from_minibatch(b)
|
||||
for b in train_val_loaders[ModelExecutionMode.VAL])
|
||||
assert len(val_items) == 1
|
||||
assert len(val_items[0]) == 1
|
||||
assert val_items[0][0].id == "S3"
|
||||
# Items in the validation set should be normalized using the mean and std on the training data.
|
||||
# Hence, the non-image features (3, 30) should turn into (2, 2)
|
||||
assert_tensors_equal(val_items[0][0].items[0].get_all_non_imaging_features(), [2., 2., 1., 0., 1., 0.])
|
||||
|
||||
# Check that the test set is also normalized correctly using the training mean and std.
|
||||
test_items = list(ClassificationItemSequence(**b)
|
||||
for b in config.get_torch_dataset_for_inference(ModelExecutionMode.TEST))
|
||||
assert test_items[0].id == "S4"
|
||||
# Check Non-image features of (4, 40)
|
||||
assert_tensors_equal(test_items[0].items[0].get_all_non_imaging_features(), [3., 3., 0., 1., 1., 0.])
|
||||
|
||||
|
||||
def test_get_class_counts(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test training and testing of sequence models that predicts at multiple time points,
|
||||
when it is started via run_ml.
|
||||
"""
|
||||
dataset_contents = _get_multi_label_sequence_dataframe()
|
||||
config = ToyMultiLabelSequenceModel(should_validate=False)
|
||||
assert config.get_target_indices() == [1, 2, 3]
|
||||
expected_prediction_targets = ["Seq_pos 01", "Seq_pos 02", "Seq_pos 03"]
|
||||
assert len(config.get_target_indices()) == len(expected_prediction_targets) # type: ignore
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = dataset_contents
|
||||
config.pre_process_dataset_dataframe()
|
||||
splits = config.get_dataset_splits()
|
||||
train_dataset = config.create_torch_datasets(splits)[ModelExecutionMode.TRAIN]
|
||||
class_counts = train_dataset.get_class_counts()
|
||||
assert class_counts == {0: 2}
|
||||
|
||||
|
||||
def test_get_labels_at_target_indices() -> None:
|
||||
"""
|
||||
Test to ensure label selection based on target indices is as expected
|
||||
"""
|
||||
sequence_items = _create_scalar_items(length=3)
|
||||
|
||||
sequence = ClassificationItemSequence(id="A", items=sequence_items)
|
||||
|
||||
# since label at sequence position 3 will not exist, we expect the result tensor to be padded with a nan
|
||||
labels = sequence.get_labels_at_target_indices(target_indices=[0, 1, 2, 3])
|
||||
assert torch.allclose(labels, torch.tensor([[1.0], [1.0], [1.0], [np.nan]]), equal_nan=True)
|
||||
|
||||
# test we can extract all of the labels in the sequence
|
||||
labels = sequence.get_labels_at_target_indices(target_indices=[0, 1, 2])
|
||||
assert torch.equal(labels, torch.tensor([[1.0], [1.0], [1.0]]))
|
||||
|
||||
# test we can extract only a subset of the labels in the sequence
|
||||
labels = sequence.get_labels_at_target_indices(target_indices=[0, 1])
|
||||
assert torch.equal(labels, torch.tensor([[1.0], [1.0]]))
|
||||
|
||||
# test we raise an exception for invalid target indices
|
||||
with pytest.raises(Exception):
|
||||
sequence.get_labels_at_target_indices(target_indices=[-1])
|
||||
|
||||
|
||||
def test_create_labels_tensor_for_minibatch() -> None:
|
||||
"""
|
||||
Test to make sure labels tensor is created as expected for minibatch
|
||||
"""
|
||||
|
||||
sequences = [ClassificationItemSequence(id=x, items=_create_scalar_items(length=i + 1))
|
||||
for i, x in enumerate(["A", "B"])]
|
||||
|
||||
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(sequences, target_indices=[0, 1, 2])
|
||||
assert torch.allclose(labels, torch.tensor([
|
||||
[[1.0], [np.nan], [np.nan]],
|
||||
[[1.0], [1.0], [np.nan]]]
|
||||
), equal_nan=True)
|
||||
|
||||
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(sequences, target_indices=[0, 1])
|
||||
assert torch.allclose(labels, torch.tensor([
|
||||
[[1.0], [np.nan]],
|
||||
[[1.0], [1.0]]]
|
||||
), equal_nan=True)
|
||||
|
||||
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(sequences, target_indices=[0])
|
||||
assert torch.equal(labels, torch.tensor([
|
||||
[[1.0]],
|
||||
[[1.0]]]
|
||||
))
|
||||
|
||||
|
||||
def _create_scalar_items(length: int, label_value: float = 1.0) -> List[ScalarItem]:
|
||||
return [ScalarItem(metadata=GeneralSampleMetadata(id="foo", sequence_position=x),
|
||||
numerical_non_image_features=torch.tensor([]),
|
||||
categorical_non_image_features=torch.tensor([]),
|
||||
label=torch.tensor([label_value]),
|
||||
images=torch.tensor([]),
|
||||
segmentations=torch.tensor([])) for x in range(length)]
|
|
@ -1,177 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
|
||||
|
||||
|
||||
def noop_transform(x: Any) -> Any:
|
||||
return x
|
||||
|
||||
|
||||
def _check_generator_consistency(dl: DataLoader) -> None:
|
||||
dataloader_generator = dl.generator
|
||||
bag_sampler_generator = dl.dataset.data.bag_sampler.generator # type: ignore
|
||||
assert torch.equal(dataloader_generator.get_state(),
|
||||
bag_sampler_generator.get_state())
|
||||
|
||||
|
||||
def compare_dataloaders(dl1: DataLoader, dl2: DataLoader) -> None:
|
||||
for batch1, batch2 in zip(dl1, dl2):
|
||||
_check_generator_consistency(dl1)
|
||||
_check_generator_consistency(dl2)
|
||||
assert batch1.keys() == batch2.keys()
|
||||
for key in batch1:
|
||||
assert len(batch1[key]) == len(batch2[key])
|
||||
for item1, item2 in zip(batch1[key], batch2[key]):
|
||||
if isinstance(item1, torch.Tensor):
|
||||
assert torch.allclose(item1, item2, equal_nan=True)
|
||||
else:
|
||||
assert item1 == item2
|
||||
|
||||
|
||||
class MockTilesDataset(TilesDataset):
|
||||
TILE_X_COLUMN = TILE_Y_COLUMN = None
|
||||
TRAIN_SPLIT_LABEL = 'train'
|
||||
VAL_SPLIT_LABEL = 'val'
|
||||
TEST_SPLIT_LABEL = 'test'
|
||||
|
||||
|
||||
def generate_mock_dataset_df(n_slides: int, n_tiles: int, n_classes: int) -> pd.DataFrame:
|
||||
np.random.seed(1234)
|
||||
slide_ids = np.random.randint(n_slides, size=n_tiles)
|
||||
slide_labels = np.random.randint(n_classes, size=n_slides)
|
||||
tile_labels = slide_labels[slide_ids]
|
||||
split_labels = [MockTilesDataset.TRAIN_SPLIT_LABEL,
|
||||
MockTilesDataset.VAL_SPLIT_LABEL,
|
||||
MockTilesDataset.TEST_SPLIT_LABEL]
|
||||
slide_splits = np.random.choice(split_labels, size=n_slides)
|
||||
tile_splits = slide_splits[slide_ids]
|
||||
|
||||
df = pd.DataFrame()
|
||||
df[MockTilesDataset.TILE_ID_COLUMN] = np.arange(n_tiles)
|
||||
df[MockTilesDataset.SLIDE_ID_COLUMN] = slide_ids
|
||||
df[MockTilesDataset.LABEL_COLUMN] = tile_labels
|
||||
df[MockTilesDataset.SPLIT_COLUMN] = tile_splits
|
||||
df[MockTilesDataset.IMAGE_COLUMN] = [f"{tile_splits[i]}/{i:06d}.png" for i in range(n_tiles)]
|
||||
|
||||
return df
|
||||
|
||||
|
||||
class MockTilesDataModule(TilesDataModule):
|
||||
def get_splits(self) -> Tuple[MockTilesDataset, MockTilesDataset, MockTilesDataset]:
|
||||
df = MockTilesDataset(self.root_path).dataset_df
|
||||
df = df.reset_index()
|
||||
split_dfs = (df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.TRAIN_SPLIT_LABEL],
|
||||
df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.VAL_SPLIT_LABEL],
|
||||
df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.TEST_SPLIT_LABEL])
|
||||
return tuple(MockTilesDataset(self.root_path, dataset_df=split_df) # type: ignore
|
||||
for split_df in split_dfs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_dir(tmp_path: Path) -> Path:
|
||||
csv_dir = tmp_path / "mock_tiles_dataset"
|
||||
csv_dir.mkdir(exist_ok=True)
|
||||
csv_path = csv_dir / MockTilesDataset.DEFAULT_CSV_FILENAME
|
||||
if not csv_path.exists():
|
||||
csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df = generate_mock_dataset_df(n_slides=8, n_tiles=100, n_classes=2)
|
||||
df.to_csv(csv_path, index=False)
|
||||
return csv_dir
|
||||
|
||||
def _get_datamodule(cache_mode: CacheMode, precache_location: CacheLocation,
|
||||
cache_dir_provided: bool, data_dir: Path) -> TilesDataModule:
|
||||
if (cache_mode is CacheMode.NONE and precache_location is not CacheLocation.NONE) \
|
||||
or (cache_mode is CacheMode.DISK and not cache_dir_provided) \
|
||||
or (precache_location is not CacheLocation.NONE and not cache_dir_provided):
|
||||
pytest.skip("Unsupported combination of caching arguments")
|
||||
|
||||
cache_dir = data_dir / f"datamodule_cache_{cache_mode.value}_{precache_location.value}" if cache_dir_provided else None
|
||||
|
||||
if cache_dir is not None and cache_dir.exists():
|
||||
shutil.rmtree(cache_dir)
|
||||
|
||||
return MockTilesDataModule(root_path=data_dir,
|
||||
transform=noop_transform,
|
||||
seed=0,
|
||||
batch_size=2,
|
||||
cache_mode=cache_mode,
|
||||
precache_location=precache_location,
|
||||
cache_dir=cache_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cache_mode', [CacheMode.MEMORY, CacheMode.DISK, CacheMode.NONE])
|
||||
@pytest.mark.parametrize('precache_location', [CacheLocation.NONE, CacheLocation.CPU, CacheLocation.SAME])
|
||||
@pytest.mark.parametrize('cache_dir_provided', [True, False])
|
||||
def test_caching_consistency(mock_data_dir: Path, cache_mode: CacheMode, precache_location: CacheLocation,
|
||||
cache_dir_provided: bool) -> None:
|
||||
# Compare two dataloaders from the same datamodule
|
||||
datamodule = _get_datamodule(cache_mode=cache_mode,
|
||||
precache_location=precache_location,
|
||||
cache_dir_provided=cache_dir_provided,
|
||||
data_dir=mock_data_dir)
|
||||
datamodule.prepare_data()
|
||||
train_dataloader = datamodule.train_dataloader()
|
||||
train_dataloader2 = datamodule.train_dataloader()
|
||||
|
||||
compare_dataloaders(train_dataloader, train_dataloader2)
|
||||
|
||||
# Compare datamodules reusing the same cache
|
||||
datamodule = _get_datamodule(cache_mode=cache_mode,
|
||||
precache_location=precache_location,
|
||||
cache_dir_provided=cache_dir_provided,
|
||||
data_dir=mock_data_dir)
|
||||
datamodule.prepare_data()
|
||||
train_dataloader = datamodule.train_dataloader()
|
||||
|
||||
reloaded_datamodule = _get_datamodule(cache_mode=cache_mode,
|
||||
precache_location=precache_location,
|
||||
cache_dir_provided=cache_dir_provided,
|
||||
data_dir=mock_data_dir)
|
||||
reloaded_datamodule.prepare_data()
|
||||
reloaded_train_dataloader = reloaded_datamodule.train_dataloader()
|
||||
|
||||
compare_dataloaders(train_dataloader, reloaded_train_dataloader)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cache_mode, precache_location, cache_dir_provided',
|
||||
[(CacheMode.DISK, CacheLocation.SAME, True),
|
||||
(CacheMode.DISK, CacheLocation.CPU, True),
|
||||
(CacheMode.MEMORY, CacheLocation.SAME, True),
|
||||
(CacheMode.MEMORY, CacheLocation.CPU, True),
|
||||
(CacheMode.MEMORY, CacheLocation.NONE, False),
|
||||
(CacheMode.NONE, CacheLocation.NONE, False)
|
||||
])
|
||||
def test_tile_id_coverage(mock_data_dir: Path, cache_mode: CacheMode, precache_location: CacheLocation,
|
||||
cache_dir_provided: bool) -> None:
|
||||
datamodule = _get_datamodule(cache_mode=cache_mode,
|
||||
precache_location=precache_location,
|
||||
cache_dir_provided=cache_dir_provided,
|
||||
data_dir=mock_data_dir)
|
||||
datamodule.prepare_data()
|
||||
train_dataset = datamodule.train_dataset
|
||||
train_dataloader = datamodule.train_dataloader()
|
||||
expected_tile_ids = set(train_dataset.dataset_df.index)
|
||||
loaded_tile_ids = set() # type: ignore
|
||||
for batch in train_dataloader:
|
||||
for stacked_bag_tile_ids in batch[train_dataset.TILE_ID_COLUMN]:
|
||||
if isinstance(stacked_bag_tile_ids, torch.Tensor):
|
||||
stacked_bag_tile_ids = stacked_bag_tile_ids.tolist()
|
||||
bag_tile_ids = set(stacked_bag_tile_ids)
|
||||
assert bag_tile_ids.isdisjoint(loaded_tile_ids), \
|
||||
f"Tile IDs already seen: {bag_tile_ids}"
|
||||
loaded_tile_ids.update(bag_tile_ids)
|
||||
assert loaded_tile_ids == expected_tile_ids
|
|
@ -1,40 +0,0 @@
|
|||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||
|
||||
HISTO_TEST_DATA_DIR = str(tests_root_directory("ML/histopathology/test_data"))
|
||||
|
||||
|
||||
class MockSlidesDataset(SlidesDataset):
|
||||
DEFAULT_CSV_FILENAME = "test_slides_dataset.csv"
|
||||
METADATA_COLUMNS = ('meta1', 'meta2')
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(root=HISTO_TEST_DATA_DIR)
|
||||
|
||||
|
||||
def test_slides_dataset() -> None:
|
||||
dataset = MockSlidesDataset()
|
||||
assert isinstance(dataset.dataset_df, pd.DataFrame)
|
||||
assert dataset.dataset_df.index.name == dataset.SLIDE_ID_COLUMN
|
||||
assert len(dataset) == len(dataset.dataset_df)
|
||||
|
||||
sample = dataset[0]
|
||||
assert isinstance(sample, dict)
|
||||
assert all(isinstance(key, SlideKey) for key in sample)
|
||||
|
||||
expected_keys = [SlideKey.SLIDE_ID, SlideKey.IMAGE, SlideKey.IMAGE_PATH, SlideKey.LABEL,
|
||||
SlideKey.METADATA]
|
||||
assert all(key in sample for key in expected_keys)
|
||||
|
||||
image_path = sample[SlideKey.IMAGE_PATH]
|
||||
assert isinstance(image_path, str)
|
||||
assert os.path.isfile(image_path)
|
||||
|
||||
metadata = sample[SlideKey.METADATA]
|
||||
assert isinstance(metadata, dict)
|
||||
assert all(meta_col in metadata for meta_col in type(dataset).METADATA_COLUMNS)
|
|
@ -1,42 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from monai.data.dataset import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
|
||||
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
from InnerEye.ML.Histopathology.models.transforms import LoadTiled
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
|
||||
reason="TCGA-CRCk dataset is unavailable")
|
||||
@pytest.mark.parametrize('train', [True, False])
|
||||
def test_dataset(train: bool) -> None:
|
||||
base_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR, train=train)
|
||||
dataset = Dataset(base_dataset, transform=LoadTiled('image')) # type: ignore
|
||||
|
||||
expected_length = 93408 if train else 98904
|
||||
assert len(dataset) == expected_length
|
||||
|
||||
sample = dataset[0]
|
||||
expected_keys = ['slide_id', 'tile_id', 'image', 'split', 'label']
|
||||
assert all(key in sample for key in expected_keys)
|
||||
assert isinstance(sample['image'], torch.Tensor)
|
||||
assert sample['image'].shape == (3, 224, 224)
|
||||
|
||||
batch_size = 16
|
||||
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # type: ignore
|
||||
batch = next(iter(loader))
|
||||
assert all(key in batch for key in expected_keys)
|
||||
assert isinstance(batch['image'], torch.Tensor)
|
||||
assert batch['image'].shape == (batch_size, 3, 224, 224)
|
||||
assert batch['image'].dtype == torch.float32
|
||||
assert batch['label'].shape == (batch_size,)
|
||||
assert batch['label'].dtype == torch.int64
|
|
@ -1,379 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
from torchmetrics import Accuracy, Metric # noqa
|
||||
from torchvision.models import resnet18
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
|
||||
from health_ml.networks.layers.attention_layers import (
|
||||
AttentionLayer,
|
||||
GatedAttentionLayer,
|
||||
MeanPoolingLayer,
|
||||
)
|
||||
|
||||
from InnerEye.ML.lightning_container import LightningContainer
|
||||
from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import (
|
||||
DeepSMILECrck,
|
||||
)
|
||||
from InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda import (
|
||||
DeepSMILEPanda,
|
||||
)
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
|
||||
from InnerEye.ML.Histopathology.datasets.default_paths import (
|
||||
TCGA_CRCK_DATASET_DIR,
|
||||
PANDA_TILES_DATASET_DIR,
|
||||
)
|
||||
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
|
||||
from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
|
||||
from InnerEye.ML.Histopathology.utils.naming import MetricsKey, ResultsKey
|
||||
|
||||
|
||||
def get_supervised_imagenet_encoder() -> TileEncoder:
|
||||
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
|
||||
|
||||
|
||||
def _test_lightningmodule(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
pool_hidden_dim: int,
|
||||
pool_out_dim: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
|
||||
assert n_classes > 0
|
||||
|
||||
# hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere
|
||||
encoder = get_supervised_imagenet_encoder()
|
||||
module = DeepMILModule(
|
||||
encoder=encoder,
|
||||
label_column="label",
|
||||
n_classes=n_classes,
|
||||
pooling_layer=pooling_layer,
|
||||
pool_hidden_dim=pool_hidden_dim,
|
||||
pool_out_dim=pool_out_dim,
|
||||
dropout_rate=dropout_rate,
|
||||
)
|
||||
|
||||
bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])
|
||||
bag_labels_list = []
|
||||
bag_logits_list = []
|
||||
bag_attn_list = []
|
||||
for bag in bag_images:
|
||||
if n_classes > 1:
|
||||
labels = randint(n_classes, size=(max_bag_size,))
|
||||
else:
|
||||
labels = randint(n_classes + 1, size=(max_bag_size,))
|
||||
bag_labels_list.append(module.get_bag_label(labels))
|
||||
logit, attn = module(bag)
|
||||
assert logit.shape == (1, n_classes)
|
||||
assert attn.shape == (module.pool_out_dim, max_bag_size)
|
||||
bag_logits_list.append(logit.view(-1))
|
||||
bag_attn_list.append(attn)
|
||||
|
||||
bag_logits = stack(bag_logits_list)
|
||||
bag_labels = stack(bag_labels_list).view(-1)
|
||||
|
||||
assert bag_logits.shape[0] == (batch_size)
|
||||
assert bag_labels.shape[0] == (batch_size)
|
||||
|
||||
if module.n_classes > 1:
|
||||
loss = module.loss_fn(bag_logits, bag_labels)
|
||||
else:
|
||||
loss = module.loss_fn(bag_logits.squeeze(1), bag_labels.float())
|
||||
|
||||
assert loss > 0
|
||||
assert loss.shape == ()
|
||||
|
||||
probs = module.activation_fn(bag_logits)
|
||||
assert ((probs >= 0) & (probs <= 1)).all()
|
||||
if n_classes > 1:
|
||||
assert probs.shape == (batch_size, n_classes)
|
||||
else:
|
||||
assert probs.shape[0] == batch_size
|
||||
|
||||
if n_classes > 1:
|
||||
preds = argmax(probs, dim=1)
|
||||
else:
|
||||
preds = round(probs)
|
||||
assert preds.shape[0] == batch_size
|
||||
|
||||
for metric_name, metric_object in module.train_metrics.items():
|
||||
if metric_name == MetricsKey.CONF_MATRIX:
|
||||
continue
|
||||
score = metric_object(preds.view(-1, 1), bag_labels.view(-1, 1))
|
||||
assert torch.all(score >= 0)
|
||||
assert torch.all(score <= 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
|
||||
@pytest.mark.parametrize("pool_out_dim", [1, 6])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule_attention(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
pool_hidden_dim: int,
|
||||
pool_out_dim: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
_test_lightningmodule(n_classes=n_classes,
|
||||
pooling_layer=pooling_layer,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=max_bag_size,
|
||||
pool_hidden_dim=pool_hidden_dim,
|
||||
pool_out_dim=pool_out_dim,
|
||||
dropout_rate=dropout_rate)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule_mean_pooling(
|
||||
n_classes: int,
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
_test_lightningmodule(n_classes=n_classes,
|
||||
pooling_layer=MeanPoolingLayer,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=max_bag_size,
|
||||
pool_hidden_dim=1,
|
||||
pool_out_dim=1,
|
||||
dropout_rate=dropout_rate)
|
||||
|
||||
|
||||
def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
|
||||
def is_integral(x: torch.Tensor) -> bool:
|
||||
return (x == x.long()).all() # type: ignore
|
||||
|
||||
assert scores.shape == labels.shape
|
||||
assert torch.is_floating_point(scores), "Received scores with integer dtype"
|
||||
assert not is_integral(scores), "Received scores with integral values"
|
||||
assert is_integral(labels), "Received labels with floating-point values"
|
||||
|
||||
|
||||
def add_callback(fn: Callable, callback: Callable) -> Callable:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
callback(*args, **kwargs)
|
||||
return fn(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def test_metrics() -> None:
|
||||
input_dim = (128,)
|
||||
module = DeepMILModule(
|
||||
encoder=IdentityEncoder(input_dim=input_dim),
|
||||
label_column=TilesDataset.LABEL_COLUMN,
|
||||
n_classes=1,
|
||||
pooling_layer=AttentionLayer,
|
||||
)
|
||||
|
||||
# Patching to enable running the module without a Trainer object
|
||||
module.trainer = MagicMock(world_size=1) # type: ignore
|
||||
module.log = MagicMock() # type: ignore
|
||||
|
||||
batch_size = 20
|
||||
bag_size = 5
|
||||
class_weights = torch.tensor([.8, .2])
|
||||
bags: List[Dict] = []
|
||||
for slide_idx in range(batch_size):
|
||||
bag_label = torch.multinomial(class_weights, 1)
|
||||
sample: Dict[str, Iterable] = {
|
||||
TilesDataset.SLIDE_ID_COLUMN: [str(slide_idx)] * bag_size,
|
||||
TilesDataset.TILE_ID_COLUMN: [f"{slide_idx}-{tile_idx}"
|
||||
for tile_idx in range(bag_size)],
|
||||
TilesDataset.IMAGE_COLUMN: rand(bag_size, *input_dim),
|
||||
TilesDataset.LABEL_COLUMN: bag_label.expand(bag_size),
|
||||
}
|
||||
sample[TilesDataset.PATH_COLUMN] = [tile_id + '.png'
|
||||
for tile_id in sample[TilesDataset.TILE_ID_COLUMN]]
|
||||
bags.append(sample)
|
||||
batch = default_collate(bags)
|
||||
|
||||
# ================
|
||||
# Test that the module metrics match manually computed metrics with the correct inputs
|
||||
module_metrics_dict = module.test_metrics
|
||||
independent_metrics_dict = module.get_metrics()
|
||||
|
||||
# Patch the metrics to check that the inputs are valid. In particular, test that the scores
|
||||
# do not have integral values, which would suggest that hard labels were passed instead.
|
||||
for metric_obj in module_metrics_dict.values():
|
||||
metric_obj.update = add_callback(metric_obj.update, validate_metric_inputs)
|
||||
|
||||
results = module.test_step(batch, 0)
|
||||
predicted_probs = results[ResultsKey.PROB]
|
||||
true_labels = results[ResultsKey.TRUE_LABEL]
|
||||
|
||||
for key, metric_obj in module_metrics_dict.items():
|
||||
value = metric_obj.compute()
|
||||
expected_value = independent_metrics_dict[key](predicted_probs, true_labels)
|
||||
assert torch.allclose(value, expected_value), f"Discrepancy in '{key}' metric"
|
||||
|
||||
# ================
|
||||
# Test that thresholded metrics (e.g. accuracy, precision, etc.) change as the threshold is varied.
|
||||
# If they don't, it suggests the inputs are hard labels instead of continuous scores.
|
||||
thresholded_metrics_keys = [key for key, metric in module_metrics_dict.items()
|
||||
if hasattr(metric, 'threshold')]
|
||||
|
||||
def set_metrics_threshold(metrics_dict: Any, threshold: float) -> None:
|
||||
for key in thresholded_metrics_keys:
|
||||
metrics_dict[key].threshold = threshold
|
||||
|
||||
def reset_metrics(metrics_dict: Any) -> None:
|
||||
for metric_obj in metrics_dict.values():
|
||||
metric_obj.reset()
|
||||
|
||||
low_threshold, high_threshold = torch.quantile(predicted_probs, torch.tensor([0.1, 0.9]))
|
||||
|
||||
reset_metrics(module_metrics_dict)
|
||||
set_metrics_threshold(module_metrics_dict, threshold=low_threshold)
|
||||
_ = module.test_step(batch, 0)
|
||||
results_low_threshold = {key: module_metrics_dict[key].compute()
|
||||
for key in thresholded_metrics_keys}
|
||||
|
||||
reset_metrics(module_metrics_dict)
|
||||
set_metrics_threshold(module_metrics_dict, threshold=high_threshold)
|
||||
_ = module.test_step(batch, 0)
|
||||
results_high_threshold = {key: module_metrics_dict[key].compute()
|
||||
for key in thresholded_metrics_keys}
|
||||
|
||||
for key in thresholded_metrics_keys:
|
||||
assert not torch.allclose(results_low_threshold[key], results_high_threshold[key]), \
|
||||
f"Got same value for '{key}' metric with low and high thresholds"
|
||||
|
||||
|
||||
def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
|
||||
device = "cuda" if use_gpu else "cpu"
|
||||
return {
|
||||
key: [
|
||||
value.to(device) if isinstance(value, Tensor) else value for value in values
|
||||
]
|
||||
for key, values in batch.items()
|
||||
}
|
||||
|
||||
|
||||
CONTAINER_DATASET_DIR = {
|
||||
DeepSMILEPanda: PANDA_TILES_DATASET_DIR,
|
||||
DeepSMILECrck: TCGA_CRCK_DATASET_DIR,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("container_type", [DeepSMILEPanda,
|
||||
DeepSMILECrck])
|
||||
@pytest.mark.parametrize("use_gpu", [True, False])
|
||||
def test_container(container_type: Type[LightningContainer], use_gpu: bool) -> None:
|
||||
dataset_dir = CONTAINER_DATASET_DIR[container_type]
|
||||
if not os.path.isdir(dataset_dir):
|
||||
pytest.skip(
|
||||
f"Dataset for container {container_type.__name__} "
|
||||
f"is unavailable: {dataset_dir}"
|
||||
)
|
||||
if container_type is DeepSMILECrck:
|
||||
container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__)
|
||||
elif container_type is DeepSMILEPanda:
|
||||
container = DeepSMILEPanda(encoder_type=ImageNetEncoder.__name__)
|
||||
else:
|
||||
container = container_type()
|
||||
|
||||
container.setup()
|
||||
|
||||
data_module: TilesDataModule = container.get_data_module() # type: ignore
|
||||
data_module.max_bag_size = 10
|
||||
module = container.create_model()
|
||||
if use_gpu:
|
||||
module.cuda()
|
||||
|
||||
train_data_loader = data_module.train_dataloader()
|
||||
for batch_idx, batch in enumerate(train_data_loader):
|
||||
batch = move_batch_to_expected_device(batch, use_gpu)
|
||||
loss = module.training_step(batch, batch_idx)
|
||||
loss.retain_grad()
|
||||
loss.backward()
|
||||
assert loss.grad is not None
|
||||
assert loss.shape == ()
|
||||
assert isinstance(loss, Tensor)
|
||||
break
|
||||
|
||||
val_data_loader = data_module.val_dataloader()
|
||||
for batch_idx, batch in enumerate(val_data_loader):
|
||||
batch = move_batch_to_expected_device(batch, use_gpu)
|
||||
loss = module.validation_step(batch, batch_idx)
|
||||
assert loss.shape == () # noqa
|
||||
assert isinstance(loss, Tensor)
|
||||
break
|
||||
|
||||
test_data_loader = data_module.test_dataloader()
|
||||
for batch_idx, batch in enumerate(test_data_loader):
|
||||
batch = move_batch_to_expected_device(batch, use_gpu)
|
||||
outputs_dict = module.test_step(batch, batch_idx)
|
||||
loss = outputs_dict[ResultsKey.LOSS] # noqa
|
||||
assert loss.shape == ()
|
||||
assert isinstance(loss, Tensor)
|
||||
break
|
||||
|
||||
|
||||
def test_class_weights_binary() -> None:
|
||||
class_weights = Tensor([0.5, 3.5])
|
||||
n_classes = 1
|
||||
module = DeepMILModule(
|
||||
encoder=get_supervised_imagenet_encoder(),
|
||||
label_column="label",
|
||||
n_classes=n_classes,
|
||||
pooling_layer=AttentionLayer,
|
||||
pool_hidden_dim=5,
|
||||
pool_out_dim=1,
|
||||
class_weights=class_weights,
|
||||
)
|
||||
logits = Tensor(randn(1, n_classes))
|
||||
bag_label = randint(n_classes + 1, size=(1,))
|
||||
|
||||
pos_weight = Tensor([class_weights[1] / (class_weights[0] + 1e-5)])
|
||||
loss_weighted = module.loss_fn(logits.squeeze(1), bag_label.float())
|
||||
criterion_unweighted = nn.BCEWithLogitsLoss()
|
||||
loss_unweighted = criterion_unweighted(logits.squeeze(1), bag_label.float())
|
||||
if bag_label.item() == 1:
|
||||
assert allclose(loss_weighted, pos_weight * loss_unweighted)
|
||||
else:
|
||||
assert allclose(loss_weighted, loss_unweighted)
|
||||
|
||||
|
||||
def test_class_weights_multiclass() -> None:
|
||||
class_weights = Tensor([0.33, 0.33, 0.33])
|
||||
n_classes = 3
|
||||
module = DeepMILModule(
|
||||
encoder=get_supervised_imagenet_encoder(),
|
||||
label_column="label",
|
||||
n_classes=n_classes,
|
||||
pooling_layer=AttentionLayer,
|
||||
pool_hidden_dim=5,
|
||||
pool_out_dim=1,
|
||||
class_weights=class_weights,
|
||||
)
|
||||
logits = Tensor(randn(1, n_classes))
|
||||
bag_label = randint(n_classes, size=(1,))
|
||||
|
||||
loss_weighted = module.loss_fn(logits, bag_label)
|
||||
criterion_unweighted = nn.CrossEntropyLoss()
|
||||
loss_unweighted = criterion_unweighted(logits, bag_label)
|
||||
# The weighted and unweighted loss functions give the same loss values for batch_size = 1.
|
||||
# https://stackoverflow.com/questions/67639540/pytorch-cross-entropy-loss-weights-not-working
|
||||
# TODO: the test should reflect actual weighted loss operation for the class weights after batch_size > 1 is implemented.
|
||||
assert allclose(loss_weighted, loss_unweighted)
|
|
@ -1,70 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from torch import Tensor, float32, nn, rand
|
||||
from torchvision.models import resnet18
|
||||
|
||||
from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
|
||||
ImageNetSimCLREncoder)
|
||||
from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor
|
||||
|
||||
TILE_SIZE = 224
|
||||
INPUT_DIMS = (3, TILE_SIZE, TILE_SIZE)
|
||||
|
||||
|
||||
def get_supervised_imagenet_encoder() -> TileEncoder:
|
||||
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
def get_simclr_imagenet_encoder() -> TileEncoder:
|
||||
return ImageNetSimCLREncoder(tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
def get_histo_ssl_encoder() -> TileEncoder:
|
||||
return HistoSSLEncoder(tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: int,
|
||||
batch_size: int = 5) -> None:
|
||||
if isinstance(encoder, nn.Module):
|
||||
for param_name, param in encoder.named_parameters():
|
||||
assert not param.requires_grad, \
|
||||
f"Feature extractor has unfrozen parameters: {param_name}"
|
||||
|
||||
images = rand(batch_size, *input_dims, dtype=float32)
|
||||
|
||||
features = encoder(images)
|
||||
assert isinstance(features, Tensor)
|
||||
assert features.shape == (batch_size, output_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
|
||||
get_simclr_imagenet_encoder,
|
||||
get_histo_ssl_encoder])
|
||||
def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
|
||||
encoder = create_encoder_fn()
|
||||
_test_encoder(encoder, input_dims=encoder.input_dim, output_dim=encoder.num_encoding)
|
||||
|
||||
|
||||
def _dummy_classifier() -> nn.Module:
|
||||
input_size = np.prod(INPUT_DIMS)
|
||||
hidden_dim = 10
|
||||
return nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(input_size, hidden_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_dim, 1)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('create_classifier_fn', [resnet18, _dummy_classifier])
|
||||
def test_setup_feature_extractor(create_classifier_fn: Callable[[], nn.Module]) -> None:
|
||||
classifier = create_classifier_fn()
|
||||
encoder, num_features = setup_feature_extractor(classifier, INPUT_DIMS)
|
||||
_test_encoder(encoder, input_dims=INPUT_DIMS, output_dim=num_features)
|
|
@ -1,210 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
|
||||
from monai.transforms import Compose
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
from torch.utils.data import Subset
|
||||
from torchvision.models import resnet18
|
||||
|
||||
from health_ml.utils.bag_utils import BagDataset
|
||||
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
|
||||
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
|
||||
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
|
||||
from Tests.ML.util import assert_dicts_equal
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
|
||||
reason="TCGA-CRCk tiles dataset is unavailable")
|
||||
def test_load_tile() -> None:
|
||||
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
|
||||
image_key = tiles_dataset.IMAGE_COLUMN
|
||||
load_transform = LoadTiled(image_key)
|
||||
index = 0
|
||||
|
||||
# Test that the transform affects only the image entry in the sample
|
||||
input_sample = tiles_dataset[index]
|
||||
loaded_sample = load_transform(input_sample)
|
||||
assert_dicts_equal(loaded_sample, input_sample, exclude_keys=[image_key])
|
||||
|
||||
# Test that the MONAI Dataset applies the same transform
|
||||
loaded_dataset = Dataset(tiles_dataset, transform=load_transform) # type:ignore
|
||||
same_dataset_sample = loaded_dataset[index]
|
||||
assert_dicts_equal(same_dataset_sample, loaded_sample)
|
||||
|
||||
# Test that loading another sample gives different results
|
||||
different_sample = loaded_dataset[index + 1]
|
||||
assert not torch.allclose(different_sample[image_key], loaded_sample[image_key])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
|
||||
reason="TCGA-CRCk tiles dataset is unavailable")
|
||||
def test_load_tiles_batch() -> None:
|
||||
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
|
||||
image_key = tiles_dataset.IMAGE_COLUMN
|
||||
max_bag_size = 5
|
||||
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
|
||||
max_bag_size=max_bag_size)
|
||||
load_batch_transform = LoadTilesBatchd(image_key)
|
||||
loaded_dataset = Dataset(tiles_dataset, transform=LoadTiled(image_key)) # type:ignore
|
||||
image_shape = loaded_dataset[0][image_key].shape
|
||||
index = 0
|
||||
|
||||
# Test that the transform affects only the image entry in the batch,
|
||||
# and that the loaded images have the expected shape
|
||||
bagged_batch = bagged_dataset[index]
|
||||
manually_loaded_batch = load_batch_transform(bagged_batch)
|
||||
assert_dicts_equal(manually_loaded_batch, bagged_batch, exclude_keys=[image_key])
|
||||
assert manually_loaded_batch[image_key].shape == (max_bag_size, *image_shape)
|
||||
|
||||
# Test that the MONAI Dataset applies the same transform
|
||||
loaded_bagged_dataset = Dataset(bagged_dataset, transform=load_batch_transform) # type:ignore
|
||||
loaded_bagged_batch = loaded_bagged_dataset[index]
|
||||
assert_dicts_equal(loaded_bagged_batch, manually_loaded_batch)
|
||||
|
||||
# Test that loading another batch gives different results
|
||||
different_batch = loaded_bagged_dataset[index + 1]
|
||||
assert not torch.allclose(different_batch[image_key], manually_loaded_batch[image_key])
|
||||
|
||||
# Test that loading and bagging commute
|
||||
bagged_loaded_dataset = BagDataset(loaded_dataset, # type: ignore
|
||||
bag_ids=tiles_dataset.slide_ids,
|
||||
max_bag_size=max_bag_size)
|
||||
bagged_loaded_batch = bagged_loaded_dataset[index]
|
||||
assert_dicts_equal(bagged_loaded_batch, loaded_bagged_batch)
|
||||
|
||||
|
||||
def _test_cache_and_persistent_datasets(tmp_path: Path,
|
||||
base_dataset: TorchDataset,
|
||||
transform: Union[Sequence[Callable], Callable],
|
||||
cache_subdir: str) -> None:
|
||||
default_dataset = Dataset(base_dataset, transform=transform) # type: ignore
|
||||
cached_dataset = CacheDataset(base_dataset, transform=transform) # type: ignore
|
||||
cache_dir = tmp_path / cache_subdir
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
persistent_dataset = PersistentDataset(base_dataset, transform=transform, # type: ignore
|
||||
cache_dir=cache_dir)
|
||||
|
||||
for default_sample, cached_sample, persistent_sample \
|
||||
in zip(default_dataset, cached_dataset, persistent_dataset): # type: ignore
|
||||
assert_dicts_equal(cached_sample, default_sample)
|
||||
assert_dicts_equal(persistent_sample, default_sample)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
|
||||
reason="TCGA-CRCk tiles dataset is unavailable")
|
||||
def test_cached_loading(tmp_path: Path) -> None:
|
||||
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
|
||||
image_key = tiles_dataset.IMAGE_COLUMN
|
||||
|
||||
max_num_tiles = 100
|
||||
tiles_subset = Subset(tiles_dataset, range(max_num_tiles))
|
||||
_test_cache_and_persistent_datasets(tmp_path,
|
||||
tiles_subset,
|
||||
transform=LoadTiled(image_key),
|
||||
cache_subdir="TCGA-CRCk_tiles_cache")
|
||||
|
||||
max_bag_size = 5
|
||||
max_num_bags = max_num_tiles // max_bag_size
|
||||
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
|
||||
max_bag_size=max_bag_size)
|
||||
bagged_subset = Subset(bagged_dataset, range(max_num_bags))
|
||||
_test_cache_and_persistent_datasets(tmp_path,
|
||||
bagged_subset,
|
||||
transform=LoadTilesBatchd(image_key),
|
||||
cache_subdir="TCGA-CRCk_load_cache")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
|
||||
reason="TCGA-CRCk tiles dataset is unavailable")
|
||||
@pytest.mark.parametrize('use_gpu , chunk_size',
|
||||
[(False, 0), (False, 2), (True, 0), (True, 2)]
|
||||
)
|
||||
def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
|
||||
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
|
||||
image_key = tiles_dataset.IMAGE_COLUMN
|
||||
max_bag_size = 5
|
||||
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
|
||||
max_bag_size=max_bag_size)
|
||||
|
||||
encoder = ImageNetEncoder(resnet18, tile_size=224, n_channels=3)
|
||||
if use_gpu:
|
||||
encoder.cuda()
|
||||
|
||||
encode_transform = EncodeTilesBatchd(image_key, encoder, chunk_size=chunk_size)
|
||||
transform = Compose([LoadTilesBatchd(image_key), encode_transform])
|
||||
dataset = Dataset(bagged_dataset, transform=transform) # type: ignore
|
||||
sample = dataset[0]
|
||||
assert sample[image_key].shape == (max_bag_size, encoder.num_encoding)
|
||||
# TODO: Ensure it works in DDP
|
||||
|
||||
max_num_bags = 20
|
||||
bagged_subset = Subset(bagged_dataset, range(max_num_bags))
|
||||
_test_cache_and_persistent_datasets(tmp_path,
|
||||
bagged_subset,
|
||||
transform=transform,
|
||||
cache_subdir="TCGA-CRCk_embed_cache")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('include_non_indexable', [True, False])
|
||||
@pytest.mark.parametrize('allow_missing_keys', [True, False])
|
||||
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
|
||||
batch_size = 5
|
||||
max_size = batch_size // 2
|
||||
data = {
|
||||
'array_1d': np.random.randn(batch_size),
|
||||
'array_2d': np.random.randn(batch_size, 4),
|
||||
'tensor_1d': torch.randn(batch_size),
|
||||
'tensor_2d': torch.randn(batch_size, 4),
|
||||
'list': torch.randn(batch_size).tolist(),
|
||||
'indices': list(range(batch_size)),
|
||||
'non-indexable': 42,
|
||||
}
|
||||
|
||||
keys_to_subsample = list(data.keys())
|
||||
if not include_non_indexable:
|
||||
keys_to_subsample.remove('non-indexable')
|
||||
keys_to_subsample.append('missing-key')
|
||||
|
||||
subsampling = Subsampled(keys_to_subsample, max_size=max_size,
|
||||
allow_missing_keys=allow_missing_keys)
|
||||
|
||||
if include_non_indexable:
|
||||
with pytest.raises(ValueError):
|
||||
sub_data = subsampling(data)
|
||||
return
|
||||
elif not allow_missing_keys:
|
||||
with pytest.raises(KeyError):
|
||||
sub_data = subsampling(data)
|
||||
return
|
||||
else:
|
||||
sub_data = subsampling(data)
|
||||
|
||||
assert set(sub_data.keys()) == set(data.keys())
|
||||
|
||||
# Check lenghts before and after subsampling
|
||||
for key in keys_to_subsample:
|
||||
if key not in data:
|
||||
continue # Skip missing keys
|
||||
assert len(data[key]) == batch_size # type: ignore
|
||||
assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore
|
||||
|
||||
# Check contents of subsampled elements
|
||||
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
|
||||
for idx, elem in zip(sub_data['indices'], sub_data[key]):
|
||||
assert np.array_equal(elem, data[key][idx]) # type: ignore
|
||||
|
||||
# Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
|
||||
sub_data2 = subsampling(data)
|
||||
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
|
||||
assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore
|
|
@ -1,171 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from monai.data.image_reader import WSIReader
|
||||
|
||||
from InnerEye.Common.common_util import is_windows
|
||||
from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
|
||||
from InnerEye.ML.Histopathology.preprocessing.tiling import tile_array_2d
|
||||
from InnerEye.ML.Histopathology.preprocessing.loading import (LoadROId, get_luminance, load_slide_at_level,
|
||||
segment_foreground)
|
||||
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||
from Tests.ML.histopathology.datasets.test_slides_dataset import MockSlidesDataset
|
||||
|
||||
TEST_IMAGE_PATH = str(tests_root_directory("ML/histopathology/test_data/panda_wsi_example.tiff"))
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
|
||||
def test_load_slide() -> None:
|
||||
level = 2
|
||||
reader = WSIReader('cuCIM')
|
||||
from cucim import CuImage
|
||||
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
|
||||
dims = slide_obj.resolutions['level_dimensions'][level][::-1]
|
||||
|
||||
slide = load_slide_at_level(reader, slide_obj, level)
|
||||
assert isinstance(slide, np.ndarray)
|
||||
expected_shape = (3, *dims)
|
||||
assert slide.shape == expected_shape
|
||||
frac_empty = (slide == 0).mean()
|
||||
assert frac_empty == 0.0
|
||||
|
||||
larger_dims = (2 * dims[0], 2 * dims[1])
|
||||
larger_slide, _ = reader.get_data(slide_obj, size=larger_dims, level=level)
|
||||
assert isinstance(larger_slide, np.ndarray)
|
||||
assert larger_slide.shape == (3, *larger_dims)
|
||||
# Overlapping parts match exactly
|
||||
assert np.array_equal(larger_slide[:, :dims[0], :dims[1]], slide)
|
||||
# Non-overlapping parts are all empty
|
||||
empty_fill_value = 0 # fill value seems to depend on the image
|
||||
assert np.array_equiv(larger_slide[:, dims[0]:, :], empty_fill_value)
|
||||
assert np.array_equiv(larger_slide[:, :, dims[1]:], empty_fill_value)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
|
||||
def test_get_luminance() -> None:
|
||||
level = 2 # here we only need to test at a single resolution
|
||||
reader = WSIReader('cuCIM')
|
||||
from cucim import CuImage
|
||||
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
|
||||
|
||||
slide = load_slide_at_level(reader, slide_obj, level)
|
||||
slide_luminance = get_luminance(slide)
|
||||
assert isinstance(slide_luminance, np.ndarray)
|
||||
assert slide_luminance.shape == slide.shape[1:]
|
||||
assert (slide_luminance <= 255).all() and (slide_luminance >= 0).all()
|
||||
|
||||
tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
|
||||
tiles_luminance = get_luminance(tiles)
|
||||
assert isinstance(tiles_luminance, np.ndarray)
|
||||
assert tiles_luminance.shape == (tiles.shape[0], *tiles.shape[2:])
|
||||
assert (tiles_luminance <= 255).all() and (tiles_luminance >= 0).all()
|
||||
|
||||
slide_luminance_tiles, _ = tile_array_2d(np.expand_dims(slide_luminance, axis=0),
|
||||
tile_size=224, constant_values=255)
|
||||
assert np.array_equal(slide_luminance_tiles.squeeze(1), tiles_luminance)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
|
||||
def test_segment_foreground() -> None:
|
||||
level = 2 # here we only need to test at a single resolution
|
||||
reader = WSIReader('cuCIM')
|
||||
from cucim import CuImage
|
||||
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
|
||||
slide = load_slide_at_level(reader, slide_obj, level)
|
||||
|
||||
auto_mask, auto_threshold = segment_foreground(slide, threshold=None)
|
||||
assert isinstance(auto_mask, np.ndarray)
|
||||
assert auto_mask.dtype == bool
|
||||
assert auto_mask.shape == slide.shape[1:]
|
||||
assert 0 < auto_mask.sum() < auto_mask.size # auto-seg should not produce trivial mask
|
||||
luminance = get_luminance(slide)
|
||||
assert luminance.min() < auto_threshold < luminance.max()
|
||||
|
||||
mask, returned_threshold = segment_foreground(slide, threshold=auto_threshold)
|
||||
assert isinstance(mask, np.ndarray)
|
||||
assert mask.dtype == bool
|
||||
assert mask.shape == slide.shape[1:]
|
||||
assert np.array_equal(mask, auto_mask)
|
||||
assert returned_threshold == auto_threshold
|
||||
|
||||
tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
|
||||
tiles_mask, _ = segment_foreground(tiles, threshold=auto_threshold)
|
||||
assert isinstance(tiles_mask, np.ndarray)
|
||||
assert tiles_mask.dtype == bool
|
||||
assert tiles_mask.shape == (tiles.shape[0], *tiles.shape[2:])
|
||||
|
||||
slide_mask_tiles, _ = tile_array_2d(np.expand_dims(mask, axis=0),
|
||||
tile_size=224, constant_values=False)
|
||||
assert np.array_equal(slide_mask_tiles.squeeze(1), tiles_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('level', [1, 2])
|
||||
@pytest.mark.parametrize('foreground_threshold', [None, 215])
|
||||
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
|
||||
def test_get_bounding_box(level: int, foreground_threshold: Optional[float]) -> None:
|
||||
margin = 0
|
||||
reader = WSIReader('cuCIM')
|
||||
loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
|
||||
foreground_threshold=foreground_threshold)
|
||||
from cucim import CuImage
|
||||
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
|
||||
level0_bbox, _ = loader._get_bounding_box(slide_obj)
|
||||
|
||||
highest_level = slide_obj.resolutions['level_count'] - 1
|
||||
# level = highest_level
|
||||
slide = load_slide_at_level(reader, slide_obj, level=level)
|
||||
scale = slide_obj.resolutions['level_downsamples'][level]
|
||||
bbox = level0_bbox / scale
|
||||
assert bbox.x >= 0 and bbox.y >= 0
|
||||
assert bbox.x + bbox.w <= slide.shape[1]
|
||||
assert bbox.y + bbox.h <= slide.shape[2]
|
||||
|
||||
# Now with nonzero margin
|
||||
margin = 42
|
||||
loader_margin = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
|
||||
foreground_threshold=foreground_threshold)
|
||||
level0_bbox_margin, _ = loader_margin._get_bounding_box(slide_obj)
|
||||
# Here we test the box differences at the highest resolution, because margin is
|
||||
# specified in low-res pixels. Otherwise could fail due to rounding error.
|
||||
level0_scale: float = slide_obj.resolutions['level_downsamples'][highest_level]
|
||||
level0_margin = int(level0_scale * margin)
|
||||
assert level0_bbox_margin.x == level0_bbox.x - level0_margin
|
||||
assert level0_bbox_margin.y == level0_bbox.y - level0_margin
|
||||
assert level0_bbox_margin.w == level0_bbox.w + 2 * level0_margin
|
||||
assert level0_bbox_margin.h == level0_bbox.h + 2 * level0_margin
|
||||
|
||||
|
||||
@pytest.mark.parametrize('level', [1, 2])
|
||||
@pytest.mark.parametrize('margin', [0, 42])
|
||||
@pytest.mark.parametrize('foreground_threshold', [None, 215])
|
||||
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
|
||||
def test_load_roi(level: int, margin: int, foreground_threshold: Optional[float]) -> None:
|
||||
dataset = MockSlidesDataset()
|
||||
sample = dataset[0]
|
||||
reader = WSIReader('cuCIM')
|
||||
loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
|
||||
foreground_threshold=foreground_threshold)
|
||||
loaded_sample = loader(sample)
|
||||
assert isinstance(loaded_sample, dict)
|
||||
# Check that none of the input keys were removed
|
||||
assert all(key in loaded_sample for key in sample)
|
||||
|
||||
# Check that the expected new keys were inserted
|
||||
additional_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.FOREGROUND_THRESHOLD]
|
||||
assert all(key in loaded_sample for key in additional_keys)
|
||||
|
||||
assert isinstance(loaded_sample[SlideKey.IMAGE], np.ndarray)
|
||||
image_shape = loaded_sample[SlideKey.IMAGE].shape
|
||||
assert len(image_shape)
|
||||
assert image_shape[0] == 3
|
||||
|
||||
origin = loaded_sample[SlideKey.ORIGIN]
|
||||
assert isinstance(origin, tuple)
|
||||
assert len(origin) == 2
|
||||
assert all(isinstance(coord, int) for coord in origin)
|
||||
|
||||
assert isinstance(loaded_sample[SlideKey.SCALE], (int, float))
|
||||
assert loaded_sample[SlideKey.SCALE] >= 1.0
|
||||
|
||||
assert isinstance(loaded_sample[SlideKey.FOREGROUND_THRESHOLD], (int, float))
|
|
@ -1,126 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from InnerEye.ML.Histopathology.preprocessing.tiling import assemble_tiles_2d, get_1d_padding, \
|
||||
pad_for_tiling_2d, tile_array_2d
|
||||
|
||||
|
||||
@pytest.mark.parametrize("length,tile_size",
|
||||
[(8, 4), (9, 4), (8, 3), (4, 4), (3, 4)])
|
||||
def test_1d_padding(length: int, tile_size: int) -> None:
|
||||
pad_pre, pad_post = get_1d_padding(length, tile_size)
|
||||
|
||||
assert pad_pre >= 0 and pad_post >= 0
|
||||
assert pad_pre < tile_size and pad_post < tile_size
|
||||
assert abs(pad_post - pad_pre) <= 1, "Asymmetric padding"
|
||||
|
||||
padded_length = pad_pre + length + pad_post
|
||||
assert padded_length % tile_size == 0
|
||||
|
||||
n_tiles = padded_length // tile_size
|
||||
expected_n_tiles = int(np.ceil(length / tile_size))
|
||||
assert n_tiles == expected_n_tiles
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width,height", [(8, 6)])
|
||||
@pytest.mark.parametrize("tile_size", [3, 4, 5])
|
||||
@pytest.mark.parametrize("channels_first", [True, False])
|
||||
def test_2d_padding(width: int, height: int, tile_size: int, channels_first: bool) -> None:
|
||||
channels = 2
|
||||
pad_value = 0
|
||||
array = np.random.rand(channels, height, width)
|
||||
|
||||
input_array = array if channels_first else array.transpose(1, 2, 0)
|
||||
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(input_array, tile_size, channels_first,
|
||||
constant_values=pad_value)
|
||||
if not channels_first:
|
||||
padded_array = padded_array.transpose(2, 0, 1)
|
||||
|
||||
padded_channels, padded_height, padded_width = padded_array.shape
|
||||
assert padded_channels == channels and padded_height >= height and padded_width >= width
|
||||
assert padded_height % tile_size == 0 and padded_width % tile_size == 0
|
||||
assert 0 <= offset_h < tile_size and 0 <= offset_w < tile_size
|
||||
|
||||
crop = padded_array[:, offset_h:offset_h + height, offset_w:offset_w + width]
|
||||
assert np.array_equal(crop, array)
|
||||
|
||||
# np.array_equiv() broadcasts the shapes
|
||||
assert np.array_equiv(padded_array[:, :offset_h, :], pad_value)
|
||||
assert np.array_equiv(padded_array[:, :, :offset_w], pad_value)
|
||||
assert np.array_equiv(padded_array[:, offset_h + height:, :], pad_value)
|
||||
assert np.array_equiv(padded_array[:, :, offset_w + width:], pad_value)
|
||||
|
||||
|
||||
def _get_2d_meshgrid(width: int, height: int, channels_first: bool = True) -> np.ndarray:
|
||||
array = np.stack(np.meshgrid(np.arange(width), np.arange(height)),
|
||||
axis=0 if channels_first else -1)
|
||||
assert array.shape == ((2, height, width) if channels_first else (height, width, 2))
|
||||
return array
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width,height", [(8, 6)])
|
||||
@pytest.mark.parametrize("tile_size", [3, 4, 5])
|
||||
@pytest.mark.parametrize("channels_first", [True, False])
|
||||
def test_tile_array_2d_both(width: int, height: int, tile_size: int, channels_first: bool) -> None:
|
||||
channels = 2
|
||||
array = _get_2d_meshgrid(width, height, channels_first)
|
||||
|
||||
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first,
|
||||
constant_values=0)
|
||||
|
||||
tiles, coords = tile_array_2d(array, tile_size, channels_first)
|
||||
assert tiles.shape[0] == coords.shape[0]
|
||||
|
||||
expected_n_tiles_w = int(np.ceil(width / tile_size))
|
||||
expected_n_tiles_h = int(np.ceil(height / tile_size))
|
||||
expected_n_tiles = expected_n_tiles_w * expected_n_tiles_h
|
||||
|
||||
if channels_first:
|
||||
assert tiles.shape == (expected_n_tiles, channels, tile_size, tile_size)
|
||||
else:
|
||||
assert tiles.shape == (expected_n_tiles, tile_size, tile_size, channels)
|
||||
assert coords.shape == (expected_n_tiles, 2)
|
||||
|
||||
for idx in range(tiles.shape[0]):
|
||||
row = coords[idx, 1] + offset_h
|
||||
col = coords[idx, 0] + offset_w
|
||||
if channels_first:
|
||||
expected_tile = padded_array[:, row:row + tile_size, col:col + tile_size]
|
||||
else:
|
||||
expected_tile = padded_array[row:row + tile_size, col:col + tile_size, :]
|
||||
assert np.array_equal(tiles[idx], expected_tile)
|
||||
|
||||
expected_x = tile_size * (idx % expected_n_tiles_w) - offset_w
|
||||
expected_y = tile_size * (idx // expected_n_tiles_w) - offset_h
|
||||
assert tuple(coords[idx]) == (expected_x, expected_y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width,height", [(8, 6)])
|
||||
@pytest.mark.parametrize("tile_size", [3, 4, 5])
|
||||
@pytest.mark.parametrize("channels_first", [True, False])
|
||||
def test_assemble_tiles_2d(width: int, height: int, tile_size: int, channels_first: bool) -> None:
|
||||
array = _get_2d_meshgrid(width, height, channels_first)
|
||||
fill_value = 0
|
||||
padded_array, padding_offset = pad_for_tiling_2d(array, tile_size, channels_first,
|
||||
constant_values=fill_value)
|
||||
|
||||
tiles, coords = tile_array_2d(array, tile_size, channels_first)
|
||||
|
||||
assembled_array, assembly_offset = assemble_tiles_2d(tiles, coords, fill_value=fill_value,
|
||||
channels_first=channels_first)
|
||||
assert np.array_equal(assembled_array, padded_array)
|
||||
assert np.array_equal(assembly_offset, padding_offset)
|
||||
|
||||
for idx in range(tiles.shape[0]):
|
||||
row = coords[idx, 1] + assembly_offset[1]
|
||||
col = coords[idx, 0] + assembly_offset[0]
|
||||
if channels_first:
|
||||
crop = assembled_array[:, row:row + tile_size, col:col + tile_size]
|
||||
else:
|
||||
crop = assembled_array[row:row + tile_size, col:col + tile_size, :]
|
||||
assert np.array_equal(crop, tiles[idx])
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:06eb0acaa2883181e9b6ab976863f71cc843a75ed9175fae8fe9b879635af1b0
|
||||
size 816563
|
|
@ -1,2 +0,0 @@
|
|||
slide_id,image,label,meta1,meta2
|
||||
foo,panda_wsi_example.tiff,0,bar,baz
|
|
|
@ -1,213 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
import matplotlib
|
||||
from torch.functional import Tensor
|
||||
import pytest
|
||||
|
||||
from InnerEye.Common.common_util import is_windows
|
||||
from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, \
|
||||
plot_heatmap_overlay, plot_normalized_confusion_matrix
|
||||
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
|
||||
from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.plotting import resize_and_save
|
||||
from InnerEye.ML.utils.ml_util import set_random_seed
|
||||
from Tests.ML.util import assert_binary_files_match
|
||||
|
||||
|
||||
def assert_equal_lists(pred: List, expected: List) -> None:
|
||||
assert len(pred) == len(expected)
|
||||
for i, slide in enumerate(pred):
|
||||
for j, value in enumerate(slide):
|
||||
if type(value) in [int, float]:
|
||||
assert math.isclose(value, expected[i][j], rel_tol=1e-06)
|
||||
elif (type(value) == Tensor) and (value.ndim >= 1):
|
||||
for k, idx in enumerate(value):
|
||||
assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06)
|
||||
elif isinstance(value, List):
|
||||
for k, idx in enumerate(value):
|
||||
if type(idx) in [int, float]:
|
||||
assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06)
|
||||
elif type(idx) == Tensor:
|
||||
assert math.isclose(idx.item(), expected[i][j][k].item(), rel_tol=1e-06)
|
||||
else:
|
||||
raise TypeError("Unexpected list composition")
|
||||
|
||||
|
||||
test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]],
|
||||
ResultsKey.CLASS_PROBS: [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.6, 0.4]), Tensor([0.0, 1.0]),
|
||||
Tensor([0.7, 0.3]), Tensor([0.8, 0.2]), Tensor([0.1, 0.9]), Tensor([0.01, 0.99])],
|
||||
ResultsKey.TRUE_LABEL: [0, 1, 1, 1, 1, 0, 0, 0],
|
||||
ResultsKey.BAG_ATTN:
|
||||
[Tensor([[0.10, 0.00, 0.20, 0.15]]),
|
||||
Tensor([[0.10, 0.18, 0.15, 0.13]]),
|
||||
Tensor([[0.25, 0.23, 0.20, 0.21]]),
|
||||
Tensor([[0.33, 0.31, 0.37, 0.35]]),
|
||||
Tensor([[0.43, 0.01, 0.07, 0.25]]),
|
||||
Tensor([[0.53, 0.11, 0.17, 0.55]]),
|
||||
Tensor([[0.63, 0.21, 0.27, 0.05]]),
|
||||
Tensor([[0.73, 0.31, 0.37, 0.15]])],
|
||||
ResultsKey.TILE_X:
|
||||
[Tensor([200, 200, 424, 424]),
|
||||
Tensor([200, 200, 424, 424]),
|
||||
Tensor([200, 200, 424, 424]),
|
||||
Tensor([200, 200, 424, 424])],
|
||||
ResultsKey.TILE_Y:
|
||||
[Tensor([200, 424, 200, 424]),
|
||||
Tensor([200, 200, 424, 424]),
|
||||
Tensor([200, 200, 424, 424]),
|
||||
Tensor([200, 200, 424, 424])]
|
||||
}
|
||||
|
||||
|
||||
def test_select_k_tiles() -> None:
|
||||
nslides = 2
|
||||
ntiles = 2
|
||||
# TP
|
||||
top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'highest_att'))
|
||||
bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'lowest_att'))
|
||||
print(top_tp)
|
||||
assert_equal_lists(top_tp, [(4, Tensor([0.0, 1.0]), [3, 4], [Tensor([0.37]), Tensor([0.35])]),
|
||||
(2, Tensor([0.3, 0.7]), [2, 3], [Tensor([0.18]), Tensor([0.15])])])
|
||||
assert_equal_lists(bottom_tp, [(4, Tensor([0.0, 1.0]), [2, 1], [Tensor([0.31]), Tensor([0.33])]),
|
||||
(2, Tensor([0.3, 0.7]), [1, 4], [Tensor([0.10]), Tensor([0.13])])])
|
||||
|
||||
# FN
|
||||
top_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
|
||||
bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'lowest_att'))
|
||||
assert_equal_lists(top_fn, [(5, Tensor([0.7, 0.3]), [1, 4], [Tensor([0.43]), Tensor([0.25])]),
|
||||
(3, Tensor([0.6, 0.4]), [1, 2], [Tensor([0.25]), Tensor([0.23])])])
|
||||
assert_equal_lists(bottom_fn, [(5, Tensor([0.7, 0.3]), [2, 3], [Tensor([0.01]), Tensor([0.07])]),
|
||||
(3, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.20]), Tensor([0.21])])])
|
||||
|
||||
# TN
|
||||
top_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'highest_att'))
|
||||
bottom_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'lowest_att'))
|
||||
assert_equal_lists(top_tn, [(6, Tensor([0.8, 0.2]), [4, 1], [Tensor([0.55]), Tensor([0.53])]),
|
||||
(1, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.2]), Tensor([0.15])])])
|
||||
assert_equal_lists(bottom_tn, [(6, Tensor([0.8, 0.2]), [2, 3], [Tensor([0.11]), Tensor([0.17])]),
|
||||
(1, Tensor([0.6, 0.4]), [2, 1], [Tensor([0.00]), Tensor([0.10])])])
|
||||
|
||||
# FP
|
||||
top_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
|
||||
bottom_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'lowest_att'))
|
||||
assert_equal_lists(top_fp, [(8, Tensor([0.01, 0.99]), [1, 3], [Tensor([0.73]), Tensor([0.37])]),
|
||||
(7, Tensor([0.1, 0.9]), [1, 3], [Tensor([0.63]), Tensor([0.27])])])
|
||||
assert_equal_lists(bottom_fp, [(8, Tensor([0.01, 0.99]), [4, 2], [Tensor([0.15]), Tensor([0.31])]),
|
||||
(7, Tensor([0.1, 0.9]), [4, 2], [Tensor([0.05]), Tensor([0.21])])])
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
|
||||
def test_plot_scores_hist(test_output_dirs: OutputFolderForTests) -> None:
|
||||
fig = plot_scores_hist(test_dict)
|
||||
assert isinstance(fig, matplotlib.figure.Figure)
|
||||
file = Path(test_output_dirs.root_dir) / "plot_score_hist.png"
|
||||
resize_and_save(5, 5, file)
|
||||
assert file.exists()
|
||||
expected = full_ml_test_data_path("histo_heatmaps") / "score_hist.png"
|
||||
# To update the stored results, uncomment this line:
|
||||
# expected.write_bytes(file.read_bytes())
|
||||
assert_binary_files_match(file, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6])
|
||||
def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None:
|
||||
set_random_seed(0)
|
||||
slide_image = np.random.rand(3, 1000, 2000)
|
||||
fig = plot_slide(slide_image=slide_image, scale=scale)
|
||||
assert isinstance(fig, matplotlib.figure.Figure)
|
||||
file = Path(test_output_dirs.root_dir) / "plot_slide.png"
|
||||
resize_and_save(5, 5, file)
|
||||
assert file.exists()
|
||||
expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}.png"
|
||||
# To update the stored results, uncomment this line:
|
||||
# expected.write_bytes(file.read_bytes())
|
||||
assert_binary_files_match(file, expected)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
|
||||
def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None:
|
||||
set_random_seed(0)
|
||||
slide_image = np.random.rand(3, 1000, 2000)
|
||||
location_bbox = [100, 100]
|
||||
slide = 1
|
||||
tile_size = 224
|
||||
level = 0
|
||||
fig = plot_heatmap_overlay(slide=slide, # type: ignore
|
||||
slide_image=slide_image,
|
||||
results=test_dict, # type: ignore
|
||||
location_bbox=location_bbox,
|
||||
tile_size=tile_size,
|
||||
level=level)
|
||||
assert isinstance(fig, matplotlib.figure.Figure)
|
||||
file = Path(test_output_dirs.root_dir) / "plot_heatmap_overlay.png"
|
||||
resize_and_save(5, 5, file)
|
||||
assert file.exists()
|
||||
expected = full_ml_test_data_path("histo_heatmaps") / "heatmap_overlay.png"
|
||||
# To update the stored results, uncomment this line:
|
||||
# expected.write_bytes(file.read_bytes())
|
||||
assert_binary_files_match(file, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
|
||||
def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests, n_classes: int) -> None:
|
||||
set_random_seed(0)
|
||||
if n_classes > 1:
|
||||
cm = np.random.randint(1, 1000, size=(n_classes, n_classes))
|
||||
class_names = [str(i) for i in range(n_classes)]
|
||||
else:
|
||||
cm = np.random.randint(1, 1000, size=(n_classes + 1, n_classes + 1))
|
||||
class_names = [str(i) for i in range(n_classes + 1)]
|
||||
cm_n = cm / cm.sum(axis=1, keepdims=True)
|
||||
assert (cm_n <= 1).all()
|
||||
|
||||
fig = plot_normalized_confusion_matrix(cm=cm_n, class_names=class_names)
|
||||
assert isinstance(fig, matplotlib.figure.Figure)
|
||||
file = Path(test_output_dirs.root_dir) / f"plot_confusion_matrix_{n_classes}.png"
|
||||
resize_and_save(5, 5, file)
|
||||
assert file.exists()
|
||||
expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png"
|
||||
# To update the stored results, uncomment this line:
|
||||
# expected.write_bytes(file.read_bytes())
|
||||
assert_binary_files_match(file, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("level", [0, 1, 2])
|
||||
def test_location_selected_tiles(level: int) -> None:
|
||||
set_random_seed(0)
|
||||
slide = 1
|
||||
location_bbox = [100, 100]
|
||||
slide_image = np.random.rand(3, 1000, 2000)
|
||||
|
||||
coords = []
|
||||
slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]] # type: ignore
|
||||
slide_idx = slide_ids.index(slide)
|
||||
for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])): # type: ignore
|
||||
tile_coords = np.transpose(
|
||||
np.array([test_dict[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), # type: ignore
|
||||
test_dict[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) # type: ignore
|
||||
coords.append(tile_coords)
|
||||
|
||||
coords = np.array(coords)
|
||||
tile_coords_transformed = location_selected_tiles(tile_coords=coords,
|
||||
location_bbox=location_bbox,
|
||||
level=level)
|
||||
tile_xs, tile_ys = tile_coords_transformed.T
|
||||
level_dict = {0: 1, 1: 4, 2: 16}
|
||||
factor = level_dict[level]
|
||||
assert min(tile_xs) >= 0
|
||||
assert max(tile_xs) <= slide_image.shape[2] // factor
|
||||
assert min(tile_ys) >= 0
|
||||
assert max(tile_ys) <= slide_image.shape[1] // factor
|
|
@ -1,28 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from InnerEye.ML.Histopathology.utils.tcga_utils import extract_fields
|
||||
|
||||
|
||||
def test_extract_fields() -> None:
|
||||
slide_id = "TCGA-XX-0123"
|
||||
tile_id = "ABCDEFGHIJKL"
|
||||
split = "train"
|
||||
label = 0
|
||||
path = (f"CRC_DX_{split.upper()}/"
|
||||
f"{['MSS', 'MSIMUT'][label]}/"
|
||||
f"blk-{tile_id}-{slide_id}-01Z-00-DX1.png")
|
||||
fields = {
|
||||
'slide_id': slide_id,
|
||||
'tile_id': tile_id,
|
||||
'image': path,
|
||||
'split': split,
|
||||
'label': label
|
||||
}
|
||||
extracted_fields = extract_fields(pd.Series(fields))
|
||||
assert fields == extracted_fields
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from InnerEye.ML.models.architectures.sequential.gru import LayerNormGRU, LayerNormGRUCell
|
||||
|
||||
|
||||
def test_cell_initialisation() -> None:
|
||||
model = LayerNormGRU(input_size=2, hidden_size=2, num_layers=3)
|
||||
assert len(model.cells) == 3
|
||||
|
||||
|
||||
def test_layer_norm_initialisation() -> None:
|
||||
cell = LayerNormGRUCell(input_size=2, hidden_size=2, use_layer_norm=True, dropout=0.50)
|
||||
assert isinstance(cell.ln_r, nn.LayerNorm)
|
||||
assert cell.dropout == 0.50
|
||||
|
||||
|
||||
def test_gru_forward_pass() -> None:
|
||||
num_layers = 2
|
||||
batch_size = 10
|
||||
hidden_dim = 2
|
||||
input_dim = 3
|
||||
seq_length = 5
|
||||
|
||||
model = LayerNormGRU(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers)
|
||||
input_sequence = torch.rand(size=(batch_size, seq_length, input_dim))
|
||||
initial_hidden_state = torch.zeros(size=(num_layers, batch_size, hidden_dim))
|
||||
output_sequence = model(input_sequence, initial_hidden_state)
|
||||
|
||||
assert output_sequence.size(0) == batch_size
|
||||
assert output_sequence.size(1) == seq_length
|
||||
assert output_sequence.size(2) == hidden_dim
|
|
@ -1,596 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from io import StringIO
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter, RandomAffine
|
||||
|
||||
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, logging_to_stdout
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, SEQUENCE_POSITION_HUE_NAME_PREFIX
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
|
||||
|
||||
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
|
||||
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
|
||||
from InnerEye.ML.lightning_models import transfer_batch_to_device
|
||||
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
|
||||
from InnerEye.ML.model_testing import create_metrics_dict_for_scalar_models
|
||||
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImageEncoder, ImagingFeatureType
|
||||
from InnerEye.ML.models.architectures.sequential.rnn_classifier import RNNClassifier, RNNClassifierWithEncoder
|
||||
from InnerEye.ML.run_ml import MLRunner
|
||||
from InnerEye.ML.scalar_config import ScalarLoss
|
||||
from InnerEye.ML.sequence_config import SEQUENCE_LENGTH_FILE, SEQUENCE_LENGTH_STATS_FILE, SequenceModelBase
|
||||
from InnerEye.ML.utils import ml_util
|
||||
|
||||
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
|
||||
from InnerEye.ML.utils.io_util import ImageAndSegmentations
|
||||
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, get_scalar_model_inputs_and_labels
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
from InnerEye.ML.visualizers.grad_cam_hooks import VisualizationMaps
|
||||
from Tests.ML.util import get_default_azure_config, model_train_unittest
|
||||
|
||||
SCAN_SIZE = (6, 64, 60)
|
||||
|
||||
|
||||
def prepare_sequences(num_sequences: int, sequence_length: int, batch_size: int) -> Tuple[List, List]:
|
||||
# Returns [batch][sequence, label]
|
||||
num_mini_batches = num_sequences // batch_size
|
||||
|
||||
inputs = np.random.choice([0, 1], size=(num_sequences, sequence_length), p=[1. / 3, 2. / 3]).astype(np.float32)
|
||||
inputs = torch.tensor(inputs)
|
||||
labels = torch.sum(inputs, dim=1) > (sequence_length // 2)
|
||||
labels = labels.long()
|
||||
data = list()
|
||||
|
||||
for batch_index in range(num_mini_batches):
|
||||
_input = inputs[batch_index * batch_size: (batch_index + 1) * batch_size]
|
||||
_label = labels[batch_index * batch_size: (batch_index + 1) * batch_size]
|
||||
data.append((_input, _label))
|
||||
|
||||
return data[:num_mini_batches // 2], data[num_mini_batches // 2:]
|
||||
|
||||
|
||||
class ToySequenceModel(SequenceModelBase):
|
||||
def __init__(self, use_combined_model: bool = False,
|
||||
imaging_feature_type: ImagingFeatureType = ImagingFeatureType.Image,
|
||||
combine_hidden_states: bool = False,
|
||||
use_encoder_layer_norm: bool = False,
|
||||
sequence_target_positions: Optional[List[int]] = None,
|
||||
use_mean_teacher_model: bool = False,
|
||||
**kwargs: Any) -> None:
|
||||
num_epochs = 3
|
||||
mean_teacher_alpha = 0.999 if use_mean_teacher_model else None
|
||||
sequence_target_positions = [2] if sequence_target_positions is None else sequence_target_positions
|
||||
image_column = "image" if use_combined_model else None
|
||||
categorical_feature_encoder = CategoricalToOneHotEncoder.create_from_dataframe(
|
||||
dataframe=_get_mock_sequence_dataset(), columns=["cat1"])
|
||||
super().__init__(
|
||||
local_dataset=full_ml_test_data_path("sequence_data_for_classification"),
|
||||
temperature_scaling_config=TemperatureScalingConfig(),
|
||||
label_value_column="label",
|
||||
numerical_columns=["numerical1", "numerical2"],
|
||||
categorical_columns=["cat1"],
|
||||
categorical_feature_encoder=categorical_feature_encoder,
|
||||
sequence_column="seqColumn",
|
||||
sequence_target_positions=sequence_target_positions,
|
||||
image_file_column=image_column,
|
||||
loss_type=ScalarLoss.WeightedCrossEntropyWithLogits,
|
||||
num_epochs=num_epochs,
|
||||
num_dataload_workers=0,
|
||||
train_batch_size=3,
|
||||
l_rate=1e-1,
|
||||
load_segmentation=True,
|
||||
use_mixed_precision=True,
|
||||
label_smoothing_eps=0.05,
|
||||
drop_last_batch_in_training=True,
|
||||
mean_teacher_alpha=mean_teacher_alpha,
|
||||
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
|
||||
max_num_gpus=1,
|
||||
**kwargs
|
||||
)
|
||||
self.use_combined_model = use_combined_model
|
||||
self.imaging_feature_type = imaging_feature_type
|
||||
self.combine_hidden_state = combine_hidden_states
|
||||
self.use_encoder_layer_norm = use_encoder_layer_norm
|
||||
|
||||
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
|
||||
return DatasetSplits.from_proportions(
|
||||
df=dataset_df,
|
||||
proportion_train=0.7,
|
||||
proportion_test=0.2,
|
||||
proportion_val=0.1,
|
||||
)
|
||||
|
||||
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
if self.use_combined_model:
|
||||
return ModelTransformsPerExecutionMode(
|
||||
train=ImageTransformationPipeline(
|
||||
transforms=[RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
|
||||
ColorJitter(brightness=0.2)]))
|
||||
else:
|
||||
return ModelTransformsPerExecutionMode()
|
||||
|
||||
def create_model(self) -> RNNClassifier:
|
||||
if self.use_combined_model:
|
||||
image_encoder: Optional[ImageEncoder] = \
|
||||
ImageEncoder(num_image_channels=1,
|
||||
imaging_feature_type=self.imaging_feature_type,
|
||||
num_non_image_features=self.get_total_number_of_non_imaging_features(),
|
||||
stride_size_per_encoding_block=(1, 2, 2),
|
||||
initial_feature_channels=4,
|
||||
num_encoder_blocks=3,
|
||||
)
|
||||
assert image_encoder is not None # for mypy
|
||||
input_dims = image_encoder.final_num_feature_channels
|
||||
else:
|
||||
image_encoder = None
|
||||
input_dims = self.get_total_number_of_non_imaging_features()
|
||||
|
||||
ref_indices = [0, 1] if self.combine_hidden_state else None
|
||||
|
||||
return RNNClassifierWithEncoder(input_dim=input_dims,
|
||||
hidden_dim=3,
|
||||
output_dim=1,
|
||||
num_rnn_layers=1,
|
||||
rnn_dropout=0.0,
|
||||
ref_indices=ref_indices,
|
||||
image_encoder=image_encoder,
|
||||
use_encoder_batch_norm=self.use_encoder_layer_norm,
|
||||
target_indices=self.get_target_indices())
|
||||
|
||||
|
||||
def _get_mock_sequence_dataset(dataset_contents: Optional[str] = None) -> pd.DataFrame:
|
||||
# The dataset has "measurements" for 3 different positions 0, 1, and 2, with columns for numerical1 and numerical2.
|
||||
# Labels are attached to position 3 only.
|
||||
if dataset_contents is None:
|
||||
dataset_contents = """subject,numerical1,numerical2,cat1,seqColumn,label,image
|
||||
2137.00005,362,71,A,0,0,scan1.npy
|
||||
2137.00005,357,69,B,1,0,scan2.npy
|
||||
2137.00005,355,64,C,2,0,scan3.npy
|
||||
2137.00005,355,63,C,3,1,scan4.npy
|
||||
2137.00125,348,64,A,0,0,scan1.npy
|
||||
2137.00125,316,68,A,1,0,scan3.npy
|
||||
2137.00125,349,68,A,2,0,scan2.npy
|
||||
2137.00125,361,67,B,3,0,scan1.npy
|
||||
2137.00125,350,68,B,4,0,scan1.npy
|
||||
2627.00001,477,58,A,0,0,scan2.npy
|
||||
2627.00001,220,59,A,1,0,scan2.npy
|
||||
2627.00001,222,60,A,2,0,scan1.npy
|
||||
2627.00001,217,65,A,5,1,scan3.npy
|
||||
2627.12341,210,60,B,0,0,scan4.npy
|
||||
2627.12341,217,61,B,1,0,scan1.npy
|
||||
2627.12341,224,63,B,2,1,scan2.npy
|
||||
3250.00005,344,76,C,0,0,scan2.npy
|
||||
3250.00005,233,76,C,1,0,scan4.npy
|
||||
3250.00005,212,84,C,2,0,scan3.npy
|
||||
3250.00005,215,84,C,3,0,scan1.npy
|
||||
3250.00005,215,82,C,4,0,scan1.npy
|
||||
3250.12345,233,84,C,0,0,scan3.npy
|
||||
3250.12345,218,84,C,1,0,scan3.npy
|
||||
3250.12345,221,84,C,2,0,scan1.npy
|
||||
3250.12345,238,84,C,3,0,scan1.npy
|
||||
"""
|
||||
return pd.read_csv(StringIO(dataset_contents), dtype=str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["use_combined_model", "imaging_feature_type"],
|
||||
[(False, ImagingFeatureType.Image),
|
||||
(True, ImagingFeatureType.Image),
|
||||
(True, ImagingFeatureType.Segmentation),
|
||||
(True, ImagingFeatureType.ImageAndSegmentation)])
|
||||
@pytest.mark.parametrize("combine_hidden_state", (True, False))
|
||||
@pytest.mark.parametrize("use_encoder_layer_norm", (True, False))
|
||||
@pytest.mark.parametrize("use_mean_teacher_model", (False,))
|
||||
@pytest.mark.gpu
|
||||
def test_rnn_classifier_via_config_1(use_combined_model: bool,
|
||||
imaging_feature_type: ImagingFeatureType,
|
||||
combine_hidden_state: bool,
|
||||
use_encoder_layer_norm: bool,
|
||||
use_mean_teacher_model: bool,
|
||||
test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test if we can build a simple RNN model that only feeds off non-image features.
|
||||
This just tests the mechanics of training, but not if the model learned.
|
||||
"""
|
||||
logging_to_stdout()
|
||||
config = ToySequenceModel(use_combined_model,
|
||||
imaging_feature_type=imaging_feature_type,
|
||||
combine_hidden_states=combine_hidden_state,
|
||||
use_encoder_layer_norm=use_encoder_layer_norm,
|
||||
use_mean_teacher_model=use_mean_teacher_model,
|
||||
should_validate=False)
|
||||
config.use_mixed_precision = True
|
||||
# Necessary because torch otherwise says "index_add_cuda_ does not have a deterministic implementation"
|
||||
config.pl_deterministic = False
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = _get_mock_sequence_dataset()
|
||||
# Patch the load_images function that will be called once we access a dataset item
|
||||
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
|
||||
segmentations=np.random.randint(0, 2, SCAN_SIZE))
|
||||
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
|
||||
model_train_unittest(config, output_folder=test_output_dirs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["use_combined_model", "imaging_feature_type"],
|
||||
[(False, ImagingFeatureType.Image),
|
||||
(True, ImagingFeatureType.Image),
|
||||
(True, ImagingFeatureType.Segmentation),
|
||||
(True, ImagingFeatureType.ImageAndSegmentation)])
|
||||
def test_run_ml_with_sequence_model(use_combined_model: bool,
|
||||
imaging_feature_type: ImagingFeatureType,
|
||||
test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test training and testing of sequence models, when it is started together via run_ml.
|
||||
"""
|
||||
logging_to_stdout()
|
||||
config = ToySequenceModel(use_combined_model, imaging_feature_type,
|
||||
should_validate=False, sequence_target_positions=[2, 10])
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = _get_mock_sequence_dataset()
|
||||
config.num_epochs = 1
|
||||
config.max_batch_grad_cam = 1
|
||||
|
||||
# make sure we are testing with at least one sequence position that will not exist
|
||||
# to ensure correct handling of sequences that do not contain all the expected target positions
|
||||
assert max(config.sequence_target_positions) > config.dataset_data_frame[config.sequence_column].astype(float).max()
|
||||
|
||||
# Patch the load_images function that will be called once we access a dataset item
|
||||
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
|
||||
segmentations=np.random.randint(0, 2, SCAN_SIZE))
|
||||
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
|
||||
azure_config = get_default_azure_config()
|
||||
azure_config.train = True
|
||||
MLRunner(config, azure_config=azure_config).run()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["use_combined_model", "imaging_feature_type"],
|
||||
[(False, ImagingFeatureType.Image),
|
||||
(True, ImagingFeatureType.Image),
|
||||
(True, ImagingFeatureType.Segmentation),
|
||||
(True, ImagingFeatureType.ImageAndSegmentation)])
|
||||
def test_visualization_with_sequence_model(use_combined_model: bool,
|
||||
imaging_feature_type: ImagingFeatureType,
|
||||
test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ToySequenceModel(use_combined_model, imaging_feature_type, should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = _get_mock_sequence_dataset()
|
||||
config.num_epochs = 1
|
||||
model = config.create_model()
|
||||
if config.use_gpu:
|
||||
model = model.cuda()
|
||||
dataloader = SequenceDataset(config,
|
||||
data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
|
||||
batch_size=2)
|
||||
# Patch the load_images function that will be called once we access a dataset item
|
||||
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
|
||||
segmentations=np.random.randint(0, 2, SCAN_SIZE))
|
||||
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
|
||||
batch = next(iter(dataloader))
|
||||
if config.use_gpu:
|
||||
batch = transfer_batch_to_device(batch, torch.device(0))
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
|
||||
target_indices=config.get_target_indices(),
|
||||
sample=batch) # type: ignore
|
||||
number_sequences = model_inputs_and_labels.model_inputs[0].shape[1]
|
||||
number_subjects = len(model_inputs_and_labels.subject_ids)
|
||||
visualizer = VisualizationMaps(model, config)
|
||||
guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
|
||||
model_inputs_and_labels.model_inputs)
|
||||
if use_combined_model:
|
||||
if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
assert guided_grad_cams.shape[:2] == (number_subjects, number_sequences * 2)
|
||||
assert grad_cams.shape[:2] == (number_subjects, number_sequences * 2)
|
||||
else:
|
||||
assert guided_grad_cams.shape[:2] == (number_subjects, number_sequences)
|
||||
assert grad_cams.shape[:2] == (number_subjects, number_sequences)
|
||||
else:
|
||||
assert guided_grad_cams is None
|
||||
assert grad_cams is None
|
||||
assert pseudo_cam_non_img.shape[:2] == (number_subjects, number_sequences)
|
||||
assert probas.shape[0] == number_subjects
|
||||
non_image_features = config.numerical_columns + config.categorical_columns
|
||||
non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
|
||||
non_image_features,
|
||||
index=0,
|
||||
target_position=3)
|
||||
assert non_imaging_plot_labels == ['numerical1_0',
|
||||
'numerical2_0',
|
||||
'cat1_0',
|
||||
'numerical1_1',
|
||||
'numerical2_1',
|
||||
'cat1_1',
|
||||
'numerical1_2',
|
||||
'numerical2_2',
|
||||
'cat1_2',
|
||||
'numerical1_3',
|
||||
'numerical2_3',
|
||||
'cat1_3']
|
||||
|
||||
|
||||
class ToySequenceModel2(SequenceModelBase):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(
|
||||
temperature_scaling_config=TemperatureScalingConfig(),
|
||||
local_dataset=full_ml_test_data_path("sequence_data_for_classification"),
|
||||
label_value_column="label",
|
||||
numerical_columns=["feature"],
|
||||
sequence_column="index",
|
||||
sequence_target_positions=[2],
|
||||
loss_type=ScalarLoss.BinaryCrossEntropyWithLogits,
|
||||
num_epochs=20,
|
||||
num_dataload_workers=0,
|
||||
train_batch_size=40,
|
||||
l_rate=1e-2,
|
||||
drop_last_batch_in_training=True,
|
||||
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
|
||||
max_num_gpus=1,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
|
||||
return DatasetSplits.from_proportions(
|
||||
df=dataset_df,
|
||||
proportion_train=0.8,
|
||||
proportion_test=0.1,
|
||||
proportion_val=0.1,
|
||||
)
|
||||
|
||||
def create_model(self) -> Any:
|
||||
return RNNClassifier(input_dim=self.get_total_number_of_non_imaging_features(),
|
||||
hidden_dim=12,
|
||||
output_dim=1,
|
||||
num_rnn_layers=1,
|
||||
rnn_dropout=0.25,
|
||||
use_layer_norm=False,
|
||||
target_indices=self.get_target_indices())
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_rnn_classifier_via_config_2(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test if we can build an RNN classifier that learns sequences, of the same kind as in
|
||||
test_rnn_classifier_toy_problem, but built via the config.
|
||||
Only test the non-combined model because otherwise the build takes too much time.
|
||||
"""
|
||||
expected_max_train_loss = 0.71
|
||||
expected_max_val_loss = 0.71
|
||||
num_sequences = 100
|
||||
ml_util.set_random_seed(123)
|
||||
dataset_contents = "subject,index,feature,label\n"
|
||||
for subject in range(num_sequences):
|
||||
# Sequences have variable length
|
||||
sequence_length = np.random.choice([9, 10, 11, 12])
|
||||
# Each sequence is a series of 0 and 1
|
||||
inputs = np.random.choice([0, 1], size=(sequence_length,), p=[1. / 3, 2. / 3])
|
||||
label = np.sum(inputs) > (sequence_length // 2)
|
||||
for i, value in enumerate(inputs.tolist()):
|
||||
dataset_contents += f"S{subject},{i},{value},{label}\n"
|
||||
logging_to_stdout()
|
||||
config = ToySequenceModel2(should_validate=False)
|
||||
# Necessary because torch otherwise says "index_add_cuda_ does not have a deterministic implementation"
|
||||
config.pl_deterministic = False
|
||||
config.num_epochs = 2
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = _get_mock_sequence_dataset(dataset_contents)
|
||||
results, _ = model_train_unittest(config, output_folder=test_output_dirs)
|
||||
|
||||
actual_train_loss = results.get_metric(is_training=True, metric_type=MetricType.LOSS.value)[-1]
|
||||
actual_val_loss = results.get_metric(is_training=False, metric_type=MetricType.LOSS.value)[-1]
|
||||
print(f"Training loss after {config.num_epochs} epochs: {actual_train_loss}")
|
||||
print(f"Validation loss after {config.num_epochs} epochs: {actual_val_loss}")
|
||||
assert actual_train_loss <= expected_max_train_loss, "Training loss too high"
|
||||
assert actual_val_loss <= expected_max_val_loss, "Validation loss too high"
|
||||
# Issue #374: put back in when temperature scaling is enabled again
|
||||
# assert np.allclose(results.optimal_temperature_scale_values_per_checkpoint_epoch, [0.97], rtol=0.1)
|
||||
|
||||
|
||||
class ToyMultiLabelSequenceModel(SequenceModelBase):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
num_epochs = 3
|
||||
super().__init__(
|
||||
temperature_scaling_config=TemperatureScalingConfig(),
|
||||
label_value_column="Label",
|
||||
numerical_columns=["NUM1", "NUM2"],
|
||||
sequence_column="Position",
|
||||
sequence_target_positions=[1, 2, 3],
|
||||
loss_type=ScalarLoss.WeightedCrossEntropyWithLogits,
|
||||
num_epochs=num_epochs,
|
||||
num_dataload_workers=0,
|
||||
train_batch_size=3,
|
||||
l_rate=1e-1,
|
||||
label_smoothing_eps=0.05,
|
||||
categorical_columns=["CAT1"],
|
||||
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
|
||||
max_num_gpus=1,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
|
||||
return DatasetSplits.from_proportions(
|
||||
df=dataset_df,
|
||||
proportion_train=0.7,
|
||||
proportion_test=0.2,
|
||||
proportion_val=0.1,
|
||||
)
|
||||
|
||||
def create_model(self) -> Any:
|
||||
return RNNClassifier(input_dim=self.get_total_number_of_non_imaging_features(),
|
||||
hidden_dim=3,
|
||||
output_dim=1,
|
||||
num_rnn_layers=2,
|
||||
rnn_dropout=0.0,
|
||||
target_indices=self.get_target_indices())
|
||||
|
||||
|
||||
def test_run_ml_with_multi_label_sequence_model(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test training and testing of sequence models that predicts at multiple time points,
|
||||
when it is started via run_ml.
|
||||
"""
|
||||
logging_to_stdout()
|
||||
config = ToyMultiLabelSequenceModel()
|
||||
assert config.get_target_indices() == [1, 2, 3]
|
||||
expected_prediction_targets = [f"{SEQUENCE_POSITION_HUE_NAME_PREFIX} {x}"
|
||||
for x in ["01", "02", "03"]]
|
||||
_target_indices = config.get_target_indices()
|
||||
assert _target_indices is not None
|
||||
assert len(_target_indices) == len(expected_prediction_targets)
|
||||
metrics_dict = create_metrics_dict_for_scalar_models(config)
|
||||
assert metrics_dict.get_hue_names(include_default=False) == expected_prediction_targets
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
# Create a fake dataset directory to make config validation pass
|
||||
config.local_dataset = test_output_dirs.root_dir
|
||||
config.dataset_data_frame = _get_multi_label_sequence_dataframe()
|
||||
config.pre_process_dataset_dataframe()
|
||||
config.num_epochs = 1
|
||||
config.max_batch_grad_cam = 1
|
||||
azure_config = get_default_azure_config()
|
||||
azure_config.train = True
|
||||
MLRunner(config, azure_config=azure_config).run()
|
||||
# The metrics file should have one entry per epoch per subject per prediction target,
|
||||
# for all the 3 prediction targets.
|
||||
metrics_file = config.outputs_folder / "Train" / SUBJECT_METRICS_FILE_NAME
|
||||
assert metrics_file.exists()
|
||||
metrics = pd.read_csv(metrics_file)
|
||||
assert LoggingColumns.Patient.value in metrics
|
||||
assert LoggingColumns.Epoch.value in metrics
|
||||
assert LoggingColumns.Hue.value in metrics
|
||||
assert metrics[LoggingColumns.Hue.value].unique().tolist() == expected_prediction_targets
|
||||
group_by_subject = metrics.groupby(by=[LoggingColumns.Patient.value,
|
||||
LoggingColumns.Epoch.value])
|
||||
expected_prediction_target_lengths = [3, 2, 3, 3]
|
||||
for i, x in enumerate(group_by_subject):
|
||||
assert len(x[1]) == expected_prediction_target_lengths[i]
|
||||
group_by_subject_and_target = metrics.groupby(by=[LoggingColumns.Patient.value,
|
||||
LoggingColumns.Epoch.value,
|
||||
LoggingColumns.Hue.value])
|
||||
for _, group in group_by_subject_and_target:
|
||||
assert len(group) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("combine_hidden_states", [True, False])
|
||||
def test_pad_gru_output(combine_hidden_states: bool) -> None:
|
||||
"""
|
||||
Test to make sure if model output does not cover the target indices then it is padded
|
||||
"""
|
||||
config = ToySequenceModel(
|
||||
sequence_target_positions=[5, 7],
|
||||
combine_hidden_states=combine_hidden_states,
|
||||
should_validate=False
|
||||
)
|
||||
model: RNNClassifier = config.create_model()
|
||||
# base case where no padding is required
|
||||
test_input = torch.rand(max(config.get_target_indices()) + 1, 1)
|
||||
padded = model.pad_gru_output(test_input)
|
||||
assert torch.equal(test_input, padded)
|
||||
# case when padding is required
|
||||
test_input = torch.rand(min(config.get_target_indices()) - 1, 1)
|
||||
expected = torch.cat([test_input, test_input.new_full((4, 1), fill_value=0)], dim=0)
|
||||
padded = model.pad_gru_output(test_input)
|
||||
assert torch.allclose(expected, padded)
|
||||
|
||||
|
||||
def test_visualization_for_different_target_weeks(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Tests that the visualizations are differentiated depending on the target week
|
||||
for which we visualize it.
|
||||
"""
|
||||
config = ToyMultiLabelSequenceModel(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = _get_multi_label_sequence_dataframe()
|
||||
config.pre_process_dataset_dataframe()
|
||||
model = create_model_with_temperature_scaling(config)
|
||||
dataloader = SequenceDataset(config,
|
||||
data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
|
||||
batch_size=2)
|
||||
batch = next(iter(dataloader))
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
|
||||
target_indices=config.get_target_indices(),
|
||||
sample=batch)
|
||||
|
||||
visualizer = VisualizationMaps(model, config)
|
||||
if config.use_gpu:
|
||||
device = visualizer.grad_cam.device
|
||||
batch = transfer_batch_to_device(batch, device)
|
||||
model = model.to(device)
|
||||
# Pseudo-grad cam explaining the prediction at target sequence 2
|
||||
_, _, pseudo_cam_non_img_3, probas_3 = visualizer.generate(model_inputs_and_labels.model_inputs,
|
||||
target_position=2,
|
||||
target_label_index=2)
|
||||
# Pseudo-grad cam explaining the prediction at target sequence 0
|
||||
_, _, pseudo_cam_non_img_1, probas_1 = visualizer.generate(model_inputs_and_labels.model_inputs,
|
||||
target_position=0,
|
||||
target_label_index=0)
|
||||
assert pseudo_cam_non_img_1.shape[1] == 1
|
||||
assert pseudo_cam_non_img_3.shape[1] == 3
|
||||
# Both visualizations should not be equal
|
||||
assert np.any(pseudo_cam_non_img_1 != pseudo_cam_non_img_3)
|
||||
assert np.any(probas_3 != probas_1)
|
||||
|
||||
|
||||
def _get_multi_label_sequence_dataframe() -> pd.DataFrame:
|
||||
"""
|
||||
Returns a mock dataset for multi label sequence model.
|
||||
"""
|
||||
dataset_contents = """subject,NUM1,CAT1,NUM2,Position,Label
|
||||
2137.00005,362,A,71,0,
|
||||
2137.00005,357,B,69,1,0
|
||||
2137.00005,355,C,64,2,0
|
||||
2137.00005,355,C,63,3,1
|
||||
2137.00125,348,A,64,0,0
|
||||
2137.00125,316,A,68,1,1
|
||||
2137.00125,349,B,68,2,0
|
||||
2137.00125,361,B,67,3,1
|
||||
2137.00125,350,B,68,4,0
|
||||
2627.00001,477,C,58,0,0
|
||||
2627.00001,220,C,59,1,0
|
||||
2627.00001,222,A,60,2,0
|
||||
2627.00001,217,A,65,5,1
|
||||
2627.12341,210,B,60,0,0
|
||||
2627.12341,217,B,61,1,0
|
||||
2627.12341,224,B,63,2,1
|
||||
3250.00005,344,B,76,0,0
|
||||
3250.00005,233,A,76,1,0
|
||||
3250.00005,212,A,84,2,0
|
||||
3250.00005,215,A,84,3,0
|
||||
3250.00005,215,A,82,4,0
|
||||
3250.12345,233,A,84,0,1
|
||||
3250.12345,218,A,84,1,0
|
||||
3250.12345,221,B,84,2,0
|
||||
3250.12345,238,B,84,3,0
|
||||
"""
|
||||
return pd.read_csv(StringIO(dataset_contents), dtype=str)
|
||||
|
||||
|
||||
def test_sequence_dataset_stats_hook(test_output_dirs: OutputFolderForTests) -> None:
|
||||
model = ToySequenceModel()
|
||||
model.set_output_to(test_output_dirs.root_dir)
|
||||
model.dataset_data_frame = _get_mock_sequence_dataset()
|
||||
model.create_and_set_torch_datasets()
|
||||
length_file = model.logs_folder / SEQUENCE_LENGTH_FILE
|
||||
assert length_file.is_file()
|
||||
assert length_file.read_text().splitlines() == [
|
||||
"cross_validation_split_index,data_split,subject,sequence_length",
|
||||
"-1,Train,2137.00005,4",
|
||||
"-1,Train,2627.12341,3",
|
||||
"-1,Train,3250.00005,5",
|
||||
"-1,Train,3250.12345,4",
|
||||
"-1,Test,2627.00001,3",
|
||||
"-1,Val,2137.00125,5"]
|
||||
stats_file = model.logs_folder / SEQUENCE_LENGTH_STATS_FILE
|
||||
assert stats_file.is_file()
|
||||
assert stats_file.read_text().splitlines() == [
|
||||
" sequence_length ",
|
||||
" count mean std min 25% 50% 75% max",
|
||||
"data_split ",
|
||||
"Test 1.0 3.0 NaN 3.0 3.00 3.0 3.00 3.0",
|
||||
"Train 4.0 4.0 0.816497 3.0 3.75 4.0 4.25 5.0",
|
||||
"Val 1.0 5.0 NaN 5.0 5.00 5.0 5.00 5.0"]
|
|
@ -19,19 +19,16 @@ from InnerEye.Common.output_directories import OutputFolderForTests
|
|||
from InnerEye.Common.type_annotations import TupleInt3
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset, ScalarItemAugmentation
|
||||
from InnerEye.ML.lightning_models import transfer_batch_to_device
|
||||
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
|
||||
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImageEncoderWithMlp, \
|
||||
ImagingFeatureType
|
||||
from InnerEye.ML.run_ml import MLRunner
|
||||
from InnerEye.ML.scalar_config import AggregationType, ScalarLoss, ScalarModelBase, get_non_image_features_dict
|
||||
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
|
||||
from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES, segmentation_to_one_hot
|
||||
from InnerEye.ML.utils.io_util import ImageAndSegmentations, NumpyFile
|
||||
from InnerEye.ML.utils.ml_util import is_gpu_available, set_random_seed
|
||||
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, get_scalar_model_inputs_and_labels
|
||||
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
from InnerEye.ML.visualizers.grad_cam_hooks import VisualizationMaps
|
||||
from InnerEye.ML.visualizers.model_summary import ModelSummary
|
||||
from Tests.ML.util import get_default_azure_config, model_train_unittest
|
||||
|
||||
|
@ -324,86 +321,3 @@ def test_segmentation_to_one_hot(use_gpu: bool, input_on_gpu: bool) -> None:
|
|||
else:
|
||||
expected = torch.zeros((B,) + dim, device=one_hot.device)
|
||||
assert one_hot[:, i, ...].float().allclose(expected), f"Dimension {i} should have all ones"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("encode_channels_jointly", [True, False])
|
||||
@pytest.mark.parametrize("use_non_imaging_features", [True, False])
|
||||
@pytest.mark.parametrize("imaging_feature_type", [ImagingFeatureType.Image,
|
||||
ImagingFeatureType.Segmentation,
|
||||
ImagingFeatureType.ImageAndSegmentation])
|
||||
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
|
||||
imaging_feature_type: ImagingFeatureType,
|
||||
encode_channels_jointly: bool,
|
||||
test_output_dirs: OutputFolderForTests) -> None:
|
||||
dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
|
||||
S1,week0,scan1.npy,,1,10,Male,Val1
|
||||
S1,week1,scan2.npy,True,2,20,Female,Val2
|
||||
S2,week0,scan3.npy,,3,30,Female,Val3
|
||||
S2,week1,scan4.npy,False,4,40,Female,Val1
|
||||
S3,week0,scan1.npy,,5,50,Male,Val2
|
||||
S3,week1,scan3.npy,True,6,60,Male,Val2
|
||||
"""
|
||||
dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
|
||||
numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
|
||||
categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
|
||||
non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
|
||||
specific_channels={"categorical2": ["week1"]}) \
|
||||
if use_non_imaging_features else {}
|
||||
|
||||
config = ImageEncoder(
|
||||
local_dataset=Path(),
|
||||
encode_channels_jointly=encode_channels_jointly,
|
||||
should_validate=False,
|
||||
numerical_columns=numerical_columns,
|
||||
categorical_columns=categorical_columns,
|
||||
imaging_feature_type=imaging_feature_type,
|
||||
non_image_feature_channels=non_image_feature_channels,
|
||||
categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
|
||||
dataframe=dataset_dataframe, columns=categorical_columns)
|
||||
)
|
||||
|
||||
dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
|
||||
.as_data_loader(shuffle=False, batch_size=2)
|
||||
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.num_epochs = 1
|
||||
model = create_model_with_temperature_scaling(config)
|
||||
visualizer = VisualizationMaps(model, config)
|
||||
# Patch the load_images function that will be called once we access a dataset item
|
||||
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
|
||||
segmentations=np.random.randint(0, 2, (6, 64, 60)))
|
||||
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
|
||||
batch = next(iter(dataloader))
|
||||
if config.use_gpu:
|
||||
device = visualizer.grad_cam.device
|
||||
batch = transfer_batch_to_device(batch, device)
|
||||
visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
|
||||
target_indices=[],
|
||||
sample=batch)
|
||||
number_channels = len(config.image_channels)
|
||||
number_subjects = len(model_inputs_and_labels.subject_ids)
|
||||
guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
|
||||
model_inputs_and_labels.model_inputs)
|
||||
|
||||
if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
|
||||
assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
|
||||
else:
|
||||
assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)
|
||||
|
||||
assert grad_cams.shape[:2] == (number_subjects, 1) if encode_channels_jointly \
|
||||
else (number_subjects, number_channels)
|
||||
|
||||
if use_non_imaging_features:
|
||||
non_image_features = config.numerical_columns + config.categorical_columns
|
||||
non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
|
||||
non_image_features,
|
||||
index=0)
|
||||
assert non_imaging_plot_labels == ['numerical1_week1',
|
||||
'numerical1_week0',
|
||||
'numerical2_week1',
|
||||
'numerical2_week0',
|
||||
'categorical1_week1',
|
||||
'categorical1_week0',
|
||||
'categorical2_week1']
|
||||
assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))
|
||||
|
|
|
@ -24,7 +24,6 @@ from InnerEye.ML.pipelines.scalar_inference import ScalarEnsemblePipeline, Scala
|
|||
from InnerEye.ML.run_ml import is_classification_model
|
||||
from InnerEye.ML.scalar_config import EnsembleAggregationType
|
||||
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
|
||||
from Tests.ML.models.architectures.sequential.test_rnn_classifier import ToySequenceModel
|
||||
from Tests.ML.utils.test_model_util import create_model_and_store_checkpoint
|
||||
|
||||
|
||||
|
@ -190,7 +189,6 @@ def test_is_classification_model() -> None:
|
|||
assert is_classification_model(GlaucomaPublic())
|
||||
assert is_classification_model(ClassificationModelForTesting())
|
||||
assert not is_classification_model(BasicModel2Epochs())
|
||||
assert not is_classification_model(ToySequenceModel())
|
||||
|
||||
|
||||
def test_inference_required_single_runs() -> None:
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pack_sequence
|
||||
|
||||
from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels, map_packed_sequence_data, \
|
||||
sequences_to_padded_tensor
|
||||
|
||||
|
||||
def test_map_packed_sequence_data() -> None:
|
||||
"""
|
||||
Test to ensure helper function to apply a map transform to a packed sequence returns expected results.
|
||||
"""
|
||||
packed = pack_sequence([torch.tensor([[1.0], [2.0]])], enforce_sorted=False)
|
||||
mapped = map_packed_sequence_data(packed, lambda x: x * 2)
|
||||
assert torch.equal(mapped.data, torch.tensor([[2.0], [4.0]]))
|
||||
|
||||
with pytest.raises(Exception):
|
||||
map_packed_sequence_data(packed, lambda x: x.unsqueeze(dim=0))
|
||||
|
||||
|
||||
def test_get_masked_model_outputs_and_labels() -> None:
|
||||
"""
|
||||
Test to ensure the helper function to get masked model outputs, labels and their associated subject ids
|
||||
returns the expected results.
|
||||
"""
|
||||
|
||||
def _create_masked_and_check_expected(_model_outputs: torch.Tensor,
|
||||
_labels: torch.Tensor,
|
||||
_subject_ids: Sequence[str],
|
||||
_sorted_indices: Optional[torch.Tensor] = None) -> None:
|
||||
_masked = get_masked_model_outputs_and_labels(_model_outputs, _labels, _subject_ids)
|
||||
assert _masked is not None
|
||||
sorted_indices = _masked.labels.sorted_indices if _sorted_indices is None else _sorted_indices
|
||||
if sorted_indices is not None:
|
||||
_labels = _labels[sorted_indices]
|
||||
_model_outputs = _model_outputs[sorted_indices]
|
||||
_subject_ids = np.array(_subject_ids)[sorted_indices].tolist()
|
||||
|
||||
_expected_labels = _labels.transpose(dim0=0, dim1=1).flatten()
|
||||
_mask = ~torch.isnan(_expected_labels)
|
||||
_expected_labels = _expected_labels[_mask]
|
||||
_expected_model_outputs = _model_outputs.transpose(dim0=0, dim1=1).flatten()[_mask]
|
||||
_expected_subject_ids = _subject_ids
|
||||
|
||||
assert torch.equal(_expected_model_outputs, _masked.model_outputs.data)
|
||||
assert torch.equal(_expected_labels, _masked.labels.data)
|
||||
assert _expected_subject_ids[:_masked.labels.sorted_indices.shape[0]] == _masked.subject_ids
|
||||
|
||||
# test base case where no masking needs to be applied
|
||||
model_outputs = torch.rand((3, 4, 1))
|
||||
labels = torch.rand((3, 4, 1)).round()
|
||||
subject_ids = ['1', '2', '3']
|
||||
|
||||
_create_masked_and_check_expected(model_outputs, labels, subject_ids)
|
||||
|
||||
# test with unequal length sequences where masking will be performed
|
||||
model_outputs = sequences_to_padded_tensor([torch.rand(x + 1, 1) for x in range(3)], padding_value=np.nan)
|
||||
labels = sequences_to_padded_tensor([torch.rand(x + 1, 1) for x in range(3)], padding_value=np.nan)
|
||||
|
||||
_create_masked_and_check_expected(model_outputs, labels, subject_ids)
|
||||
|
||||
# test where one sequence is totally removed
|
||||
model_outputs[0] = np.nan
|
||||
labels[0] = np.nan
|
||||
|
||||
_create_masked_and_check_expected(model_outputs, labels, subject_ids, _sorted_indices=torch.tensor([2, 1, 0]))
|
|
@ -13,14 +13,13 @@ from pandas.core.dtypes.common import is_string_dtype
|
|||
|
||||
from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
|
||||
from InnerEye.Common.common_util import CROSSVAL_RESULTS_FOLDER, FULL_METRICS_DATAFRAME_FILE, \
|
||||
METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME, logging_to_stdout
|
||||
METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME
|
||||
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.Common.metrics_constants import LoggingColumns
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode
|
||||
from InnerEye.ML.deep_learning_config import ModelCategory
|
||||
from InnerEye.ML.run_ml import MLRunner
|
||||
from InnerEye.ML.utils.csv_util import CSV_INSTITUTION_HEADER, CSV_SERIES_HEADER
|
||||
from InnerEye.ML.visualizers.plot_cross_validation import COL_MODE, \
|
||||
METRICS_BY_MODE_AND_STRUCTURE_FILE, METRICS_BY_MODE_FILE, \
|
||||
|
@ -29,8 +28,6 @@ from InnerEye.ML.visualizers.plot_cross_validation import COL_MODE, \
|
|||
create_results_breakdown, download_crossval_result_files, get_split_id, load_dataframes, \
|
||||
plot_cross_validation_from_files, save_outliers
|
||||
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, get_most_recent_run_id
|
||||
from Tests.ML.models.architectures.sequential.test_rnn_classifier import ToyMultiLabelSequenceModel, \
|
||||
_get_multi_label_sequence_dataframe
|
||||
from Tests.ML.util import assert_text_files_match, get_default_azure_config
|
||||
|
||||
|
||||
|
@ -430,29 +427,6 @@ def test_download_or_get_local_file_2(test_output_dirs: OutputFolderForTests) ->
|
|||
assert file_in_folder == download_to_folder / file2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is only used to create input for test_load_files_with_prediction_target")
|
||||
def test_run_ml_with_multi_label_sequence_in_crossval(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test training and testing of sequence models that predicts at multiple time points,
|
||||
including aggregation of cross validation results.
|
||||
"""
|
||||
logging_to_stdout()
|
||||
config = ToyMultiLabelSequenceModel(should_validate=False)
|
||||
assert config.get_target_indices() == [1, 2, 3]
|
||||
expected_prediction_targets = ["Seq_pos 01", "Seq_pos 02", "Seq_pos 03"]
|
||||
target_indices = config.get_target_indices()
|
||||
assert target_indices
|
||||
assert len(target_indices) == len(expected_prediction_targets)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.dataset_data_frame = _get_multi_label_sequence_dataframe()
|
||||
config.pre_process_dataset_dataframe()
|
||||
config.num_epochs = 1
|
||||
config.number_of_cross_validation_splits = 2
|
||||
azure_config = get_default_azure_config()
|
||||
azure_config.train = True
|
||||
MLRunner(config, azure_config=azure_config).run()
|
||||
|
||||
|
||||
def test_load_files_with_prediction_target() -> None:
|
||||
"""
|
||||
For multi-week RNNs that predict at multiple sequence points: Test that the dataframes
|
||||
|
|
|
@ -73,7 +73,7 @@ steps:
|
|||
scanType: 'Register'
|
||||
verbosity: 'Normal'
|
||||
alertWarningLevel: 'High'
|
||||
failOnAlert: true
|
||||
failOnAlert: false
|
||||
failOnStderr: true
|
||||
|
||||
- task: PublishBuildArtifacts@1
|
||||
|
|
|
@ -16,5 +16,5 @@ steps:
|
|||
scanType: 'Register'
|
||||
verbosity: 'Normal'
|
||||
alertWarningLevel: 'High'
|
||||
failOnAlert: true
|
||||
failOnAlert: false
|
||||
failOnStderr: true
|
||||
|
|
|
@ -119,64 +119,12 @@ You can find instructions for other Linux distributions on NVidia website: https
|
|||
The following steps describe how to set up specific tools. You can execute most of those at a later
|
||||
point, if you want to dig deeper into the code.
|
||||
|
||||
## PyCharm
|
||||
## VSCode
|
||||
|
||||
Our team uses [PyCharm](https://www.jetbrains.com/pycharm/) for development, but any good editor
|
||||
([VSCode](https://code.visualstudio.com/) for example) will do as well.
|
||||
([VSCode](https://code.visualstudio.com/) for example)
|
||||
|
||||
This repository already contains a PyCharm configuration file in `.idea/InnerEye-DeepLearning.iml`. It should
|
||||
automatically pick the WSL Python interpreter (see [WSL.md](WSL.md)) as the default (no need to import the settings file)
|
||||
- if it doesn't happen you will need to adjust that as described [here](https://www.jetbrains.com/help/pycharm/configuring-python-interpreter.html).
|
||||
|
||||
|
||||
## How to manually set up flake8 as a PyCharm external tool
|
||||
|
||||
Go to File / Settings / Tools / External Tools / Add.
|
||||
|
||||
* Name: Flake8
|
||||
* Program: $PyInterpreterDirectory$/python
|
||||
* Arguments: -m flake8 $ProjectFileDir$
|
||||
* Working directory: $ProjectFileDir$
|
||||
* Advanced Options / Output Filters: $FILE_PATH$\:$LINE$\:$COLUMN$\:.*
|
||||
|
||||
Run Flake8 by right-clicking on a source file, External Tools / Flake8
|
||||
|
||||
## How to manually set up mypy as a PyCharm external tool
|
||||
|
||||
Go to File / Settings / Tools / External Tools / Add.
|
||||
|
||||
* Name: mypy
|
||||
* Program: $PyInterpreterDirectory$/python
|
||||
* Arguments: $ProjectFileDir$/mypy_runner.py -m <path to mypy executable>
|
||||
You can find the path to the mypy executable by typing `where mypy` on Windows or `which mypy` on Linux.
|
||||
If you have configured a virtual environment in PyCharm, the path will usually be
|
||||
`$PyInterpreterDirectory$/Scripts/mypy.exe` on Windows and `$PyInterpreterDirectory$/mypy` on Linux.
|
||||
* Working directory: $ProjectFileDir$
|
||||
* Advanced Options / Output Filters: $FILE_PATH$\:$LINE$\:.*
|
||||
|
||||
Run mypy by right-clicking on a source file, External Tools / mypy
|
||||
|
||||
## Deleting and creating a Conda environment
|
||||
|
||||
To delete, make sure the environment being deleted is not your current environment (just run `deactivate`). Then run
|
||||
`conda env remove --name environmentToDelete`.
|
||||
|
||||
To create an environment from scratch and then export it to a YAML file:
|
||||
|
||||
conda create --name envName python
|
||||
pip install whatEverPackage
|
||||
pip install packageWithVersion==1.0.42
|
||||
conda env export --no-builds --file=my_env.yml
|
||||
|
||||
With conda installation, the Apex library is built without the C++ files that are intended be used in backend-op
|
||||
computations such as fused_adam and fused_layernorm. This is mainly because we are unable to pass the
|
||||
required input install arguments to the setup file through a conda environment file. By building the library with
|
||||
these arguments, one could expect further speed-ups in both forward-backward model passes. If you are interested in
|
||||
installing Apex with these flags, please run the following commands in your shell:
|
||||
|
||||
git clone https://github.com/NVIDIA/apex; cd apex
|
||||
git checkout 880ab925bce9f817a93988b021e12db5f67f7787
|
||||
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
|
||||
## Conda
|
||||
- `conda env create -f environment.yml`
|
||||
|
||||
## Conda updates
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# How to setup Azure Machine Learning for InnerEye
|
||||
|
||||
Our preferred way to use AzureML is using the [AzureTRE](https://microsoft.github.io/AzureTRE/)
|
||||
|
||||
In order to be able to train models on Azure Machine Learning (AML) you will need to setup your environment in the
|
||||
Azure Portal first. In this document we will walk you through this process step-by-step.
|
||||
|
|
Загрузка…
Ссылка в новой задаче