* Remove rnns

* Fix flake8

* Edit README

* Edit README

* Remove sequence

* Remove sequence

* Fix all

* Remove more

* Remove ignore

* Fix tests

* Undo config

* Fix config

* Revert pycharm

* Fix tests

* Undo outputlogger

* Fix flake8

* Fix ignore file

* Revert hi-ml

* Disable fail on alert
This commit is contained in:
Javier 2022-03-09 14:53:12 +00:00 committed by GitHub
Parent 8a78ec8c1e
Commit 1606729c7a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
84 changed files: 87 additions and 8193 deletions

5
.gitignore vendored
View file

@@ -166,5 +166,6 @@ InnerEye-DataQuality/name_stats_scoring.png
InnerEye-DataQuality/cifar-10-batches-py
InnerEye-DataQuality/logs
InnerEye-DataQuality/data
!**/InnerEye/ML/Histopathology/datasets
None
cifar-10-batches-py
cifar-10-python.tar.gz

8
.vscode/settings.json vendored Normal file
View file

@@ -0,0 +1,8 @@
{
"python.analysis.typeCheckingMode": "basic",
"python.testing.pytestArgs": [
"Tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View file

@@ -13,7 +13,7 @@ created.
## Upcoming
### Added
- ([#671](https://github.com/microsoft/InnerEye-DeepLearning/pull/671)) Remove sequence models and unused variables. Simplify README.
- ([#678](https://github.com/microsoft/InnerEye-DeepLearning/pull/678)) Add function to get log level name and use it for logging.
- ([#666](https://github.com/microsoft/InnerEye-DeepLearning/pull/666)) Replace RadIO with TorchIO for patch-based inference.
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train

View file

@@ -24,8 +24,6 @@ from InnerEye.Common.generic_parsing import GenericConfig
# The name of the "azureml" property of AzureConfig
AZURECONFIG_SUBMIT_TO_AZUREML = "azureml"
INPUT_DATA_KEY = "input_data"
@dataclass(frozen=True)
class GitInformation:

View file

@@ -21,8 +21,6 @@ from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.utils.config_loader import ModelConfigLoader
SLEEP_TIME_SECONDS = 30
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
# Environment variables used for multi-node training

View file

@@ -52,8 +52,7 @@ up each score from one set of results with a score from the other set.
"""
from collections import defaultdict
from itertools import filterfalse, tee
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
@@ -67,8 +66,6 @@ from InnerEye.Common.common_util import FULL_METRICS_DATAFRAME_FILE
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML.visualizers.metrics_scatterplot import create_scatterplots
INTERSECT = lambda l, r: np.intersect1d(l, r, False)
"""
The factor by which the Wilcoxon Z value should be divided to allow for incomplete independence of the data.
Experimentation (from comparing models built from exactly the same code and data, but different random seeds)
@@ -206,15 +203,6 @@ def compose_pairwise_result(threshold: float, results: Dict[str, Dict[str, float
return []
def partition_results(pred: Callable, results: List[Any]) -> Tuple[List[Any], List[Any]]:
"""
Helper function to partition results into passed/failed
"""
t1, t2 = tee(results)
map_func = lambda r: (r, results[r])
return list(map(map_func, filterfalse(pred, t1))), list(map(map_func, filter(pred, t2)))
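# For reference, a minimal standalone sketch of the tee/filterfalse partition idiom that the deleted
# helper above relied on; the toy scores and the `_partition_example` name below are illustrative only:
def _partition_example(pred: Callable, items: List[str]) -> Tuple[List[str], List[str]]:
    """Split `items` into (failing pred, passing pred), traversing one tee'd copy each."""
    t1, t2 = tee(items)
    return list(filterfalse(pred, t1)), list(filter(pred, t2))

_toy_scores = {"model_A": 0.91, "model_B": 0.48, "model_C": 0.77}
_failed, _passed = _partition_example(lambda name: _toy_scores[name] >= 0.5, list(_toy_scores))
# _failed == ["model_B"], _passed == ["model_A", "model_C"]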
def read_data(csv_file: str, subset: str = 'all', exclude: Optional[List[str]] = None) \
-> Dict[str, Dict[str, Dict[str, float]]]:
"""

View file

@@ -1,171 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.transforms import LoadTilesBatchd
class CacheMode(Enum):
NONE = 'none'
MEMORY = 'memory'
DISK = 'disk'
class CacheLocation(Enum):
NONE = 'none'
CPU = 'cpu'
SAME = 'same'
class TilesDataModule(LightningDataModule):
"""Base class to load the tiles of a dataset as train, val, test sets"""
def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
seed: Optional[int] = None, transform: Optional[Callable] = None,
cache_mode: CacheMode = CacheMode.NONE,
precache_location: CacheLocation = CacheLocation.NONE,
cache_dir: Optional[Path] = None,
number_of_cross_validation_splits: int = 0,
cross_validation_split_index: int = 0) -> None:
"""
:param root_path: Root directory of the source dataset.
:param max_bag_size: Upper bound on number of tiles in each loaded bag. If 0 (default),
will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
random subsets of instances.
:param batch_size: Number of slides to load per batch.
:param seed: pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in
train/val/test splits is handled independently in `get_splits()`. (default: `None`)
:param transform: A transform to apply to the source tiles dataset, or a composition of
transforms using `monai.transforms.Compose`. By default (`None`), applies `LoadTilesBatchd`.
:param cache_mode: The type of caching to perform, i.e. whether the results of all
transforms up to the first randomised one should be computed only once and reused in
subsequent iterations:
- `MEMORY`: MONAI CacheDataset is used, the entire transformed dataset is kept in memory for fastest access;
- `DISK`: MONAI PersistentDataset is used, each transformed sample is saved to disk and loaded on-demand;
- `NONE` (default): standard MONAI dataset is used, no caching is performed.
:param precache_location: Whether to pre-cache the entire transformed dataset upfront and save
it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
multiple processes can afterwards access the same cache without contention in DDP settings. This parameter also allows
choosing whether the cache will be re-loaded into CPU or GPU memory:
- `NONE` (default): no pre-cache is performed;
- `CPU`: each transformed sample is saved to disk and, if cache_mode is `MEMORY`, reloaded into CPU;
- `SAME`: each transformed sample is saved to disk and, if cache_mode is `MEMORY`, reloaded on the same device it was saved from;
If cache_mode is `DISK`, precache_location `CPU` and `SAME` are equivalent.
:param cache_dir: The directory onto which to cache data if caching is enabled.
:param number_of_cross_validation_splits: Number of folds to perform.
:param cross_validation_split_index: Index of the cross validation split to be performed.
"""
if precache_location is not CacheLocation.NONE and cache_mode is CacheMode.NONE:
raise ValueError("Can only pre-cache if caching is enabled")
if precache_location is not CacheLocation.NONE and cache_dir is None:
raise ValueError("A cache directory is required for pre-caching")
if cache_mode is CacheMode.DISK and cache_dir is None:
raise ValueError("A cache directory is required for on-disk caching")
super().__init__()
self.root_path = root_path
self.max_bag_size = max_bag_size
self.transform = transform
self.cache_mode = cache_mode
self.precache_location = precache_location
self.cache_dir = cache_dir
self.batch_size = batch_size
self.number_of_cross_validation_splits = number_of_cross_validation_splits
self.cross_validation_split_index = cross_validation_split_index
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
self.class_weights = self.train_dataset.get_class_weights()
self.seed = seed
def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
"""Create the training, validation, and test datasets"""
raise NotImplementedError
def prepare_data(self) -> None:
if self.precache_location != CacheLocation.NONE:
self._load_dataset(self.train_dataset, stage='train', shuffle=True)
self._load_dataset(self.val_dataset, stage='val', shuffle=True)
self._load_dataset(self.test_dataset, stage='test', shuffle=True)
def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
if self.cache_dir is None:
return None
return self.cache_dir / f"{stage}_dataset.pt"
def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
dataset_pickle_path = self._dataset_pickle_path(stage)
if dataset_pickle_path and dataset_pickle_path.is_file():
if self.precache_location == CacheLocation.CPU:
memory_location = torch.device('cpu')
print(f"Loading dataset from {dataset_pickle_path} into {memory_location}")
else:
# by default torch.load will reload on the same device it was saved from
memory_location = None # type: ignore
with dataset_pickle_path.open('rb') as f:
return torch.load(f, map_location=memory_location)
generator = _create_generator(self.seed)
bag_dataset = BagDataset(tiles_dataset, # type: ignore
bag_ids=tiles_dataset.slide_ids,
max_bag_size=self.max_bag_size,
shuffle_samples=shuffle,
generator=generator)
transform = self.transform or LoadTilesBatchd(tiles_dataset.IMAGE_COLUMN)
# Save and restore PRNG state for consistency across (pre-)caching options
generator_state = generator.get_state()
transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform) # type: ignore
generator.set_state(generator_state)
# Dataset is saved if cache_dir is set, regardless of CacheMode
if dataset_pickle_path:
dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
with dataset_pickle_path.open('wb') as f:
torch.save(transformed_bag_dataset, f)
return transformed_bag_dataset
def _get_transformed_dataset(self, base_dataset: BagDataset,
transform: Union[Sequence[Callable], Callable]) -> Dataset:
if self.cache_mode is CacheMode.MEMORY:
dataset = CacheDataset(base_dataset, transform, num_workers=1) # type: ignore
elif self.cache_mode is CacheMode.DISK:
dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir) # type: ignore
if self.precache_location != CacheLocation.NONE:
import tqdm # TODO: Make optional
for i in tqdm.trange(len(dataset), desc="Loading dataset"):
dataset[i] # empty loop to pre-compute all transformed samples
else:
dataset = Dataset(base_dataset, transform) # type: ignore
return dataset
def _get_dataloader(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool,
**dataloader_kwargs: Any) -> DataLoader:
transformed_bag_dataset = self._load_dataset(tiles_dataset, stage=stage, shuffle=shuffle)
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
generator = bag_dataset.bag_sampler.generator
return DataLoader(transformed_bag_dataset, batch_size=self.batch_size,
collate_fn=multibag_collate, shuffle=shuffle, generator=generator,
pin_memory=False, # disable pinning as loaded data may already be on GPU
**dataloader_kwargs)
def train_dataloader(self) -> DataLoader:
return self._get_dataloader(self.train_dataset, 'train', shuffle=True)
def val_dataloader(self) -> DataLoader:
return self._get_dataloader(self.val_dataset, 'val', shuffle=True)
def test_dataloader(self) -> DataLoader:
return self._get_dataloader(self.test_dataset, 'test', shuffle=True)
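For orientation, a minimal sketch of how a subclass of the deleted TilesDataModule could be instantiated with on-disk caching, building on the imports and classes shown above. `ToyTilesDataModule`, the in-memory dataframe, and all paths below are hypothetical and only illustrate how the constructor arguments fit together:

import pandas as pd

class ToyTilesDataModule(TilesDataModule):  # hypothetical subclass, for illustration only
    def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
        # Tiny in-memory dataframe with the columns that TilesDataset expects by default.
        df = pd.DataFrame({
            TilesDataset.TILE_ID_COLUMN: ["t1", "t2", "t3", "t4"],
            TilesDataset.SLIDE_ID_COLUMN: ["s1", "s1", "s2", "s2"],
            TilesDataset.IMAGE_COLUMN: ["s1/t1.png", "s1/t2.png", "s2/t3.png", "s2/t4.png"],
            TilesDataset.LABEL_COLUMN: [0, 0, 1, 1],
            TilesDataset.SPLIT_COLUMN: ["train"] * 4,
            TilesDataset.TILE_X_COLUMN: [0, 224, 0, 224],
            TilesDataset.TILE_Y_COLUMN: [0, 0, 0, 0],
        })
        dataset = TilesDataset(root="/tmp/toy_tiles", dataset_df=df)  # hypothetical root path
        return dataset, dataset, dataset

data_module = ToyTilesDataModule(root_path=Path("/tmp/toy_tiles"),
                                 max_bag_size=2,                       # sample at most 2 tiles per bag
                                 batch_size=1,                         # one bag (slide) per batch
                                 cache_mode=CacheMode.DISK,            # persist transformed samples to disk
                                 precache_location=CacheLocation.CPU,  # pre-cache once on the rank-0 process
                                 cache_dir=Path("/tmp/toy_tiles_cache"),
                                 seed=42)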

View file

@@ -1,31 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Tuple, Any
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits
class PandaTilesDataModule(TilesDataModule):
""" PandaTilesDataModule is the child class of TilesDataModule specific to PANDA dataset
Method get_splits() returns the train, val, test splits from the PANDA dataset
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
dataset = PandaTilesDataset(self.root_path)
splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
proportion_train=.8,
proportion_test=.1,
proportion_val=.1,
subject_column=dataset.TILE_ID_COLUMN,
group_column=dataset.SLIDE_ID_COLUMN)
return (PandaTilesDataset(self.root_path, dataset_df=splits.train),
PandaTilesDataset(self.root_path, dataset_df=splits.val),
PandaTilesDataset(self.root_path, dataset_df=splits.test))
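The `group_column=dataset.SLIDE_ID_COLUMN` argument above is what keeps all tiles of a slide in the same split. A small standalone illustration of that grouping idea, using scikit-learn's GroupShuffleSplit purely as a stand-in for the repo's DatasetSplits.from_proportions (toy tile and slide IDs):

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

tiles = pd.DataFrame({"tile_id": ["t1", "t2", "t3", "t4", "t5", "t6"],
                      "slide_id": ["s1", "s1", "s2", "s2", "s3", "s3"]})
# Splitting on groups (slides) rather than rows (tiles) prevents tiles from one slide
# leaking into both the train and test sets.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.34, random_state=0)
train_idx, test_idx = next(splitter.split(tiles, groups=tiles["slide_id"]))
train_tiles, test_tiles = tiles.iloc[train_idx], tiles.iloc[test_idx]
assert set(train_tiles["slide_id"]).isdisjoint(set(test_tiles["slide_id"]))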

View file

@@ -1,38 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Tuple, Any
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits
class TcgaCrckTilesDataModule(TilesDataModule):
""" TcgaCrckTilesDataModule is the child class of TilesDataModule specific to TCGA-Crck dataset
Method get_splits() returns the train, val, test splits from the TCGA-Crck dataset
Methods train_dataloader(), val_dataloader() and test_dataloader() override the base class methods for bag loading
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def get_splits(self) -> Tuple[TcgaCrck_TilesDataset, TcgaCrck_TilesDataset, TcgaCrck_TilesDataset]:
trainval_dataset = TcgaCrck_TilesDataset(self.root_path, train=True)
splits = DatasetSplits.from_proportions(trainval_dataset.dataset_df.reset_index(),
proportion_train=0.8,
proportion_test=0.0,
proportion_val=0.2,
subject_column=trainval_dataset.TILE_ID_COLUMN,
group_column=trainval_dataset.SLIDE_ID_COLUMN,
random_seed=5)
if self.number_of_cross_validation_splits > 1:
# Function get_k_fold_cross_validation_splits() will concatenate train and val splits
splits = splits.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)[self.cross_validation_split_index]
return (TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.train),
TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.val),
TcgaCrck_TilesDataset(self.root_path, train=False))

View file

@@ -1,220 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset
from InnerEye.ML.Histopathology.utils.naming import SlideKey
class TilesDataset(Dataset):
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.
:param TILE_ID_COLUMN: CSV column name for tile ID.
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
:param IMAGE_COLUMN: CSV column name for relative path to image file.
:param PATH_COLUMN: CSV column name for relative path to image file. Replicated to propagate the path to the batch.
:param LABEL_COLUMN: CSV column name for tile label.
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
:param TILE_X_COLUMN: CSV column name for horizontal tile coordinate (optional).
:param TILE_Y_COLUMN: CSV column name for vertical tile coordinate (optional).
:param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
:param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
:param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
"""
TILE_ID_COLUMN: str = 'tile_id'
SLIDE_ID_COLUMN: str = 'slide_id'
IMAGE_COLUMN: str = 'image'
PATH_COLUMN: str = 'image_path'
LABEL_COLUMN: str = 'label'
SPLIT_COLUMN: Optional[str] = 'split'
TILE_X_COLUMN: Optional[str] = 'tile_x'
TILE_Y_COLUMN: Optional[str] = 'tile_y'
TRAIN_SPLIT_LABEL: str = 'train'
TEST_SPLIT_LABEL: str = 'test'
DEFAULT_CSV_FILENAME: str = "dataset.csv"
N_CLASSES: int = 1 # binary classification by default
def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
train: Optional[bool] = None) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
:param train: If `True`, loads only the training split (resp. `False` for test split). By
default (`None`), loads the entire dataset as-is.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")
self.root_dir = Path(root)
if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)
columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
for column in columns:
if column is not None and column not in dataset_df.columns:
raise ValueError(f"Expected column '{column}' not found in the dataframe")
dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
if train is None:
self.dataset_df = dataset_df
else:
split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]
def __len__(self) -> int:
return self.dataset_df.shape[0]
def __getitem__(self, index: int) -> Dict[str, Any]:
tile_id = self.dataset_df.index[index]
sample = {
self.TILE_ID_COLUMN: tile_id,
**self.dataset_df.loc[tile_id].to_dict()
}
sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
# we're replicating this column because we want to propagate the path to the batch
sample[self.PATH_COLUMN] = sample[self.IMAGE_COLUMN]
return sample
@property
def slide_ids(self) -> pd.Series:
return self.dataset_df[self.SLIDE_ID_COLUMN]
def get_slide_labels(self) -> pd.Series:
return self.dataset_df.groupby(self.SLIDE_ID_COLUMN)[self.LABEL_COLUMN].agg(pd.Series.mode)
def get_class_weights(self) -> torch.Tensor:
slide_labels = self.get_slide_labels()
classes = np.unique(slide_labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
return torch.as_tensor(class_weights)
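# Illustration only (toy values, not from the repo): how the two methods above derive slide-level
# labels and balanced class weights from per-tile rows.
_toy_df = pd.DataFrame({'slide_id': ['s1', 's1', 's2'], 'label': [0, 0, 1]},
                       index=pd.Index(['t1', 't2', 't3'], name='tile_id'))
_toy_slide_labels = _toy_df.groupby('slide_id')['label'].agg(pd.Series.mode)  # s1 -> 0, s2 -> 1
_toy_weights = compute_class_weight(class_weight='balanced',
                                    classes=np.unique(_toy_slide_labels), y=_toy_slide_labels)
# _toy_weights == array([1., 1.]) because the two slide-level classes are balanced in this toy case.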
class SlidesDataset(Dataset):
"""Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.
The output dictionaries are indexed by `..utils.naming.SlideKey`.
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
:param IMAGE_COLUMN: CSV column name for relative path to image file.
:param LABEL_COLUMN: CSV column name for tile label.
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
:param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
:param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
:param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
"""
SLIDE_ID_COLUMN: str = 'slide_id'
IMAGE_COLUMN: str = 'image'
LABEL_COLUMN: str = 'label'
MASK_COLUMN: Optional[str] = None
SPLIT_COLUMN: Optional[str] = None
TRAIN_SPLIT_LABEL: str = 'train'
TEST_SPLIT_LABEL: str = 'test'
METADATA_COLUMNS: Tuple[str, ...] = ()
DEFAULT_CSV_FILENAME: str = "dataset.csv"
N_CLASSES: int = 1 # binary classification by default
def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
train: Optional[bool] = None,
validate_columns: bool = True) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
:param train: If `True`, loads only the training split (resp. `False` for test split). By
default (`None`), loads the entire dataset as-is.
:param validate_columns: Whether to call `validate_columns()` at the end of `__init__()`.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")
self.root_dir = Path(root)
if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)
dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
if train is None:
self.dataset_df = dataset_df
else:
split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]
if validate_columns:
self.validate_columns()
def validate_columns(self) -> None:
"""Check that loaded dataframe contains expected columns, raises `ValueError` otherwise.
If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
call `validate_columns()` after creating derived columns, for example.
"""
columns = [self.IMAGE_COLUMN, self.LABEL_COLUMN, self.MASK_COLUMN,
self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS)
for column in columns:
if column is not None and column not in self.dataset_df.columns:
raise ValueError(f"Expected column '{column}' not found in the dataframe")
def __len__(self) -> int:
return self.dataset_df.shape[0]
def __getitem__(self, index: int) -> Dict[SlideKey, Any]:
slide_id = self.dataset_df.index[index]
slide_row = self.dataset_df.loc[slide_id]
sample = {SlideKey.SLIDE_ID: slide_id}
rel_image_path = slide_row[self.IMAGE_COLUMN]
sample[SlideKey.IMAGE] = str(self.root_dir / rel_image_path)
# we're replicating this column because we want to propagate the path to the batch
sample[SlideKey.IMAGE_PATH] = sample[SlideKey.IMAGE]
if self.MASK_COLUMN:
rel_mask_path = slide_row[self.MASK_COLUMN]
sample[SlideKey.MASK] = str(self.root_dir / rel_mask_path)
sample[SlideKey.MASK_PATH] = sample[SlideKey.MASK]
sample[SlideKey.LABEL] = slide_row[self.LABEL_COLUMN]
sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.METADATA_COLUMNS}
return sample
@classmethod
def has_mask(cls) -> bool:
return cls.MASK_COLUMN is not None

View file

@@ -1,15 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
PANDA_DATASET_ID = "PANDA"
PANDA_TILES_DATASET_ID = "PANDA_tiles"
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"
DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
PANDA_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_DATASET_ID
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID

View file

@@ -1,128 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Dict, Union, Optional
import pandas as pd
from health_ml.utils import box_utils
from monai.config import KeysCollection
from monai.data.image_reader import ImageReader, WSIReader
from monai.transforms import MapTransform
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
try:
from cucim import CuImage
except:
logging.warning("cucim library not available, code may fail.")
class PandaDataset(SlidesDataset):
"""Dataset class for loading files from the PANDA challenge dataset.
Iterating over this dataset returns a dictionary following the `SlideKey` schema plus meta-data
from the original dataset (`'data_provider'`, `'isup_grade'`, and `'gleason_score'`).
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
"""
SLIDE_ID_COLUMN = 'image_id'
IMAGE_COLUMN = 'image'
MASK_COLUMN = 'mask'
LABEL_COLUMN = 'isup_grade'
METADATA_COLUMNS = ('data_provider', 'isup_grade', 'gleason_score')
DEFAULT_CSV_FILENAME = "train.csv"
def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None) -> None:
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
# PANDA CSV does not come with paths for image and mask files
slide_ids = self.dataset_df.index
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
self.validate_columns()
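# Illustration only (toy slide ID): the pandas broadcasting used above to derive image and mask
# paths from the slide-ID index, since the PANDA CSV ships without file path columns.
_toy_ids = pd.Index(['slide_001'], name='image_id')
_toy_df = pd.DataFrame({'isup_grade': [0]}, index=_toy_ids)
_toy_df['image'] = "train_images/" + _toy_ids + ".tiff"            # -> 'train_images/slide_001.tiff'
_toy_df['mask'] = "train_label_masks/" + _toy_ids + "_mask.tiff"   # -> 'train_label_masks/slide_001_mask.tiff'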
# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
class ReadImaged(MapTransform):
"""Basic transform to read image files."""
def __init__(self, reader: ImageReader, keys: KeysCollection,
allow_missing_keys: bool = False, **kwargs: Any) -> None:
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.reader = reader
self.kwargs = kwargs
def __call__(self, data: Dict) -> Dict:
for key in self.keys:
if key in data or not self.allow_missing_keys:
data[key] = self.reader.read(data[key], **self.kwargs)
return data
class LoadPandaROId(MapTransform):
"""Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI).
Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
- `'location'` (tuple): top-right coordinates of the bounding box
- `'size'` (tuple): width and height of the bounding box
- `'level'` (int): chosen magnification level
- `'scale'` (float): corresponding scale, loaded from the file
"""
def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str = 'mask',
level: int = 0, margin: int = 0, **kwargs: Any) -> None:
"""
:param reader: An instance of MONAI's `WSIReader`.
:param image_key: Image key in the input and output dictionaries.
:param mask_key: Mask key in the input and output dictionaries.
:param level: Magnification level to load from the raw multi-scale files.
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
"""
super().__init__([image_key, mask_key], allow_missing_keys=False)
self.reader = reader
self.image_key = image_key
self.mask_key = mask_key
self.level = level
self.margin = margin
self.kwargs = kwargs
def _get_bounding_box(self, mask_obj: 'CuImage') -> box_utils.Box:
# Estimate bounding box at the lowest resolution (i.e. highest level)
highest_level = mask_obj.resolutions['level_count'] - 1
scale = mask_obj.resolutions['level_downsamples'][highest_level]
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
return bbox
def __call__(self, data: Dict) -> Dict:
mask_obj: CuImage = self.reader.read(data[self.mask_key])
image_obj: CuImage = self.reader.read(data[self.image_key])
level0_bbox = self._get_bounding_box(mask_obj)
# cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
scale = mask_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y),
size=(scaled_bbox.w, scaled_bbox.h),
level=self.level)
mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs) # type: ignore
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
data.update(get_data_kwargs)
data['scale'] = scale
mask_obj.close()
image_obj.close()
return data
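To make the coordinate convention above concrete, a small worked example with toy numbers and plain tuples (standing in for box_utils.Box): cuCIM/OpenSlide expect the crop location in level-0 pixels but the crop size in pixels at the requested level.

level = 2
level_downsamples = [1.0, 4.0, 16.0]        # hypothetical pyramid: level 0 -> level 2
x0, y0, w0, h0 = 1024, 2048, 8000, 6000     # bounding box estimated in level-0 pixels
scale = level_downsamples[level]            # 16.0
location = (x0, y0)                         # stays in level-0 coordinates
size = (int(w0 / scale), int(h0 / scale))   # (500, 375) pixels at level 2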

View file

@@ -1,90 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import pandas as pd
from torchvision.datasets.vision import VisionDataset
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.transforms import load_pil_image
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
class PandaTilesDataset(TilesDataset):
"""
Dataset class for loading PANDA tiles.
Iterating over this dataset returns a dictionary containing:
- `'slide_id'` (str): parent slide ID (`'image_id'` in the PANDA dataset)
- `'tile_id'` (str)
- `'image'` (`PIL.Image`): RGB tile
- `'mask'` (str): path to mask PNG file
- `'tile_x'`, `'tile_y'` (int): top-right tile coordinates
- `'data_provider'`, `'slide_isup_grade'`, `'slide_gleason_score'` (str): parent slide metadata
"""
LABEL_COLUMN = "slide_isup_grade"
SPLIT_COLUMN = None # PANDA does not have an official train/test split
N_CLASSES = 6
_RELATIVE_ROOT_FOLDER = Path("PANDA_tiles_20210926-135446/panda_tiles_level1_224")
def __init__(self,
root: Path,
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
occupancy_threshold: Optional[float] = None) -> None:
super().__init__(root=Path(root) / self._RELATIVE_ROOT_FOLDER,
dataset_csv=dataset_csv,
dataset_df=dataset_df,
train=None)
if occupancy_threshold is not None:
dataset_df_filtered = self.dataset_df.loc[self.dataset_df['occupancy'] > occupancy_threshold] # type: ignore
self.dataset_df = dataset_df_filtered
class PandaTilesDatasetReturnImageLabel(VisionDataset):
"""
Any dataset used in SSL needs to return a tuple where the first element is the image and the second is a
class label.
"""
occupancy_threshold = 0
def __init__(self,
root: Path,
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
transform: Optional[Callable] = None,
**kwargs: Any) -> None:
super().__init__(root=root, transform=transform)
self.base_dataset = PandaTilesDataset(root=root,
dataset_csv=dataset_csv,
dataset_df=dataset_df,
occupancy_threshold=self.occupancy_threshold)
def __getitem__(self, index: int) -> Tuple: # type: ignore
sample = self.base_dataset[index]
# TODO change to a meaningful evaluation
image = load_pil_image(sample[self.base_dataset.IMAGE_COLUMN])
if self.transform:
image = self.transform(image)
# get binary label
label = 0 if sample[self.base_dataset.LABEL_COLUMN] == 0 else 1
return image, label
def __len__(self) -> int:
return len(self.base_dataset)
class PandaTilesDatasetWithReturnIndex(InnerEyeDataClassBaseWithReturnIndex, PandaTilesDatasetReturnImageLabel):
"""
Any dataset used in SSL needs to inherit from InnerEyeDataClassBaseWithReturnIndex as well as VisionData.
This class is just a shorthand notation for this double inheritance. Please note that this class needs
to override __getitem__(); this is why we need a separate PandaTilesDatasetReturnImageLabel.
"""
@property
def num_classes(self) -> int:
return 2

View file

@@ -1,70 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import pandas as pd
from torchvision.datasets.vision import VisionDataset
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.transforms import load_pil_image
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
class TcgaCrck_TilesDataset(TilesDataset):
"""Dataset class for loading TCGA-CRCk tiles.
Iterating over this dataset returns a dictionary containing:
- `'slide_id'` (str): parent slide ID
- `'tile_id'` (str)
- `'image'` (`PIL.Image`): RGB tile
- `'label'` (str): MSS (0) vs MSIMUT (1)
"""
TILE_X_COLUMN = TILE_Y_COLUMN = None # no tile coordinates available
# This dataset conforms to all other defaults in TilesDataset
class TcgaCrck_TilesDatasetReturnImageLabel(VisionDataset):
"""
Any dataset used in SSL needs to return a tuple where the first element is the image and the second is a
class label.
"""
def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
train: Optional[bool] = None,
transform: Optional[Callable] = None,
**kwargs: Any) -> None:
super().__init__(root=root, transform=transform)
self.base_dataset = TcgaCrck_TilesDataset(root=root,
dataset_csv=dataset_csv,
dataset_df=dataset_df,
train=train)
def __getitem__(self, index: int) -> Tuple: # type: ignore
sample = self.base_dataset[index]
# TODO change to a meaningful evaluation
image = load_pil_image(sample[self.base_dataset.IMAGE_COLUMN])
if self.transform:
image = self.transform(image)
return image, sample[self.base_dataset.LABEL_COLUMN]
def __len__(self) -> int:
return len(self.base_dataset)
class TcgaCrck_TilesDatasetWithReturnIndex(InnerEyeDataClassBaseWithReturnIndex,
TcgaCrck_TilesDatasetReturnImageLabel):
"""
Any dataset used in SSL needs to inherit from InnerEyeDataClassBaseWithReturnIndex as well as VisionData.
This class is just a shorthand notation for this double inheritance. Please note that this class needs
to override __getitem__(); this is why we need a separate TcgaCrck_TilesDatasetReturnImageLabel.
"""
@property
def num_classes(self) -> int:
return 2

View file

@@ -1,42 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Optional, Union
import pandas as pd
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
class TcgaPradDataset(SlidesDataset):
"""Dataset class for loading TCGA-PRAD slides.
Iterating over this dataset returns a dictionary containing:
- `'slide_id'` (str)
- `'case_id'` (str)
- `'image_path'` (str): absolute slide image path
- `'label'` (int, 0 or 1): label for predicting positive or negative
"""
IMAGE_COLUMN: str = 'image_path'
LABEL_COLUMN: str = 'label'
DEFAULT_CSV_FILENAME: str = "dataset.csv"
def __init__(self, root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from
`"{root}/{DEFAULT_CSV_FILENAME}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
"""
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
# Example of how to define a custom label column from existing columns:
self.dataset_df[self.LABEL_COLUMN] = (self.dataset_df['label1']
| self.dataset_df['label2']).astype(int)
self.validate_columns()

View file

View file

@@ -1,440 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
import pandas as pd
import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, List
import torch
import matplotlib.pyplot as plt
import more_itertools as mi
from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round
from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
from InnerEye.ML.Histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_tiles,
plot_scores_hist, plot_heatmap_overlay,
plot_slide, plot_normalized_confusion_matrix)
from InnerEye.ML.Histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey
from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict
from health_ml.utils import log_on_epoch
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
def _format_cuda_memory_stats() -> str:
return (f"GPU {torch.cuda.current_device()} memory: "
f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB allocated, "
f"{torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB reserved")
class DeepMILModule(LightningModule):
"""Base class for deep multiple-instance learning"""
def __init__(self,
label_column: str,
n_classes: int,
encoder: TileEncoder,
pooling_layer: Callable[[int, int, int], nn.Module],
pool_hidden_dim: int = 128,
pool_out_dim: int = 1,
dropout_rate: Optional[float] = None,
class_weights: Optional[Tensor] = None,
l_rate: float = 5e-4,
weight_decay: float = 1e-4,
adam_betas: Tuple[float, float] = (0.9, 0.99),
verbose: bool = False,
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1,
class_names: Optional[List[str]] = None,
is_finetune: bool = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
:param encoder: The tile encoder to use for feature extraction. If no encoding is needed,
you should use `IdentityEncoder`.
:param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a
`torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
:param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
:param pool_out_dim: Output dimension of pooling layer (default=1).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param class_weights: Tensor containing class weights (default=None).
:param l_rate: Optimiser learning rate.
:param weight_decay: Weight decay parameter for L2 regularisation.
:param adam_betas: Beta parameters for Adam optimiser.
:param verbose: If True, statements about memory usage are printed at each step.
:param slide_dataset: Slide dataset object, if available.
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: The names of the classes if available (default=None).
:param is_finetune: Boolean value to enable/disable finetuning (default=False).
"""
super().__init__()
# Dataset specific attributes
self.label_column = label_column
self.n_classes = n_classes
self.pool_hidden_dim = pool_hidden_dim
self.pool_out_dim = pool_out_dim
self.pooling_layer = pooling_layer
self.dropout_rate = dropout_rate
self.class_weights = class_weights
self.encoder = encoder
self.num_encoding = self.encoder.num_encoding
if class_names is not None:
self.class_names = class_names
else:
if self.n_classes > 1:
self.class_names = [str(i) for i in range(self.n_classes)]
else:
self.class_names = ['0', '1']
if self.n_classes > 1 and len(self.class_names) != self.n_classes:
raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes})")
if self.n_classes == 1 and len(self.class_names) != 2:
raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes+1})")
# Optimiser hyperparameters
self.l_rate = l_rate
self.weight_decay = weight_decay
self.adam_betas = adam_betas
# Slide specific attributes
self.slide_dataset = slide_dataset
self.tile_size = tile_size
self.level = level
self.save_hyperparameters()
self.verbose = verbose
# Finetuning attributes
self.is_finetune = is_finetune
self.aggregation_fn, self.num_pooling = self.get_pooling()
self.classifier_fn = self.get_classifier()
self.loss_fn = self.get_loss()
self.activation_fn = self.get_activation()
# Metrics Objects
self.train_metrics = self.get_metrics()
self.val_metrics = self.get_metrics()
self.test_metrics = self.get_metrics()
def get_pooling(self) -> Tuple[Callable, int]:
pooling_layer = self.pooling_layer(self.num_encoding,
self.pool_hidden_dim,
self.pool_out_dim)
num_features = self.num_encoding*self.pool_out_dim
return pooling_layer, num_features
def get_classifier(self) -> Callable:
classifier_layer = nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
if self.dropout_rate is None:
return classifier_layer
elif 0 <= self.dropout_rate < 1:
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
else:
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
def get_loss(self) -> Callable:
if self.n_classes > 1:
if self.class_weights is None:
return nn.CrossEntropyLoss()
else:
class_weights = self.class_weights.float()
return nn.CrossEntropyLoss(weight=class_weights)
else:
pos_weight = None
if self.class_weights is not None:
pos_weight = Tensor([self.class_weights[1]/(self.class_weights[0]+1e-5)])
return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
def get_activation(self) -> Callable:
if self.n_classes > 1:
return nn.Softmax()
else:
return nn.Sigmoid()
@staticmethod
def get_bag_label(labels: Tensor) -> Tensor:
# Get bag (batch) labels as majority vote
bag_label = mode(labels).values
return bag_label.view(1)
def get_metrics(self) -> nn.ModuleDict:
if self.n_classes > 1:
return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes, average='micro'),
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
else:
threshold = 0.5
return nn.ModuleDict({MetricsKey.ACC: Accuracy(threshold=threshold),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.PRECISION: Precision(threshold=threshold),
MetricsKey.RECALL: Recall(threshold=threshold),
MetricsKey.F1: F1(threshold=threshold),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=2, threshold=threshold)})
def log_metrics(self,
stage: str) -> None:
valid_stages = ['train', 'test', 'val']
if stage not in valid_stages:
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
for metric_name, metric_object in self.get_metrics_dict(stage).items():
if metric_name == MetricsKey.CONF_MATRIX:
metric_value = metric_object.compute()
metric_value_n = metric_value/metric_value.sum(axis=1, keepdims=True)
for i in range(metric_value_n.shape[0]):
log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
else:
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with set_grad_enabled(self.is_finetune):
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(1, -1)
bag_logit = self.classifier_fn(bag_features)
return bag_logit, attentions
def configure_optimizers(self) -> optim.Optimizer:
return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay,
betas=self.adam_betas)
def get_metrics_dict(self, stage: str) -> nn.ModuleDict:
return getattr(self, f'{stage}_metrics')
def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsKey, Tensor]:
# The batch dict contains lists of tensors of different sizes, for all bags in the batch.
# This means we can't stack them along a new axis without padding to the same length.
# We could alternatively concatenate them, but this would require other changes (e.g. in
# the attention layers) to correctly split the tensors by bag/slide ID.
bag_labels_list = []
bag_logits_list = []
bag_attn_list = []
for bag_idx in range(len(batch[self.label_column])):
images = batch[TilesDataset.IMAGE_COLUMN][bag_idx]
labels = batch[self.label_column][bag_idx]
bag_labels_list.append(self.get_bag_label(labels))
logit, attn = self(images)
bag_logits_list.append(logit.view(-1))
bag_attn_list.append(attn)
bag_logits = torch.stack(bag_logits_list)
bag_labels = torch.stack(bag_labels_list).view(-1)
if self.n_classes > 1:
loss = self.loss_fn(bag_logits, bag_labels.long())
else:
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())
predicted_probs = self.activation_fn(bag_logits)
if self.n_classes > 1:
predicted_labels = argmax(predicted_probs, dim=1)
probs_perclass = predicted_probs
else:
predicted_labels = round(predicted_probs)
probs_perclass = Tensor([[1.0 - predicted_probs[i][0].item(), predicted_probs[i][0].item()] for i in range(len(predicted_probs))])
loss = loss.view(-1, 1)
predicted_labels = predicted_labels.view(-1, 1)
if self.n_classes == 1:
predicted_probs = predicted_probs.view(-1, 1)
bag_labels = bag_labels.view(-1, 1)
results = dict()
for metric_object in self.get_metrics_dict(stage).values():
if self.n_classes > 1:
metric_object.update(predicted_probs, bag_labels.squeeze())
else:
metric_object.update(predicted_probs, bag_labels)
results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN],
ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN],
ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss,
ResultsKey.PROB: predicted_probs, ResultsKey.CLASS_PROBS: probs_perclass,
ResultsKey.PRED_LABEL: predicted_labels,
ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})
if (TilesDataset.TILE_X_COLUMN in batch.keys()) and (TilesDataset.TILE_Y_COLUMN in batch.keys()):
results.update({ResultsKey.TILE_X: batch[TilesDataset.TILE_X_COLUMN],
ResultsKey.TILE_Y: batch[TilesDataset.TILE_Y_COLUMN]}
)
else:
logging.warning("Coordinates not found in batch. If this is not expected check your input tiles dataset.")
return results
def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
train_result = self._shared_step(batch, batch_idx, 'train')
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
if self.verbose:
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
self.log_metrics('train')
return train_result[ResultsKey.LOSS]
def validation_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
val_result = self._shared_step(batch, batch_idx, 'val')
self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
self.log_metrics('val')
return val_result[ResultsKey.LOSS]
def test_step(self, batch: Dict, batch_idx: int) -> Dict[ResultsKey, Any]: # type: ignore
test_result = self._shared_step(batch, batch_idx, 'test')
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
self.log_metrics('test')
return test_result
def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore
# outputs object consists of a list of dictionaries (of metadata and results, including encoded features)
# It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx]
# example of batch_key ResultsKey.SLIDE_ID_COL
# for batch keys that contains multiple values for slides e.g. ResultsKey.BAG_ATTN_COL
# outputs[batch_idx][batch_key][bag_idx][tile_idx]
# contains the tile value
# collate the batches
results: Dict[str, List[Any]] = {}
[results.update({col: []}) for col in outputs[0].keys()]
for key in results.keys():
for batch_id in range(len(outputs)):
results[key] += outputs[batch_id][key]
print("Saving outputs ...")
# collate at slide level
list_slide_dicts = []
list_encoded_features = []
# any column can be used here, the assumption is that the first dimension is the N of slides
for slide_idx in range(len(results[ResultsKey.SLIDE_ID])):
slide_dict = dict()
for key in results.keys():
if key not in [ResultsKey.IMAGE, ResultsKey.LOSS]:
slide_dict[key] = results[key][slide_idx]
list_slide_dicts.append(slide_dict)
list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx])
outputs_path = fixed_paths.repository_parent_directory() / 'outputs'
print(f"Metrics results will be output to {outputs_path}")
outputs_fig_path = outputs_path / 'fig'
csv_filename = outputs_path / 'test_output.csv'
encoded_features_filename = outputs_path / 'test_encoded_features.pickle'
# Collect the list of dictionaries in a list of pandas dataframe and save
df_list = []
for slide_dict in list_slide_dicts:
slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False)
df_list.append(pd.DataFrame.from_dict(slide_dict))
df = pd.concat(df_list, ignore_index=True)
df.to_csv(csv_filename, mode='w', header=True)
# Collect all features in a list and save
features_list = self.move_list_to_device(list_encoded_features, use_gpu=False)
torch.save(features_list, encoded_features_filename)
print("Selecting tiles ...")
# Class 0
tn_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'highest_att'))
tn_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'lowest_att'))
fp_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'highest_att'))
fp_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'lowest_att'))
report_cases = {'TN': [tn_top_tiles, tn_bottom_tiles], 'FP': [fp_top_tiles, fp_bottom_tiles]}
# Class 1 to n_classes-1
n_classes_to_select = self.n_classes if self.n_classes > 1 else 2
for i in range(1, n_classes_to_select):
fn_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'highest_att'))
fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'lowest_att'))
tp_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'highest_att'))
tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'lowest_att'))
report_cases.update({'TP_'+str(i): [tp_top_tiles, tp_bottom_tiles], 'FN_'+str(i): [fn_top_tiles, fn_bottom_tiles]})
for key in report_cases.keys():
print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...")
key_folder_path = outputs_fig_path / f'{key}'
Path(key_folder_path).mkdir(parents=True, exist_ok=True)
nslides = len(report_cases[key][0])
for i in range(nslides):
slide, score, paths, top_attn = report_cases[key][0][i]
fig = plot_attention_tiles(slide, score, paths, top_attn, key + '_top', ncols=4)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_top.png'))
slide, score, paths, bottom_attn = report_cases[key][1][i]
fig = plot_attention_tiles(slide, score, paths, bottom_attn, key + '_bottom', ncols=4)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png'))
if self.slide_dataset is not None:
slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide) # type: ignore
_ = load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore
slide_image = slide_dict[SlideKey.IMAGE]
location_bbox = slide_dict[SlideKey.LOCATION]
fig = plot_slide(slide_image=slide_image, scale=1.0)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png'))
fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
location_bbox=location_bbox, tile_size=self.tile_size, level=self.level)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png'))
print("Plotting histogram ...")
fig = plot_scores_hist(results)
self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png')
print("Computing and saving confusion matrix...")
metrics_dict = self.get_metrics_dict('test')
cf_matrix = metrics_dict[MetricsKey.CONF_MATRIX].compute()
cf_matrix = np.array(cf_matrix.cpu())
# We can't log tensors in the normal way - just print it to console
print('test/confusion matrix:')
print(cf_matrix)
# Save the normalized confusion matrix as a figure in outputs
cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1, keepdims=True)
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=self.class_names)
self.save_figure(fig=fig, figpath=outputs_fig_path / 'normalized_confusion_matrix.png')
@staticmethod
def save_figure(fig: plt.figure, figpath: Path) -> None:
fig.savefig(figpath, bbox_inches='tight')
@staticmethod
def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict:
# slide-level dictionaries are processed by making value dimensions uniform and converting to numpy arrays.
# these steps are required to convert the dictionary to pandas dataframe.
device = 'cuda' if use_gpu else 'cpu'
dict_new = dict()
bag_size = len(dict_old[ResultsKey.SLIDE_ID])
for key, value in dict_old.items():
if key not in [ResultsKey.CLASS_PROBS, ResultsKey.PROB]:
if isinstance(value, Tensor):
value = value.squeeze(0).to(device).numpy()
if value.ndim == 0:
value = np.full(bag_size, fill_value=value)
dict_new[key] = value
elif key == ResultsKey.CLASS_PROBS:
if isinstance(value, Tensor):
value = value.squeeze(0).to(device).numpy()
for i in range(len(value)):
dict_new[key+str(i)] = np.repeat(value[i], bag_size)
return dict_new
@staticmethod
def move_list_to_device(list_encoded_features: List, use_gpu: bool) -> List:
# a list of features on cpu obtained from original list on gpu
features_list = []
device = 'cuda' if use_gpu else 'cpu'
for feature in list_encoded_features:
feature = feature.squeeze(0).to(device)
features_list.append(feature)
return features_list
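For the binary branch of get_loss() above, a minimal standalone sketch of how the positive-class weight is derived from the two class weights (toy numbers, illustrative only):

import torch
from torch import Tensor, nn

class_weights = Tensor([0.6, 1.4])                                   # toy slide-level class weights
pos_weight = Tensor([class_weights[1] / (class_weights[0] + 1e-5)])  # ~2.33, mirroring get_loss()
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fn(torch.zeros(4), torch.tensor([0.0, 1.0, 0.0, 1.0]))   # bag logits vs. binary bag labels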

View file

@@ -1,147 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Callable, Optional, Sequence, Tuple
import numpy as np
import torch
from pl_bolts.models.self_supervised import SimCLR
from torch import nn
from torchvision.models import resnet18
from torchvision.transforms import Compose
from InnerEye.ML.Histopathology.utils.layer_utils import (get_imagenet_preprocessing,
load_weights_to_model,
setup_feature_extractor)
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.utils import create_ssl_image_classifier
class TileEncoder(nn.Module):
"""Base tile encoder class for use in dataset transforms or as part of a bigger model"""
def __init__(self, tile_size: int = 0, n_channels: int = 3,
input_dim: Optional[Sequence[int]] = None) -> None:
"""The `TileEncoder` constructor should be called after setting any attributes needed in
`_get_preprocessing()` or `_get_encoder()`.
:param tile_size: Tile width/height, in pixels.
:param n_channels: Number of channels in the tile (default=3).
:param input_dim: Input shape, to override default of `(n_channels, tile_size, tile_size)`.
"""
super().__init__()
if input_dim is None:
if tile_size == 0:
raise ValueError("Either input_dim or tile_size must be specified")
input_dim = (n_channels, tile_size, tile_size)
self.input_dim = tuple(input_dim)
self.preprocessing_fn = self._get_preprocessing()
self.feature_extractor_fn, self.num_encoding = self._get_encoder()
def _get_preprocessing(self) -> Callable:
return Compose([])
def _get_encoder(self) -> Tuple[Callable, int]:
raise NotImplementedError
def forward(self, images: torch.Tensor) -> torch.Tensor:
prep_images = self.preprocessing_fn(images)
return self.feature_extractor_fn(prep_images)
class IdentityEncoder(TileEncoder):
"""Dummy encoder that just flattens the input"""
def _get_encoder(self) -> Tuple[Callable, int]:
return nn.Flatten(), np.prod(self.input_dim)
class ImageNetEncoder(TileEncoder):
"""Feature extractor pretrained for classification on ImageNet"""
def __init__(self, feature_extraction_model: Callable[..., nn.Module],
tile_size: int, n_channels: int = 3) -> None:
"""
:param feature_extraction_model: A function accepting a `pretrained` keyword argument that
returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`.
:param tile_size: Tile width/height, in pixels.
:param n_channels: Number of channels in the tile (default=3).
"""
self.create_feature_extractor_fn = feature_extraction_model
super().__init__(tile_size=tile_size, n_channels=n_channels)
def _get_preprocessing(self) -> Callable:
return get_imagenet_preprocessing()
def _get_encoder(self) -> Tuple[Callable, int]:
pretrained_model = self.create_feature_extractor_fn(pretrained=True)
return setup_feature_extractor(pretrained_model, self.input_dim) # type: ignore
class ImageNetSimCLREncoder(TileEncoder):
"""SimCLR encoder pretrained on ImageNet"""
WEIGHTS_URL = ("https://pl-bolts-weights.s3.us-east-2.amazonaws.com/"
"simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt")
EMBEDDING_DIM = 2048
def _get_preprocessing(self) -> Callable:
return get_imagenet_preprocessing()
def _get_encoder(self) -> Tuple[SimCLR, int]:
simclr = SimCLR.load_from_checkpoint(self.WEIGHTS_URL, strict=False)
simclr.freeze()
return simclr, self.EMBEDDING_DIM
class InnerEyeSSLEncoder(TileEncoder):
"""SSL encoder trained on Azure ML using InnerEye"""
def __init__(self, pl_checkpoint_path: Path, tile_size: int, n_channels: int = 3) -> None:
"""
:param pl_checkpoint_path: The path of the downloaded checkpoint file.
:param tile_size: Tile width/height, in pixels.
:param n_channels: Number of channels in the tile (default=3).
"""
self.pl_checkpoint_path = pl_checkpoint_path
super().__init__(tile_size=tile_size, n_channels=n_channels)
def _get_encoder(self) -> Tuple[torch.nn.Module, int]:
model: SSLClassifier = create_ssl_image_classifier( # type: ignore
num_classes=1, # dummy value
freeze_encoder=True,
pl_checkpoint_path=str(self.pl_checkpoint_path)
)
encoder = model.encoder # type: ignore
for param in encoder.parameters():
param.requires_grad = False # freeze_encoder does not disable gradients
classifier_head = model.classifier_head
embedding_dim = classifier_head.n_input # type: ignore
return encoder, embedding_dim
class HistoSSLEncoder(TileEncoder):
"""HistoSSL encoder pretrained on multiple histological datasets
Reference:
- Ciga, Xu, Martel (2021). Self supervised contrastive learning for digital histopathology.
arXiv:2011.13971
"""
WEIGHTS_URL = ("https://github.com/ozanciga/self-supervised-histopathology/releases/"
"download/tenpercent/tenpercent_resnet18.ckpt")
def _get_encoder(self) -> Tuple[Callable, int]:
resnet18_model = resnet18(pretrained=False)
num_features = resnet18_model.fc.in_features
histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
histossl_encoder.fc = torch.nn.Sequential()
for param in histossl_encoder.parameters():
param.requires_grad = False
return histossl_encoder, num_features # type: ignore
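A minimal usage sketch, assuming 224x224 RGB tiles; resnet18 is only one example of a torchvision constructor that ImageNetEncoder accepts (ImageNet weights are downloaded on first use):

import torch
from torchvision.models import resnet18

encoder = ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
tiles = torch.rand(4, 3, 224, 224)   # a hypothetical batch of 4 tiles
with torch.no_grad():
    features = encoder(tiles)        # shape: (4, encoder.num_encoding), i.e. (4, 512) for resnet18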

View file

@@ -1,169 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Mapping, Sequence, Union
import torch
import numpy as np
import PIL
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import MapTransform, Randomizable
from torchvision.transforms.functional import to_tensor
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
PathOrString = Union[Path, str]
def load_pil_image(image_path: PathOrString) -> np.ndarray:
    """Load a PNG image from the given path as a numpy array"""
with PIL.PngImagePlugin.PngImageFile(image_path) as pil_png:
image = np.asarray(pil_png)
return image
def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor:
"""Load an image as a tensor from the given path"""
pil_image = load_pil_image(image_path)
return to_tensor(pil_image)
def load_image_stack_as_tensor(image_paths: Sequence[PathOrString],
progress: bool = False) -> torch.Tensor:
"""Load a batch of images of the same size as a tensor from the given paths"""
loading_generator = (load_image_as_tensor(path) for path in image_paths)
if progress:
from tqdm import tqdm
loading_generator = tqdm(loading_generator, desc="Loading image stack",
total=len(image_paths), leave=False)
image_tensors = list(loading_generator)
return torch.stack(image_tensors, dim=0)
class LoadTiled(MapTransform):
"""Dictionary transform to load an individual image tile as a tensor from an input path"""
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
"""
super().__init__(keys, allow_missing_keys)
def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
for key in self.key_iterator(out_data):
out_data[key] = load_image_as_tensor(data[key])
return out_data
class LoadTilesBatchd(MapTransform):
"""Dictionary transform to load a batch of image tiles as a tensor from a list of input paths"""
# Cannot reuse MONAI readers because they support stacking only images with no channels
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False,
progress: bool = False) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
:param progress: Whether to display a tqdm progress bar.
"""
super().__init__(keys, allow_missing_keys)
self.progress = progress
def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
for key in self.key_iterator(out_data):
out_data[key] = load_image_stack_as_tensor(data[key], progress=self.progress)
return out_data
class EncodeTilesBatchd(MapTransform):
"""Dictionary transform to extract features from a batch tensor of image tiles"""
def __init__(self,
keys: KeysCollection,
encoder: TileEncoder,
allow_missing_keys: bool = False,
chunk_size: int = 0) -> None:
"""
:param keys: Key(s) for the image tensor(s) in the input dictionary.
:param encoder: The tile encoder to use for feature extraction.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
:param chunk_size: if > 0, extracts features in chunks of size chunk_size.
"""
super().__init__(keys, allow_missing_keys)
self.encoder = encoder
self.chunk_size = chunk_size
@torch.no_grad()
def _encode_tiles(self, images: torch.Tensor) -> torch.Tensor:
device = next(self.encoder.parameters()).device
if self.chunk_size > 0:
embeddings = []
chunks = torch.split(images, self.chunk_size)
# TODO parallelize encoding - keep metadata and images aligned
for chunk in chunks:
chunk_embeddings = self._encode_images(chunk, device)
embeddings.append(chunk_embeddings)
return torch.cat(embeddings)
else:
return self._encode_images(images, device)
def _encode_images(self, images: torch.Tensor, device: torch.device) -> torch.Tensor:
images = images.to(device)
embeddings = self.encoder(images)
del images
torch.cuda.empty_cache()
return embeddings
def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
for key in self.key_iterator(out_data):
out_data[key] = self._encode_tiles(data[key])
return out_data
def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
if isinstance(data, (np.ndarray, torch.Tensor)):
return data[indices]
elif isinstance(data, Sequence):
return [data[i] for i in indices]
else:
raise ValueError(f"Data of type {type(data)} is not indexable")
class Subsampled(MapTransform, Randomizable):
"""Dictionary transform to randomly subsample the data down to a fixed maximum length"""
def __init__(self, keys: KeysCollection, max_size: int,
allow_missing_keys: bool = False) -> None:
"""
:param keys: Key(s) for all batch elements that must be subsampled.
:param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
random down to `max_size` along their first dimension. If shorter, the elements are merely
shuffled.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
"""
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.max_size = max_size
self._indices: np.ndarray
def randomize(self, total_size: int) -> None:
subsample_size = min(self.max_size, total_size)
        self._indices = self.R.choice(total_size, size=subsample_size, replace=False)
def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
size = len(data[self.keys[0]])
self.randomize(size)
for key in self.key_iterator(out_data):
out_data[key] = take_indices(data[key], self._indices)
return out_data
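A minimal sketch of composing these dictionary transforms into a tile-bag pipeline; the key names, encoder choice, and chunk size are illustrative assumptions:

from monai.transforms import Compose

encoder = IdentityEncoder(tile_size=224)  # any TileEncoder subclass can be plugged in here
bag_transform = Compose([
    LoadTilesBatchd(keys=['image'], progress=True),                    # tile paths -> stacked tensor
    EncodeTilesBatchd(keys=['image'], encoder=encoder, chunk_size=8),  # tiles -> feature vectors
    Subsampled(keys=['image'], max_size=100),                          # cap the bag size
])
# bag = {'image': ['tile_0001.png', 'tile_0002.png', ...]}  # hypothetical tile paths for one bag
# encoded_bag = bag_transform(bag)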

View file

View file

@@ -1,230 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
"""This script is specific to PANDA and is kept only for retrocompatibility.
`create_tiles_dataset.py` is the new supported way to process slide datasets.
"""
import functools
import os
import logging
import shutil
import traceback
import warnings
from pathlib import Path
from typing import Sequence, Tuple, Union
import numpy as np
import PIL
from monai.data import Dataset
from monai.data.image_reader import WSIReader
from tqdm import tqdm
from InnerEye.ML.Histopathology.preprocessing import tiling
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy',
'data_provider', 'slide_isup_grade', 'slide_gleason_score']
TMP_SUFFIX = "_tmp"
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \
-> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]:
if occupancy_threshold < 0. or occupancy_threshold > 1.:
raise ValueError("Tile occupancy threshold must be between 0 and 1")
foreground_mask = mask_tile > 0
occupancy = foreground_mask.mean(axis=(-2, -1))
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze()
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
path.parent.mkdir(parents=True, exist_ok=True)
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
pil_image = PIL.Image.fromarray(array_hwc)
pil_image.convert('RGB').save(path)
return pil_image
def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
constant_values=255)
mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)
selected: np.ndarray
occupancies: np.ndarray
selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
n_discarded = (~selected).sum()
logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}")
image_tiles = image_tiles[selected]
mask_tiles = mask_tiles[selected]
tile_locations = tile_locations[selected]
occupancies = occupancies[selected]
abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)
return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
# TODO refactor this to separate metadata identification from saving. We might want the metadata
# even if the saving fails
def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
tile_location: Sequence[int], output_dir: Path) -> dict:
slide_id = sample['image_id']
descriptor = get_tile_descriptor(tile_location)
image_tile_filename = f"train_images/{descriptor}.png"
mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"
save_image(image_tile, output_dir / image_tile_filename)
save_image(mask_tile, output_dir / mask_tile_filename)
tile_metadata = {
'slide_id': slide_id,
'tile_id': get_tile_id(slide_id, tile_location),
'image': image_tile_filename,
'mask': mask_tile_filename,
'tile_x': tile_location[0],
'tile_y': tile_location[1],
'data_provider': sample['data_provider'],
'slide_isup_grade': sample['isup_grade'],
'slide_gleason_score': sample['gleason_score'],
}
return tile_metadata
def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
output_dir: Path, tile_progress: bool = False) -> None:
slide_id = sample['image_id']
slide_dir: Path = output_dir / (slide_id + "/")
logging.info(f">>> Slide dir {slide_dir}")
if slide_dir.exists(): # already processed slide - skip
logging.info(f">>> Skipping {slide_dir} - already processed")
return
else:
try:
slide_dir.mkdir(parents=True)
dataset_csv_path = slide_dir / "dataset.csv"
dataset_csv_file = dataset_csv_path.open('w')
dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
tiles_failure = 0
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
failed_tiles_file = failed_tiles_csv_path.open('w')
failed_tiles_file.write('tile_id' + '\n')
logging.info(f"Loading slide {slide_id} ...")
loader = LoadPandaROId(WSIReader(), level=level, margin=margin)
sample = loader(sample) # load 'image' and 'mask' from disk
logging.info(f"Tiling slide {slide_id} ...")
image_tiles, mask_tiles, tile_locations, occupancies, _ = \
generate_tiles(sample, tile_size, occupancy_threshold)
n_tiles = image_tiles.shape[0]
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
try:
tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
slide_dir)
tile_metadata['occupancy'] = occupancies[i]
tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
dataset_csv_file.write(dataset_row + '\n')
except Exception as e:
tiles_failure += 1
descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
failed_tiles_file.write(descriptor)
traceback.print_exc()
warnings.warn(f"An error occurred while saving tile "
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
dataset_csv_file.close()
failed_tiles_file.close()
if tiles_failure > 0:
                # TODO: what do we want to do with slides that have some failed tiles?
logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.")
except Exception as e:
traceback.print_exc()
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
full_csv = dataset_dir / "dataset.csv"
    # TODO: change how we retrieve these filenames; when the dataset is mounted, the glob below
    # is slow and seems to find many more files than expected
# print("List of files")
# print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")])
with full_csv.open('w') as full_csv_file:
# full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
first_file = True
for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
logging.info(f"Merging slide {slide_csv}")
content = slide_csv.read_text()
if not first_file:
content = content[content.index('\n') + 1:] # discard header row for all but the first file
full_csv_file.write(content)
first_file = False
return full_csv
def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
# to select a subsample use keyword n_slides
dataset = Dataset(PandaDataset(panda_dir)) # type: ignore
output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}"
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}")
if overwrite and output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=not overwrite)
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
tile_progress=not parallel)
if parallel:
import multiprocessing
pool = multiprocessing.Pool()
map_func = pool.imap_unordered # type: ignore
else:
map_func = map # type: ignore
list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore
if parallel:
pool.close()
logging.info("Merging slide files in a single file")
merge_dataset_csv_files(output_dir)
if __name__ == '__main__':
main(panda_dir="/tmp/datasets/PANDA",
root_output_dir="/datadrive",
level=1,
tile_size=224,
margin=64,
occupancy_threshold=0.05,
parallel=True,
overwrite=False)

View file

@@ -1,306 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import functools
import logging
import shutil
import traceback
import warnings
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
import numpy as np
import PIL
from monai.data import Dataset
from monai.data.image_reader import WSIReader
from tqdm import tqdm
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.preprocessing import tiling
from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId, segment_foreground
from InnerEye.ML.Histopathology.utils.naming import SlideKey, TileKey
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
def select_tiles(foreground_mask: np.ndarray, occupancy_threshold: float) \
-> Tuple[np.ndarray, np.ndarray]:
"""Exclude tiles that are mostly background based on estimated occupancy.
:param foreground_mask: Boolean array of shape (*, H, W).
:param occupancy_threshold: Tiles with lower occupancy (between 0 and 1) will be discarded.
:return: A tuple containing which tiles were selected and the estimated occupancies. These will
be boolean and float arrays of shape (*,), or scalars if `foreground_mask` is a single tile.
"""
if occupancy_threshold < 0. or occupancy_threshold > 1.:
raise ValueError("Tile occupancy threshold must be between 0 and 1")
occupancy = foreground_mask.mean(axis=(-2, -1))
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze() # type: ignore
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
"""Format the XY tile coordinates into a tile descriptor."""
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
"""Format the slide ID and XY tile coordinates into a unique tile ID."""
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
"""Save an image array in (C, H, W) format to disk."""
path.parent.mkdir(parents=True, exist_ok=True)
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
pil_image = PIL.Image.fromarray(array_hwc)
pil_image.convert('RGB').save(path)
return pil_image
def generate_tiles(slide_image: np.ndarray, tile_size: int, foreground_threshold: float,
occupancy_threshold: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""Split the foreground of an input slide image into tiles.
:param slide_image: The RGB image array in (C, H, W) format.
:param tile_size: Lateral dimensions of each tile, in pixels.
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
:return: A tuple containing the image tiles (N, C, H, W), tile coordinates (N, 2), occupancies
(N,), and total number of discarded empty tiles.
"""
image_tiles, tile_locations = tiling.tile_array_2d(slide_image, tile_size=tile_size,
constant_values=255)
foreground_mask, _ = segment_foreground(image_tiles, foreground_threshold)
selected, occupancies = select_tiles(foreground_mask, occupancy_threshold)
n_discarded = (~selected).sum()
logging.info(f"Percentage tiles discarded: {n_discarded / len(selected) * 100:.2f}")
image_tiles = image_tiles[selected]
tile_locations = tile_locations[selected]
occupancies = occupancies[selected]
return image_tiles, tile_locations, occupancies, n_discarded
def get_tile_info(sample: Dict[SlideKey, Any], occupancy: float, tile_location: Sequence[int],
rel_slide_dir: Path) -> Dict[TileKey, Any]:
"""Map slide information and tiling outputs into tile-specific information dictionary.
:param sample: Slide dictionary.
    :param occupancy: Estimated tile foreground occupancy.
:param tile_location: Tile XY coordinates.
:param rel_slide_dir: Directory where tiles are saved, relative to dataset root.
:return: Tile information dictionary.
"""
slide_id = sample[SlideKey.SLIDE_ID]
descriptor = get_tile_descriptor(tile_location)
rel_image_path = f"{rel_slide_dir}/{descriptor}.png"
tile_info = {
TileKey.SLIDE_ID: slide_id,
TileKey.TILE_ID: get_tile_id(slide_id, tile_location),
TileKey.IMAGE: rel_image_path,
TileKey.LABEL: sample[SlideKey.LABEL],
TileKey.TILE_X: tile_location[0],
TileKey.TILE_Y: tile_location[1],
TileKey.OCCUPANCY: occupancy,
TileKey.SLIDE_METADATA: {TileKey.from_slide_metadata_key(key): value
for key, value in sample[SlideKey.METADATA].items()}
}
return tile_info
def format_csv_row(tile_info: Dict[TileKey, Any], keys_to_save: Iterable[TileKey],
metadata_keys: Iterable[str]) -> str:
"""Format tile information dictionary as a row to write to a dataset CSV tile.
:param tile_info: Tile information dictionary.
:param keys_to_save: Which main keys to include in the row, and in which order.
:param metadata_keys: Likewise for metadata keys.
:return: The formatted CSV row.
"""
tile_slide_metadata = tile_info.pop(TileKey.SLIDE_METADATA)
fields = [str(tile_info[key]) for key in keys_to_save]
fields.extend(str(tile_slide_metadata[key]) for key in metadata_keys)
dataset_row = ','.join(fields)
return dataset_row
def process_slide(sample: Dict[SlideKey, Any], level: int, margin: int, tile_size: int,
foreground_threshold: Optional[float], occupancy_threshold: float, output_dir: Path,
tile_progress: bool = False) -> None:
"""Load and process a slide, saving tile images and information to a CSV file.
:param sample: Slide information dictionary, returned by the input slide dataset.
:param level: Magnification level at which to process the slide.
:param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
:param tile_size: Lateral dimensions of each tile, in pixels.
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
If `None` (default), an optimal threshold will be estimated automatically.
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
:param output_dir: Root directory for the output dataset; outputs for a single slide will be
saved inside `output_dir/slide_id/`.
:param tile_progress: Whether to display a progress bar in the terminal.
"""
slide_metadata: Dict[str, Any] = sample[SlideKey.METADATA]
keys_to_save = (TileKey.SLIDE_ID, TileKey.TILE_ID, TileKey.IMAGE, TileKey.LABEL,
TileKey.TILE_X, TileKey.TILE_Y, TileKey.OCCUPANCY)
metadata_keys = tuple(TileKey.from_slide_metadata_key(key) for key in slide_metadata)
csv_columns: Tuple[str, ...] = (*keys_to_save, *metadata_keys)
slide_id: str = sample[SlideKey.SLIDE_ID]
rel_slide_dir = Path(slide_id)
slide_dir = output_dir / rel_slide_dir
logging.info(f">>> Slide dir {slide_dir}")
if slide_dir.exists(): # already processed slide - skip
logging.info(f">>> Skipping {slide_dir} - already processed")
return
else:
try:
slide_dir.mkdir(parents=True)
dataset_csv_path = slide_dir / "dataset.csv"
dataset_csv_file = dataset_csv_path.open('w')
dataset_csv_file.write(','.join(csv_columns) + '\n') # write CSV header
n_failed_tiles = 0
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
failed_tiles_file = failed_tiles_csv_path.open('w')
failed_tiles_file.write('tile_id' + '\n')
logging.info(f"Loading slide {slide_id} ...")
loader = LoadROId(WSIReader('cuCIM'), level=level, margin=margin,
foreground_threshold=foreground_threshold)
sample = loader(sample) # load 'image' from disk
logging.info(f"Tiling slide {slide_id} ...")
image_tiles, rel_tile_locations, occupancies, _ = \
generate_tiles(sample[SlideKey.IMAGE], tile_size,
sample[SlideKey.FOREGROUND_THRESHOLD],
occupancy_threshold)
tile_locations = (sample[SlideKey.SCALE] * rel_tile_locations
+ sample[SlideKey.ORIGIN]).astype(int)
n_tiles = image_tiles.shape[0]
logging.info(f"Saving tiles for slide {slide_id} ...")
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
try:
tile_info = get_tile_info(sample, occupancies[i], tile_locations[i], rel_slide_dir)
save_image(image_tiles[i], output_dir / tile_info[TileKey.IMAGE])
dataset_row = format_csv_row(tile_info, keys_to_save, metadata_keys)
dataset_csv_file.write(dataset_row + '\n')
except Exception as e:
n_failed_tiles += 1
descriptor = get_tile_descriptor(tile_locations[i])
failed_tiles_file.write(descriptor + '\n')
traceback.print_exc()
warnings.warn(f"An error occurred while saving tile "
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
dataset_csv_file.close()
failed_tiles_file.close()
if n_failed_tiles > 0:
                # TODO: what do we want to do with slides that have some failed tiles?
logging.warning(f"{slide_id} is incomplete. {n_failed_tiles} tiles failed.")
logging.info(f"Finished processing slide {slide_id}")
except Exception as e:
traceback.print_exc()
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
"""Combines all "*/dataset.csv" files into a single "dataset.csv" file in the given directory."""
full_csv = dataset_dir / "dataset.csv"
    # TODO: change how we retrieve these filenames; when the dataset is mounted, the glob below
    # is slow and seems to find many more files than expected
# print("List of files")
# print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")])
with full_csv.open('w') as full_csv_file:
# full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
first_file = True
for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
logging.info(f"Merging slide {slide_csv}")
content = slide_csv.read_text()
if not first_file:
content = content[content.index('\n') + 1:] # discard header row for all but the first file
full_csv_file.write(content)
first_file = False
return full_csv
def main(slides_dataset: SlidesDataset, root_output_dir: Union[str, Path],
level: int, tile_size: int, margin: int, foreground_threshold: Optional[float],
occupancy_threshold: float, parallel: bool = False, overwrite: bool = False,
n_slides: Optional[int] = None) -> None:
"""Process a slides dataset to produce a tiles dataset.
:param slides_dataset: Input tiles dataset object.
:param root_output_dir: The root directory of the output tiles dataset.
:param level: Magnification level at which to process the slide.
:param tile_size: Lateral dimensions of each tile, in pixels.
:param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
If `None` (default), an optimal threshold will be estimated automatically.
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
:param parallel: Whether slides should be processed in parallel with multiprocessing.
:param overwrite: Whether to overwrite an existing output tiles dataset. If `True`, will delete
and recreate `root_output_dir`, otherwise will resume by skipping already processed slides.
:param n_slides: If given, limit the total number of slides for debugging.
"""
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
# to select a subsample use keyword n_slides
dataset = Dataset(slides_dataset)[:n_slides] # type: ignore
output_dir = Path(root_output_dir)
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} "
f"{slides_dataset.__class__.__name__} tiles at: {output_dir}")
if overwrite and output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=not overwrite)
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
foreground_threshold=foreground_threshold,
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
tile_progress=not parallel)
if parallel:
import multiprocessing
pool = multiprocessing.Pool()
map_func = pool.imap_unordered # type: ignore
else:
map_func = map # type: ignore
list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore
if parallel:
pool.close()
logging.info("Merging slide files in a single file")
merge_dataset_csv_files(output_dir)
if __name__ == '__main__':
from InnerEye.ML.Histopathology.datasets.tcga_prad_dataset import TcgaPradDataset
# Example set up for an existing slides dataset:
main(slides_dataset=TcgaPradDataset("/tmp/datasets/TCGA-PRAD"),
root_output_dir="/datadrive/TCGA-PRAD_tiles",
n_slides=5,
level=3,
tile_size=224,
margin=64,
foreground_threshold=None,
occupancy_threshold=0.05,
parallel=False,
overwrite=True)

View file

@@ -1,114 +0,0 @@
import logging
from typing import Dict, Optional, Tuple
import numpy as np
import skimage.filters
from health_ml.utils import box_utils
from monai.data.image_reader import WSIReader
from monai.transforms import MapTransform
from InnerEye.ML.Histopathology.utils.naming import SlideKey
try:
from cucim import CuImage
except Exception:
logging.warning("cucim library not available, code may fail.")
def get_luminance(slide: np.ndarray) -> np.ndarray:
"""Compute a grayscale version of the input slide.
:param slide: The RGB image array in (*, C, H, W) format.
:return: The single-channel luminance array as (*, H, W).
"""
# TODO: Consider more sophisticated luminance calculation if necessary
return slide.mean(axis=-3) # type: ignore
def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
-> Tuple[np.ndarray, float]:
"""Segment the given slide by thresholding its luminance.
:param slide: The RGB image array in (*, C, H, W) format.
:param threshold: Pixels with luminance below this value will be considered foreground.
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
:return: A tuple containing the boolean output array in (*, H, W) format and the threshold used.
"""
luminance = get_luminance(slide)
if threshold is None:
threshold = skimage.filters.threshold_otsu(luminance)
return luminance < threshold, threshold
def load_slide_at_level(reader: WSIReader, slide_obj: 'CuImage', level: int) -> np.ndarray:
"""Load full slide array at the given magnification level.
This is a manual workaround for a MONAI bug (https://github.com/Project-MONAI/MONAI/issues/3415)
fixed in a currently unreleased PR (https://github.com/Project-MONAI/MONAI/pull/3417).
:param reader: A MONAI `WSIReader` using cuCIM backend.
:param slide_obj: The cuCIM image object returned by `reader.read(<image_file>)`.
:param level: Index of the desired magnification level as defined in the `slide_obj` headers.
:return: The loaded image array in (C, H, W) format.
"""
size = slide_obj.resolutions['level_dimensions'][level][::-1]
    slide, _ = reader.get_data(slide_obj, size=size, level=level)  # loaded as an RGB numpy array
return slide
class LoadROId(MapTransform):
"""Transform that loads a pathology slide, cropped to an estimated bounding box (ROI).
Operates on dictionaries, replacing the file path in `image_key` with the loaded array in
(C, H, W) format. Also adds the following entries:
    - `SlideKey.ORIGIN` (tuple): top-left coordinates of the bounding box
- `SlideKey.SCALE` (float): corresponding scale, loaded from the file
- `SlideKey.FOREGROUND_THRESHOLD` (float): threshold used to segment the foreground
"""
def __init__(self, reader: WSIReader, image_key: str = SlideKey.IMAGE, level: int = 0,
margin: int = 0, foreground_threshold: Optional[float] = None) -> None:
"""
        :param reader: An instance of MONAI's `WSIReader`.
:param image_key: Image key in the input and output dictionaries.
:param level: Magnification level to load from the raw multi-scale file.
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
:param foreground_threshold: Pixels with luminance below this value will be considered foreground.
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
"""
super().__init__([image_key], allow_missing_keys=False)
self.reader = reader
self.image_key = image_key
self.level = level
self.margin = margin
self.foreground_threshold = foreground_threshold
def _get_bounding_box(self, slide_obj: 'CuImage') -> Tuple[box_utils.Box, float]:
# Estimate bounding box at the lowest resolution (i.e. highest level)
highest_level = slide_obj.resolutions['level_count'] - 1
scale = slide_obj.resolutions['level_downsamples'][highest_level]
slide = load_slide_at_level(self.reader, slide_obj, level=highest_level)
foreground_mask, threshold = segment_foreground(slide, self.foreground_threshold)
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
return bbox, threshold
def __call__(self, data: Dict) -> Dict:
image_obj: CuImage = self.reader.read(data[self.image_key])
level0_bbox, threshold = self._get_bounding_box(image_obj)
# cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
origin = (level0_bbox.x, level0_bbox.y)
scale = image_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
size=(scaled_bbox.w, scaled_bbox.h))
data[SlideKey.ORIGIN] = origin
data[SlideKey.SCALE] = scale
data[SlideKey.FOREGROUND_THRESHOLD] = threshold
image_obj.close()
return data
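A minimal sketch of applying LoadROId to a single slide entry; the slide path is a placeholder and the cuCIM backend must be installed:

from monai.data.image_reader import WSIReader

loader = LoadROId(WSIReader('cuCIM'), level=1, margin=64)
# sample = {SlideKey.IMAGE: '/path/to/slide.tiff'}
# sample = loader(sample)       # adds SlideKey.ORIGIN, SlideKey.SCALE and SlideKey.FOREGROUND_THRESHOLD
# roi = sample[SlideKey.IMAGE]  # (C, H, W) array cropped to the estimated foreground bounding box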

View file

@@ -1,128 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
# These tiling implementations are adapted from PANDA Kaggle solutions, for example:
# https://github.com/kentaroy47/Kaggle-PANDA-1st-place-solution/blob/master/src/data_process/a00_save_tiles.py
from typing import Any, Optional, Tuple
import numpy as np
def get_1d_padding(length: int, tile_size: int) -> Tuple[int, int]:
"""Computes symmetric padding for `length` to be divisible by `tile_size`."""
pad = (tile_size - length % tile_size) % tile_size
return (pad // 2, pad - pad // 2)
def pad_for_tiling_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True,
**pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
"""Symmetrically pads a 2D `array` such that both dimensions are divisible by `tile_size`.
:param array: 2D image array.
:param tile_size: Width/height of each tile in pixels.
:param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout.
:param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`).
:return: A tuple containing:
- `padded_array`: Resulting array, in the same CHW/HWC layout as the input.
- `offset`: XY offset introduced by the padding. Add this to coordinates relative to the
original array to obtain indices for the padded array.
"""
height, width = array.shape[1:] if channels_first else array.shape[:-1]
padding_h = get_1d_padding(height, tile_size)
padding_w = get_1d_padding(width, tile_size)
padding = [padding_h, padding_w]
channels_axis = 0 if channels_first else 2
padding.insert(channels_axis, (0, 0)) # zero padding on channels axis
padded_array = np.pad(array, padding, **pad_kwargs)
offset = (padding_w[0], padding_h[0])
return padded_array, np.array(offset)
def tile_array_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True,
**pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
"""Split an image array into square non-overlapping tiles.
The array will be padded symmetrically if its dimensions are not exact multiples of `tile_size`.
:param array: Image array.
:param tile_size: Width/height of each tile in pixels.
:param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`).
:param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout.
:return: A tuple containing:
- `tiles`: A batch of tiles in NCHW layout.
- `coords`: XY coordinates of each tile, in the same order.
"""
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first, **pad_kwargs)
if channels_first:
channels, height, width = padded_array.shape
else:
height, width, channels = padded_array.shape
n_tiles_h = height // tile_size
n_tiles_w = width // tile_size
if channels_first:
intermediate_shape = (channels, n_tiles_h, tile_size, n_tiles_w, tile_size)
axis_order = (1, 3, 0, 2, 4) # (n_tiles_h, n_tiles_w, channels, tile_size, tile_size)
output_shape = (n_tiles_h * n_tiles_w, channels, tile_size, tile_size)
else:
intermediate_shape = (n_tiles_h, tile_size, n_tiles_w, tile_size, channels)
axis_order = (0, 2, 1, 3, 4) # (n_tiles_h, n_tiles_w, tile_size, tile_size, channels)
output_shape = (n_tiles_h * n_tiles_w, tile_size, tile_size, channels)
tiles = padded_array.reshape(intermediate_shape) # Split width and height axes
tiles = tiles.transpose(axis_order)
tiles = tiles.reshape(output_shape) # Flatten tile batch dimension
# Compute top-left coordinates of every tile, relative to the original array's origin
coords_h = tile_size * np.arange(n_tiles_h) - offset_h
coords_w = tile_size * np.arange(n_tiles_w) - offset_w
# Shape: (n_tiles_h * n_tiles_w, 2)
coords = np.stack(np.meshgrid(coords_w, coords_h), axis=-1).reshape(-1, 2)
return tiles, coords
def assemble_tiles_2d(tiles: np.ndarray, coords: np.ndarray, fill_value: Optional[float] = np.nan,
channels_first: Optional[bool] = True) -> Tuple[np.ndarray, np.ndarray]:
"""Assembles a 2D array from sequences of tiles and coordinates.
:param tiles: Stack of tiles with batch dimension first.
    :param coords: XY tile coordinates, assumed to be spaced by multiples of the tile size (shape: [N, 2]).
:param fill_value: Value to assign to empty elements (default: `NaN`).
:param channels_first: Whether each tile is in CHW (`True`, default) or HWC (`False`) layout.
:return: A tuple containing:
- `array`: The reassembled 2D array with the smallest dimensions to contain all given tiles.
    - `offset`: XY offset introduced by the assembly (the negated lowest XY coordinates). Add this
      to tile coordinates to obtain indices for the assembled array.
"""
if coords.shape[0] != tiles.shape[0]:
raise ValueError(f"Tile coordinates and values must have the same length, "
f"got {coords.shape[0]} and {tiles.shape[0]}")
if channels_first:
n_tiles, channels, tile_size, _ = tiles.shape
else:
n_tiles, tile_size, _, channels = tiles.shape
tile_xs, tile_ys = coords.T
x_min, x_max = min(tile_xs), max(tile_xs + tile_size)
y_min, y_max = min(tile_ys), max(tile_ys + tile_size)
width = x_max - x_min
height = y_max - y_min
output_shape = (channels, height, width) if channels_first else (height, width, channels)
array = np.full(output_shape, fill_value)
offset = np.array([-x_min, -y_min])
for idx in range(n_tiles):
row = coords[idx, 1] + offset[1]
col = coords[idx, 0] + offset[0]
if channels_first:
array[:, row:row + tile_size, col:col + tile_size] = tiles[idx]
else:
array[row:row + tile_size, col:col + tile_size, :] = tiles[idx]
return array, offset
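A small round-trip sketch showing the tiling and re-assembly conventions on a random CHW image (the shapes follow from the symmetric padding described above):

import numpy as np

image = np.random.randint(0, 255, size=(3, 300, 500))  # dimensions are not multiples of 224
tiles, coords = tile_array_2d(image, tile_size=224, constant_values=255)
assert tiles.shape == (6, 3, 224, 224)     # a 2x3 grid of tiles after symmetric padding
assert coords.shape == (6, 2)              # XY per tile; negative values mark padded margins
reassembled, offset = assemble_tiles_2d(tiles, coords)
assert reassembled.shape == (3, 448, 672)  # the padded extent of the original image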

View file

@@ -1,40 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
"""
Script to find mean and standard deviation of desired metrics from cross validation child runs.
"""
import os
import pandas as pd
from health_azure import aggregate_hyperdrive_metrics, get_workspace
from InnerEye.Common import fixed_paths
def get_cross_validation_metrics_df(run_id: str) -> pd.DataFrame:
"""
Function to aggregate the metric over cross-validation runs
:param run_id: run id of the hyperdrive run containing child runs
"""
aml_workspace = get_workspace()
os.chdir(fixed_paths.repository_root_directory())
df = aggregate_hyperdrive_metrics(run_id=run_id,
child_run_arg_name="cross_validation_split_index",
aml_workspace=aml_workspace)
return df
if __name__ == "__main__":
metrics_list = ['test/accuracy', 'test/auroc', 'test/f1score', 'test/precision', 'test/recall']
run_id = "hsharma_features_viz:HD_eff4c009-2f9f-4c2c-94c6-c0c84944a412"
metrics_df = get_cross_validation_metrics_df(run_id=run_id)
for metric in metrics_list:
if metric in metrics_df.index.values:
mean = metrics_df.loc[[metric]].mean(axis=1)[metric]
std = metrics_df.loc[[metric]].std(axis=1)[metric]
print(f"{metric}: {round(mean,4)} ± {round(std,4)}")
else:
print(f"Metric {metric} not found in the Hyperdrive run metrics for run id {run_id}.")

View file

@@ -1,26 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from health_azure import DatasetConfig
from health_azure.utils import get_workspace
def mount_dataset(dataset_id: str) -> str:
ws = get_workspace()
target_folder = "/tmp/datasets/"
dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
dataset_mount_folder, mount_ctx = dataset.to_input_dataset_local(ws)
mount_ctx.start()
assert next(dataset_mount_folder.iterdir()), "Mounted data folder is empty"
return str(dataset_mount_folder)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--id', type=str, dest='dataset_id',
help='Name of the Azure dataset e.g. PANDA or TCGA-CRCk')
args = parser.parse_args()
mount_dataset(args.dataset_id)

View file

@@ -1,98 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import numpy as np
from typing import List, Any
import umap
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
"""
Get the t-sne projection of high dimensional data in a lower dimensional space
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
:param **kwargs: keyword arguments to be passed to TSNE()
:return: list of features in lower dimensional space (n x c for n samples and c components)
"""
tsne_2d = TSNE(n_components=n_components, n_jobs=n_jobs, **kwargs)
tsne_proj = tsne_2d.fit_transform(features)
return tsne_proj
def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
"""
Get the umap projection of high dimensional data in a lower dimensional space
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
:param **kwargs: keyword arguments to be passed to UMAP()
:return: list of features in lower dimensional space (n x c for n samples and c components)
"""
umap_2d = umap.UMAP(n_components=n_components, n_jobs=n_jobs, **kwargs)
umap_proj = umap_2d.fit_transform(features)
return umap_proj
def normalize_array_minmax(arr: List[float]) -> List[float]:
"""
Normalize an array in range 0 to 1
:param arr: array to be normalized
:return: normalized array
"""
return (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
def normalize_array_mean(arr: List[float]) -> List[float]:
"""
Normalize an array with zero mean and unit variance
:param arr: array to be normalized
:return: normalized array
"""
return (arr - np.mean(arr)) / np.std(arr)
def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str], title: str = "") -> None:
"""
Plot a scatter plot of projected features in two dimensions
:param data: features projected in 2d space (nx2)
:param labels: corresponding labels of the data (nx1)
:param classes: list of classes in the dataset
:param title: plot title string
"""
plt.figure()
scatter = plt.scatter(data[:, 0], data[:, 1], 20, labels)
plt.legend(handles=scatter.legend_elements()[0], labels=classes)
plt.title(title)
def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outliers: bool, title: str = "") -> None:
"""
Plot a box whisker plot of column data
:param columns: data to be plotted in columns
:param column_names: names of the columns
:param show_outliers: whether outliers need to be shown
:param title: plot title string
"""
plt.figure()
_, ax = plt.subplots()
ax.boxplot(data_list, showfliers=show_outliers)
positions = range(1, len(column_names)+1)
means = []
for i in range(len(data_list)):
means.append(np.mean(data_list[i]))
ax.plot(positions, means, 'rs')
plt.xticks(positions, column_names)
plt.title(title)
def plot_histogram(data: List[Any], title: str = "") -> None:
"""
Plot a histogram given some data
:param data: data to be plotted
:param title: plot title string
"""
plt.figure()
plt.hist(data, bins=50)
plt.gca().set(title=title, xlabel='Values', ylabel='Frequency')
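A minimal sketch combining the projection and plotting helpers on random features; the feature dimensions and class names are placeholders:

import numpy as np
from matplotlib import pyplot as plt

features = np.random.rand(100, 128)                  # 100 samples with 128-dimensional features
labels = np.random.randint(0, 2, size=100).tolist()
proj = get_tsne_projection(features, n_components=2)
plot_projected_features_2d(proj, labels, classes=['negative', 'positive'], title='t-SNE of tile features')
plt.show()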

View file

@@ -1,36 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from health_azure import download_files_from_run_id, get_workspace
from InnerEye.Common import fixed_paths
def download_file_if_necessary(run_id: str, remote_dir: Path, download_dir: Path, filename: str) -> None:
"""
Function to download any file from an AML run if it doesn't exist locally
:param run_id: run ID of the AML run
:param remote_dir: remote directory from where the file is downloaded
:param download_dir: local directory where to save the downloaded file
:param filename: name of the file to be downloaded (e.g. `"test_output.csv"`).
"""
aml_workspace = get_workspace()
os.chdir(fixed_paths.repository_root_directory())
local_path = download_dir / run_id.split(":")[1] / "outputs" / filename
remote_path = remote_dir / filename
if local_path.exists():
print("File already exists at", local_path)
else:
local_dir = local_path.parent.parent
local_dir.mkdir(exist_ok=True, parents=True)
download_files_from_run_id(run_id=run_id,
output_folder=local_dir,
prefix=str(remote_path),
aml_workspace=aml_workspace,
validate_checksum=True)
assert local_path.exists()
print("File is downloaded at", local_path)

View file

@@ -1,31 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import List
import numpy as np
def location_selected_tiles(tile_coords: np.ndarray,
location_bbox: List[int],
level: int) -> np.ndarray:
""" Return the scaled and shifted tile co-ordinates for selected tiles in the slide.
:param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]) in original resolution.
:param location_bbox: Location of the bounding box on the slide in original resolution.
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available.
(e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled).
"""
level_dict = {0: 1, 1: 4, 2: 16}
factor = level_dict[level]
x_tr, y_tr = location_bbox
tile_xs, tile_ys = tile_coords.T
tile_xs = tile_xs - x_tr
tile_ys = tile_ys - y_tr
tile_xs = tile_xs//factor
tile_ys = tile_ys//factor
sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()])
return sel_coords
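A small worked example of the coordinate conversion, assuming level-1 tiles (4x downsampled, as in PANDA):

import numpy as np

tile_coords = np.array([[448, 896], [672, 896]])  # level-0 XY coordinates of two selected tiles
location_bbox = [448, 640]                        # level-0 origin of the loaded region
print(location_selected_tiles(tile_coords, location_bbox, level=1))
# [[ 0 64]
#  [56 64]]  -> shifted by the bounding-box origin, then divided by the level-1 factor of 4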

View file

@@ -1,58 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Tuple
from torch import as_tensor, device, nn, no_grad, prod, rand
from torch.hub import load_state_dict_from_url
from torchvision.transforms import Normalize
def get_imagenet_preprocessing() -> nn.Module:
return Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def setup_feature_extractor(pretrained_model: nn.Module,
input_dim: Tuple[int, int, int]) -> Tuple[nn.Module, int]:
try:
# Attempt to auto-detect final classification layer:
num_features: int = pretrained_model.fc.in_features # type: ignore
pretrained_model.fc = nn.Flatten()
feature_extractor = pretrained_model
except AttributeError:
# Otherwise fallback to sequence of child modules:
layers = list(pretrained_model.children())[:-1]
layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps
feature_extractor = nn.Sequential(*layers)
with no_grad():
feature_shape = feature_extractor(rand(1, *input_dim)).shape
num_features = int(prod(as_tensor(feature_shape)).item())
# fix weights, no fine-tuning
for param in feature_extractor.parameters():
param.requires_grad = False
return feature_extractor, num_features
def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module:
"""
Load weights to the histoSSL model from the given URL
https://github.com/ozanciga/self-supervised-histopathology
"""
map_location = device('cpu')
state = load_state_dict_from_url(weights_url, map_location=map_location)
state_dict = state['state_dict']
model_dict = model.state_dict()
new_weights = {}
for key, value in state_dict.items():
model_key = key.replace('model.', '').replace('resnet.', '')
if model_key in model_dict:
new_weights[model_key] = value
if len(new_weights) == 0:
raise RuntimeError("Weights could not be loaded.")
model_dict.update(new_weights) # type: ignore
model.load_state_dict(model_dict) # type: ignore
return model
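A minimal sketch of setup_feature_extractor on a torchvision backbone; resnet18 is used only as an example:

import torch
from torchvision.models import resnet18

backbone = resnet18(pretrained=False)
feature_extractor, num_features = setup_feature_extractor(backbone, input_dim=(3, 224, 224))
with torch.no_grad():
    features = feature_extractor(torch.rand(2, 3, 224, 224))
assert features.shape == (2, num_features)  # 512 features per tile for resnet18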

View file

@@ -1,184 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Tuple, List, Any, Dict
import torch
import matplotlib.pyplot as plt
from math import ceil
import numpy as np
import matplotlib.patches as patches
import matplotlib.collections as collection
import seaborn as sns
from InnerEye.ML.Histopathology.models.transforms import load_pil_image
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles
def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1,
select: Tuple = ('lowest_pred', 'highest_att'),
slide_col: str = ResultsKey.SLIDE_ID, gt_col: str = ResultsKey.TRUE_LABEL,
attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.CLASS_PROBS,
return_col: str = ResultsKey.IMAGE_PATH) -> List[Tuple[Any, Any, List[Any], List[Any]]]:
"""
:param results: List that contains slide_level dicts
:param n_tiles: number of tiles to be selected for each slide
:param n_slides: number of slides to be selected
:param label: which label to use to select slides
:param select: criteria to be used to sort the slides (select[0]) and the tiles (select[1])
:param slide_col: column name that contains slide identifiers
:param gt_col: column name that contains labels
:param attn_col: column name that contains scores used to sort tiles
:param prob_col: column name that contains scores used to sort slides
:param return_col: column name of the values we want to return for each tile
:return: tuple containing the slides id, the slide score, the tile ids, the tiles scores
"""
    tmp_s = [(results[prob_col][i][label], i) for i, gt in enumerate(results[gt_col]) if gt == label]  # type: ignore
if select[0] == 'lowest_pred':
tmp_s.sort(reverse=False)
elif select[0] == 'highest_pred':
tmp_s.sort(reverse=True)
    else:
        raise ValueError('select value not recognised')
_, sorted_idx = zip(*tmp_s)
k_idx = []
    if select[1] == 'highest_att':
        descending = True
    elif select[1] == 'lowest_att':
        descending = False
    else:
        raise ValueError('select value not recognised')
for _, slide_idx in enumerate(sorted_idx[:n_slides]):
tmp = results[attn_col][slide_idx]
_, t_indices = torch.sort(tmp, descending=descending)
k_tiles = []
scores = []
for t_idx in t_indices[0][:n_tiles]:
k_tiles.append(results[return_col][slide_idx][t_idx])
scores.append(results[attn_col][slide_idx][0][t_idx])
# slide_ids are duplicated
k_idx.append((results[slide_col][slide_idx][0],
results[prob_col][slide_idx],
k_tiles, scores))
return k_idx
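
# A small synthetic sketch of the expected `results` layout (two slides, three tiles each;
# all values are made up) and how the selection feeds plot_attention_tiles defined further down:
import torch

toy_results = {
    ResultsKey.SLIDE_ID: [["slide_0"] * 3, ["slide_1"] * 3],
    ResultsKey.TRUE_LABEL: [1, 1],
    ResultsKey.CLASS_PROBS: [torch.tensor([0.3, 0.7]), torch.tensor([0.9, 0.1])],
    ResultsKey.BAG_ATTN: [torch.rand(1, 3), torch.rand(1, 3)],
    ResultsKey.IMAGE_PATH: [[f"slide_0/tile_{i}.png" for i in range(3)],
                            [f"slide_1/tile_{i}.png" for i in range(3)]],
}
top = select_k_tiles(toy_results, n_tiles=2, n_slides=1, label=1,
                     select=('highest_pred', 'highest_att'))
for slide_id, slide_probs, tile_paths, tile_attn in top:
    print(slide_id, tile_paths, [a.item() for a in tile_attn])
    # plot_attention_tiles(slide_id, slide_probs, tile_paths, tile_attn, case="TP")
    # (the plotting call needs the tile image files to exist on disk)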
def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.CLASS_PROBS,
gt_col: str = ResultsKey.TRUE_LABEL) -> plt.figure:
"""
:param results: Dict of slide-level results, keyed by ResultsKey, with one list entry per slide
:param prob_col: column name that contains the scores
:param gt_col: column name that contains the true label
:return: matplotlib figure of the scores histogram by class
"""
n_classes = len(results[prob_col][0])
scores_class = []
for j in range(n_classes):
scores = [results[prob_col][i][j].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == j]
scores_class.append(scores)
fig, ax = plt.subplots()
ax.hist(scores_class, label=[str(i) for i in range(n_classes)], alpha=0.5)
ax.set_xlabel("Predicted Score")
ax.legend()
return fig
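
# Sketch with made-up values: one class-probability tensor and one integer label per slide.
import torch

toy_results = {ResultsKey.CLASS_PROBS: [torch.tensor([0.2, 0.8]), torch.tensor([0.7, 0.3]),
                                        torch.tensor([0.4, 0.6])],
               ResultsKey.TRUE_LABEL: [1, 0, 1]}
fig = plot_scores_hist(toy_results)
fig.savefig("scores_hist.png")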
def plot_attention_tiles(slide: str, scores: List[float], paths: List, attn: List, case: str, ncols: int = 5,
size: Tuple = (10, 10)) -> plt.figure:
"""
:param slide: slide identifier
:param scores: predicted scores of each class for the slide
:param paths: list of paths to tiles belonging to the slide
:param attn: list of scores belonging to the tiles in paths. paths and attn are expected to have the same shape
:param case: string used to define the title of the plot e.g. TP
:param ncols: number of cols the produced figure should have
:param size: size of the plot
:return: matplotlib figure of each tile in paths with attn score
"""
nrows = int(ceil(len(paths) / ncols))
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=size)
fig.suptitle(f"{case}: {slide} P={max(scores):.2f}")
for i in range(len(paths)):
img = load_pil_image(paths[i])
axs.ravel()[i].imshow(img, clim=(0, 255), cmap='gray')
axs.ravel()[i].set_title("%.6f" % attn[i].cpu().item())
for i in range(len(axs.ravel())):
axs.ravel()[i].set_axis_off()
return fig
def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure:
"""Plots a slide thumbnail from a given slide image and scale.
:param slide_image: Numpy array of the slide image (shape: [3, H, W]).
:param scale: Scaling factor applied to the default figure size.
:return: matplotlib figure of the slide thumbnail.
"""
fig, ax = plt.subplots()
slide_image = slide_image.transpose(1, 2, 0)
ax.imshow(slide_image)
ax.set_axis_off()
original_size = fig.get_size_inches()
fig.set_size_inches((original_size[0]*scale, original_size[1]*scale))
return fig
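
# Illustrative only: a random channels-first RGB array stands in for a real slide thumbnail.
import numpy as np

fake_thumbnail = np.random.randint(0, 256, size=(3, 200, 300), dtype=np.uint8)
fig = plot_slide(fake_thumbnail, scale=1.5)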
def plot_heatmap_overlay(slide: str,
slide_image: np.ndarray,
results: Dict[str, List[Any]],
location_bbox: List[int],
tile_size: int = 224,
level: int = 1) -> plt.figure:
"""Plots heatmap of selected tiles (e.g. tiles in a bag) overlay on the corresponding slide.
:param slide: slide identifier.
:param slide_image: Numpy array of the slide image (shape: [3, H, W]).
:param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
:param tile_size: Size of each tile. Default 224.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1.
:param location_bbox: Location of the bounding box of the slide.
:return: matplotlib figure of the heatmap of the given tiles on slide.
"""
fig, ax = plt.subplots()
slide_image = slide_image.transpose(1, 2, 0)
ax.imshow(slide_image)
ax.set_xlim(0, slide_image.shape[1])
ax.set_ylim(slide_image.shape[0], 0)
coords = []
slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]]
slide_idx = slide_ids.index(slide)
attentions = results[ResultsKey.BAG_ATTN][slide_idx]
# for each tile in the bag
for tile_idx in range(len(results[ResultsKey.IMAGE_PATH][slide_idx])):
tile_coords = np.transpose(np.array([results[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(),
results[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()]))
coords.append(tile_coords)
coords = np.array(coords)
attentions = np.array(attentions.cpu()).reshape(-1)
sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level)
cmap = plt.cm.get_cmap('Reds')
tile_xs, tile_ys = sel_coords.T
rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)]
pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None)
pc.set_array(np.array(attentions))
pc.set_clim([0, 1])
ax.add_collection(pc)
plt.colorbar(pc, ax=ax)
return fig
def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: List[str]) -> plt.figure:
"""Plots a normalized confusion matrix and returns the figure.
:param cm: Normalized confusion matrix to be plotted.
:param class_names: List of class names.
"""
fig, ax = plt.subplots()
ax = sns.heatmap(cm, annot=True, cmap='Blues', fmt=".2%")
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.xaxis.set_ticklabels(class_names)
ax.yaxis.set_ticklabels(class_names)
return fig
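
# Hedged example: row-normalise a raw confusion matrix before plotting (counts and class names are made up).
import numpy as np

raw_cm = np.array([[48, 2], [5, 45]])
norm_cm = raw_cm / raw_cm.sum(axis=1, keepdims=True)
fig = plot_normalized_confusion_matrix(norm_cm, class_names=["negative", "positive"])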

@ -1,67 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from enum import Enum
class SlideKey(str, Enum):
SLIDE_ID = 'slide_id'
IMAGE = 'image'
IMAGE_PATH = 'image_path'
MASK = 'mask'
MASK_PATH = 'mask_path'
LABEL = 'label'
SPLIT = 'split'
SCALE = 'scale'
ORIGIN = 'origin'
FOREGROUND_THRESHOLD = 'foreground_threshold'
METADATA = 'metadata'
LOCATION = 'location'
class TileKey(str, Enum):
TILE_ID = 'tile_id'
SLIDE_ID = 'slide_id'
IMAGE = 'image'
IMAGE_PATH = 'image_path'
MASK = 'mask'
MASK_PATH = 'mask_path'
LABEL = 'label'
SPLIT = 'split'
TILE_X = 'tile_x'
TILE_Y = 'tile_y'
OCCUPANCY = 'occupancy'
FOREGROUND_THRESHOLD = 'foreground_threshold'
SLIDE_METADATA = 'slide_metadata'
@staticmethod
def from_slide_metadata_key(slide_metadata_key: str) -> str:
return 'slide_' + slide_metadata_key
class ResultsKey(str, Enum):
SLIDE_ID = 'slide_id'
TILE_ID = 'tile_id'
IMAGE = 'image'
IMAGE_PATH = 'image_path'
LOSS = 'loss'
PROB = 'prob'
CLASS_PROBS = 'prob_class'
PRED_LABEL = 'pred_label'
TRUE_LABEL = 'true_label'
BAG_ATTN = 'bag_attn'
TILE_X = "x"
TILE_Y = "y"
class MetricsKey(str, Enum):
ACC = 'accuracy'
ACC_MACRO = 'macro_accuracy'
ACC_WEIGHTED = 'weighted_accuracy'
CONF_MATRIX = 'confusion_matrix'
AUROC = 'auroc'
PRECISION = 'precision'
RECALL = 'recall'
F1 = 'f1score'

@ -1,21 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pandas as pd
def extract_fields(row: pd.Series) -> dict:
# Paths are structured as follows:
# "CRC_DX_[TEST|TRAIN]/[MSS|MSIMUT]/blk-{tile_id}-{slide_id}-01Z-00-DX1.png"
# - tile_id is an uppercase string of 12 letters
# - slide_id is "TCGA-XX-XXXX"
parts = row.image.split('/')
return {
'slide_id': parts[2][17:29],
'tile_id': parts[2][4:16],
'image': row.image,
'label': {'MSS': 0, 'MSIMUT': 1}[parts[1]],
'split': parts[0][7:].lower(),
}
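
# A sketch of turning the raw tile listing into a tidy dataframe; the single path below is
# invented but follows the documented pattern.
import pandas as pd

listing = pd.DataFrame({"image": ["CRC_DX_TRAIN/MSS/blk-ABCDEFGHIJKL-TCGA-AA-1234-01Z-00-DX1.png"]})
tiles_df = listing.apply(extract_fields, axis=1, result_type="expand")
# tiles_df columns: slide_id, tile_id, image, label, split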

@ -1,61 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from typing import Any, Dict
import matplotlib.pyplot as plt
from monai.data.dataset import Dataset
from monai.data.image_reader import WSIReader
from torch.utils.data import DataLoader
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
from InnerEye.ML.Histopathology.utils.naming import SlideKey
def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
"""
Load image from metadata dictionary
:param sample: dict describing image metadata. Example:
{'image_id': ['1ca999adbbc948e69783686e5b5414e4'],
'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'],
'mask': ['/tmp/datasets/PANDA/train_label_masks/1ca999adbbc948e69783686e5b5414e4_mask.tiff'],
'data_provider': ['karolinska'],
'isup_grade': tensor([0]),
'gleason_score': ['0+0']}
:param level: level of resolution to be loaded
:param margin: margin to be included
:return: a dict containing the image data and metadata
"""
loader = LoadPandaROId(WSIReader('cuCIM'), level=level, margin=margin)
img = loader(sample)
return img
def plot_panda_data_sample(panda_dir: str, nsamples: int, ncols: int, level: int, margin: int,
title_key: str = 'data_provider') -> None:
"""
:param panda_dir: path to the dataset, it's expected a file called "train.csv" exists at the path.
Look at the PandaDataset for more detail
:param nsamples: number of random samples to be visualized
:param ncols: number of columns in the figure grid. Nrows is automatically inferred
:param level: level of resolution to be loaded
:param margin: margin to be included
:param title_key: metadata key in image_dict used to label each subplot
"""
panda_dataset = Dataset(PandaDataset(root=panda_dir))[:nsamples] # type: ignore
loader = DataLoader(panda_dataset, batch_size=1)
nrows = math.ceil(nsamples/ncols)
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 9))
for dict_images, ax in zip(loader, axes.flat):
slide_id = dict_images[SlideKey.SLIDE_ID]
title = dict_images[SlideKey.METADATA][title_key]
print(f">>> Slide {slide_id}")
img = load_image_dict(dict_images, level=level, margin=margin)
ax.imshow(img[SlideKey.IMAGE].transpose(1, 2, 0))
ax.set_title(title)
fig.tight_layout()
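
# Example invocation, assuming a local copy of the PANDA dataset under /tmp/datasets/PANDA:
plot_panda_data_sample("/tmp/datasets/PANDA", nsamples=4, ncols=2, level=2, margin=0)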

@ -1,122 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
"""BaseMIL is an abstract container defining basic functionality for running MIL experiments.
It is responsible for instantiating the encoder and full DeepMIL model. Subclasses should define
their datamodules and configure experiment-specific parameters.
"""
from pathlib import Path
from typing import Optional, Type # noqa
import param
from torch import nn
from torchvision.models.resnet import resnet18
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
ImageNetEncoder, ImageNetSimCLREncoder,
InnerEyeSSLEncoder, TileEncoder)
class BaseMIL(LightningContainer):
# Model parameters:
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
is_finetune: bool = param.Boolean(doc="Whether to fine-tune the encoder. Options: "
"`False` (default), or `True`.")
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
# Encoder parameters:
encoder_type: str = param.String(doc="Name of the encoder class to use.")
tile_size: int = param.Integer(224, bounds=(1, None), doc="Tile width/height, in pixels.")
n_channels: int = param.Integer(3, bounds=(1, None), doc="Number of channels in the tile.")
# Data module parameters:
batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
max_bag_size: int = param.Integer(1000, bounds=(0, None),
doc="Upper bound on number of tiles in each loaded bag. "
"If 0 (default), will return all samples in each bag. "
"If > 0, bags larger than `max_bag_size` will yield "
"random subsets of instances.")
cache_mode: CacheMode = param.ClassSelector(default=CacheMode.MEMORY, class_=CacheMode,
doc="The type of caching to perform: "
"'memory' (default), 'disk', or 'none'.")
precache_location: str = param.ClassSelector(default=CacheLocation.NONE, class_=CacheLocation,
doc="Whether to pre-cache the entire transformed dataset upfront "
"and save it to disk and if re-load in cpu or gpu. Options:"
"`none` (default),`cpu`, `gpu`")
encoding_chunk_size: int = param.Integer(0, doc="If > 0, performs encoding in chunks by loading "
"encoding_chunk_size tiles per chunk.")
# local_dataset (used as data module root_path) is declared in DatasetParams superclass
@property
def cache_dir(self) -> Path:
raise NotImplementedError
def setup(self) -> None:
if self.encoder_type == InnerEyeSSLEncoder.__name__:
raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")
self.encoder = self.get_encoder()
if not self.is_finetune:
self.encoder.eval()
def get_encoder(self) -> TileEncoder:
if self.encoder_type == ImageNetEncoder.__name__:
return ImageNetEncoder(feature_extraction_model=resnet18,
tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == ImageNetSimCLREncoder.__name__:
return ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == HistoSSLEncoder.__name__:
return HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)
elif self.encoder_type == InnerEyeSSLEncoder.__name__:
return InnerEyeSSLEncoder(pl_checkpoint_path=self.downloader.local_checkpoint_path,
tile_size=self.tile_size, n_channels=self.n_channels)
else:
raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
def get_pooling_layer(self) -> Type[nn.Module]:
if self.pooling_type == AttentionLayer.__name__:
return AttentionLayer
elif self.pooling_type == GatedAttentionLayer.__name__:
return GatedAttentionLayer
elif self.pooling_type == MeanPoolingLayer.__name__:
return MeanPoolingLayer
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
def create_model(self) -> DeepMILModule:
self.data_module = self.get_data_module()
# Encoding is done in the datamodule, so here we provide instead a dummy
# no-op IdentityEncoder to be used inside the model
if self.is_finetune:
self.model_encoder = self.encoder
for params in self.model_encoder.parameters():
params.requires_grad = True
else:
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
dropout_rate=self.dropout_rate,
class_weights=self.data_module.class_weights,
l_rate=self.l_rate,
weight_decay=self.weight_decay,
adam_betas=self.adam_betas)
def get_data_module(self) -> TilesDataModule:
raise NotImplementedError
def get_slide_dataset(self) -> SlidesDataset:
raise NotImplementedError
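
# A minimal sketch of a concrete container; MyTilesDataModule and MySlidesDataset are hypothetical
# stand-ins, and DeepSMILECrck / DeepSMILEPanda further below show the full pattern.
class MyMILContainer(BaseMIL):
    def __init__(self, **kwargs):
        super().__init__(encoder_type=ImageNetEncoder.__name__,
                         pooling_type=AttentionLayer.__name__,
                         **kwargs)

    @property
    def cache_dir(self) -> Path:
        return Path("/tmp/innereye_cache1/MyMILContainer")

    def get_data_module(self) -> TilesDataModule:
        return MyTilesDataModule(root_path=self.local_dataset, cache_dir=self.cache_dir)  # hypothetical

    def get_slide_dataset(self) -> SlidesDataset:
        return MySlidesDataset(root=self.local_dataset)  # hypothetical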

@ -1,175 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
"""DeepSMILECrck is the container for experiments relating to DeepSMILE using the TCGA-CRCk dataset.
Run using `python InnerEyePrivate/ML/runner.py --model=DeepSMILECrck --encoder_type=<encoder class name>`
For convenience, this module also defines encoder-specific containers that can be invoked without
additional arguments, e.g. `python InnerEyePrivate/ML/runner.py --model=TcgaCrckImageNetMIL`
Reference:
- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA
damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
"""
from typing import Any, List
from pathlib import Path
import os
from monai.transforms import Compose
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
from health_azure.utils import CheckpointDownloader
from health_azure.utils import get_workspace
from health_ml.networks.layers.attention_layers import AttentionLayer
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
from InnerEye.ML.common import get_best_checkpoint_path
from InnerEye.ML.Histopathology.models.transforms import (
EncodeTilesBatchd,
LoadTilesBatchd,
)
from InnerEye.ML.Histopathology.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder,
ImageNetSimCLREncoder,
InnerEyeSSLEncoder,
)
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
class DeepSMILECrck(BaseMIL):
def __init__(self, **kwargs: Any) -> None:
# Define dictionary with default params that can be overriden from subclasses or CLI
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=AttentionLayer.__name__,
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,
precache_location=CacheLocation.CPU,
# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/TCGA-CRCk"),
azure_dataset_id="TCGA-CRCk",
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
# declared in TrainerParams:
num_epochs=50,
# declared in WorkflowParams:
number_of_cross_validation_splits=5,
cross_validation_split_index=0,
# declared in OptimizerParams:
l_rate=5e-4,
weight_decay=1e-4,
adam_betas=(0.9, 0.99),
)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
self.best_checkpoint_filename = "checkpoint_max_val_auroc"
self.best_checkpoint_filename_with_suffix = (
self.best_checkpoint_filename + ".ckpt"
)
self.checkpoint_folder_path = "outputs/checkpoints/"
best_checkpoint_callback = ModelCheckpoint(
dirpath=self.checkpoint_folder_path,
monitor="val/auroc",
filename=self.best_checkpoint_filename,
auto_insert_metric_name=False,
mode="max",
)
self.callbacks = best_checkpoint_callback
@property
def cache_dir(self) -> Path:
return Path(
f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/"
)
def setup(self) -> None:
if self.encoder_type == InnerEyeSSLEncoder.__name__:
from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint_crck_4ws
self.downloader = CheckpointDownloader(
azure_config_json_path=get_workspace(),
run_id=innereye_ssl_checkpoint_crck_4ws,
checkpoint_filename="best_checkpoint.ckpt",
download_dir="outputs/",
remote_checkpoint_dir=Path("outputs/checkpoints")
)
os.chdir(fixed_paths.repository_parent_directory())
self.downloader.download_checkpoint_if_necessary()
self.encoder = self.get_encoder()
self.encoder.cuda()
self.encoder.eval()
def get_data_module(self) -> TilesDataModule:
image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN
transform = Compose(
[
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder),
]
)
return TcgaCrckTilesDataModule(
root_path=self.local_dataset,
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
transform=transform,
cache_mode=self.cache_mode,
precache_location=self.precache_location,
cache_dir=self.cache_dir,
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
cross_validation_split_index=self.cross_validation_split_index,
)
def get_callbacks(self) -> List[Callback]:
return super().get_callbacks() + [self.callbacks]
def get_path_to_best_checkpoint(self) -> Path:
"""
Returns the full path to a checkpoint file that was found to be best during training, whatever criterion
was applied there.
"""
# absolute path is required for registering the model.
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path.is_file():
return absolute_checkpoint_path
absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path_parent.is_file():
return absolute_checkpoint_path_parent
checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path))
if checkpoint_path.is_file():
return checkpoint_path
raise ValueError("Path to best checkpoint not found")
class TcgaCrckImageNetMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs)
class TcgaCrckInnerEyeSSLMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs)
class TcgaCrckHistoSSLMIL(DeepSMILECrck):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs)

@ -1,208 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List
from pathlib import Path
import os
from monai.transforms import Compose
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
from health_azure.utils import CheckpointDownloader
from health_azure.utils import get_workspace, is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.common import get_best_checkpoint_path
from InnerEye.ML.Histopathology.models.transforms import (
EncodeTilesBatchd,
LoadTilesBatchd,
)
from InnerEye.ML.Histopathology.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder,
ImageNetSimCLREncoder,
InnerEyeSSLEncoder,
IdentityEncoder
)
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
class DeepSMILEPanda(BaseMIL):
"""`is_finetune` sets the fine-tuning mode. If this is set, setting cache_mode=CacheMode.NONE takes ~30 min/epoch and
cache_mode=CacheMode.MEMORY, precache_location=CacheLocation.CPU takes ~[5-10] min/epoch.
Fine-tuning with caching completes using batch_size=4, max_bag_size=1000, num_epochs=20, max_num_gpus=1 on PANDA.
"""
def __init__(self, **kwargs: Any) -> None:
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=GatedAttentionLayer.__name__,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,
precache_location=CacheLocation.CPU,
is_finetune=False,
# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
azure_dataset_id="PANDA_tiles",
extra_azure_dataset_ids=["PANDA"],
extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")],
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
# declared in TrainerParams:
num_epochs=200,
# use_mixed_precision = True,
# declared in WorkflowParams:
number_of_cross_validation_splits=5,
cross_validation_split_index=0,
# declared in OptimizerParams:
l_rate=5e-4,
weight_decay=1e-4,
adam_betas=(0.9, 0.99))
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
if not is_running_in_azure_ml():
self.num_epochs = 1
self.best_checkpoint_filename = "checkpoint_max_val_auroc"
self.best_checkpoint_filename_with_suffix = (
self.best_checkpoint_filename + ".ckpt"
)
self.checkpoint_folder_path = "outputs/checkpoints/"
best_checkpoint_callback = ModelCheckpoint(
dirpath=self.checkpoint_folder_path,
monitor="val/accuracy",
filename=self.best_checkpoint_filename,
auto_insert_metric_name=False,
mode="max",
)
self.callbacks = best_checkpoint_callback
@property
def cache_dir(self) -> Path:
return Path(
f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/"
)
def setup(self) -> None:
if self.encoder_type == InnerEyeSSLEncoder.__name__:
from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint_binary
self.downloader = CheckpointDownloader(
aml_workspace=get_workspace(),
run_id=innereye_ssl_checkpoint_binary, # innereye_ssl_checkpoint
checkpoint_filename="best_checkpoint.ckpt", # "last.ckpt",
download_dir="outputs/",
remote_checkpoint_dir=Path("outputs/checkpoints")
)
os.chdir(fixed_paths.repository_parent_directory())
self.downloader.download_checkpoint_if_necessary()
self.encoder = self.get_encoder()
if not self.is_finetune:
self.encoder.eval()
def get_data_module(self) -> PandaTilesDataModule:
image_key = PandaTilesDataset.IMAGE_COLUMN
if self.is_finetune:
transform = Compose([LoadTilesBatchd(image_key, progress=True)])
else:
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
])
return PandaTilesDataModule(
root_path=self.local_dataset,
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
transform=transform,
cache_mode=self.cache_mode,
precache_location=self.precache_location,
cache_dir=self.cache_dir,
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
cross_validation_split_index=self.cross_validation_split_index,
)
def create_model(self) -> DeepMILModule:
self.data_module = self.get_data_module()
# Encoding is done in the datamodule, so here we provide instead a dummy
# no-op IdentityEncoder to be used inside the model
self.slide_dataset = self.get_slide_dataset()
self.level = 1
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
if self.is_finetune:
self.model_encoder = self.encoder
for params in self.model_encoder.parameters():
params.requires_grad = True
else:
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
class_weights=self.data_module.class_weights,
l_rate=self.l_rate,
weight_decay=self.weight_decay,
adam_betas=self.adam_betas,
slide_dataset=self.get_slide_dataset(),
tile_size=self.tile_size,
level=self.level,
class_names=self.class_names,
is_finetune=self.is_finetune)
def get_slide_dataset(self) -> PandaDataset:
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore
def get_callbacks(self) -> List[Callback]:
return super().get_callbacks() + [self.callbacks]
def get_path_to_best_checkpoint(self) -> Path:
"""
Returns the full path to a checkpoint file that was found to be best during training, whatever criterion
was applied there.
"""
# absolute path is required for registering the model.
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path.is_file():
return absolute_checkpoint_path
absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path_parent.is_file():
return absolute_checkpoint_path_parent
checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path))
if checkpoint_path.is_file():
return checkpoint_path
raise ValueError("Path to best checkpoint not found")
class PandaImageNetMIL(DeepSMILEPanda):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
class PandaImageNetSimCLRMIL(DeepSMILEPanda):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs)
class PandaInnerEyeSSLMIL(DeepSMILEPanda):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs)
class PandaHistoSSLMIL(DeepSMILEPanda):
def __init__(self, **kwargs: Any) -> None:
super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs)

@ -1,7 +0,0 @@
innereye_ssl_checkpoint = "hsharma_panda_explore:hsharma_panda_explore_1638437076_357167ae"
innereye_ssl_checkpoint_binary = "hsharma_panda_tiles_ssl:hsharma_panda_tiles_ssl_1639766433_161e03b9"
innereye_ssl_checkpoint_crck_4ws = "ModifyOldSSLCheckpoint:a9259fdb-3964-4c5b-8962-4660e0b79d44"
innereye_ssl_checkpoint_crck_radiomics = "ModifyOldSSLCheckpoint:704b1af8-7c75-46ed-8460-d80a0e603194"
# outdated checkpoints
# innereye_ssl_checkpoint_crck_radiomics = updated_transforms:updated_transforms_1636471522_5473e3ff

@ -10,7 +10,7 @@ from abc import abstractmethod
from collections import Counter, defaultdict
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Set, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Union
import numpy as np
import pandas as pd
@ -22,13 +22,10 @@ from InnerEye.ML.dataset.full_image_dataset import GeneralDataset
from InnerEye.ML.dataset.sample import GeneralSampleMetadata
from InnerEye.ML.dataset.scalar_sample import ScalarDataSource, ScalarItem, SequenceDataSource
from InnerEye.ML.scalar_config import LabelTransformation, ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.features_util import FeatureStatistics
T = TypeVar('T', bound=ScalarDataSource)
def extract_label_classification(label_string: str, sample_id: str, num_classes: int,
is_classification_dataset: bool) -> List[float]:
@ -180,7 +177,7 @@ def load_single_data_source(subject_rows: pd.DataFrame,
metadata_columns: Optional[Set[str]] = None,
is_classification_dataset: bool = True,
num_classes: int = 1,
sequence_position_numeric: Optional[int] = None) -> T:
sequence_position_numeric: Optional[int] = None) -> ScalarDataSource:
"""
Converts a set of dataset rows for a single subject to a ScalarDataSource instance, which contains the
labels, the non-image features, and the paths to the image files.
@ -328,7 +325,7 @@ def load_single_data_source(subject_rows: pd.DataFrame,
return datasource # type: ignore
class DataSourceReader(Generic[T]):
class DataSourceReader():
"""
Class that allows reading of data sources from a scalar dataset data frame.
"""
@ -421,7 +418,7 @@ class DataSourceReader(Generic[T]):
@staticmethod
def load_data_sources_as_per_config(data_frame: pd.DataFrame,
args: ScalarModelBase) -> List[T]:
args: ScalarModelBase) -> List[ScalarDataSource]:
"""
Loads dataset items from the given dataframe, where all column and channel configurations are taken from their
respective model config elements.
@ -436,11 +433,7 @@ class DataSourceReader(Generic[T]):
if args.categorical_feature_encoder is not None:
assert isinstance(args.categorical_feature_encoder, CategoricalToOneHotEncoder) # mypy
sequence_column = None
if isinstance(args, SequenceModelBase):
sequence_column = args.sequence_column
return DataSourceReader[T](
return DataSourceReader(
data_frame=data_frame,
image_channels=args.image_channels,
image_file_column=args.image_file_column,
@ -450,14 +443,13 @@ class DataSourceReader(Generic[T]):
non_image_feature_channels=args.get_non_image_feature_channels_dict(),
numerical_columns=args.numerical_columns,
categorical_data_encoder=args.categorical_feature_encoder,
sequence_column=sequence_column,
subject_column=args.subject_column,
channel_column=args.channel_column,
num_classes=len(args.class_names),
is_classification_dataset=args.is_classification_model
).load_data_sources(num_dataset_reader_workers=args.num_dataset_reader_workers)
def load_data_sources(self, num_dataset_reader_workers: int = 0) -> List[T]:
def load_data_sources(self, num_dataset_reader_workers: int = 0) -> List[ScalarDataSource]:
"""
Extracts information from a dataframe to create a list of ClassificationItem. This will create one entry per
unique
@ -484,12 +476,12 @@ class DataSourceReader(Generic[T]):
return list(flatten(filter(None, results)))
def load_datasources_for_subject(self, subject_id: str) -> Optional[List[T]]:
def load_datasources_for_subject(self, subject_id: str) -> Optional[List[ScalarDataSource]]:
rows = self.data_frame[np.in1d(self.data_frame[self.subject_column].values, [subject_id])]
def _load_single_data_source(_rows: pd.DataFrame,
_sequence_position_numeric: Optional[int] = None) -> T:
_sequence_position_numeric: Optional[int] = None) -> ScalarDataSource:
return load_single_data_source(
subject_rows=_rows,
subject_id=subject_id,
@ -506,28 +498,14 @@ class DataSourceReader(Generic[T]):
is_classification_dataset=self.is_classification_dataset,
num_classes=self.num_classes,
sequence_position_numeric=_sequence_position_numeric
)
def _load_sequence_data_source(_sequence_position: Any) -> T:
_sequence_position_numeric = int(_sequence_position)
if _sequence_position_numeric < 0:
raise ValueError(
f"Sequence positions must be non-negative integers, but got: {_sequence_position}")
else:
seq_rows = rows[np.in1d(rows[self.sequence_column].values, [_sequence_position])]
return _load_single_data_source(seq_rows, _sequence_position_numeric)
if self.sequence_column:
seq_positions = rows[self.sequence_column].unique()
return list(map(_load_sequence_data_source, seq_positions))
else:
if len(self.expected_channels) > 0:
missing_channels = self.expected_channels - set(rows[self.channel_column])
if len(missing_channels) > 0:
logging.warning(f"Subject {subject_id} will be skipped completely because the following "
f"channels are missing: {','.join(missing_channels)}.")
return None
return [_load_single_data_source(rows)]
)
if len(self.expected_channels) > 0:
missing_channels = self.expected_channels - set(rows[self.channel_column])
if len(missing_channels) > 0:
logging.warning(f"Subject {subject_id} will be skipped completely because the following "
f"channels are missing: {','.join(missing_channels)}.")
return None
return [_load_single_data_source(rows)]
def files_by_stem(root_path: Path) -> Dict[str, Path]:
@ -582,10 +560,10 @@ def is_valid_item_index(item: ScalarDataSource,
return min_sequence_position_value <= item.metadata.sequence_position <= max_sequence_position_value
def filter_valid_classification_data_sources_items(items: Iterable[T],
def filter_valid_classification_data_sources_items(items: Iterable[ScalarDataSource],
file_to_path_mapping: Optional[Dict[str, Path]],
max_sequence_position_value: Optional[int] = None,
min_sequence_position_value: int = 0) -> List[T]:
min_sequence_position_value: int = 0) -> List[ScalarDataSource]:
"""
Consumes a list of classification data sources, and removes all of those that have missing file names,
or that have NaN or Inf features. If the file_to_path_mapping is given too, all items that have any missing files
@ -601,7 +579,7 @@ def filter_valid_classification_data_sources_items(items: Iterable[T],
:return: A list of items, all of which are valid now.
"""
def all_files_present(item: T) -> bool:
def all_files_present(item: ScalarDataSource) -> bool:
if file_to_path_mapping:
return all(f in file_to_path_mapping for f in item.channel_files)
else:
@ -678,14 +656,14 @@ class ScalarItemAugmentation:
return item
class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
"""
A base class for datasets for classification tasks. It contains logic for loading images from disk,
either from a fixed folder or traversing into subfolders.
"""
one_hot_encoder: Optional[CategoricalToOneHotEncoder] = None
status: str = ""
items: List[T]
items: List[ScalarDataSource]
def __init__(self, args: ScalarModelBase,
data_frame: Optional[pd.DataFrame] = None,
@ -710,7 +688,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
self.file_to_full_path = files_by_stem(args.local_dataset)
logging.info("Finished traversing folder.")
def load_all_data_sources(self) -> List[T]:
def load_all_data_sources(self) -> List[ScalarDataSource]:
"""
Uses the dataframe to create data sources to be used by the dataset.
:return:
@ -721,7 +699,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
self.status += f"After filtering: {self.create_status_string(all_data_sources)}"
return all_data_sources
def filter_valid_data_sources_items(self, data_sources: List[T]) -> List[T]:
def filter_valid_data_sources_items(self, data_sources: List[ScalarDataSource]) -> List[ScalarDataSource]:
raise NotImplementedError("filter_valid_data_source_items must be implemented by child classes")
@abstractmethod
@ -737,7 +715,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
If None, they will be computed from the data in the present object.
"""
if self.items:
self.feature_statistics = self.feature_statistics or FeatureStatistics[T].from_data_sources(self.items)
self.feature_statistics = self.feature_statistics or FeatureStatistics.from_data_sources(self.items)
self.items = self.feature_statistics.standardize(self.items)
def load_item(self, item: ScalarDataSource) -> ScalarItem:
@ -757,7 +735,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
return self.transform(sample)
def create_status_string(self, items: List[T]) -> str:
def create_status_string(self, items: List[ScalarDataSource]) -> str:
"""
Creates a human readable string that contains the number of items, and the distinct number of subjects.
:param items: Use the items provided to create the string
@ -767,14 +745,14 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
return f"{len(items)} items for {distinct} subjects. "
class ScalarDataset(ScalarDatasetBase[ScalarDataSource]):
class ScalarDataset(ScalarDatasetBase):
"""
A dataset class that can read CSV files with a flexible schema, and extract image file paths and non-image features.
"""
def __init__(self, args: ScalarModelBase,
data_frame: Optional[pd.DataFrame] = None,
feature_statistics: Optional[FeatureStatistics[ScalarDataSource]] = None,
feature_statistics: Optional[FeatureStatistics] = None,
name: Optional[str] = None,
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""

@ -1,304 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
import logging
from collections import Counter, defaultdict
from typing import Any, Callable, DefaultDict, Dict, Iterable, List, Optional
import numpy as np
import pandas as pd
import torch
from InnerEye.ML.dataset.scalar_dataset import ScalarDatasetBase, ScalarItemAugmentation, \
filter_valid_classification_data_sources_items
from InnerEye.ML.dataset.scalar_sample import ScalarItem, SequenceDataSource
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence, ListOfSequences
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.features_util import FeatureStatistics
def get_longest_contiguous_sequence(items: List[SequenceDataSource],
min_sequence_position_value: int = 0,
max_sequence_position_value: Optional[int] = None) -> List[SequenceDataSource]:
"""
From a list of classification items, extract the longest contiguous sequence of items starting
at position value min_sequence_position_value.
For example:
if min_sequence_position_value = 1 and the
input has sequence positions [0, 1, 3, 4], the result retains the items for positions [1].
if min_sequence_position_value = 1 and max_sequence_position_value = 2, then if
input has sequence positions [0, 1, 2, 3], the result retains the items for positions [1, 2].
if min_sequence_position_value = 1 and max_sequence_position_value = 4, then if
input has sequence positions [0, 1, 2, 3], the result retains the items for positions [1, 2, 3].
If the input sequence is [2, 3], the result is an empty list
(the longest sequence must start at 1 but there is no item with position 1)
:param items: A list of classification items, sorted by sequence_position.
:param min_sequence_position_value: The minimum sequence position all sequences start from, 0 is default.
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
end with. Longer sequences will be truncated. None is default.
:return: A list of classification items, with a maximum sequence_position that is the
len(result) - 1.
"""
result: List[SequenceDataSource] = []
# make sure the input sequence is sorted by sequence position first
items = list(sorted(items, key=lambda x: x.metadata.sequence_position))
_last_seq_item = next((x for x in items if x.metadata.sequence_position == min_sequence_position_value), None)
if _last_seq_item is None:
return result
else:
result.append(_last_seq_item)
for item in items[items.index(_last_seq_item) + 1:]:
if max_sequence_position_value is not None and \
item.metadata.sequence_position > max_sequence_position_value:
break
elif item.metadata.sequence_position - _last_seq_item.metadata.sequence_position == 1:
_last_seq_item = item
result.append(item)
else:
break
return result
def group_samples_into_sequences(items: Iterable[SequenceDataSource],
min_sequence_position_value: int = 0,
max_sequence_position_value: Optional[int] = None) -> ListOfSequences:
"""
Turns a flat list of classification items into a list of per-subject classification items. The resulting list
has one entry per unique sample ID in the input. With a single sample ID, the items
are sorted by metadata.sequence_position in ascending order.
Also, all subject data is restricted to the largest contiguous sequence starting at 0
(e.g., if sequence positions are [0, 1, 4], only [0, 1] are retained,
if sequence positions are [1, 2, 3] nothing is retained)
:param items: The items that should be grouped.
:param max_sequence_position_value: If provided, this is the maximum sequence position the sequence can
end with, up to and including this value. Entries beyond that sequence_position will be dropped. None is default.
:param min_sequence_position_value: All sequences must have entries with sequence_position starting
from and including this value, 0 is default.
:return:
"""
if min_sequence_position_value < 0:
raise ValueError("Argument min_sequence_position_value must be >= 0")
if max_sequence_position_value:
if max_sequence_position_value < min_sequence_position_value:
raise ValueError(f"Argument max_sequence_position_value: {max_sequence_position_value} must "
f"be >= min_sequence_position_value: {min_sequence_position_value}")
grouped: DefaultDict[str, List[SequenceDataSource]] = defaultdict(list)
for item in items:
grouped[item.id].append(item)
result: List[ClassificationItemSequence[SequenceDataSource]] = []
for sample_id, items in grouped.items():
unique_positions = set(x.metadata.sequence_position for x in items)
if len(unique_positions) != len(items):
raise ValueError(f"The set of sequence positions for subject {sample_id} contains duplicates.")
group_sorted = get_longest_contiguous_sequence(
items=items,
min_sequence_position_value=min_sequence_position_value,
max_sequence_position_value=max_sequence_position_value
)
if len(group_sorted) > 0:
result.append(ClassificationItemSequence(id=sample_id, items=group_sorted))
else:
# No contiguous sequence at all
logging.warning(f"Skipped sequence for subject {sample_id} as it was not contiguous")
return result
def add_difference_features(sequences: ListOfSequences, feature_indices: List[int]) -> ListOfSequences:
"""
For each sequence in the argument, computes feature differences to the first sequence element, and adds them
as new features at the end of the non-image features. Feature differences are only computed for those columns
in numerical_non_image_features that are given in the feature_indices argument.
The first sequence element gets feature differences that are all zeros. The i-th sequence element will get
additional features that are the differences of numerical_non_image_features[:,j] and the same element in the
0-th sequence element.
:param sequences: The input sequences.
:param feature_indices: The column indices in numerical_non_image_features for which differences should be computed.
:return: A new list of sequences with the feature differences added as columns in the
numerical_non_image_features field.
"""
def add_features(seq: ClassificationItemSequence) -> ClassificationItemSequence:
items_mapped: List[SequenceDataSource] = []
feature_baseline = None
for item_index, item in enumerate(seq.items):
if item_index == 0:
feature_baseline = torch.stack([item.numerical_non_image_features[:, i] for i in feature_indices],
dim=0)
features_for_diff = torch.stack([item.numerical_non_image_features[:, i] for i in feature_indices], dim=0)
diff = features_for_diff - feature_baseline
new_features = torch.cat([item.numerical_non_image_features, diff.t()], dim=1)
items_mapped.append(item.clone_with_overrides(numerical_non_image_features=new_features))
return ClassificationItemSequence(id=seq.id, items=items_mapped)
return list(map(add_features, sequences))
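
# A pure-tensor sketch of the computation above, using made-up values: two sequence steps with
# two numerical features each, and differences taken for feature index 1 only.
import torch

step0 = torch.tensor([[92.0, 362.0]])   # numerical_non_image_features at sequence position 0
step1 = torch.tensor([[92.0, 357.0]])   # numerical_non_image_features at sequence position 1
baseline = step0[:, [1]]                # baseline features chosen via feature_indices=[1]
step0_aug = torch.cat([step0, step0[:, [1]] - baseline], dim=1)  # [[92., 362., 0.]]
step1_aug = torch.cat([step1, step1[:, [1]] - baseline], dim=1)  # [[92., 357., -5.]]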
"""
Example for the use of SequenceDataset:
A sequence dataset groups rows not only by subject ID (as the normal ClassificationDataset does), but also
by a sequence position. That sequence position is read out from a column specified in the `sequence_column`
field of the model configuration.
Within a given subject, a sequence dataset returns instance of ClassificationItemSequence, each of which contains
a ClassificationItem for each individual sequence position.
Example use case:
subject,POSITION,measure0,measure1,image,Label
1,0,92,362,img1,0
1,1,92,357,img1,1
1,2,92,400,,0
2,0,82,477,img2,0
2,1,82,,img2,1
2,2,82,220,img2,0
To read images and measure1 as a non-imaging feature from this file, you would specify:
image_channels = []
image_file_column = "image"
label_channel = None
label_value_column = "Label"
non_image_feature_channels = []
numerical_columns = ["measure1"]
sequence_column = "POSITION"
All of the "*_channel" arguments can be left empty. After grouping by subject and sequence position,
only 1 row remains, and it is hence clear to the data loader which row to read from.
After reading the CSV files, the data loader will remove all rows where
* there is no image file path given in the file
* there is a missing value in the non-image features (missing measure1 in the example above)
* If the traverse_dirs_when_loading is given in the model config, the data loader will also remove items where
the image file does not exist.
After this filtering, the data loader will group the items by subject, and sort by position within a subject.
Within a subject, the sequences must start at position 0, and are kept up to the first "gap". Hence, if only positions
0, 1, and 2 are valid, the sequence that is kept contains items [0, 1]
Assuming that the image files all exist, this would return
* result[0] containing "1" with POSITION numbers 0 and 1 (position 2 has no image file)
* result[1] containing "2" with POSITION number 0 only (position 1 has missing measure1, and hence position 2 has to
be dropped as well)
"""
class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
"""
A dataset class that groups its raw dataset rows by subject ID and a sequence index. Each item in the dataset
has all the rows for a given subject, and within each subject, a sorted sequence of rows.
"""
items: List[ClassificationItemSequence[SequenceDataSource]] # type: ignore
def __init__(self,
args: SequenceModelBase,
data_frame: pd.DataFrame,
feature_statistics: Optional[
FeatureStatistics[ClassificationItemSequence[SequenceDataSource]]] = None,
name: Optional[str] = None,
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""
Creates a new sequence dataset from a dataframe.
:param args: The model configuration object.
:param data_frame: The dataframe to read from.
:param feature_statistics: If given, the normalization factor for the non-image features is taken
:param sample_transform: Transformation to apply to each sample in the loading step. By default, no
transformation is applied.
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
:param name: Name of the dataset, used for logging
"""
super().__init__(args=args,
data_frame=data_frame,
feature_statistics=feature_statistics,
name=name,
sample_transform=sample_transform)
if self.args.sequence_column is None:
raise ValueError("This class requires a value in the `sequence_column`, specifying where the "
"sequence index should be read from.")
if len(self.args.class_names) > 1:
raise ValueError("Multilabel configs not supported for sequence datasets.")
data_sources = self.load_all_data_sources()
grouped = group_samples_into_sequences(
data_sources,
min_sequence_position_value=self.args.min_sequence_position_value,
max_sequence_position_value=self.args.max_sequence_position_value
)
if self.args.add_differences_for_features:
missing_columns = set(self.args.add_differences_for_features) - set(self.args.numerical_columns)
if len(missing_columns) > 0:
raise ValueError(f"Unable to add differences for these columns because they have not been specified "
f"in the `non_image_feature_channels` property: {missing_columns}")
feature_indices = [self.args.numerical_columns.index(f) for f in self.args.add_differences_for_features]
grouped = add_difference_features(grouped, feature_indices)
self.status += f"After grouping: {len(grouped)} subjects."
self.items = grouped
self.standardize_non_imaging_features()
def get_status(self) -> str:
"""
Creates a human readable string that describes the contents of the dataset.
"""
return self.status
def filter_valid_data_sources_items(self, data_sources: List[SequenceDataSource]) -> List[SequenceDataSource]:
return filter_valid_classification_data_sources_items(
items=data_sources,
file_to_path_mapping=self.file_to_full_path,
min_sequence_position_value=self.args.min_sequence_position_value,
max_sequence_position_value=self.args.max_sequence_position_value
)
def get_labels_for_imbalanced_sampler(self) -> List[float]:
"""
Returns a list of all the labels at the target_index position. Is used to
compute the weights in the ImbalancedSampler. If more than one target position
is specified, the ImbalancedSampler cannot be used.
:return:
"""
if len(self.args.get_target_indices()) > 1:
raise NotImplementedError("You cannot use the ImbalancedSampler if you"
"want to predict more than one sequence position."
"Use loss weighting instead.")
return [seq.get_labels_at_target_indices(self.args.get_target_indices())[-1].item()
for seq in self.items]
def get_class_counts(self) -> Dict[int, int]:
"""
Return the label counts (summed over all target indices).
:return: Dictionary of {"label": count}
"""
all_labels_per_target = torch.stack([seq.get_labels_at_target_indices(self.args.get_target_indices())
for seq in self.items]) # [N, T, 1]
non_nan_and_nonzero_labels = list(
filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
counts = dict(Counter(non_nan_and_nonzero_labels))
if not len(counts.keys()) == 1 or 1 not in counts.keys():
raise ValueError("get_class_counts supports only binary targets.")
return {0: counts[1]}
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, i: int) -> Dict[str, Any]:
loaded = list(map(self.load_item, self.items[i].items))
return vars(ClassificationItemSequence(id=self.items[i].id, items=loaded))

@ -1,77 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, TypeVar
import numpy as np
import torch
from InnerEye.Common.common_util import check_properties_are_not_none
from InnerEye.ML.dataset.scalar_sample import ScalarItem, SequenceDataSource
from InnerEye.ML.utils.sequence_utils import sequences_to_padded_tensor
T = TypeVar('T', SequenceDataSource, ScalarItem)
@dataclass(frozen=True)
class ClassificationItemSequence(Generic[T]):
"""
A class that holds a sequence of samples for a given patient ID.
"""
id: str
items: List[T]
def __post_init__(self) -> None:
check_properties_are_not_none(self)
@staticmethod
def create_labels_tensor_for_minibatch(sequences: List[ClassificationItemSequence[ScalarItem]],
target_indices: List[int]) -> torch.Tensor:
"""
Create label tensor for a minibatch training from a list of sequences for the provided
target indices. If sequences are unequal then they are padded with a NaN value.
:param sequences: sequences to create label tensor from.
:param target_indices: label indices for which to extract label for from the provided sequences.
:return: A label tensor with NaN padding if required.
"""
return sequences_to_padded_tensor(
sequences=[seq.get_labels_at_target_indices(target_indices) for seq in sequences],
padding_value=np.nan
)
@staticmethod
def from_minibatch(minibatch: Dict[str, Any]) -> List[ClassificationItemSequence[ScalarItem]]:
"""
Creates a list of ClassificationItemSequence from the output of a data loader. The data loader returns a
dictionary with collated items, this function is effectively the inverse.
:param minibatch: A dictionary that contains the collated fields of ClassificationItemSequence objects.
:return: A list of ClassificationItemSequence objects.
"""
# batched is a de-generate ClassificationItemSequence, with id being a list of strings, and items being
# a list of lists.
batched = ClassificationItemSequence(**minibatch)
return [ClassificationItemSequence(id=sample_id, items=items)
for (sample_id, items) in zip(batched.id, batched.items)]
def get_labels_at_target_indices(self, target_indices: List[int]) -> torch.Tensor:
"""
Gets the label fields for the sequence elements with the given zero-based indices, if they exist
otherwise fill with NaN.
"""
target_indices = sorted(target_indices)
nan = torch.tensor([np.nan]).type_as(self.items[0].label)
def _get_label_or_nan(idx: int) -> torch.Tensor:
return self.items[idx].label if idx < len(self.items) else nan
if any(p < 0 for p in target_indices):
raise ValueError("Argument target_indices cannot contain negative values")
return torch.stack(list(map(_get_label_or_nan, target_indices)))
ListOfSequences = List[ClassificationItemSequence[SequenceDataSource]]

@ -19,11 +19,9 @@ from InnerEye.ML.metrics import compute_dice_across_patches
from InnerEye.ML.metrics_dict import DataframeLogger, MetricsDict
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import image_util, metrics_util, model_util
from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss
from pytorch_lightning import Trainer
SUBJECT_OUTPUT_PER_RANK_PREFIX = f"{SUBJECT_METRICS_FILE_NAME}.rank"
@ -184,12 +182,7 @@ class ScalarLightning(InnerEyeLightning):
super().__init__(config, *args, **kwargs)
self.model = config.create_model()
raw_loss = model_util.create_scalar_loss_function(config)
if isinstance(config, SequenceModelBase):
self.loss_fn = lambda model_output, loss: apply_sequence_model_loss(raw_loss, model_output, loss)
self.target_indices = config.get_target_indices()
else:
self.loss_fn = raw_loss
self.target_indices = []
self.loss_fn = raw_loss
self.target_names = config.target_names
self.is_classification_model = config.is_classification_model
@ -202,11 +195,6 @@ class ScalarLightning(InnerEyeLightning):
self.train_metric_computers = config.create_metric_computers()
self.val_metric_computers = config.create_metric_computers()
self.compute_and_log_metrics = config.compute_and_log_metrics
# if config.compute_grad_cam:
# model_to_evaluate = self.train_val_params.mean_teacher_model if \
# config.compute_mean_teacher_model else self.train_val_params.model
# self.guided_grad_cam = VisualizationMaps(model_to_evaluate, config)
# config.visualization_folder.mkdir(exist_ok=True)
def forward(self, *model_inputs: torch.Tensor) -> torch.Tensor: # type: ignore
"""
@ -246,7 +234,7 @@ class ScalarLightning(InnerEyeLightning):
:param batch_index: The index of the present batch (supplied only for diagnostics).
Runs a minibatch of training or validation data through the model.
"""
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, self.target_indices, sample)
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, sample)
labels = model_inputs_and_labels.labels
if is_training:
logits = self.model(*model_inputs_and_labels.model_inputs)


@ -33,7 +33,6 @@ from InnerEye.ML.pipelines.scalar_inference import ScalarEnsemblePipeline, Scala
ScalarInferencePipelineBase
from InnerEye.ML.reports.segmentation_report import boxplot_per_structure
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import io_util, ml_util
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array
from InnerEye.ML.utils.io_util import ImageHeader, MedicalImageFileType, load_nifti_image, save_lines_to_file
@ -397,11 +396,7 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
Create an instance of either a ScalarMetricsDict or SequenceMetricsDict, depending on whether the given
configuration is a sequence model configuration or not.
"""
if isinstance(config, SequenceModelBase):
return SequenceMetricsDict.create(is_classification_model=config.is_classification_model,
sequence_target_positions=config.sequence_target_positions)
else:
return ScalarMetricsDict(hues=config.target_names,
return ScalarMetricsDict(hues=config.target_names,
is_classification_metrics=config.is_classification_model)
@ -424,25 +419,19 @@ def classification_model_test(config: ScalarModelBase,
checkpoint_paths=checkpoint_paths)
if pipeline is None:
raise ValueError("Inference pipeline could not be created.")
# for mypy
assert isinstance(pipeline, ScalarInferencePipelineBase)
ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
shuffle=False,
batch_size=1,
num_dataload_workers=0
)
logging.info(f"Starting to evaluate model on {data_split.value} set.")
results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
os.makedirs(str(results_folder), exist_ok=True)
metrics_dict = create_metrics_dict_for_scalar_models(config)
if not isinstance(config, SequenceModelBase):
output_logger: Optional[DataframeLogger] = DataframeLogger(csv_path=results_folder / MODEL_OUTPUT_CSV)
else:
output_logger = None
output_logger: Optional[DataframeLogger] = DataframeLogger(csv_path=results_folder / MODEL_OUTPUT_CSV)
for sample in ds:
result = pipeline.predict(sample)
@ -463,15 +452,11 @@ def classification_model_test(config: ScalarModelBase,
labels=label,
loss_type=config.loss_type)
logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")
average = metrics_dict.average(across_hues=False)
logging.info(average.to_string())
if isinstance(metrics_dict, ScalarMetricsDict):
csv_file = results_folder / SUBJECT_METRICS_FILE_NAME
logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")
# If we are running inference after a training run, the validation set metrics may have been written
# during train time. If this is not the case, or we are running on the test set, create the metrics
# file.


@ -1,101 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import RNNCellBase
from InnerEye.ML.models.layers.identity import Identity
class LayerNormGRUCell(RNNCellBase):
"""
Implements GRUCell with layer normalisation and zone-out on top.
It inherits the base RNN cell whose trainable weight matrices are used.
References:
[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer Normalization." (2016).
[2] Krueger, David, et al. "Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations." (2016).
:param input_size: Number of input features to the cell
:param hidden_size: Number of hidden states in the cell
:param use_layer_norm: If set to True, layer normalisation is applied to
reset, update and new tensors before activation.
:param dropout: Dropout probability for the hidden states [0,1]
"""
def __init__(self, input_size: int, hidden_size: int, use_layer_norm: bool = False, dropout: float = 0.0):
super(LayerNormGRUCell, self).__init__(input_size, hidden_size, bias=False, num_chunks=3)
self.dropout = dropout
self.ln_r = nn.LayerNorm(self.hidden_size) if use_layer_norm else Identity()
self.ln_z = nn.LayerNorm(self.hidden_size) if use_layer_norm else Identity()
self.ln_n = nn.LayerNorm(self.hidden_size) if use_layer_norm else Identity()
def forward(self, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> torch.Tensor: # type: ignore
if hx is None:
hx = input.new_zeros(size=(input.size(0), self.hidden_size), requires_grad=False)
ih = input.mm(self.weight_ih.t())
hh = hx.mm(self.weight_hh.t())
i_r, i_z, i_n = ih.chunk(3, dim=1)
h_r, h_z, h_n = hh.chunk(3, dim=1)
# Activations with layer normalisation
r = torch.sigmoid(self.ln_r(i_r + h_r))
z = torch.sigmoid(self.ln_z(i_z + h_z))
n = torch.tanh(self.ln_n(i_n + r * h_n))
new_h = (torch.tensor(1.0) - z) * n + z * hx
# Apply zoneout drop-out on hidden states
if self.dropout > 0.0:
bernoulli_mask = F.dropout(torch.ones_like(new_h), p=self.dropout, training=bool(self.training))
new_h = bernoulli_mask * new_h + (torch.tensor(1.0) - bernoulli_mask) * hx
return new_h
class LayerNormGRU(nn.Module):
"""
Implements stacked GRU layers. Differs from the torch.nn.GRU implementation by
the use of layer normalisation and a hidden-state-preserving drop-out technique
(zone-out), which are currently not provided in the default implementation.
https://arxiv.org/pdf/1607.06450.pdf
https://arxiv.org/pdf/1606.01305.pdf
:param input_size: Number of input features.
:param hidden_size: Number of hidden states in GRU, it is used for all layers.
:param num_layers: Number of stacked GRU layers in the module.
:param batch_first: If set to true, input tensor should have the batch dimension in the first axis.
"""
def __init__(self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
**kwargs: Any):
super(LayerNormGRU, self).__init__()
self.num_layers = num_layers
self.cells = nn.ModuleList([
LayerNormGRUCell(input_size if i == 0 else hidden_size, hidden_size, **kwargs)
for i in range(self.num_layers)
])
def forward(self, x: torch.Tensor, hx: torch.Tensor) -> torch.Tensor: # type: ignore
seq_axis = 1
for i, cell in enumerate(self.cells): # type: ignore
y = []
hidden = hx[i]
for xc in x.chunk(x.size(seq_axis), dim=seq_axis):
xc = xc.squeeze(seq_axis)
hidden = cell(xc, hidden)
y.append(hidden.unsqueeze(0))
x = torch.stack(y, dim=seq_axis + 1).squeeze(0)
return x
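A toy forward pass through the stacked GRU defined above; this is a sketch only, it assumes the LayerNormGRU class above is in scope, and all dimensions are made up.

import torch

batch_size, seq_len, input_dim, hidden_dim, num_layers = 2, 5, 8, 16, 1
gru = LayerNormGRU(input_size=input_dim, hidden_size=hidden_dim,
                   num_layers=num_layers, use_layer_norm=True, dropout=0.1)
x = torch.randn(batch_size, seq_len, input_dim)       # batch-first input, as expected by forward()
h0 = torch.zeros(num_layers, batch_size, hidden_dim)  # one initial hidden state per layer
out = gru(x, h0)
print(out.shape)  # expected: torch.Size([2, 5, 16])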


@ -1,200 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Any, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
from InnerEye.ML.models.architectures.base_model import DeviceAwareModule
from InnerEye.ML.models.architectures.sequential.gru import LayerNormGRU
from InnerEye.ML.models.layers.identity import Identity
from InnerEye.ML.utils.sequence_utils import sequences_to_padded_tensor
class RNNClassifier(DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor]):
"""
Recurrent neural network (GRU) to perform a binary classification of sequence datasets.
The class scores that the model outputs are results of a log softmax.
:param input_dim: Number of input channels for the GRU layer.
:param hidden_dim: Number of hidden states
:param output_dim: Number of model output channels
:param num_rnn_layers: Number of RNN layers to be stacked in the classifier. By default, a single GRU layer is used.
:param use_layer_norm: If set to True, hidden state activations are normalised at each time step.
:param target_indices: Output target indices. For a many-to-one sequential model,
it should be equal to the last index `-1`. If a list of indices is provided,
the model returns sequential outputs at the given time indices.
:param ref_indices: Optional, if set then the hidden states from these reference indices are concatenated to the
hidden state of the target position before computing the class posterior.
"""
def __init__(self,
input_dim: int,
hidden_dim: int,
output_dim: int,
target_indices: List[int],
num_rnn_layers: int = 1,
use_layer_norm: bool = False,
rnn_dropout: float = 0.00,
ref_indices: Optional[List[int]] = None) -> None:
super().__init__()
self.target_indices = target_indices or [-1]
self.input_dim = input_dim
self.ref_indices = ref_indices
self.hidden_dim = hidden_dim
# The GRU takes embeddings as inputs, and outputs hidden states
# with dimensionality hidden_dim.
self.gru: LayerNormGRU = LayerNormGRU(input_size=input_dim,
hidden_size=hidden_dim,
num_layers=num_rnn_layers,
use_layer_norm=use_layer_norm,
dropout=rnn_dropout)
# The linear layer that maps from hidden state space to class space
if self.ref_indices is None:
self.hidden2class = nn.Linear(hidden_dim, output_dim)
else:
self.hidden2class = nn.Linear((len(ref_indices) + 1) * hidden_dim, output_dim) # type: ignore
# Create a parameter to learn the initial hidden state
hidden_size = torch.Size([num_rnn_layers, 1, hidden_dim])
self.h0 = nn.Parameter(torch.zeros(size=hidden_size), requires_grad=True)
self.initialise_parameters()
def forward(self, *input_seq: torch.Tensor) -> torch.Tensor: # type: ignore
"""
input_seq: Input sequence (batch_size x sequence_length x input_dim)
return: Sequence classification output (batch_size x target_indices x output_dim)
"""
batch_size, seq_length, _ = input_seq[0].size()
# GRU forward pass and linear mapping from hidden state to the output space.
# gru_out of shape [batch_size, seq_length, hidden_dim]
gru_out: torch.Tensor = self.gru(input_seq[0], self.h0.repeat(1, batch_size, 1))
# pad the gru output if required to ensure values for each target index
gru_out = self.pad_gru_output(gru_out)
if self.ref_indices is None:
return self.hidden2class(gru_out[:, self.target_indices, :])
else:
predictions = []
for target_index in self.target_indices:
input_to_classifier = gru_out[:, self.ref_indices + [target_index], :].view(batch_size, -1)
predictions.append(self.hidden2class(input_to_classifier))
return torch.stack(predictions, dim=1)
def pad_gru_output(self, input: torch.Tensor) -> torch.Tensor:
"""
Pad the GRU output with zeros along the sequence dimension, if required, to make sure there is a value
for each target index this RNN classifier is initialized with.
"""
current_sequence_length = input.shape[1]
required_sequence_length = max(self.target_indices) + 1
pad_size = max(required_sequence_length - current_sequence_length, 0)
return F.pad(input=input, pad=[0, 0, 0, pad_size], mode='constant', value=0)
def initialise_parameters(self) -> None:
"""
Initialises the initial hidden state parameter in GRU.
"""
# Disable type checking here because these parameters are created via setattr, and hence
# confuse mypy
nn.init.xavier_normal_(self.h0, gain=nn.init.calculate_gain('tanh'))
def get_input_tensors(self, sequences: List[ClassificationItemSequence]) -> List[torch.Tensor]:
"""
Returns the input tensors as a List where the first element corresponds to the non-imaging features.
"""
seq_flattened = [torch.stack([i.get_all_non_imaging_features() for i in seq.items], dim=0)
for seq in sequences]
return [sequences_to_padded_tensor(seq_flattened)]
class RNNClassifierWithEncoder(RNNClassifier):
"""
RNN classifier for a combination of imaging and non-imaging features. The images are first encoded using
an image encoder that is passed to the constructor.
:param image_encoder: torch module used to encode the image features. For example, an ImageEncoder could be
used for a U-Net-like encoding of the images.
:param input_dim: Number of input channels for the GRU layer i.e. number of non_imaging features + number
of features at the output of the image encoder (if images/segmentations are used).
:param hidden_dim: Number of hidden states
:param output_dim: Number of model output channels
:param num_rnn_layers: Number of RNN layers to be stacked in the classifier. By default, a single GRU layer is used.
:param use_layer_norm: If set to True, hidden state activations are normalised at each time step.
:param target_indices: Output target indices. For a many-to-one sequential model,
it should be equal to the last index `-1`. If a list of indices is provided,
the model returns sequential outputs at the given time indices.
:param ref_indices: Optional, if set then the hidden states from these reference indices are concatenated to the
hidden state of the target position before computing the class posterior.
:param use_encoder_batch_norm: If True, apply batchNorm to the encoded features at the output of the image_encoder
module prior to feeding them to the GRU layers. If False, the raw output from the image encoder is fed to the GRU
layer.
"""
def __init__(self,
input_dim: int,
hidden_dim: int,
output_dim: int,
use_encoder_batch_norm: bool = False,
image_encoder: Optional[DeviceAwareModule[ScalarItem, torch.Tensor]] = None,
**kwargs: Any) -> None:
super().__init__(input_dim, hidden_dim, output_dim, **kwargs)
self.image_encoder = image_encoder
if self.image_encoder is not None:
# Adding necessary attributes required by GradCam computation.
self.imaging_feature_type = image_encoder.imaging_feature_type # type: ignore
self.num_non_image_features = image_encoder.num_non_image_features # type: ignore
self.last_encoder_layer = ["image_encoder"] + image_encoder.get_last_encoder_layer_names() # type: ignore
self.conv_in_3d = self.image_encoder.conv_in_3d
self.encode_channels_jointly = False
self.use_encoder_batch_norm = use_encoder_batch_norm
self.layer_norm = nn.BatchNorm1d(input_dim) if use_encoder_batch_norm else Identity()
def forward(self, *input_seq: torch.Tensor) -> torch.Tensor: # type: ignore
"""
input_seq: Input sequence (batch_size x sequence_length x input_dim)
return: Sequence classification output (batch_size x target_indices x output_dim)
"""
batch_size, seq_length = input_seq[0].size()[:2]
# If we have a model that uses images the input will be a List with 2 elements
# (imaging_features, non_imaging_features). Else it will be a List with only one
# element (non_imaging_features).
non_imaging_seq = input_seq[0] if len(input_seq) == 1 else input_seq[1]
encoded_seq = []
if self.image_encoder is not None:
imaging_seq = input_seq[0]
for seq in range(seq_length):
encoded_features = self.image_encoder(imaging_seq[:, seq, :], non_imaging_seq[:, seq, :])
if self.training and batch_size == 1 and self.use_encoder_batch_norm:
# This check is necessary as BatchNorm fails if the
# batch_size is equal to 1 (can't compute the variance).
logging.warning("BatchNorm will not be applied to the encoded image features as the"
"effective batch size is 1 on this device.")
else:
encoded_features = self.layer_norm(encoded_features)
encoded_seq.append(encoded_features)
encoded_input = non_imaging_seq if encoded_seq == [] else torch.stack(encoded_seq, dim=1)
return super().forward(encoded_input)
def get_last_encoder_layer_names(self) -> List[str]:
return self.last_encoder_layer
def get_input_tensors(self, sequences: List[ClassificationItemSequence]) -> List[torch.Tensor]:
"""
Returns the input tensors as a List where the first element corresponds to the non-imaging features.
The second corresponds to the images loaded as required by the image encoder.
"""
non_imaging_seq = super().get_input_tensors(sequences)[0]
if self.image_encoder is not None:
seq_flattened_imaging = [torch.stack([self.image_encoder.get_input_tensors(item)[0]
for item in seq.items], dim=0)
for seq in sequences]
imaging_seq = sequences_to_padded_tensor(seq_flattened_imaging)
return [imaging_seq, non_imaging_seq]
else:
return [non_imaging_seq]
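A toy invocation of the non-imaging RNNClassifier above; a sketch only, assuming the class definitions in this file are in scope, with made-up dimensions.

import torch

model = RNNClassifier(input_dim=8, hidden_dim=16, output_dim=1,
                      target_indices=[2], num_rnn_layers=1)
batch = torch.randn(4, 3, 8)   # batch_size x sequence_length x input_dim
logits = model(batch)
print(logits.shape)            # expected: torch.Size([4, 1, 1]), one prediction per target index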


@ -15,7 +15,6 @@ from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_in
from InnerEye.ML.lightning_models import ScalarLightning
from InnerEye.ML.pipelines.inference import InferencePipelineBase
from InnerEye.ML.scalar_config import EnsembleAggregationType, ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
@ -103,10 +102,7 @@ class ScalarInferencePipeline(ScalarInferencePipelineBase):
:return: Returns ScalarInferencePipelineBase.Result with the subject ids, ground truth labels and predictions.
"""
assert isinstance(self.model_config, ScalarModelBase)
target_indices = self.model_config.get_target_indices() \
if isinstance(self.model_config, SequenceModelBase) else []
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model.model,
target_indices=target_indices,
sample=sample)
model_inputs_and_labels.move_to_device(self.model.device)
with torch.no_grad():


@ -54,7 +54,6 @@ from InnerEye.ML.reports.notebook_report import generate_classification_crossval
generate_classification_multilabel_notebook, generate_classification_notebook, generate_segmentation_notebook, \
get_ipynb_report_name, reports_folder
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler, download_all_checkpoints_from_run
from InnerEye.ML.visualizers import activation_maps
from InnerEye.ML.visualizers.plot_cross_validation import \
@ -96,9 +95,7 @@ def is_classification_model(model: Any) -> bool:
"""
Returns True if the given object is an InnerEye classification model, but not a sequence model.
"""
return (isinstance(model, ScalarModelBase)
and model.is_classification_model
and not isinstance(model, SequenceModelBase))
return isinstance(model, ScalarModelBase)
class MLRunner:
@ -843,7 +840,7 @@ class MLRunner:
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
else:
if isinstance(config, ScalarModelBase) and not isinstance(config, SequenceModelBase):
if isinstance(config, ScalarModelBase):
generate_classification_notebook(
result_notebook=reports_dir / get_ipynb_report_name(config.model_category.value),
config=config,


@ -1,157 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Any, Dict, List, Optional
import pandas as pd
import param
from pandas import DataFrame
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
from InnerEye.ML.metrics_dict import SequenceMetricsDict
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.utils.split_dataset import DatasetSplits
SEQUENCE_LENGTH_STATS_FILE = "sequence_length_stats.txt"
SEQUENCE_LENGTH_FILE = "sequence_length.csv"
class SequenceModelBase(ScalarModelBase):
sequence_column: Optional[str] = \
param.String(allow_None=True, default=None,
doc="If provided, create a sequence dataset, ordering by the column given here. The value in that "
"column is expected to be numeric, starting at 0.")
min_sequence_position_value: int = \
param.Integer(default=0, bounds=(0, None),
doc="When creating a sequence dataset, restrict it to items with a sequence index at min this "
"value. For example, if sequence_min_index==2, only items with sequence positions >= 2 will "
"be retained, and sequences that don't have all positions up to (including) 2 will be "
"discarded.")
max_sequence_position_value: Optional[int] = \
param.Integer(default=None, allow_None=True, bounds=(0, None),
doc="When creating a sequence dataset, restrict it to items with a sequence index "
"at most this value. For example, if sequence_max_index==3, only items with sequence "
"positions 0, 1, 2, and 3 will be retained.")
sequence_target_positions: List[int] = \
param.List(class_=int,
doc="Stores the sequence positions for which the model should make predictions. Sequence positions "
"are given by the value in the sequence_column. For example, if a sequence consists of items "
"with positions [2, 3, 4, 5], and sequence_target_position==[2,5], the model would be evaluated "
"on the first and last sequence elements.")
temperature_scaling_config: Optional[TemperatureScalingConfig] = param.ClassSelector(
class_=TemperatureScalingConfig,
allow_None=True,
default=None,
doc="If a config is provided then it will be used to learn a temperature scaling parameter using the "
"validation set to calibrate the model logits see: https://arxiv.org/abs/1706.04599 for each "
"epoch that requires a checkpoint to be saved. Turned off by default.")
def __init__(self, **params: Any):
super().__init__(**params)
# For sequence models, create a hook for computing dataset statistics by default, because sequence
# length is expected to have a major impact on performance. If an alternative hook is needed,
# overwrite the hook in a derived class or after instantiating the model configuration.
self.dataset_stats_hook = self.compute_dataset_stats_hook
if len(self.sequence_target_positions) == 0:
raise ValueError("sequence_target_positions must not be empty")
if self.temperature_scaling_config:
logging.info(f"Temperature scaling will be performed on the "
f"validation set using the config: {self.temperature_scaling_config}")
def validate(self) -> None:
self.target_names = [SequenceMetricsDict.get_hue_name_from_target_index(p)
for p in self.sequence_target_positions]
def get_target_indices(self) -> List[int]:
"""
Computes the zero-based array indices within a sequence of items
for which the model should make predictions.
"""
return [pos - self.min_sequence_position_value for pos in self.sequence_target_positions]
def get_total_number_of_numerical_non_imaging_features(self) -> int:
return len(self.numerical_columns)
def get_total_number_of_categorical_non_imaging_features(self) -> int:
if self.categorical_feature_encoder:
return sum([self.categorical_feature_encoder.get_feature_length(col) for col in self.categorical_columns])
else:
return 0
@property
def is_non_imaging_model(self) -> bool:
"""
Returns whether the model uses non-image features only.
"""
return self.image_file_column is None
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
sample_transform = self.get_scalar_item_transform()
assert sample_transform.train is not None # for mypy
assert sample_transform.val is not None # for mypy
assert sample_transform.test is not None # for mypy
train = SequenceDataset(self,
dataset_splits.train,
name="training",
sample_transform=sample_transform.train)
val = SequenceDataset(self,
dataset_splits.val,
feature_statistics=train.feature_statistics,
name="validation",
sample_transform=sample_transform.val)
test = SequenceDataset(self,
dataset_splits.test,
feature_statistics=train.feature_statistics,
name="test",
sample_transform=sample_transform.test)
return {
ModelExecutionMode.TRAIN: train,
ModelExecutionMode.VAL: val,
ModelExecutionMode.TEST: test
}
def compute_dataset_stats_hook(self, datasets: Dict[ModelExecutionMode, Any]) -> None:
"""
Writes files with details and summary statistics about the datasets for each of the 3 dataset
splits (train/val/test).
"""
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
mode_series = []
id_series = []
length_series = []
for mode in ModelExecutionMode:
dataset = datasets[mode]
assert isinstance(dataset, SequenceDataset)
for seq in dataset.items:
mode_series.append(mode.value)
id_series.append(seq.id)
length_series.append(len(seq.items))
# Add a constant column that is the cross validation index, so that we can more easily merge these files later
# in the post-crossvalidation hook.
df = DataFrame.from_dict({
LoggingColumns.CrossValidationSplitIndex.value: [self.cross_validation_split_index] * len(mode_series),
LoggingColumns.DataSplit.value: mode_series,
LoggingColumns.Patient.value: id_series,
LoggingColumns.SequenceLength.value: length_series
})
self.logs_folder.mkdir(exist_ok=True, parents=True)
details_file = self.logs_folder / SEQUENCE_LENGTH_FILE
df.to_csv(details_file, index=False)
# Drop all columns apart from the sequence length column, so that the stats file will also contain
# the name of the series that is described
stats = df.drop(columns=[LoggingColumns.Patient.value, LoggingColumns.CrossValidationSplitIndex.value]) \
.groupby(by=LoggingColumns.DataSplit.value).describe()
out_file = self.logs_folder / SEQUENCE_LENGTH_STATS_FILE
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 150):
out_file.write_text(str(stats))
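A standalone sketch of the groupby/describe aggregation used by compute_dataset_stats_hook above, with made-up sequence lengths; the column names here are illustrative, not the LoggingColumns constants.

import pandas as pd

df = pd.DataFrame({"data_split": ["Train", "Train", "Val", "Test"],
                   "sequence_length": [3, 5, 4, 2]})
stats = df.groupby("data_split").describe()
print(stats)  # count/mean/std/min/quartiles/max of sequence_length per split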


@ -5,19 +5,16 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, List, TypeVar
from typing import Any, List
import torch
from InnerEye.Common.common_util import check_properties_are_not_none
from InnerEye.ML.dataset.scalar_dataset import ScalarDataSource, SequenceDataSource
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
FT = TypeVar('FT', ClassificationItemSequence[SequenceDataSource], ScalarDataSource)
from InnerEye.ML.dataset.scalar_dataset import ScalarDataSource
@dataclass(frozen=True)
class FeatureStatistics(Generic[FT]):
class FeatureStatistics:
"""
Class to store statistics (mean and standard deviation) of a set of features in a given dataset.
Allows feature standardization to be performed for this set of features.
@ -29,7 +26,7 @@ class FeatureStatistics(Generic[FT]):
check_properties_are_not_none(self)
@staticmethod
def from_data_sources(sources: List[FT]) -> FeatureStatistics:
def from_data_sources(sources: List[ScalarDataSource]) -> FeatureStatistics:
"""
For the provided data sources, compute the mean and std across all non-image features across all entries.
@ -40,11 +37,7 @@ class FeatureStatistics(Generic[FT]):
if len(sources) == 0:
raise ValueError("sources must have a length greater than 0")
data_sources: List[Any] # for mypy
if isinstance(sources[0], ClassificationItemSequence):
data_sources = [item for seq in sources for item in seq.items]
else:
data_sources = sources
data_sources: List[Any] = sources
numerical_non_image_features = [x.numerical_non_image_features for x in data_sources]
if len(numerical_non_image_features) == 0:
@ -88,7 +81,7 @@ class FeatureStatistics(Generic[FT]):
std = torch.sqrt(torch.max(variance, torch.zeros_like(variance)))
return FeatureStatistics(mean=mean, std=std)
def standardize(self, sources: List[FT]) -> List[FT]:
def standardize(self, sources: List[ScalarDataSource]) -> List[ScalarDataSource]:
"""
For the provided data sources, apply standardization to the non-image features in each source. This will
standardize them to mean 0, variance 1 across all sequences.
@ -104,14 +97,7 @@ class FeatureStatistics(Generic[FT]):
new_features[zero_or_nan] = source.numerical_non_image_features[zero_or_nan]
return source.clone_with_overrides(numerical_non_image_features=new_features)
def apply_sequence(seq: ClassificationItemSequence) -> ClassificationItemSequence:
# noinspection PyTypeChecker
return ClassificationItemSequence(id=seq.id, items=list(map(apply_source, seq.items)))
if len(sources) > 0:
if isinstance(sources[0], ClassificationItemSequence):
return list(map(apply_sequence, sources)) # type: ignore
else:
return list(map(apply_source, sources)) # type: ignore
return list(map(apply_source, sources)) # type: ignore
else:
return sources
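A self-contained sketch of the standardization this class performs (zero mean, unit variance per feature), with made-up feature values and a population standard deviation for simplicity; the real implementation additionally leaves NaN and zero-std entries unchanged.

import torch

features = torch.tensor([[1.0, 10.0],
                         [3.0, 30.0],
                         [5.0, 50.0]])       # rows = data sources, columns = numerical features
mean = features.mean(dim=0)
std = features.std(dim=0, unbiased=False)
standardized = (features - mean) / std
print(standardized)  # each column now has mean 0 and unit variance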


@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------
import logging
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
from typing import Any, Dict, Iterator, List, Optional, Union
import torch
from torch.nn import MSELoss
@ -18,7 +18,6 @@ from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import ModelArchitectureConfig, PaddingMode, SegmentationLoss, SegmentationModelBase, \
basic_size_shrinkage
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
from InnerEye.ML.deep_learning_config import OptimizerParams, OptimizerType
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel, CropSizeConstraints
@ -30,11 +29,9 @@ from InnerEye.ML.models.losses.cross_entropy import CrossEntropyLoss
from InnerEye.ML.models.losses.mixture import MixtureLoss
from InnerEye.ML.models.losses.soft_dice import SoftDiceLoss
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
from InnerEye.ML.utils.ml_util import RandomStateSnapshot
from InnerEye.ML.utils.supervised_criterion import BinaryCrossEntropyWithLogitsLoss, SupervisedLearningCriterion
from InnerEye.ML.utils.temperature_scaling import ModelWithTemperature
from InnerEye.ML.visualizers.model_summary import ModelSummary
@ -221,9 +218,7 @@ def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAware
# get_model_input function to convert the dataset item to input tensors, and feed them through the model.
train_dataset = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN)
train_item_0 = next(iter(train_dataset.as_data_loader(shuffle=False, batch_size=1, num_dataload_workers=0)))
target_indices = config.get_target_indices() if isinstance(config, SequenceModelBase) else []
model_inputs = get_scalar_model_inputs_and_labels(model,
target_indices=target_indices,
sample=train_item_0)
# The model inputs may already be converted to float16, assuming that we would do mixed precision.
# However, the model is not yet converted to float16 when this function is called, hence convert back to float32
@ -248,16 +243,11 @@ def create_model_with_temperature_scaling(config: ModelConfigBase) -> Any:
"""
# wrap the model around a temperature scaling model if required
model = config.create_model()
if isinstance(config, SequenceModelBase) and config.temperature_scaling_config:
model = ModelWithTemperature(model, config.temperature_scaling_config)
return model
E = TypeVar('E', List[ClassificationItemSequence[ScalarItem]], ScalarItem)
@dataclass
class ScalarModelInputsAndLabels(Generic[E]):
class ScalarModelInputsAndLabels():
"""
Holds the results of calling get_scalar_model_inputs_and_labels: For a given sample returned by the data loader,
create the model inputs, the labels, the list of subjects (data loader sample can be batched),
@ -266,7 +256,7 @@ class ScalarModelInputsAndLabels(Generic[E]):
model_inputs: List[torch.Tensor]
labels: torch.Tensor
subject_ids: List[str]
data_item: E
data_item: ScalarItem
def __post_init__(self) -> None:
common_util.check_properties_are_not_none(self)
@ -281,7 +271,6 @@ class ScalarModelInputsAndLabels(Generic[E]):
def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
target_indices: List[int],
sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
"""
For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
@ -293,31 +282,14 @@ def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
:return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors,
label tensor, subject IDs, and the data item reconstructed from the data loader output
"""
if target_indices:
sequence_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model # type: ignore
sequences = ClassificationItemSequence.from_minibatch(sample)
subject_ids = [x.id for x in sequences]
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
sequences=sequences,
target_indices=target_indices
)
model_inputs = sequence_model.get_input_tensors(sequences)
scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model # type: ignore
scalar_item = ScalarItem.from_dict(sample)
subject_ids = [str(x.id) for x in scalar_item.metadata] # type: ignore
model_inputs = scalar_model.get_input_tensors(scalar_item)
return ScalarModelInputsAndLabels[List[ClassificationItemSequence]](
model_inputs=model_inputs,
labels=labels,
subject_ids=subject_ids,
data_item=sequences
)
else:
scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model # type: ignore
scalar_item = ScalarItem.from_dict(sample)
subject_ids = [str(x.id) for x in scalar_item.metadata] # type: ignore
model_inputs = scalar_model.get_input_tensors(scalar_item)
return ScalarModelInputsAndLabels[ScalarItem](
model_inputs=model_inputs,
labels=scalar_item.label,
subject_ids=subject_ids,
data_item=scalar_item
)
return ScalarModelInputsAndLabels(
model_inputs=model_inputs,
labels=scalar_item.label,
subject_ids=subject_ids,
data_item=scalar_item
)


@ -3,15 +3,10 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import abc
from typing import Any, Dict, List, Optional, TypeVar
from typing import Any, Dict, List, Optional
import torch
from torch.nn import BCEWithLogitsLoss
from torch.nn.utils.rnn import PackedSequence
from InnerEye.ML.utils.sequence_utils import map_packed_sequence_data
T = TypeVar('T', torch.Tensor, PackedSequence)
class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
@ -27,7 +22,7 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
self.smoothing_eps = smoothing_eps
self.is_binary_classification = is_binary_classification
def forward(self, *input: T, **kwargs: Any) -> Any:
def forward(self, *input: torch.Tensor, **kwargs: Any) -> Any:
def _smooth_target(target: torch.Tensor) -> torch.Tensor:
if self.is_binary_classification or len(target.shape) <= 2:
_num_classes = 2
@ -39,12 +34,9 @@ class SupervisedLearningCriterion(torch.nn.Module, abc.ABC):
return target * (1.0 - self.smoothing_eps) + \
(1.0 - target) * self.smoothing_eps / (_num_classes - 1.0) # type: ignore
_input: List[T] = list(input)
_input: List[torch.Tensor] = list(input)
if self.smoothing_eps > 0.0:
if isinstance(_input[1], PackedSequence):
_input[1] = map_packed_sequence_data(_input[1], _smooth_target)
else:
_input[1] = _smooth_target(_input[1])
_input[1] = _smooth_target(_input[1])
return self.forward_minibatch(*_input, **kwargs)
@ -102,8 +94,5 @@ class BinaryCrossEntropyWithLogitsLoss(SupervisedLearningCriterion):
sorted(self._class_counts.items())] # Uses the first number on the tuple to compare
return torch.tensor(weights, dtype=torch.float32)
def forward_minibatch(self, output: T, target: T, **kwargs: Any) -> Any:
if isinstance(target, PackedSequence) and isinstance(output, PackedSequence):
return self._loss_fn(output.data.view(-1, 1), target.data.view(-1, 1))
else:
return self._loss_fn(output, target)
def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor, **kwargs: Any) -> Any:
return self._loss_fn(output, target)
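A worked example (not part of the diff) of the label smoothing applied by _smooth_target above for binary targets, i.e. target * (1 - eps) + (1 - target) * eps / (num_classes - 1) with num_classes = 2.

import torch

eps = 0.1
target = torch.tensor([0.0, 1.0])
smoothed = target * (1.0 - eps) + (1.0 - target) * eps / (2 - 1.0)
print(smoothed)  # tensor([0.1000, 0.9000]): hard 0/1 labels become 0.1/0.9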


@ -1,661 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn import Module
from InnerEye.Common.metrics_constants import SEQUENCE_POSITION_HUE_NAME_PREFIX
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType
from InnerEye.ML.reports.notebook_report import convert_to_html
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES
from InnerEye.ML.visualizers.model_hooks import HookBasedFeatureExtractor
def _tensor_as_numpy(tensor: torch.Tensor) -> np.ndarray:
return tensor.detach().cpu().numpy()
class GradientBasedFeatureExtractor(HookBasedFeatureExtractor):
"""
Base class for GradCam and BackPropagation classes.
"""
def __init__(self, model: Module,
config: ScalarModelBase,
target_layer: Any,
target_pos: int = -1):
if not config.is_classification_model:
raise NotImplementedError("Visualizations maps with GradCam are only"
"implemented for classification models.")
super().__init__(model, target_layer)
self.config = config
self.hooks: List[Any] = []
if config.use_gpu:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.encode_jointly = getattr(self.net, "encode_channels_jointly", True)
self.imaging_feature_type = getattr(self.net, "imaging_feature_type", ImagingFeatureType.Image)
self.num_non_image_features = getattr(self.net, "num_non_image_features", 0)
self.target_label_index = target_pos
self.logits = torch.Tensor()
self.probabilities = torch.Tensor()
def remove_hooks(self) -> None:
for hook in self.hooks:
hook.remove()
self.hooks = []
def get_target_score(self) -> torch.Tensor:
"""
Returns the target score, i.e. the logits for a positively predicted class and the
negated logits for a negatively predicted class. Uses the `probabilities` and `logits`
attributes set during the forward pass (the logits are the network output before the Sigmoid).
"""
if self.logits.shape[-1] != 1:
raise NotImplementedError("More than one output class")
return torch.where(self.probabilities > 0.5, self.logits, -self.logits)
def backward(self) -> None:
"""
Defines the backward pass. First computes the target scores to use
for backpropagation and then updates the `gradients` attribute based on
the current value of the `logits` attribute (set in the forward pass).
"""
target_scores = self.get_target_score() # [B, num_targets, 1] or [B, 1]
self.model.zero_grad()
# If we have a sequence model, with potentially multiple labels.
# Only backpropagate the gradients for the given target_pos
if isinstance(self.config, SequenceModelBase):
gradients_to_propagate = torch.zeros(target_scores.shape, device=self.device)
gradients_to_propagate[:, self.target_label_index, :] = 1
else:
gradients_to_propagate = torch.ones(target_scores.shape, device=self.device)
target_scores.backward(gradient=gradients_to_propagate)
self.remove_hooks()
class GradCam(GradientBasedFeatureExtractor):
"""
Class to generate GradCam maps for images, "Pseudo-GradCam" (i.e. ReLu(input x gradients))
for non-image features of one batch for the given classification model. Tested and maintained for
ImageEncoderWithMLP and RNNClassifier (models that take both images and non-imaging features as input).
GradCam computes Relu(Gradients x Activations) at the output of the encoder of the network
(before the global pooling layer). "PseudoGradCam" for non-imaging features denotes
ReLu(input x gradients) for non-imaging features. "PseudoGradCam" is
used to compare relative feature importance of various non-imaging features for the final classification
task.
"""
def __init__(self, model: Union[DeviceAwareModule, torch.nn.DataParallel],
config: ScalarModelBase) -> None:
"""
:param model: The model to analyse
:param config: The ScalarModelBase config defining the parameters of this model.
"""
self.total_num_categorical_features = config.get_total_number_of_categorical_non_imaging_features()
self.total_number_of_numerical_non_imaging_features = \
config.get_total_number_of_numerical_non_imaging_features()
self.is_non_imaging_model = config.is_non_imaging_model
if self.is_non_imaging_model:
super().__init__(model, config=config, target_layer=None)
else:
if isinstance(model, torch.nn.DataParallel):
_model: DeviceAwareModule = model.module # type: ignore
target_layer = _model.get_last_encoder_layer_names()
self.conv_in_3d = bool(_model.conv_in_3d)
else:
target_layer = model.get_last_encoder_layer_names()
self.conv_in_3d = bool(model.conv_in_3d)
super().__init__(model=model, config=config, target_layer=target_layer)
self.gradients: Dict = {}
self.activations: Dict = {}
def backward_hook_fn(self, module: Module, grad_in: torch.Tensor, grad_out: torch.Tensor) -> None:
"""
Backward hook to save the gradients per device (to allow GradCam to be computed
with DataParallel models when training on multiple GPUs).
"""
device = str(grad_out[0].get_device())
if device not in self.gradients:
self.gradients[device] = []
self.gradients[device].append(grad_out[0].data.clone())
def forward_hook_fn(self, module: Module, input: torch.Tensor, output: torch.Tensor) -> None:
"""
Forward hook to save the activations of a given layer (per device to allow GradCam to be computed
with DataParallel models when training on multiple GPUs).
"""
device = str(output[0].get_device())
if device not in self.activations:
self.activations[device] = []
if isinstance(output, tuple):
self.activations[device].append([output[index].data.clone() for index in range(len(output))])
else:
self.activations[device].append(output.data.clone())
def forward(self, *input) -> None: # type: ignore
"""
Triggers the call to the forward pass of the module. Prior to calling the model's forward function, we
register the forward and backward hooks. After calling this function, the `activations` attribute
will contain the activations of the target layer for the given `input` batch passed as an
argument to this function.
"""
self.activations = {}
if self.layer_name is not None:
submodule = self.net
for el in self.layer_name:
submodule = submodule._modules[el] # type: ignore
target_layer = submodule
self.hooks.append(target_layer.register_forward_hook(self.forward_hook_fn))
self.hooks.append(target_layer.register_backward_hook(self.backward_hook_fn)) # type: ignore
self.logits = self.model(*input)
if isinstance(self.logits, List):
self.logits = torch.nn.parallel.gather(self.logits, target_device=self.device)
self.probabilities = torch.nn.Sigmoid()(self.logits)
def backward(self) -> None:
"""
Defines the backward pass. First computes the target scores to use
for backpropagation and then updates the `gradients` attribute based on
the current value of the `logits` attribute (set in the forward pass).
"""
self.gradients = {}
super().backward()
def _get_image_grad_cam(self, input: List[torch.Tensor]) -> np.ndarray:
"""
Get GradCam maps for image inputs. GradCam computes
Relu(Gradients x Activations) at the output of the encoder of the network
(before the global pooling layer).
:param input: input batch
:return: the GradCam maps
"""
list_gradients = []
list_activations = []
# put all channels in one tensor per device
for device in self.gradients:
list_gradients.append(torch.stack(self.gradients[device], dim=1)) # [B, C_in, C_out, Z, X, Y]
list_activations.append(torch.stack(self.activations[device], dim=1)) # [B, C_in, C_out, Z, X, Y]
if self.config.use_gpu:
activations = torch.nn.parallel.gather(list_activations, target_device=self.device)
gradients = torch.nn.parallel.gather(list_gradients, target_device=self.device)
else:
assert len(list_activations) == 1
activations = list_activations[0]
gradients = list_gradients[0]
self.gradients = {}
self.activations = {}
B, C_in = input[0].shape[:2]
Z, X, Y = input[0].shape[-3:]
B_act, _, C_act, Z_act, X_act, Y_act = activations.shape
if self.conv_in_3d:
weights = torch.mean(gradients, dim=(3, 4, 5), keepdim=True)
Z_low = Z_act
else:
weights = torch.mean(gradients, dim=(4, 5), keepdim=True)
Z_low = Z
del list_gradients, gradients
low_dim_cam = torch.nn.functional.relu(torch.mul(activations, weights).sum(dim=2))
del weights, list_activations, activations
# Case one separate encoding per channel i.e. one GradCam map per channel
if not self.encode_jointly:
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
assert low_dim_cam.shape == (B, C_in, Z_low, X_act, Y_act) \
or low_dim_cam.shape == (B, C_in / HDF5_NUM_SEGMENTATION_CLASSES, Z_low, X_act, Y_act)
elif self.imaging_feature_type == ImagingFeatureType.Image:
assert low_dim_cam.shape == (B, C_in, Z_low, X_act, Y_act)
# Case one global encoding i.e. one GradCam map per image
else:
assert low_dim_cam.shape == (B, 1, Z_low, X_act, Y_act)
grad_cam = torch.nn.functional.interpolate(
low_dim_cam,
(Z, X, Y),
mode="trilinear"
)
return _tensor_as_numpy(grad_cam)
def _get_non_imaging_grad_cam(self) -> np.ndarray:
"""
Computes the "Pseudo GradCam" for non-imaging features i.e.
ReLu(non_imaging_inputs x gradients).
"""
assert self.non_image_input.grad is not None
total_pseudo_cam_non_image = _tensor_as_numpy(torch.nn.functional.relu(
torch.mul(self.non_image_input, self.non_image_input.grad)))
batch_size = self.non_image_input.shape[0]
non_image_input = _tensor_as_numpy(self.non_image_input.detach())
if self.total_num_categorical_features > 0:
if len(total_pseudo_cam_non_image.shape) == 2:
total_pseudo_cam_non_image = total_pseudo_cam_non_image.reshape(batch_size, 1, -1)
non_image_input = self.non_image_input.reshape(batch_size, 1, -1)
pseudo_cam_numerical = total_pseudo_cam_non_image[:, :,
:self.total_number_of_numerical_non_imaging_features]
pseudo_cam_one_hot = total_pseudo_cam_non_image[:, :,
self.total_number_of_numerical_non_imaging_features:]
categorical_input_one_hot = non_image_input[:, :, self.total_number_of_numerical_non_imaging_features:]
# Back to "not one hot", only one value per feature is non zero
batch_size, number_positions = pseudo_cam_one_hot.shape[:2]
if isinstance(self.config, SequenceModelBase):
pseudo_cam_categorical = np.zeros((batch_size, number_positions, len(self.config.categorical_columns)))
for b in range(batch_size):
for t in range(number_positions):
# Some features come from sequence padding, for those the entire row is 0 i.e. the feature
# is not really one-hot encoded.
if np.any(categorical_input_one_hot[b, t] != 0):
pseudo_cam_categorical[b, t] = pseudo_cam_one_hot[
b, t, categorical_input_one_hot[b, t] != 0]
else:
# For a non-sequence model a categorical feature might appear several times for several channels but
# there
# is no padding. Hence we handle the conversion differently.
pseudo_cam_categorical = pseudo_cam_one_hot[categorical_input_one_hot.cpu() != 0].reshape(
(batch_size, number_positions, -1))
return np.concatenate([pseudo_cam_numerical, pseudo_cam_categorical], axis=2)
else:
return total_pseudo_cam_non_image
def generate(self, input: List[torch.Tensor], target_position: int = -1, target_label_index: int = -1) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Generates the GradCam for images, PseudoGradCam for non-imaging features
of one batch for the given classification model.
GradCam computes Relu(Gradients x Activations) at the output of the encoder of the network
(before the global pooling layer). "PseudoGradCam" for non-imaging features denotes
ReLu(input x gradients) for non-imaging features. "PseudoGradCam" is used to compare relative feature importance
of various non-imaging features for the final classification task.
:param input: input image [B, C, Z, X, Y]
:param target_position: in case of sequence model with multiple target weeks, specify which target
position prediction should be visualized. By default the last one.
:param target_label_index: index of the target label in the array of targets labels i.e. if target
positions are [2,3,5], the target_label_index for position 3 is 1.
:return: grad_cam: grad_cam maps [B, Z, X, Y]
"""
self.target_label_index = target_label_index
self.model.eval()
if self.num_non_image_features > 0:
self.non_image_input = input[1].clone().to(self.device).requires_grad_(True)
self.forward(*[input[0], self.non_image_input])
elif self.is_non_imaging_model:
self.non_image_input = input[0].clone().to(self.device).requires_grad_(True)
self.forward(self.non_image_input)
else:
self.forward(*input)
self.backward()
with torch.no_grad():
grad_cam_image = None
pseudo_cam_non_image = None
if not self.is_non_imaging_model:
grad_cam_image = self._get_image_grad_cam(input)
if target_position > -1:
grad_cam_image = grad_cam_image[:, :(target_position + 1), ...]
if self.num_non_image_features > 0 or self.is_non_imaging_model:
pseudo_cam_non_image = self._get_non_imaging_grad_cam()
if target_position > -1:
pseudo_cam_non_image = pseudo_cam_non_image[:, :(target_position + 1), ...]
if self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
if self.encode_jointly:
# In this case broadcasting can happen automatically in GuidedGradCam
# computation.
assert grad_cam_image is not None # for mypy
assert grad_cam_image.shape[1] == 1
else:
# Otherwise, copy GradCam output twice to compute GuidedGradCam (once for images and
# once for segmentations).
grad_cam_image = np.concatenate([grad_cam_image, grad_cam_image], axis=1)
return grad_cam_image, pseudo_cam_non_image, _tensor_as_numpy(self.probabilities[:, target_label_index])
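A self-contained sketch of the core GradCam arithmetic described above (mean gradient per channel as the weight, weighted sum over channels, then ReLU); the tensors and shapes are illustrative, not the encoder's actual activations.

import torch

activations = torch.randn(2, 8, 4, 6, 6)       # [B, C, Z, X, Y], e.g. the encoder output
gradients = torch.randn_like(activations)      # gradient of the target score w.r.t. the activations
weights = gradients.mean(dim=(2, 3, 4), keepdim=True)  # one weight per channel
cam = torch.relu((weights * activations).sum(dim=1))   # [B, Z, X, Y]
print(cam.shape)  # torch.Size([2, 4, 6, 6])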
class GuidedBackPropagation(GradientBasedFeatureExtractor):
"""
Class to compute GuidedBackPropagation maps for images features.
"""
def __init__(self, model: Module, config: ScalarModelBase) -> None:
super().__init__(model=model, config=config, target_layer=None)
def guided_backprop_hook(self, module: Module, grad_in: torch.Tensor, grad_out: torch.Tensor) \
-> Optional[Tuple[torch.Tensor]]:
"""
Backward hook for guided Backpropagation.
Propagate only positive gradient when backpropagating through ReLu layers.
"""
# For all ReLU layers propagate only positive gradients
if isinstance(module, torch.nn.ReLU):
return torch.nn.functional.relu(grad_in[0]),
return None
def forward(self, *input): # type: ignore
"""
Triggers the call to the forward pass of the module and the registration of the backward hook.
"""
for layer in self.net.modules():
# Type check disabled: the type is correct but the PyTorch documentation is not.
# noinspection PyTypeChecker
self.hooks.append(layer.register_backward_hook(self.guided_backprop_hook))
self.image_input_grad = input[0].clone().requires_grad_(True)
if self.num_non_image_features > 0:
self.logits = self.model(self.image_input_grad, input[1])
else:
self.logits = self.model(self.image_input_grad)
if isinstance(self.logits, List):
self.logits = torch.nn.parallel.gather(self.logits, target_device=self.device)
self.probabilities = torch.nn.Sigmoid()(self.logits)
def generate(self, input: List[torch.Tensor],
target_position: int = -1,
target_label_index: int = -1) -> np.ndarray:
"""
Generate Guided Backpropagation maps for one input batch.
:param input: input batch
:param target_position: in case of sequence model with multiple target weeks, specify which target
position prediction should be visualized. By default the last one.
:param target_label_index: index of the target label in the array of targets labels i.e. if target
positions are [2,3,5], the target_label_index for position 3 is 1.
:return: guided backprop maps, size [B, C, Z, X, Y]
"""
self.target_label_index = target_label_index
self.model.eval()
self.forward(*input)
if self.config.use_gpu:
torch.cuda.empty_cache()
self.backward()
B, C = input[0].shape[:2]
Z, X, Y = input[0].shape[-3:]
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
grads_of_one_hot = -_tensor_as_numpy(self.image_input_grad.grad)
one_hot_input = _tensor_as_numpy(self.image_input_grad)
backprop_map = grads_of_one_hot * one_hot_input
backprop_map = backprop_map.reshape((B, -1, HDF5_NUM_SEGMENTATION_CLASSES, Z, X, Y))
backprop_map = backprop_map.sum(axis=2) # [B, C, Z, X, Y]
if target_position > -1:
backprop_map = backprop_map[:, :(target_position + 1), ...]
return backprop_map
elif self.imaging_feature_type == ImagingFeatureType.Image:
backprop_map = self.image_input_grad.grad.detach().cpu().numpy().reshape((B, C, Z, X, Y))
if target_position > -1:
backprop_map = backprop_map[:, :(target_position + 1), ...]
return backprop_map
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
grads = _tensor_as_numpy(self.image_input_grad.grad) # [1, CLASSES, Z, X, Y] or [1, week, CLASSES, Z, X, Y]
input_image = _tensor_as_numpy(self.image_input_grad)
if len(grads.shape) == 5:
grads = grads.reshape((B, -1, HDF5_NUM_SEGMENTATION_CLASSES + 1, Z, X, Y))
input_image = input_image.reshape((B, -1, HDF5_NUM_SEGMENTATION_CLASSES + 1, Z, X, Y))
one_hot_input = input_image[:, :, :-1, ...]
grads_of_one_hot = grads[:, :, :-1, ...]
backprop_map_segmentation = (grads_of_one_hot * one_hot_input).sum(axis=2) # [B, C, Z, X, Y]
backprop_map_image = grads[:, :, -1, ...] # [B, C, Z, X, Y]
if target_position > -1:
backprop_map_segmentation = backprop_map_segmentation[:, :(target_position + 1), ...]
backprop_map_image = backprop_map_image[:, :(target_position + 1), ...]
return np.concatenate([backprop_map_segmentation, backprop_map_image], axis=1) # [B, 2*C, Z, X, Y]
else:
ValueError("This imaging feature type is not supported.")
class VisualizationMaps:
"""
Wrapper class to compute GradCam maps, GuidedGradCam and "Pseudo-GradCam" maps
for a specific model.
"""
def __init__(self, model: Union[DeviceAwareModule, torch.nn.DataParallel],
config: ScalarModelBase) -> None:
self.config = config
self.is_non_imaging_model = config.is_non_imaging_model
self.grad_cam: GradCam = GradCam(model, config)
if not self.is_non_imaging_model:
self.guided_backprop: GuidedBackPropagation = GuidedBackPropagation(model, config)
self.encode_channels_jointly: bool = self.guided_backprop.encode_jointly
self.imaging_feature_type = self.grad_cam.imaging_feature_type
def generate(self, input: List[torch.Tensor],
target_position: int = -1,
target_label_index: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Generates the GuidedGradCam, GradCam and PseudoGradCam maps (for non-imaging data)
for one batch of data.
:param input: input batch of size [B, C, Z, X, Y]
:param target_position: in case of sequence model with multiple target weeks, specify which target
position prediction should be visualized. By default the last one.
:param target_label_index: index of the target label in the array of targets labels i.e. if target
positions are [2,3,5], the target_label_index for position 3 is 1.
:return: A tuple of GuidedGradCam maps (for image inputs), GradCam maps (for image inputs),
PseudoGradCam maps (for non-image inputs), and the posteriors predicted by the model for the
given target position for this batch of data.
"""
image_gcam, pseudo_cam_non_img, probability = self.grad_cam.generate(input, target_position, target_label_index)
if self.is_non_imaging_model:
image_guided_gcam = None
else:
guided_bp = self.guided_backprop.generate(input, target_position, target_label_index)
image_guided_gcam = image_gcam * guided_bp
return image_guided_gcam, \
image_gcam, \
pseudo_cam_non_img, \
probability
def save_visualizations_in_notebook(self,
classification_sequence: Union[ScalarItem,
List[ClassificationItemSequence[ScalarItem]]],
input_batch: List[torch.Tensor],
filenames: List[str],
ground_truth_labels: np.ndarray,
gradcam_dir: Path
) -> None:
"""
Generate, plot and save the visualizations for one batch of
data for a sequence model. The visualizations are produced in
a Jupyter Notebook for readability. There is one notebook generated
for each subject. This notebook can be viewed in the AML UI. Additionally,
an HTML file is produced containing only the cells' output.
:param input_batch: input to the network
:param classification_sequence: classification item for current batch
:param filenames: a list of filenames for the plots, size: [Batch]
:param ground_truth_labels: the labels for this input_batch
:param gradcam_dir: directory where to save the plots.
"""
non_image_features = self.config.numerical_columns + self.config.categorical_columns
has_non_image_features = len(non_image_features) > 0
batch_size = len(filenames)
if isinstance(self.config, SequenceModelBase):
target_indices = self.config.get_target_indices()
if target_indices is None:
target_indices = [-1]
else:
target_indices = [-1]
for label_index in range(len(target_indices)):
target_position = target_indices[label_index]
current_output_dir = self.config.visualization_folder / f"{SEQUENCE_POSITION_HUE_NAME_PREFIX}_" \
f"{target_position}"
current_output_dir.mkdir(exist_ok=True)
guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = self.generate(input_batch,
target_position,
label_index)
for i in range(batch_size):
if not self.is_non_imaging_model:
non_imaging_labels = self._get_non_imaging_plot_labels(
classification_sequence, # type: ignore
non_image_features,
index=i,
target_position=target_position)
if isinstance(self.config, SequenceModelBase):
image = self._get_image_attributes_for_sequence_item(classification_sequence, # type: ignore
index=i,
target_position=target_position)
else:
image = self._get_image_attributes_for_scalar_item(classification_sequence, i) # type: ignore
# Need to temporarily save the variables to access them from the notebook,
# because papermill does not support passing numpy arrays as parameters.
np.save(str(gradcam_dir / "image.npy"), image)
np.save(str(gradcam_dir / "gradcam.npy"), grad_cams[i])
np.save(str(gradcam_dir / "guided_grad_cam.npy"), guided_grad_cams[i])
if has_non_image_features:
np.save(str(gradcam_dir / "non_image_pseudo_cam.npy"), pseudo_cam_non_img[i])
has_image_features = True
else:
non_imaging_labels = self._get_non_imaging_plot_labels(
classification_sequence, non_image_features, index=i, target_position=target_position)
has_non_image_features = True
has_image_features = False
self.encode_channels_jointly = False
self.imaging_feature_type = ImagingFeatureType.Image
np.save(str(gradcam_dir / "non_image_pseudo_cam.npy"), pseudo_cam_non_img[i])
current_label = ground_truth_labels[i, label_index]
# If the label is NaN it means that we don't have data for this position and
# we used padding for the input. Hence do not save visualizations for this position.
if not np.isnan(current_label):
params_dict = dict(subject_id=filenames[i],
target_position=target_position,
gradcam_dir=str(gradcam_dir),
has_non_image_features=has_non_image_features,
probas=str(probas[i]),
ground_truth_labels=str(current_label),
non_image_labels=non_imaging_labels,
encode_jointly=self.encode_channels_jointly,
imaging_feature_type=self.imaging_feature_type.value,
has_image_features=has_image_features,
value_image_and_segmentation=ImagingFeatureType.ImageAndSegmentation.value, )
result_path = str(current_output_dir.joinpath(f"{filenames[i]}.ipynb"))
import papermill
papermill.execute_notebook(os.path.join(os.path.dirname(os.path.realpath(__file__)),
"gradcam_visualization.ipynb"),
result_path,
parameters=params_dict,
progress_bar=False)
convert_to_html(Path(result_path))
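The np.save calls above follow from the papermill limitation noted in the comments: arrays cannot be passed directly as notebook parameters, so they travel via .npy files and only plain strings (paths, ids) go into params_dict. Below is a minimal, stand-alone sketch of that save-then-execute pattern; the output path, parameter values and array contents are illustrative placeholders, not taken from this repository.

import numpy as np
import papermill
from pathlib import Path

gradcam_dir = Path("gradcam_tmp")
gradcam_dir.mkdir(exist_ok=True)
# The array is written to disk because papermill cannot take numpy arrays as parameters.
np.save(str(gradcam_dir / "gradcam.npy"), np.random.rand(1, 3, 4, 4))
papermill.execute_notebook(
    "gradcam_visualization.ipynb",             # template notebook, as used above
    str(gradcam_dir / "subject_1.ipynb"),      # executed copy, one per subject (illustrative path)
    parameters=dict(gradcam_dir=str(gradcam_dir), subject_id="subject_1"),
    progress_bar=False)
# Inside the notebook, the array is recovered with np.load(Path(gradcam_dir) / "gradcam.npy").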
def _get_non_imaging_plot_labels(self, classification_item: Union[ScalarItem,
List[ClassificationItemSequence[ScalarItem]]],
non_image_features: List[str],
index: int,
target_position: int = -1) -> List[str]:
"""
Gets labels to use for the plots of non-imaging feature importance.
:param classification_item: The classification item for which to return the
labels (can vary from subject to subject as they might not all contain the same
position in case of a sequence model).
:param non_image_features: The names of the non-imaging features used by the model.
:param index: The index of the subject in the batch (used only for sequence models).
:return: the labels (list of strings)
"""
if isinstance(self.config, SequenceModelBase):
channels = []
for item in classification_item[index].items: # type: ignore
if (item.metadata.sequence_position - self.config.min_sequence_position_value) <= target_position \
or target_position == -1:
channels.append(item.metadata.sequence_position)
return [f"{col}_{channel}" for channel in channels for col in
non_image_features] # type: ignore
else:
non_imaging_labels = []
non_image_features = self.config.numerical_columns + self.config.categorical_columns
non_image_feature_channels_dict = self.config.get_non_image_feature_channels_dict()
for col in non_image_features:
non_imaging_labels.extend(
[f"{col}_{channel}" for channel in non_image_feature_channels_dict[col]]) # type: ignore
return non_imaging_labels
def _get_image_attributes_for_sequence_item(self,
classification_sequence: List[ClassificationItemSequence[
ScalarItem]],
index: int,
target_position: int) -> np.ndarray:
"""
Extract the image and/or the segmentation for the classification item to be able to
produce the visualizations.
:param classification_sequence: The classification sequence for which to plot (contains
the entire batch)
:param index: the index of the subject in the batch for which to plot.
:return: An array containing the imaging input to plot.
"""
images = []
segmentations = []
for item in classification_sequence[index].items:
if (item.metadata.sequence_position - self.config.min_sequence_position_value) <= target_position \
or target_position == -1:
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
segmentations.append(item.segmentations)
elif self.imaging_feature_type == ImagingFeatureType.Image:
images.append(item.images)
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
segmentations.append(item.segmentations)
images.append(item.images)
if self.imaging_feature_type == ImagingFeatureType.Image:
return _tensor_as_numpy(torch.cat(images, dim=0)).astype(float) # type: ignore
elif self.imaging_feature_type == ImagingFeatureType.Segmentation:
return _tensor_as_numpy(torch.cat(segmentations, dim=0)).astype(int) # type: ignore
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
return np.concatenate(
[_tensor_as_numpy(torch.cat(images, dim=0)).astype(float), # type: ignore
_tensor_as_numpy(torch.cat(segmentations, dim=0)).astype(int)], # type: ignore
axis=0)
def _get_image_attributes_for_scalar_item(self, classification_item: ScalarItem, index: int) -> np.ndarray:
"""
Extract the image and/or the segmentation for the classification item to be able to
produce the visualizations.
:param classification_item: The classification items for which to plot (contains the
entire batch data).
:param index: the index of the subject in the batch for which to plot.
:return: An array containing the imaging input to plot.
"""
if self.imaging_feature_type == ImagingFeatureType.Segmentation:
if classification_item.segmentations is None:
raise ValueError("Expected classification_item.segmentations to not be None")
return _tensor_as_numpy(classification_item.segmentations[index]).astype(int)
elif self.imaging_feature_type == ImagingFeatureType.Image:
return _tensor_as_numpy(classification_item.images[index]).astype(float)
elif self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
image = _tensor_as_numpy(classification_item.images[index]).astype(float) # [C, Z, X, Y]
assert classification_item.segmentations is not None # for mypy
seg = _tensor_as_numpy(classification_item.segmentations[index]).astype(int)
return np.concatenate([seg, image], axis=0)
else:
raise ValueError("The provided imaging feature type is not supported.")

Просмотреть файл

@ -11,7 +11,6 @@ On the modelling side, this toolbox supports
- Segmentation models
- Classification and regression models
- Sequence models
- Adding cloud support to any PyTorch Lightning model, via a [bring-your-own-model setup](docs/bring_your_own_model.md)
- Active label cleaning and noise robust learning toolbox (stand-alone folder)
@ -48,8 +47,7 @@ often seen with medical images.
- Easy creation of new models via a configuration-based approach, and inheritance from an existing
architecture.
Once training in AzureML is done, the models can be deployed from within AzureML or via
[Azure Stack Hub](https://azure.microsoft.com/en-us/products/azure-stack/hub/).
Once training in AzureML is done, the models can be deployed from within AzureML.
## Getting started
@ -180,14 +178,7 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## Credits
## This toolbox is maintained by the
[Microsoft Medical Image Analysis team](https://www.microsoft.com/en-us/research/project/medical-image-analysis/).
This toolbox is maintained by the
[Microsoft InnerEye team](https://www.microsoft.com/en-us/research/project/medical-image-analysis/),
and has received valuable contributions from a number
of people outside our team. We would like to thank in particular our interns,
[Yao Quin](http://cseweb.ucsd.edu/~yaq007/), [Zoe Landgraf](https://www.linkedin.com/in/zoe-landgraf-a2212293),
[Padmaja Jonnalagedda](https://www.linkedin.com/in/jspadmaja/),
[Mathias Perslev](https://github.com/perslev), as well as the AI Residents
[Patricia Gillespie](https://www.microsoft.com/en-us/research/people/t-pagill/) and
[Guilherme Ilunga](https://gilunga.github.io/).

Просмотреть файл

@ -185,7 +185,7 @@ S1,image2,img12.nii,True
S2,image2,image22.nii,False
""")
df = pd.read_csv(csv_string, sep=",", dtype=str)
items: List[ScalarDataSource] = DataSourceReader[ScalarDataSource](
items: List[ScalarDataSource] = DataSourceReader(
data_frame=df,
image_channels=["image1", "image2"],
image_file_column="path",
@ -197,7 +197,7 @@ S2,image2,image22.nii,False
def test_load_items_errors() -> None:
"""
Test error cases when creating a list of classificationItems from a dataframe
Test error cases when creating a list of classification Items from a dataframe
"""
def load(csv_string: StringIO) -> str:

Просмотреть файл

@ -1,676 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from io import StringIO
from pathlib import Path
from typing import List, Optional, Union
from unittest import mock
import numpy as np
import pandas as pd
import pytest
import torch
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.dataset.full_image_dataset import collate_with_metadata
from InnerEye.ML.dataset.sample import GeneralSampleMetadata
from InnerEye.ML.dataset.scalar_dataset import DataSourceReader, filter_valid_classification_data_sources_items
from InnerEye.ML.dataset.scalar_sample import ScalarDataSource, ScalarItem, SequenceDataSource
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset, add_difference_features, \
group_samples_into_sequences
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence, ListOfSequences
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.features_util import FeatureStatistics
from InnerEye.ML.utils.io_util import ImageAndSegmentations
from InnerEye.ML.utils.ml_util import set_random_seed
from InnerEye.ML.utils.sequence_utils import sequences_to_padded_tensor
from InnerEye.ML.utils.split_dataset import DatasetSplits
from Tests.ML.models.architectures.sequential.test_rnn_classifier import ToyMultiLabelSequenceModel, \
_get_multi_label_sequence_dataframe
from Tests.ML.util import assert_tensors_equal, create_dataset_csv_file
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
def test_load_items_seq() -> None:
"""
Test loading file paths and labels from a dataframe.
"""
csv_string = StringIO("""subject,seq,path,value,scalar1,scalar2,META
S1,0,foo.nii,,0,0,M1
S1,1,,True,1.1,1.2,M2
S2,1,bar.nii,False,2.1,2.2,
""")
df = pd.read_csv(csv_string, sep=",", dtype=str)
items: List[SequenceDataSource] = DataSourceReader[SequenceDataSource](
data_frame=df,
image_channels=None,
image_file_column="path",
label_channels=None,
label_value_column="value",
numerical_columns=["scalar1", "scalar2"],
sequence_column="seq").load_data_sources()
assert len(items) == 3
assert isinstance(items[0].metadata, GeneralSampleMetadata)
assert items[0].metadata.id == "S1"
assert items[0].metadata.props == {"META": "M1"}
assert items[0].metadata.sequence_position == 0
assert len(items[0].label.tolist()) == 1
assert math.isnan(items[0].label.item())
assert items[0].channel_files == ["foo.nii"]
assert_tensors_equal(items[0].numerical_non_image_features, [0.0, 0.0])
assert isinstance(items[1].metadata, GeneralSampleMetadata)
assert items[1].metadata.id == "S1"
assert items[1].metadata.props == {"META": "M2"}
assert items[1].metadata.sequence_position == 1
assert_tensors_equal(items[1].label, [1.0])
assert items[1].channel_files == ['']
assert_tensors_equal(items[1].numerical_non_image_features, [1.1, 1.2])
assert isinstance(items[2].metadata, GeneralSampleMetadata)
assert items[2].metadata.id == "S2"
assert items[2].metadata.props == {"META": ''}
assert items[2].metadata.sequence_position == 1
assert_tensors_equal(items[2].label, [0.0])
assert items[2].channel_files == ["bar.nii"]
assert_tensors_equal(items[2].numerical_non_image_features, [2.1, 2.2])
def test_load_items_seq_from_dataset() -> None:
"""
Test loading a sequence dataset with numerical, categorical features and images.
"""
dummy_dataset = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
df = pd.read_csv(dummy_dataset, sep=",", dtype=str)
items: List[SequenceDataSource] = DataSourceReader[SequenceDataSource](
data_frame=df,
image_channels=None,
image_file_column="IMG",
label_channels=None,
label_value_column="Label",
numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
sequence_column="Position").load_data_sources()
assert len(items) == 3 * 9 # 3 subjects, 9 visits each, no missing
assert items[0].metadata.id == "2137.00005"
assert items[0].metadata.sequence_position == 0
assert items[0].metadata.props["CAT2"] == "category_A"
# One of the labels is missing, missing labels should be encoded as NaN
assert math.isnan(items[0].label[0])
assert items[0].channel_files == ["img_1"]
assert str(items[0].numerical_non_image_features.tolist()) == str([362.0, np.nan, np.nan, 71.0])
assert items[8].metadata.id == "2137.00005"
assert items[8].metadata.sequence_position == 8
assert items[8].label.tolist() == [0.0]
assert items[8].channel_files == ['']
assert str(items[8].numerical_non_image_features.tolist()) == str([350.0, np.nan, np.nan, 8.0])
assert items[16].metadata.id == "2627.00001"
assert items[16].label.tolist() == [0.0]
assert items[16].channel_files == ["img_2"]
assert_tensors_equal(items[16].numerical_non_image_features, [217.0, 0.0, 0.01, 153.0])
assert items[26].metadata.id == "3250.00005"
assert items[26].metadata.sequence_position == 8
assert_tensors_equal(items[26].label, [0.0])
assert items[26].channel_files == ["img_11"]
assert_tensors_equal(items[26].numerical_non_image_features, [238.0, 0.0, 0.02, 84.0])
grouped = group_samples_into_sequences(
filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
max_sequence_position_value=None))
# There are 3 patients total, but one of them has missing measurements for all visits
assert len(grouped) == 2
assert grouped[0].id == "2627.00001"
assert grouped[1].id == "3250.00005"
# 2627.00001 has full information for weeks 0, 4, and 8
assert len(grouped[0].items) == 3
assert grouped[0].items[0].metadata["VISIT"] == "V1"
assert grouped[0].items[2].metadata["VISIT"] == "VST 3"
assert len(grouped[1].items) == 9
assert items[16].metadata.sequence_position == 7
def test_seq_dataset_loader() -> None:
dummy_dataset = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
df = pd.read_csv(dummy_dataset, sep=",", dtype=str)
dataset = SequenceDataset(
args=SequenceModelBase(
image_file_column="IMG",
label_value_column="Label",
numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
sequence_target_positions=[8],
sequence_column="Position",
local_dataset=Path(),
should_validate=False
),
data_frame=df
)
assert len(dataset) == 2
# Patch the load_images function that will be called once we access a dataset item
with mock.patch('InnerEye.ML.dataset.scalar_sample.load_images_and_stack',
return_value=ImageAndSegmentations[torch.Tensor](images=torch.ones(1),
segmentations=torch.empty(0))):
item0 = ClassificationItemSequence(**dataset[0])
item1 = ClassificationItemSequence(**dataset[1])
assert item0.id == "2627.00001"
len_2627 = 3
assert len(item0.items) == len_2627
assert item1.id == "3250.00005"
len_3250 = 9
assert len(item1.items) == len_3250
# Data loaders use a customized collate function, that must work with the sequences too.
collated = collate_with_metadata([dataset[0], dataset[1]])
assert collated["id"] == ["2627.00001", "3250.00005"]
# All subject sequences should be turned into lists of lists.
assert isinstance(collated["items"], list)
assert len(collated["items"]) == 2
assert isinstance(collated["items"][0], list)
assert isinstance(collated["items"][1], list)
assert len(collated["items"][0]) == len_2627
assert len(collated["items"][1]) == len_3250
back_to_items = ClassificationItemSequence(**collated)
assert back_to_items.id == ["2627.00001", "3250.00005"]
def test_group_items() -> None:
"""
Test if grouping and filtering of sequence data sets works.
"""
def _create(id: str, sequence_position: int, file: Optional[str], metadata: str) -> SequenceDataSource:
return SequenceDataSource(channel_files=[file],
numerical_non_image_features=torch.tensor([]),
categorical_non_image_features=torch.tensor([]),
label=torch.tensor([]),
metadata=GeneralSampleMetadata(id=id, sequence_position=sequence_position,
props={"M": metadata}))
items = [
_create("a", 1, "f", "a.1"),
_create("a", 0, "f", "a.0"),
_create("a", 4, "f", "a.4"),
_create("b", 1, None, "b.1"),
_create("b", 0, None, "b.0"),
_create("c", 0, "f", "c.0"),
_create("d", 1, "f", "d.1"),
]
grouped = group_samples_into_sequences(items)
assert len(grouped) == 3
def assert_group(group: ClassificationItemSequence, subject: str, props: List[str]) -> None:
assert isinstance(group, ClassificationItemSequence)
assert group.id == subject
assert [i.metadata.props["M"] for i in group.items] == props
# For subject a, item a.4 should be dropped because the consecutive sequence is only [0, 1]
assert_group(grouped[0], "a", ["a.0", "a.1"])
assert_group(grouped[1], "b", ["b.0", "b.1"])
assert_group(grouped[2], "c", ["c.0"])
# Group should not contain subject d because its only item is at index 1
def _create_item(id: str, sequence_position: int, metadata: str, label: Optional[float] = None) -> SequenceDataSource:
return SequenceDataSource(channel_files=["foo"],
numerical_non_image_features=torch.tensor([]),
categorical_non_image_features=torch.tensor([]),
label=(torch.tensor([label]) if label else torch.tensor([])),
metadata=GeneralSampleMetadata(id=id, sequence_position=sequence_position,
props={"M": metadata}))
def _assert_group(group: ClassificationItemSequence, subject: str, props: List[str]) -> None:
assert group.id == subject
assert [i.metadata.props["M"] for i in group.items] == props
def test_group_items_with_min_and_max_sequence_position_values() -> None:
"""
Test if grouping of sequence data works when requiring a full set of items.
"""
items = [
_create_item("a", 1, "a.1"),
_create_item("a", 0, "a.0"),
_create_item("a", 2, "a.2"),
_create_item("b", 1, "b.1"),
_create_item("b", 0, "b.0"),
]
# When not providing a max_sequence_position_value, sequences of any length are OK.
grouped = group_samples_into_sequences(items, max_sequence_position_value=None)
assert len(grouped) == 2
_assert_group(grouped[0], "a", ["a.0", "a.1", "a.2"])
_assert_group(grouped[1], "b", ["b.0", "b.1"])
# With a max_sequence_position_value, the set must be complete up to the given index.
grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=2)
assert len(grouped) == 2
_assert_group(grouped[0], "a", ["a.1", "a.2"])
# When a max position is given, the sequence will be truncated to at most contain the given value.
grouped = group_samples_into_sequences(items, min_sequence_position_value=0, max_sequence_position_value=1)
assert len(grouped) == 2
_assert_group(grouped[0], "a", ["a.0", "a.1"])
_assert_group(grouped[1], "b", ["b.0", "b.1"])
grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=1)
assert len(grouped) == 2
_assert_group(grouped[0], "a", ["a.1"])
_assert_group(grouped[1], "b", ["b.1"])
# Allow sequences up to max_sequence_position_value=2
grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=2)
assert len(grouped) == 2
_assert_group(grouped[0], "a", ["a.1", "a.2"])
_assert_group(grouped[1], "b", ["b.1"])
# There are no items that have sequence position == 3, hence the next call should not return any items.
grouped = group_samples_into_sequences(items, min_sequence_position_value=3)
assert len(grouped) == 0
# Check that items up to max_sequence_position_value=3 are included
grouped = group_samples_into_sequences(items, max_sequence_position_value=3)
assert len(grouped) == 2
# Sequence positions must be unique
with pytest.raises(ValueError) as ex:
group_samples_into_sequences([_create_item("a", 0, "a.0")] * 2)
assert "contains duplicates" in str(ex)
def test_group_items_with_label_positions() -> None:
items = [
_create_item("a", 0, "a.0", 1),
_create_item("a", 3, "a.3", math.inf),
_create_item("a", 1, "a.1", 0),
_create_item("a", 2, "a.2", 1),
]
# Extracting the sequence from 2 to 3
grouped = group_samples_into_sequences(items, min_sequence_position_value=2, max_sequence_position_value=3)
assert len(grouped) == 1
_assert_group(grouped[0], "a", ["a.2", 'a.3'])
def test_filter_valid_items() -> None:
"""
Test if filtering of sequence data sets works.
"""
def _create(id: str, sequence_position: int, file: Optional[str], metadata: str) -> SequenceDataSource:
return SequenceDataSource(channel_files=[file],
numerical_non_image_features=torch.tensor([]),
categorical_non_image_features=torch.tensor([]),
label=torch.tensor([]),
metadata=GeneralSampleMetadata(id=id, sequence_position=sequence_position,
props={"M": metadata}))
items = [
_create("a", 1, "f1", "a.1"), # Valid item
_create("b", 0, None, "b.0"), # Invalid because no file
_create("b", 1, "d", "b.1"), # valid
_create("c", 0, "f3", "c.0"), # valid item for subject "c"
]
def assert_items(filtered: List[SequenceDataSource], props: List[str]) -> None:
assert [i.metadata.props["M"] for i in filtered] == props
# Standard filtering should remove items with missing file name only, that is b.0
filtered1 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
max_sequence_position_value=None)
assert_items(filtered1, ["a.1", "b.1", "c.0"])
# Filtering also for max_sequence_position_value
filtered2 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
max_sequence_position_value=1)
assert_items(filtered2, ["a.1", "b.1", "c.0"])
filtered3 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
max_sequence_position_value=0)
assert_items(filtered3, ["c.0"])
# Filtering also for min_sequence_position_value
filtered4 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
min_sequence_position_value=1,
max_sequence_position_value=None)
assert_items(filtered4, ["a.1", "b.1"])
filtered5 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=None,
min_sequence_position_value=2,
max_sequence_position_value=None)
assert_items(filtered5, [])
# Now also filter by file name mapping: only "d" is in the mapping, hence only b.1 should survive
file_mapping = {"d": Path("d"), "foo": Path("bar")}
filtered4 = filter_valid_classification_data_sources_items(items, file_to_path_mapping=file_mapping,
max_sequence_position_value=1)
assert_items(filtered4, ["b.1"])
# noinspection PyUnresolvedReferences
def test_sequence_dataloader() -> None:
"""
Test if we can create a data loader from the dataset, and recover the items as expected in batched form,
including instances where not all elements of the sequence have labels.
"""
csv_string = StringIO("""subject,seq,path,value,scalar1,scalar2,META
S1,0,foo.nii,,0,0,M1
S1,1,,True,1.1,1.2,M2
S2,0,bar.nii,False,2.1,2.2,M3
S2,1,,False,2.0,2.0,M4
""")
df = pd.read_csv(csv_string, sep=",", dtype=str)
config = SequenceModelBase(
image_file_column=None,
label_value_column="value",
numerical_columns=["scalar1"],
sequence_target_positions=[1],
sequence_column="seq",
local_dataset=Path.cwd(),
should_validate=False
)
dataset = SequenceDataset(config, data_frame=df)
assert len(dataset) == 2
data_loader = dataset.as_data_loader(shuffle=False, batch_size=2, num_dataload_workers=0)
# We have 2 subjects, with a batch size of 2 those should be turned into 1 batch
data_loader_output = list(i for i in data_loader)
assert len(data_loader_output) == 1
loaded = list(ClassificationItemSequence(**i) for i in data_loader_output)
assert loaded[0].id == ["S1", "S2"]
assert isinstance(loaded[0].items[0][0], ScalarItem)
assert loaded[0].items[0][0].metadata.id == "S1"
assert loaded[0].items[0][1].metadata.id == "S1"
assert loaded[0].items[1][0].metadata.id == "S2"
assert loaded[0].items[1][1].metadata.id == "S2"
# The batched sequence data are awkward to work with. Check if we can un-roll them correctly via
# from_minibatch
un_batched = ClassificationItemSequence.from_minibatch(data_loader_output[0])
assert len(un_batched) == 2
for i in range(2):
assert un_batched[i].id == dataset.items[i].id
assert len(un_batched[i].items) == len(dataset.items[i].items)
for j in range(len(un_batched[i].items)):
assert un_batched[i].items[j].metadata.id == dataset.items[i].items[j].metadata.id
def test_standardize_features() -> None:
"""
Test if the non-image features can be normalized to mean 0, std 1.
"""
set_random_seed(1234)
expected_mean = torch.tensor([[123, 2, 3], [4, 5, 6]])
expected_std = torch.tensor([[0, 2, 3], [3, 4, 4]])
feature_size = (2, 3)
sequences: List[ClassificationItemSequence] = []
for s in range(1000):
items = []
seq_length = torch.randint(low=3, high=6, size=(1,)).item()
for i in range(seq_length): # type: ignore
# All features are random Gaussian, apart from feature 0 which is constant.
# Normalization must be able to deal with constant features when dividing by standard deviation.
features = torch.randn(size=feature_size, dtype=torch.float32) * expected_std + expected_mean
# Randomly put some infinite values in the vector
features[s % 2, s % 3] = np.inf if torch.rand(1) > 0.9 else features[s % 2, s % 3]
features[0, 0] = expected_mean[0, 0]
item = ScalarItem(metadata=GeneralSampleMetadata(id="foo"),
numerical_non_image_features=features,
categorical_non_image_features=features,
label=torch.tensor([]),
images=torch.tensor([]),
segmentations=torch.tensor([]))
items.append(item)
sequences.append(ClassificationItemSequence(id="foo", items=items))
mean_std = FeatureStatistics.from_data_sources(sequences)
assert mean_std.mean.shape == feature_size
assert mean_std.std.shape == feature_size
assert_tensors_equal(mean_std.mean, expected_mean, 0.07)
assert_tensors_equal(mean_std.std, expected_std, 0.07)
# After normalization, mean should be 0, and std should be 1.
standardized_seq = mean_std.standardize(sequences)
mean_std_from_standardized = FeatureStatistics.from_data_sources(standardized_seq)
# After normalization, the mean should be 0, apart from the constant feature, which should be left untouched,
# hence its mean is the original feature value.
expected_mean_from_standardized = torch.zeros(feature_size)
expected_mean_from_standardized[0, 0] = expected_mean[0, 0]
expected_std_from_standardized = torch.ones(feature_size)
expected_std_from_standardized[0, 0] = 0.0
assert_tensors_equal(mean_std_from_standardized.mean, expected_mean_from_standardized, abs=1e-5)
assert_tensors_equal(mean_std_from_standardized.std, expected_std_from_standardized, abs=1e-5)
@pytest.mark.parametrize("is_sequence", [True, False])
def test_standardize_features_when_singleton(is_sequence: bool) -> None:
"""
Test how feature standardization copes with datasets that only have 1 entry.
"""
numerical_features = torch.ones((1, 3))
categorical_features = torch.tensor([[0, 1, 1], [1, 0, 0]])
item: Union[SequenceDataSource, ScalarDataSource]
sources: Union[ListOfSequences, List[ScalarDataSource]]
if is_sequence:
item = SequenceDataSource(metadata=GeneralSampleMetadata(id="foo"),
numerical_non_image_features=numerical_features,
categorical_non_image_features=categorical_features,
label=torch.tensor([]),
channel_files=[])
sources = [ClassificationItemSequence(id="foo", items=[item])]
mean_std = FeatureStatistics.from_data_sources(sources)
else:
item = ScalarDataSource(metadata=GeneralSampleMetadata(id="foo"),
numerical_non_image_features=numerical_features,
categorical_non_image_features=categorical_features,
label=torch.tensor([]),
channel_files=[])
sources = [item]
mean_std = FeatureStatistics.from_data_sources(sources)
assert_tensors_equal(mean_std.mean, numerical_features)
# Standard deviation can't be computed because there is only one element, hence becomes nan.
assert torch.all(torch.isnan(mean_std.std))
# When applying such a standardization to the sequences, they should not be changed (similar to features that
# are constant)
standardized_sources = mean_std.standardize(sources)
if is_sequence:
assert_tensors_equal(standardized_sources[0].items[0].numerical_non_image_features, numerical_features)
assert_tensors_equal(standardized_sources[0].items[0].categorical_non_image_features, categorical_features)
else:
assert_tensors_equal(standardized_sources[0].numerical_non_image_features, numerical_features)
assert_tensors_equal(standardized_sources[0].categorical_non_image_features, categorical_features)
def test_add_difference_features() -> None:
"""
Test if we can add difference features for sequence data sets (differences from position i compared to position 0
in the sequence)
"""
def _create(features: List) -> SequenceDataSource:
return SequenceDataSource(metadata=GeneralSampleMetadata(id="foo"),
channel_files=[],
label=torch.tensor([]),
categorical_non_image_features=torch.tensor([]),
numerical_non_image_features=torch.tensor(features).float())
item1 = _create([[1, 2, 3], [4, 5, 6]])
item2 = _create([[11, 22, 33], [44, 55, 66]])
items = [ClassificationItemSequence[SequenceDataSource](id="bar", items=[item1, item2])]
updated = add_difference_features(items, [0, 2])
# The two difference features should be added along dimension 1 of the tensor
assert updated[0].items[0].numerical_non_image_features.shape == (2, 5)
# Item 0 should have differences of 0
assert_tensors_equal(updated[0].items[0].numerical_non_image_features[:, 0:3], item1.numerical_non_image_features)
assert_tensors_equal(updated[0].items[0].numerical_non_image_features[:, 3:5], [[0, 0], [0, 0]])
# Item 1 should have non-zero diff, and keep the original non-image features in the first few dim
assert_tensors_equal(updated[0].items[1].numerical_non_image_features[:, 0:3], item2.numerical_non_image_features)
assert_tensors_equal(updated[0].items[1].numerical_non_image_features[:, 3:5], [[10, 30], [40, 60]])
def test_seq_to_tensor() -> None:
"""
Test if we can create a tensor from a variable length sequence.
"""
def _create(features: List) -> torch.Tensor:
return ScalarItem(
segmentations=torch.empty(0),
metadata=GeneralSampleMetadata(id="foo"),
images=torch.tensor([]),
label=torch.tensor([]),
categorical_non_image_features=torch.tensor(features).float(),
numerical_non_image_features=torch.tensor(features).float()
).get_all_non_imaging_features()
item1 = _create([1, 2, 3, 4, 5, 6])
item2 = _create([11, 22, 33])
items = [item1, item1, item2, item1]
stacked = sequences_to_padded_tensor(items)
assert torch.is_tensor(stacked)
# pad_sequence will pad the tensors to the maximum sequence length
assert stacked.shape == (len(items), item1.numel())
def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None:
"""
Check that the sequence dataset works end-to-end, including applying the right standardization.
"""
csv_string = """subject,seq,value,scalar1,scalar2,META,BETA
S1,0,False,0,0,M1,B1
S1,1,True,1,10,M2,B2
S2,0,False,2,20,M2,B1
S3,0,True,3,30,M1,B1
S4,0,True,4,40,M2,B1
"""
csv_path = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
config = SequenceModelBase(
local_dataset=csv_path,
image_file_column=None,
label_value_column="value",
numerical_columns=["scalar1", "scalar2"],
sequence_target_positions=[0],
categorical_columns=["META", "BETA"],
sequence_column="seq",
num_dataload_workers=0,
train_batch_size=2,
should_validate=False,
shuffle=False
)
config.read_dataset_if_needed()
df = config.dataset_data_frame
assert df is not None
df1 = df[df.subject.isin(["S1", "S2"])]
df2 = df[df.subject == "S3"]
df3 = df[df.subject == "S4"]
splits = DatasetSplits(train=df1, val=df2, test=df3)
with mock.patch.object(SequenceModelBase,
'get_model_train_test_dataset_splits',
return_value=splits):
train_val_loaders = config.create_data_loaders()
# Expected feature mean: Mean of the training data (0, 0), (1, 10), (2, 20) = (1, 10)
# Expected (bias-corrected) std estimate: Std of (0, 0), (1, 10), (2, 20) = (1, 10)
feature_stats = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN).feature_statistics
assert feature_stats is not None
assert_tensors_equal(feature_stats.mean, [1, 10])
assert_tensors_equal(feature_stats.std, [1, 10])
train_items = list(ClassificationItemSequence.from_minibatch(b)
for b in train_val_loaders[ModelExecutionMode.TRAIN])
assert len(train_items) == 1, "2 items in training set with batch size of 2 should return 1 minibatch"
assert len(train_items[0]) == 2
assert train_items[0][0].id == "S1"
assert_tensors_equal(train_items[0][0].items[0].get_all_non_imaging_features(), [-1., -1., 1., 0., 1., 0.])
assert_tensors_equal(train_items[0][0].items[1].get_all_non_imaging_features(), [0., 0., 0., 1., 0., 1.])
assert train_items[0][1].id == "S2"
assert_tensors_equal(train_items[0][1].items[0].get_all_non_imaging_features(), [1., 1., 0., 1., 1., 0.])
val_items = list(ClassificationItemSequence.from_minibatch(b)
for b in train_val_loaders[ModelExecutionMode.VAL])
assert len(val_items) == 1
assert len(val_items[0]) == 1
assert val_items[0][0].id == "S3"
# Items in the validation set should be normalized using the mean and std on the training data.
# Hence, the non-image features (3, 30) should turn into (2, 2)
assert_tensors_equal(val_items[0][0].items[0].get_all_non_imaging_features(), [2., 2., 1., 0., 1., 0.])
# Check that the test set is also normalized correctly using the training mean and std.
test_items = list(ClassificationItemSequence(**b)
for b in config.get_torch_dataset_for_inference(ModelExecutionMode.TEST))
assert test_items[0].id == "S4"
# Check Non-image features of (4, 40)
assert_tensors_equal(test_items[0].items[0].get_all_non_imaging_features(), [3., 3., 0., 1., 1., 0.])
def test_get_class_counts(test_output_dirs: OutputFolderForTests) -> None:
"""
Test that class counts are computed as expected for the training split of a multi-label sequence dataset.
"""
dataset_contents = _get_multi_label_sequence_dataframe()
config = ToyMultiLabelSequenceModel(should_validate=False)
assert config.get_target_indices() == [1, 2, 3]
expected_prediction_targets = ["Seq_pos 01", "Seq_pos 02", "Seq_pos 03"]
assert len(config.get_target_indices()) == len(expected_prediction_targets) # type: ignore
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = dataset_contents
config.pre_process_dataset_dataframe()
splits = config.get_dataset_splits()
train_dataset = config.create_torch_datasets(splits)[ModelExecutionMode.TRAIN]
class_counts = train_dataset.get_class_counts()
assert class_counts == {0: 2}
def test_get_labels_at_target_indices() -> None:
"""
Test to ensure label selection based on target indices is as expected
"""
sequence_items = _create_scalar_items(length=3)
sequence = ClassificationItemSequence(id="A", items=sequence_items)
# since label at sequence position 3 will not exist, we expect the result tensor to be padded with a nan
labels = sequence.get_labels_at_target_indices(target_indices=[0, 1, 2, 3])
assert torch.allclose(labels, torch.tensor([[1.0], [1.0], [1.0], [np.nan]]), equal_nan=True)
# test we can extract all of the labels in the sequence
labels = sequence.get_labels_at_target_indices(target_indices=[0, 1, 2])
assert torch.equal(labels, torch.tensor([[1.0], [1.0], [1.0]]))
# test we can extract only a subset of the labels in the sequence
labels = sequence.get_labels_at_target_indices(target_indices=[0, 1])
assert torch.equal(labels, torch.tensor([[1.0], [1.0]]))
# test we raise an exception for invalid target indices
with pytest.raises(Exception):
sequence.get_labels_at_target_indices(target_indices=[-1])
def test_create_labels_tensor_for_minibatch() -> None:
"""
Test to make sure the labels tensor is created as expected for a minibatch
"""
sequences = [ClassificationItemSequence(id=x, items=_create_scalar_items(length=i + 1))
for i, x in enumerate(["A", "B"])]
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(sequences, target_indices=[0, 1, 2])
assert torch.allclose(labels, torch.tensor([
[[1.0], [np.nan], [np.nan]],
[[1.0], [1.0], [np.nan]]]
), equal_nan=True)
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(sequences, target_indices=[0, 1])
assert torch.allclose(labels, torch.tensor([
[[1.0], [np.nan]],
[[1.0], [1.0]]]
), equal_nan=True)
labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(sequences, target_indices=[0])
assert torch.equal(labels, torch.tensor([
[[1.0]],
[[1.0]]]
))
def _create_scalar_items(length: int, label_value: float = 1.0) -> List[ScalarItem]:
return [ScalarItem(metadata=GeneralSampleMetadata(id="foo", sequence_position=x),
numerical_non_image_features=torch.tensor([]),
categorical_non_image_features=torch.tensor([]),
label=torch.tensor([label_value]),
images=torch.tensor([]),
segmentations=torch.tensor([])) for x in range(length)]

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -1,177 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import shutil
from pathlib import Path
from typing import Any, Tuple
import numpy as np
import pandas as pd
import pytest
import torch
from torch.utils.data import DataLoader
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
def noop_transform(x: Any) -> Any:
return x
def _check_generator_consistency(dl: DataLoader) -> None:
dataloader_generator = dl.generator
bag_sampler_generator = dl.dataset.data.bag_sampler.generator # type: ignore
assert torch.equal(dataloader_generator.get_state(),
bag_sampler_generator.get_state())
def compare_dataloaders(dl1: DataLoader, dl2: DataLoader) -> None:
for batch1, batch2 in zip(dl1, dl2):
_check_generator_consistency(dl1)
_check_generator_consistency(dl2)
assert batch1.keys() == batch2.keys()
for key in batch1:
assert len(batch1[key]) == len(batch2[key])
for item1, item2 in zip(batch1[key], batch2[key]):
if isinstance(item1, torch.Tensor):
assert torch.allclose(item1, item2, equal_nan=True)
else:
assert item1 == item2
class MockTilesDataset(TilesDataset):
TILE_X_COLUMN = TILE_Y_COLUMN = None
TRAIN_SPLIT_LABEL = 'train'
VAL_SPLIT_LABEL = 'val'
TEST_SPLIT_LABEL = 'test'
def generate_mock_dataset_df(n_slides: int, n_tiles: int, n_classes: int) -> pd.DataFrame:
np.random.seed(1234)
slide_ids = np.random.randint(n_slides, size=n_tiles)
slide_labels = np.random.randint(n_classes, size=n_slides)
tile_labels = slide_labels[slide_ids]
split_labels = [MockTilesDataset.TRAIN_SPLIT_LABEL,
MockTilesDataset.VAL_SPLIT_LABEL,
MockTilesDataset.TEST_SPLIT_LABEL]
slide_splits = np.random.choice(split_labels, size=n_slides)
tile_splits = slide_splits[slide_ids]
df = pd.DataFrame()
df[MockTilesDataset.TILE_ID_COLUMN] = np.arange(n_tiles)
df[MockTilesDataset.SLIDE_ID_COLUMN] = slide_ids
df[MockTilesDataset.LABEL_COLUMN] = tile_labels
df[MockTilesDataset.SPLIT_COLUMN] = tile_splits
df[MockTilesDataset.IMAGE_COLUMN] = [f"{tile_splits[i]}/{i:06d}.png" for i in range(n_tiles)]
return df
class MockTilesDataModule(TilesDataModule):
def get_splits(self) -> Tuple[MockTilesDataset, MockTilesDataset, MockTilesDataset]:
df = MockTilesDataset(self.root_path).dataset_df
df = df.reset_index()
split_dfs = (df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.TRAIN_SPLIT_LABEL],
df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.VAL_SPLIT_LABEL],
df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.TEST_SPLIT_LABEL])
return tuple(MockTilesDataset(self.root_path, dataset_df=split_df) # type: ignore
for split_df in split_dfs)
@pytest.fixture
def mock_data_dir(tmp_path: Path) -> Path:
csv_dir = tmp_path / "mock_tiles_dataset"
csv_dir.mkdir(exist_ok=True)
csv_path = csv_dir / MockTilesDataset.DEFAULT_CSV_FILENAME
if not csv_path.exists():
csv_path.parent.mkdir(parents=True, exist_ok=True)
df = generate_mock_dataset_df(n_slides=8, n_tiles=100, n_classes=2)
df.to_csv(csv_path, index=False)
return csv_dir
def _get_datamodule(cache_mode: CacheMode, precache_location: CacheLocation,
cache_dir_provided: bool, data_dir: Path) -> TilesDataModule:
if (cache_mode is CacheMode.NONE and precache_location is not CacheLocation.NONE) \
or (cache_mode is CacheMode.DISK and not cache_dir_provided) \
or (precache_location is not CacheLocation.NONE and not cache_dir_provided):
pytest.skip("Unsupported combination of caching arguments")
cache_dir = data_dir / f"datamodule_cache_{cache_mode.value}_{precache_location.value}" if cache_dir_provided else None
if cache_dir is not None and cache_dir.exists():
shutil.rmtree(cache_dir)
return MockTilesDataModule(root_path=data_dir,
transform=noop_transform,
seed=0,
batch_size=2,
cache_mode=cache_mode,
precache_location=precache_location,
cache_dir=cache_dir)
@pytest.mark.parametrize('cache_mode', [CacheMode.MEMORY, CacheMode.DISK, CacheMode.NONE])
@pytest.mark.parametrize('precache_location', [CacheLocation.NONE, CacheLocation.CPU, CacheLocation.SAME])
@pytest.mark.parametrize('cache_dir_provided', [True, False])
def test_caching_consistency(mock_data_dir: Path, cache_mode: CacheMode, precache_location: CacheLocation,
cache_dir_provided: bool) -> None:
# Compare two dataloaders from the same datamodule
datamodule = _get_datamodule(cache_mode=cache_mode,
precache_location=precache_location,
cache_dir_provided=cache_dir_provided,
data_dir=mock_data_dir)
datamodule.prepare_data()
train_dataloader = datamodule.train_dataloader()
train_dataloader2 = datamodule.train_dataloader()
compare_dataloaders(train_dataloader, train_dataloader2)
# Compare datamodules reusing the same cache
datamodule = _get_datamodule(cache_mode=cache_mode,
precache_location=precache_location,
cache_dir_provided=cache_dir_provided,
data_dir=mock_data_dir)
datamodule.prepare_data()
train_dataloader = datamodule.train_dataloader()
reloaded_datamodule = _get_datamodule(cache_mode=cache_mode,
precache_location=precache_location,
cache_dir_provided=cache_dir_provided,
data_dir=mock_data_dir)
reloaded_datamodule.prepare_data()
reloaded_train_dataloader = reloaded_datamodule.train_dataloader()
compare_dataloaders(train_dataloader, reloaded_train_dataloader)
@pytest.mark.parametrize('cache_mode, precache_location, cache_dir_provided',
[(CacheMode.DISK, CacheLocation.SAME, True),
(CacheMode.DISK, CacheLocation.CPU, True),
(CacheMode.MEMORY, CacheLocation.SAME, True),
(CacheMode.MEMORY, CacheLocation.CPU, True),
(CacheMode.MEMORY, CacheLocation.NONE, False),
(CacheMode.NONE, CacheLocation.NONE, False)
])
def test_tile_id_coverage(mock_data_dir: Path, cache_mode: CacheMode, precache_location: CacheLocation,
cache_dir_provided: bool) -> None:
datamodule = _get_datamodule(cache_mode=cache_mode,
precache_location=precache_location,
cache_dir_provided=cache_dir_provided,
data_dir=mock_data_dir)
datamodule.prepare_data()
train_dataset = datamodule.train_dataset
train_dataloader = datamodule.train_dataloader()
expected_tile_ids = set(train_dataset.dataset_df.index)
loaded_tile_ids = set() # type: ignore
for batch in train_dataloader:
for stacked_bag_tile_ids in batch[train_dataset.TILE_ID_COLUMN]:
if isinstance(stacked_bag_tile_ids, torch.Tensor):
stacked_bag_tile_ids = stacked_bag_tile_ids.tolist()
bag_tile_ids = set(stacked_bag_tile_ids)
assert bag_tile_ids.isdisjoint(loaded_tile_ids), \
f"Tile IDs already seen: {bag_tile_ids}"
loaded_tile_ids.update(bag_tile_ids)
assert loaded_tile_ids == expected_tile_ids

Просмотреть файл

Просмотреть файл

@ -1,40 +0,0 @@
import os
import pandas as pd
from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.utils.naming import SlideKey
HISTO_TEST_DATA_DIR = str(tests_root_directory("ML/histopathology/test_data"))
class MockSlidesDataset(SlidesDataset):
DEFAULT_CSV_FILENAME = "test_slides_dataset.csv"
METADATA_COLUMNS = ('meta1', 'meta2')
def __init__(self) -> None:
super().__init__(root=HISTO_TEST_DATA_DIR)
def test_slides_dataset() -> None:
dataset = MockSlidesDataset()
assert isinstance(dataset.dataset_df, pd.DataFrame)
assert dataset.dataset_df.index.name == dataset.SLIDE_ID_COLUMN
assert len(dataset) == len(dataset.dataset_df)
sample = dataset[0]
assert isinstance(sample, dict)
assert all(isinstance(key, SlideKey) for key in sample)
expected_keys = [SlideKey.SLIDE_ID, SlideKey.IMAGE, SlideKey.IMAGE_PATH, SlideKey.LABEL,
SlideKey.METADATA]
assert all(key in sample for key in expected_keys)
image_path = sample[SlideKey.IMAGE_PATH]
assert isinstance(image_path, str)
assert os.path.isfile(image_path)
metadata = sample[SlideKey.METADATA]
assert isinstance(metadata, dict)
assert all(meta_col in metadata for meta_col in type(dataset).METADATA_COLUMNS)

Просмотреть файл

@ -1,42 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
import pytest
import torch
from monai.data.dataset import Dataset
from torch.utils.data import DataLoader
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.Histopathology.models.transforms import LoadTiled
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
reason="TCGA-CRCk dataset is unavailable")
@pytest.mark.parametrize('train', [True, False])
def test_dataset(train: bool) -> None:
base_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR, train=train)
dataset = Dataset(base_dataset, transform=LoadTiled('image')) # type: ignore
expected_length = 93408 if train else 98904
assert len(dataset) == expected_length
sample = dataset[0]
expected_keys = ['slide_id', 'tile_id', 'image', 'split', 'label']
assert all(key in sample for key in expected_keys)
assert isinstance(sample['image'], torch.Tensor)
assert sample['image'].shape == (3, 224, 224)
batch_size = 16
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # type: ignore
batch = next(iter(loader))
assert all(key in batch for key in expected_keys)
assert isinstance(batch['image'], torch.Tensor)
assert batch['image'].shape == (batch_size, 3, 224, 224)
assert batch['image'].dtype == torch.float32
assert batch['label'].shape == (batch_size,)
assert batch['label'].dtype == torch.int64

Просмотреть файл

Просмотреть файл

@ -1,379 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from unittest.mock import MagicMock
import pytest
import torch
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from torchmetrics import Accuracy, Metric # noqa
from torchvision.models import resnet18
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
MeanPoolingLayer,
)
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import (
DeepSMILECrck,
)
from InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda import (
DeepSMILEPanda,
)
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.default_paths import (
TCGA_CRCK_DATASET_DIR,
PANDA_TILES_DATASET_DIR,
)
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder, ImageNetEncoder, TileEncoder
from InnerEye.ML.Histopathology.utils.naming import MetricsKey, ResultsKey
def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
def _test_lightningmodule(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
assert n_classes > 0
# hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere
encoder = get_supervised_imagenet_encoder()
module = DeepMILModule(
encoder=encoder,
label_column="label",
n_classes=n_classes,
pooling_layer=pooling_layer,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate,
)
bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])
bag_labels_list = []
bag_logits_list = []
bag_attn_list = []
for bag in bag_images:
if n_classes > 1:
labels = randint(n_classes, size=(max_bag_size,))
else:
labels = randint(n_classes + 1, size=(max_bag_size,))
bag_labels_list.append(module.get_bag_label(labels))
logit, attn = module(bag)
assert logit.shape == (1, n_classes)
assert attn.shape == (module.pool_out_dim, max_bag_size)
bag_logits_list.append(logit.view(-1))
bag_attn_list.append(attn)
bag_logits = stack(bag_logits_list)
bag_labels = stack(bag_labels_list).view(-1)
assert bag_logits.shape[0] == (batch_size)
assert bag_labels.shape[0] == (batch_size)
if module.n_classes > 1:
loss = module.loss_fn(bag_logits, bag_labels)
else:
loss = module.loss_fn(bag_logits.squeeze(1), bag_labels.float())
assert loss > 0
assert loss.shape == ()
probs = module.activation_fn(bag_logits)
assert ((probs >= 0) & (probs <= 1)).all()
if n_classes > 1:
assert probs.shape == (batch_size, n_classes)
else:
assert probs.shape[0] == batch_size
if n_classes > 1:
preds = argmax(probs, dim=1)
else:
preds = round(probs)
assert preds.shape[0] == batch_size
for metric_name, metric_object in module.train_metrics.items():
if metric_name == MetricsKey.CONF_MATRIX:
continue
score = metric_object(preds.view(-1, 1), bag_labels.view(-1, 1))
assert torch.all(score >= 0)
assert torch.all(score <= 1)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=pooling_layer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
n_classes: int,
batch_size: int,
max_bag_size: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=MeanPoolingLayer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=1,
pool_out_dim=1,
dropout_rate=dropout_rate)
def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
def is_integral(x: torch.Tensor) -> bool:
return (x == x.long()).all() # type: ignore
assert scores.shape == labels.shape
assert torch.is_floating_point(scores), "Received scores with integer dtype"
assert not is_integral(scores), "Received scores with integral values"
assert is_integral(labels), "Received labels with floating-point values"
def add_callback(fn: Callable, callback: Callable) -> Callable:
def wrapper(*args: Any, **kwargs: Any) -> Any:
callback(*args, **kwargs)
return fn(*args, **kwargs)
return wrapper
def test_metrics() -> None:
input_dim = (128,)
module = DeepMILModule(
encoder=IdentityEncoder(input_dim=input_dim),
label_column=TilesDataset.LABEL_COLUMN,
n_classes=1,
pooling_layer=AttentionLayer,
)
# Patching to enable running the module without a Trainer object
module.trainer = MagicMock(world_size=1) # type: ignore
module.log = MagicMock() # type: ignore
batch_size = 20
bag_size = 5
class_weights = torch.tensor([.8, .2])
bags: List[Dict] = []
for slide_idx in range(batch_size):
bag_label = torch.multinomial(class_weights, 1)
sample: Dict[str, Iterable] = {
TilesDataset.SLIDE_ID_COLUMN: [str(slide_idx)] * bag_size,
TilesDataset.TILE_ID_COLUMN: [f"{slide_idx}-{tile_idx}"
for tile_idx in range(bag_size)],
TilesDataset.IMAGE_COLUMN: rand(bag_size, *input_dim),
TilesDataset.LABEL_COLUMN: bag_label.expand(bag_size),
}
sample[TilesDataset.PATH_COLUMN] = [tile_id + '.png'
for tile_id in sample[TilesDataset.TILE_ID_COLUMN]]
bags.append(sample)
batch = default_collate(bags)
# ================
# Test that the module metrics match manually computed metrics with the correct inputs
module_metrics_dict = module.test_metrics
independent_metrics_dict = module.get_metrics()
# Patch the metrics to check that the inputs are valid. In particular, test that the scores
# do not have integral values, which would suggest that hard labels were passed instead.
for metric_obj in module_metrics_dict.values():
metric_obj.update = add_callback(metric_obj.update, validate_metric_inputs)
results = module.test_step(batch, 0)
predicted_probs = results[ResultsKey.PROB]
true_labels = results[ResultsKey.TRUE_LABEL]
for key, metric_obj in module_metrics_dict.items():
value = metric_obj.compute()
expected_value = independent_metrics_dict[key](predicted_probs, true_labels)
assert torch.allclose(value, expected_value), f"Discrepancy in '{key}' metric"
# ================
# Test that thresholded metrics (e.g. accuracy, precision, etc.) change as the threshold is varied.
# If they don't, it suggests the inputs are hard labels instead of continuous scores.
thresholded_metrics_keys = [key for key, metric in module_metrics_dict.items()
if hasattr(metric, 'threshold')]
def set_metrics_threshold(metrics_dict: Any, threshold: float) -> None:
for key in thresholded_metrics_keys:
metrics_dict[key].threshold = threshold
def reset_metrics(metrics_dict: Any) -> None:
for metric_obj in metrics_dict.values():
metric_obj.reset()
low_threshold, high_threshold = torch.quantile(predicted_probs, torch.tensor([0.1, 0.9]))
reset_metrics(module_metrics_dict)
set_metrics_threshold(module_metrics_dict, threshold=low_threshold)
_ = module.test_step(batch, 0)
results_low_threshold = {key: module_metrics_dict[key].compute()
for key in thresholded_metrics_keys}
reset_metrics(module_metrics_dict)
set_metrics_threshold(module_metrics_dict, threshold=high_threshold)
_ = module.test_step(batch, 0)
results_high_threshold = {key: module_metrics_dict[key].compute()
for key in thresholded_metrics_keys}
for key in thresholded_metrics_keys:
assert not torch.allclose(results_low_threshold[key], results_high_threshold[key]), \
f"Got same value for '{key}' metric with low and high thresholds"
def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
device = "cuda" if use_gpu else "cpu"
return {
key: [
value.to(device) if isinstance(value, Tensor) else value for value in values
]
for key, values in batch.items()
}
CONTAINER_DATASET_DIR = {
DeepSMILEPanda: PANDA_TILES_DATASET_DIR,
DeepSMILECrck: TCGA_CRCK_DATASET_DIR,
}
@pytest.mark.parametrize("container_type", [DeepSMILEPanda,
DeepSMILECrck])
@pytest.mark.parametrize("use_gpu", [True, False])
def test_container(container_type: Type[LightningContainer], use_gpu: bool) -> None:
dataset_dir = CONTAINER_DATASET_DIR[container_type]
if not os.path.isdir(dataset_dir):
pytest.skip(
f"Dataset for container {container_type.__name__} "
f"is unavailable: {dataset_dir}"
)
if container_type is DeepSMILECrck:
container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__)
elif container_type is DeepSMILEPanda:
container = DeepSMILEPanda(encoder_type=ImageNetEncoder.__name__)
else:
container = container_type()
container.setup()
data_module: TilesDataModule = container.get_data_module() # type: ignore
data_module.max_bag_size = 10
module = container.create_model()
if use_gpu:
module.cuda()
train_data_loader = data_module.train_dataloader()
for batch_idx, batch in enumerate(train_data_loader):
batch = move_batch_to_expected_device(batch, use_gpu)
loss = module.training_step(batch, batch_idx)
loss.retain_grad()
loss.backward()
assert loss.grad is not None
assert loss.shape == ()
assert isinstance(loss, Tensor)
break
val_data_loader = data_module.val_dataloader()
for batch_idx, batch in enumerate(val_data_loader):
batch = move_batch_to_expected_device(batch, use_gpu)
loss = module.validation_step(batch, batch_idx)
assert loss.shape == () # noqa
assert isinstance(loss, Tensor)
break
test_data_loader = data_module.test_dataloader()
for batch_idx, batch in enumerate(test_data_loader):
batch = move_batch_to_expected_device(batch, use_gpu)
outputs_dict = module.test_step(batch, batch_idx)
loss = outputs_dict[ResultsKey.LOSS] # noqa
assert loss.shape == ()
assert isinstance(loss, Tensor)
break
def test_class_weights_binary() -> None:
class_weights = Tensor([0.5, 3.5])
n_classes = 1
module = DeepMILModule(
encoder=get_supervised_imagenet_encoder(),
label_column="label",
n_classes=n_classes,
pooling_layer=AttentionLayer,
pool_hidden_dim=5,
pool_out_dim=1,
class_weights=class_weights,
)
logits = Tensor(randn(1, n_classes))
bag_label = randint(n_classes + 1, size=(1,))
pos_weight = Tensor([class_weights[1] / (class_weights[0] + 1e-5)])
loss_weighted = module.loss_fn(logits.squeeze(1), bag_label.float())
criterion_unweighted = nn.BCEWithLogitsLoss()
loss_unweighted = criterion_unweighted(logits.squeeze(1), bag_label.float())
if bag_label.item() == 1:
assert allclose(loss_weighted, pos_weight * loss_unweighted)
else:
assert allclose(loss_weighted, loss_unweighted)
def test_class_weights_multiclass() -> None:
class_weights = Tensor([0.33, 0.33, 0.33])
n_classes = 3
module = DeepMILModule(
encoder=get_supervised_imagenet_encoder(),
label_column="label",
n_classes=n_classes,
pooling_layer=AttentionLayer,
pool_hidden_dim=5,
pool_out_dim=1,
class_weights=class_weights,
)
logits = Tensor(randn(1, n_classes))
bag_label = randint(n_classes, size=(1,))
loss_weighted = module.loss_fn(logits, bag_label)
criterion_unweighted = nn.CrossEntropyLoss()
loss_unweighted = criterion_unweighted(logits, bag_label)
# The weighted and unweighted loss functions give the same loss values for batch_size = 1.
# https://stackoverflow.com/questions/67639540/pytorch-cross-entropy-loss-weights-not-working
# TODO: the test should reflect actual weighted loss operation for the class weights after batch_size > 1 is implemented.
assert allclose(loss_weighted, loss_unweighted)
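# Sketch of why the equality above holds (illustrative, using plain PyTorch): with the default
# "mean" reduction, CrossEntropyLoss divides the weighted sum of per-sample losses by the sum of
# the per-sample weights, so for a single sample the class weight cancels: (w[y] * ce) / w[y] = ce.
# With batch_size > 1 and mixed labels the weighted and unweighted losses generally differ.
def _weighted_ce_reduction_sketch() -> None:
    weights = Tensor([0.2, 0.8])
    logits = Tensor(randn(4, 2))
    targets = randint(2, size=(4,))
    per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)
    per_sample_weighted = nn.CrossEntropyLoss(weight=weights, reduction="none")(logits, targets)
    assert allclose(per_sample_weighted, weights[targets] * per_sample)
    weighted_mean = nn.CrossEntropyLoss(weight=weights)(logits, targets)
    assert allclose(weighted_mean, per_sample_weighted.sum() / weights[targets].sum())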

Просмотреть файл

@ -1,70 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Callable, Tuple
import numpy as np
import pytest
from torch import Tensor, float32, nn, rand
from torchvision.models import resnet18
from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
ImageNetSimCLREncoder)
from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor
TILE_SIZE = 224
INPUT_DIMS = (3, TILE_SIZE, TILE_SIZE)
def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
def get_simclr_imagenet_encoder() -> TileEncoder:
return ImageNetSimCLREncoder(tile_size=TILE_SIZE)
def get_histo_ssl_encoder() -> TileEncoder:
return HistoSSLEncoder(tile_size=TILE_SIZE)
def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: int,
batch_size: int = 5) -> None:
if isinstance(encoder, nn.Module):
for param_name, param in encoder.named_parameters():
assert not param.requires_grad, \
f"Feature extractor has unfrozen parameters: {param_name}"
images = rand(batch_size, *input_dims, dtype=float32)
features = encoder(images)
assert isinstance(features, Tensor)
assert features.shape == (batch_size, output_dim)
@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
get_simclr_imagenet_encoder,
get_histo_ssl_encoder])
def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
encoder = create_encoder_fn()
_test_encoder(encoder, input_dims=encoder.input_dim, output_dim=encoder.num_encoding)
def _dummy_classifier() -> nn.Module:
input_size = np.prod(INPUT_DIMS)
hidden_dim = 10
return nn.Sequential(
nn.Flatten(),
nn.Linear(input_size, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
@pytest.mark.parametrize('create_classifier_fn', [resnet18, _dummy_classifier])
def test_setup_feature_extractor(create_classifier_fn: Callable[[], nn.Module]) -> None:
classifier = create_classifier_fn()
encoder, num_features = setup_feature_extractor(classifier, INPUT_DIMS)
_test_encoder(encoder, input_dims=INPUT_DIMS, output_dim=num_features)
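# A minimal sketch of the wrapping pattern these tests exercise, assuming the common recipe of
# stripping the classification head and freezing the weights (an illustration only, not
# necessarily how setup_feature_extractor is implemented):
def _sketch_feature_extractor(classifier: nn.Module,
                              input_dims: Tuple[int, ...]) -> Tuple[nn.Module, int]:
    if hasattr(classifier, "fc"):  # torchvision ResNets expose their head as .fc
        classifier.fc = nn.Identity()
    for param in classifier.parameters():
        param.requires_grad = False  # matches the frozen-parameter check in _test_encoder
    num_features = classifier.eval()(rand(1, *input_dims, dtype=float32)).shape[-1]
    return classifier, num_features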

Просмотреть файл

@ -1,210 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from typing import Callable, Sequence, Union
import numpy as np
import pytest
import torch
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from monai.transforms import Compose
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import Subset
from torchvision.models import resnet18
from health_ml.utils.bag_utils import BagDataset
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
from Tests.ML.util import assert_dicts_equal
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
reason="TCGA-CRCk tiles dataset is unavailable")
def test_load_tile() -> None:
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
image_key = tiles_dataset.IMAGE_COLUMN
load_transform = LoadTiled(image_key)
index = 0
# Test that the transform affects only the image entry in the sample
input_sample = tiles_dataset[index]
loaded_sample = load_transform(input_sample)
assert_dicts_equal(loaded_sample, input_sample, exclude_keys=[image_key])
# Test that the MONAI Dataset applies the same transform
loaded_dataset = Dataset(tiles_dataset, transform=load_transform) # type:ignore
same_dataset_sample = loaded_dataset[index]
assert_dicts_equal(same_dataset_sample, loaded_sample)
# Test that loading another sample gives different results
different_sample = loaded_dataset[index + 1]
assert not torch.allclose(different_sample[image_key], loaded_sample[image_key])
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
reason="TCGA-CRCk tiles dataset is unavailable")
def test_load_tiles_batch() -> None:
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
image_key = tiles_dataset.IMAGE_COLUMN
max_bag_size = 5
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
max_bag_size=max_bag_size)
load_batch_transform = LoadTilesBatchd(image_key)
loaded_dataset = Dataset(tiles_dataset, transform=LoadTiled(image_key)) # type:ignore
image_shape = loaded_dataset[0][image_key].shape
index = 0
# Test that the transform affects only the image entry in the batch,
# and that the loaded images have the expected shape
bagged_batch = bagged_dataset[index]
manually_loaded_batch = load_batch_transform(bagged_batch)
assert_dicts_equal(manually_loaded_batch, bagged_batch, exclude_keys=[image_key])
assert manually_loaded_batch[image_key].shape == (max_bag_size, *image_shape)
# Test that the MONAI Dataset applies the same transform
loaded_bagged_dataset = Dataset(bagged_dataset, transform=load_batch_transform) # type:ignore
loaded_bagged_batch = loaded_bagged_dataset[index]
assert_dicts_equal(loaded_bagged_batch, manually_loaded_batch)
# Test that loading another batch gives different results
different_batch = loaded_bagged_dataset[index + 1]
assert not torch.allclose(different_batch[image_key], manually_loaded_batch[image_key])
# Test that loading and bagging commute
bagged_loaded_dataset = BagDataset(loaded_dataset, # type: ignore
bag_ids=tiles_dataset.slide_ids,
max_bag_size=max_bag_size)
bagged_loaded_batch = bagged_loaded_dataset[index]
assert_dicts_equal(bagged_loaded_batch, loaded_bagged_batch)
def _test_cache_and_persistent_datasets(tmp_path: Path,
base_dataset: TorchDataset,
transform: Union[Sequence[Callable], Callable],
cache_subdir: str) -> None:
default_dataset = Dataset(base_dataset, transform=transform) # type: ignore
cached_dataset = CacheDataset(base_dataset, transform=transform) # type: ignore
cache_dir = tmp_path / cache_subdir
cache_dir.mkdir(exist_ok=True)
persistent_dataset = PersistentDataset(base_dataset, transform=transform, # type: ignore
cache_dir=cache_dir)
for default_sample, cached_sample, persistent_sample \
in zip(default_dataset, cached_dataset, persistent_dataset): # type: ignore
assert_dicts_equal(cached_sample, default_sample)
assert_dicts_equal(persistent_sample, default_sample)
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
reason="TCGA-CRCk tiles dataset is unavailable")
def test_cached_loading(tmp_path: Path) -> None:
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
image_key = tiles_dataset.IMAGE_COLUMN
max_num_tiles = 100
tiles_subset = Subset(tiles_dataset, range(max_num_tiles))
_test_cache_and_persistent_datasets(tmp_path,
tiles_subset,
transform=LoadTiled(image_key),
cache_subdir="TCGA-CRCk_tiles_cache")
max_bag_size = 5
max_num_bags = max_num_tiles // max_bag_size
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
max_bag_size=max_bag_size)
bagged_subset = Subset(bagged_dataset, range(max_num_bags))
_test_cache_and_persistent_datasets(tmp_path,
bagged_subset,
transform=LoadTilesBatchd(image_key),
cache_subdir="TCGA-CRCk_load_cache")
@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
reason="TCGA-CRCk tiles dataset is unavailable")
@pytest.mark.parametrize('use_gpu, chunk_size',
[(False, 0), (False, 2), (True, 0), (True, 2)]
)
def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
image_key = tiles_dataset.IMAGE_COLUMN
max_bag_size = 5
bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
max_bag_size=max_bag_size)
encoder = ImageNetEncoder(resnet18, tile_size=224, n_channels=3)
if use_gpu:
encoder.cuda()
encode_transform = EncodeTilesBatchd(image_key, encoder, chunk_size=chunk_size)
transform = Compose([LoadTilesBatchd(image_key), encode_transform])
dataset = Dataset(bagged_dataset, transform=transform) # type: ignore
sample = dataset[0]
assert sample[image_key].shape == (max_bag_size, encoder.num_encoding)
# TODO: Ensure it works in DDP
max_num_bags = 20
bagged_subset = Subset(bagged_dataset, range(max_num_bags))
_test_cache_and_persistent_datasets(tmp_path,
bagged_subset,
transform=transform,
cache_subdir="TCGA-CRCk_embed_cache")
@pytest.mark.parametrize('include_non_indexable', [True, False])
@pytest.mark.parametrize('allow_missing_keys', [True, False])
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
batch_size = 5
max_size = batch_size // 2
data = {
'array_1d': np.random.randn(batch_size),
'array_2d': np.random.randn(batch_size, 4),
'tensor_1d': torch.randn(batch_size),
'tensor_2d': torch.randn(batch_size, 4),
'list': torch.randn(batch_size).tolist(),
'indices': list(range(batch_size)),
'non-indexable': 42,
}
keys_to_subsample = list(data.keys())
if not include_non_indexable:
keys_to_subsample.remove('non-indexable')
keys_to_subsample.append('missing-key')
subsampling = Subsampled(keys_to_subsample, max_size=max_size,
allow_missing_keys=allow_missing_keys)
if include_non_indexable:
with pytest.raises(ValueError):
sub_data = subsampling(data)
return
elif not allow_missing_keys:
with pytest.raises(KeyError):
sub_data = subsampling(data)
return
else:
sub_data = subsampling(data)
assert set(sub_data.keys()) == set(data.keys())
# Check lengths before and after subsampling
for key in keys_to_subsample:
if key not in data:
continue # Skip missing keys
assert len(data[key]) == batch_size # type: ignore
assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore
# Check contents of subsampled elements
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
for idx, elem in zip(sub_data['indices'], sub_data[key]):
assert np.array_equal(elem, data[key][idx]) # type: ignore
# Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
sub_data2 = subsampling(data)
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore

Просмотреть файл

Просмотреть файл

@ -1,171 +0,0 @@
from typing import Optional
import numpy as np
import pytest
from monai.data.image_reader import WSIReader
from InnerEye.Common.common_util import is_windows
from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
from InnerEye.ML.Histopathology.preprocessing.tiling import tile_array_2d
from InnerEye.ML.Histopathology.preprocessing.loading import (LoadROId, get_luminance, load_slide_at_level,
segment_foreground)
from InnerEye.ML.Histopathology.utils.naming import SlideKey
from Tests.ML.histopathology.datasets.test_slides_dataset import MockSlidesDataset
TEST_IMAGE_PATH = str(tests_root_directory("ML/histopathology/test_data/panda_wsi_example.tiff"))
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
def test_load_slide() -> None:
level = 2
reader = WSIReader('cuCIM')
from cucim import CuImage
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
dims = slide_obj.resolutions['level_dimensions'][level][::-1]
slide = load_slide_at_level(reader, slide_obj, level)
assert isinstance(slide, np.ndarray)
expected_shape = (3, *dims)
assert slide.shape == expected_shape
frac_empty = (slide == 0).mean()
assert frac_empty == 0.0
larger_dims = (2 * dims[0], 2 * dims[1])
larger_slide, _ = reader.get_data(slide_obj, size=larger_dims, level=level)
assert isinstance(larger_slide, np.ndarray)
assert larger_slide.shape == (3, *larger_dims)
# Overlapping parts match exactly
assert np.array_equal(larger_slide[:, :dims[0], :dims[1]], slide)
# Non-overlapping parts are all empty
empty_fill_value = 0 # fill value seems to depend on the image
assert np.array_equiv(larger_slide[:, dims[0]:, :], empty_fill_value)
assert np.array_equiv(larger_slide[:, :, dims[1]:], empty_fill_value)
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
def test_get_luminance() -> None:
level = 2 # here we only need to test at a single resolution
reader = WSIReader('cuCIM')
from cucim import CuImage
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
slide = load_slide_at_level(reader, slide_obj, level)
slide_luminance = get_luminance(slide)
assert isinstance(slide_luminance, np.ndarray)
assert slide_luminance.shape == slide.shape[1:]
assert (slide_luminance <= 255).all() and (slide_luminance >= 0).all()
tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
tiles_luminance = get_luminance(tiles)
assert isinstance(tiles_luminance, np.ndarray)
assert tiles_luminance.shape == (tiles.shape[0], *tiles.shape[2:])
assert (tiles_luminance <= 255).all() and (tiles_luminance >= 0).all()
slide_luminance_tiles, _ = tile_array_2d(np.expand_dims(slide_luminance, axis=0),
tile_size=224, constant_values=255)
assert np.array_equal(slide_luminance_tiles.squeeze(1), tiles_luminance)
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
def test_segment_foreground() -> None:
level = 2 # here we only need to test at a single resolution
reader = WSIReader('cuCIM')
from cucim import CuImage
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
slide = load_slide_at_level(reader, slide_obj, level)
auto_mask, auto_threshold = segment_foreground(slide, threshold=None)
assert isinstance(auto_mask, np.ndarray)
assert auto_mask.dtype == bool
assert auto_mask.shape == slide.shape[1:]
assert 0 < auto_mask.sum() < auto_mask.size # auto-seg should not produce trivial mask
luminance = get_luminance(slide)
assert luminance.min() < auto_threshold < luminance.max()
mask, returned_threshold = segment_foreground(slide, threshold=auto_threshold)
assert isinstance(mask, np.ndarray)
assert mask.dtype == bool
assert mask.shape == slide.shape[1:]
assert np.array_equal(mask, auto_mask)
assert returned_threshold == auto_threshold
tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
tiles_mask, _ = segment_foreground(tiles, threshold=auto_threshold)
assert isinstance(tiles_mask, np.ndarray)
assert tiles_mask.dtype == bool
assert tiles_mask.shape == (tiles.shape[0], *tiles.shape[2:])
slide_mask_tiles, _ = tile_array_2d(np.expand_dims(mask, axis=0),
tile_size=224, constant_values=False)
assert np.array_equal(slide_mask_tiles.squeeze(1), tiles_mask)
@pytest.mark.parametrize('level', [1, 2])
@pytest.mark.parametrize('foreground_threshold', [None, 215])
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
def test_get_bounding_box(level: int, foreground_threshold: Optional[float]) -> None:
margin = 0
reader = WSIReader('cuCIM')
loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
foreground_threshold=foreground_threshold)
from cucim import CuImage
slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
level0_bbox, _ = loader._get_bounding_box(slide_obj)
highest_level = slide_obj.resolutions['level_count'] - 1
# level = highest_level
slide = load_slide_at_level(reader, slide_obj, level=level)
scale = slide_obj.resolutions['level_downsamples'][level]
bbox = level0_bbox / scale
assert bbox.x >= 0 and bbox.y >= 0
assert bbox.x + bbox.w <= slide.shape[1]
assert bbox.y + bbox.h <= slide.shape[2]
# Now with nonzero margin
margin = 42
loader_margin = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
foreground_threshold=foreground_threshold)
level0_bbox_margin, _ = loader_margin._get_bounding_box(slide_obj)
# Here we test the box differences at the highest resolution, because margin is
# specified in low-res pixels. Otherwise could fail due to rounding error.
level0_scale: float = slide_obj.resolutions['level_downsamples'][highest_level]
level0_margin = int(level0_scale * margin)
assert level0_bbox_margin.x == level0_bbox.x - level0_margin
assert level0_bbox_margin.y == level0_bbox.y - level0_margin
assert level0_bbox_margin.w == level0_bbox.w + 2 * level0_margin
assert level0_bbox_margin.h == level0_bbox.h + 2 * level0_margin
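# Worked example of the margin arithmetic above, with an assumed downsample factor: if the
# highest pyramid level is downsampled 16x relative to level 0, a margin of 42 low-resolution
# pixels maps to int(16 * 42) = 672 full-resolution pixels, so the level-0 box moves out by 672
# on each side and grows by 2 * 672 in both width and height.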
@pytest.mark.parametrize('level', [1, 2])
@pytest.mark.parametrize('margin', [0, 42])
@pytest.mark.parametrize('foreground_threshold', [None, 215])
@pytest.mark.skipif(is_windows(), reason="cucim package is not available on Windows")
def test_load_roi(level: int, margin: int, foreground_threshold: Optional[float]) -> None:
dataset = MockSlidesDataset()
sample = dataset[0]
reader = WSIReader('cuCIM')
loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
foreground_threshold=foreground_threshold)
loaded_sample = loader(sample)
assert isinstance(loaded_sample, dict)
# Check that none of the input keys were removed
assert all(key in loaded_sample for key in sample)
# Check that the expected new keys were inserted
additional_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.FOREGROUND_THRESHOLD]
assert all(key in loaded_sample for key in additional_keys)
assert isinstance(loaded_sample[SlideKey.IMAGE], np.ndarray)
image_shape = loaded_sample[SlideKey.IMAGE].shape
assert len(image_shape) == 3
assert image_shape[0] == 3
origin = loaded_sample[SlideKey.ORIGIN]
assert isinstance(origin, tuple)
assert len(origin) == 2
assert all(isinstance(coord, int) for coord in origin)
assert isinstance(loaded_sample[SlideKey.SCALE], (int, float))
assert loaded_sample[SlideKey.SCALE] >= 1.0
assert isinstance(loaded_sample[SlideKey.FOREGROUND_THRESHOLD], (int, float))

Просмотреть файл

@ -1,126 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import numpy as np
import pytest
from InnerEye.ML.Histopathology.preprocessing.tiling import assemble_tiles_2d, get_1d_padding, \
pad_for_tiling_2d, tile_array_2d
@pytest.mark.parametrize("length,tile_size",
[(8, 4), (9, 4), (8, 3), (4, 4), (3, 4)])
def test_1d_padding(length: int, tile_size: int) -> None:
pad_pre, pad_post = get_1d_padding(length, tile_size)
assert pad_pre >= 0 and pad_post >= 0
assert pad_pre < tile_size and pad_post < tile_size
assert abs(pad_post - pad_pre) <= 1, "Asymmetric padding"
padded_length = pad_pre + length + pad_post
assert padded_length % tile_size == 0
n_tiles = padded_length // tile_size
expected_n_tiles = int(np.ceil(length / tile_size))
assert n_tiles == expected_n_tiles
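# Worked example of the arithmetic checked above: for length=9 and tile_size=4 the padded length
# is the next multiple of 4, i.e. 12, so the 3 padding pixels are split almost evenly as 1 and 2
# (in whichever order get_1d_padding rounds), giving ceil(9 / 4) = 3 tiles.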
@pytest.mark.parametrize("width,height", [(8, 6)])
@pytest.mark.parametrize("tile_size", [3, 4, 5])
@pytest.mark.parametrize("channels_first", [True, False])
def test_2d_padding(width: int, height: int, tile_size: int, channels_first: bool) -> None:
channels = 2
pad_value = 0
array = np.random.rand(channels, height, width)
input_array = array if channels_first else array.transpose(1, 2, 0)
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(input_array, tile_size, channels_first,
constant_values=pad_value)
if not channels_first:
padded_array = padded_array.transpose(2, 0, 1)
padded_channels, padded_height, padded_width = padded_array.shape
assert padded_channels == channels and padded_height >= height and padded_width >= width
assert padded_height % tile_size == 0 and padded_width % tile_size == 0
assert 0 <= offset_h < tile_size and 0 <= offset_w < tile_size
crop = padded_array[:, offset_h:offset_h + height, offset_w:offset_w + width]
assert np.array_equal(crop, array)
# np.array_equiv() broadcasts the shapes
assert np.array_equiv(padded_array[:, :offset_h, :], pad_value)
assert np.array_equiv(padded_array[:, :, :offset_w], pad_value)
assert np.array_equiv(padded_array[:, offset_h + height:, :], pad_value)
assert np.array_equiv(padded_array[:, :, offset_w + width:], pad_value)
def _get_2d_meshgrid(width: int, height: int, channels_first: bool = True) -> np.ndarray:
array = np.stack(np.meshgrid(np.arange(width), np.arange(height)),
axis=0 if channels_first else -1)
assert array.shape == ((2, height, width) if channels_first else (height, width, 2))
return array
@pytest.mark.parametrize("width,height", [(8, 6)])
@pytest.mark.parametrize("tile_size", [3, 4, 5])
@pytest.mark.parametrize("channels_first", [True, False])
def test_tile_array_2d_both(width: int, height: int, tile_size: int, channels_first: bool) -> None:
channels = 2
array = _get_2d_meshgrid(width, height, channels_first)
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first,
constant_values=0)
tiles, coords = tile_array_2d(array, tile_size, channels_first)
assert tiles.shape[0] == coords.shape[0]
expected_n_tiles_w = int(np.ceil(width / tile_size))
expected_n_tiles_h = int(np.ceil(height / tile_size))
expected_n_tiles = expected_n_tiles_w * expected_n_tiles_h
if channels_first:
assert tiles.shape == (expected_n_tiles, channels, tile_size, tile_size)
else:
assert tiles.shape == (expected_n_tiles, tile_size, tile_size, channels)
assert coords.shape == (expected_n_tiles, 2)
for idx in range(tiles.shape[0]):
row = coords[idx, 1] + offset_h
col = coords[idx, 0] + offset_w
if channels_first:
expected_tile = padded_array[:, row:row + tile_size, col:col + tile_size]
else:
expected_tile = padded_array[row:row + tile_size, col:col + tile_size, :]
assert np.array_equal(tiles[idx], expected_tile)
expected_x = tile_size * (idx % expected_n_tiles_w) - offset_w
expected_y = tile_size * (idx // expected_n_tiles_w) - offset_h
assert tuple(coords[idx]) == (expected_x, expected_y)
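# Worked example of the tiling grid above: with width=8, height=6 and tile_size=3 the image is
# padded to 9 x 6 and split into ceil(8/3) * ceil(6/3) = 3 * 2 = 6 tiles, enumerated row-major,
# so tile idx sits at grid column idx % 3 and row idx // 3 before the padding offsets are
# subtracted from its coordinates.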
@pytest.mark.parametrize("width,height", [(8, 6)])
@pytest.mark.parametrize("tile_size", [3, 4, 5])
@pytest.mark.parametrize("channels_first", [True, False])
def test_assemble_tiles_2d(width: int, height: int, tile_size: int, channels_first: bool) -> None:
array = _get_2d_meshgrid(width, height, channels_first)
fill_value = 0
padded_array, padding_offset = pad_for_tiling_2d(array, tile_size, channels_first,
constant_values=fill_value)
tiles, coords = tile_array_2d(array, tile_size, channels_first)
assembled_array, assembly_offset = assemble_tiles_2d(tiles, coords, fill_value=fill_value,
channels_first=channels_first)
assert np.array_equal(assembled_array, padded_array)
assert np.array_equal(assembly_offset, padding_offset)
for idx in range(tiles.shape[0]):
row = coords[idx, 1] + assembly_offset[1]
col = coords[idx, 0] + assembly_offset[0]
if channels_first:
crop = assembled_array[:, row:row + tile_size, col:col + tile_size]
else:
crop = assembled_array[row:row + tile_size, col:col + tile_size, :]
assert np.array_equal(crop, tiles[idx])

Просмотреть файл

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06eb0acaa2883181e9b6ab976863f71cc843a75ed9175fae8fe9b879635af1b0
size 816563

Просмотреть файл

@ -1,2 +0,0 @@
slide_id,image,label,meta1,meta2
foo,panda_wsi_example.tiff,0,bar,baz

Просмотреть файл

Просмотреть файл

@ -1,213 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
import math
import numpy as np
from typing import List
import matplotlib
from torch.functional import Tensor
import pytest
from InnerEye.Common.common_util import is_windows
from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, \
plot_heatmap_overlay, plot_normalized_confusion_matrix
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.plotting import resize_and_save
from InnerEye.ML.utils.ml_util import set_random_seed
from Tests.ML.util import assert_binary_files_match
def assert_equal_lists(pred: List, expected: List) -> None:
assert len(pred) == len(expected)
for i, slide in enumerate(pred):
for j, value in enumerate(slide):
if type(value) in [int, float]:
assert math.isclose(value, expected[i][j], rel_tol=1e-06)
elif (type(value) == Tensor) and (value.ndim >= 1):
for k, idx in enumerate(value):
assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06)
elif isinstance(value, List):
for k, idx in enumerate(value):
if type(idx) in [int, float]:
assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06)
elif type(idx) == Tensor:
assert math.isclose(idx.item(), expected[i][j][k].item(), rel_tol=1e-06)
else:
raise TypeError("Unexpected list composition")
test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]],
ResultsKey.CLASS_PROBS: [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.6, 0.4]), Tensor([0.0, 1.0]),
Tensor([0.7, 0.3]), Tensor([0.8, 0.2]), Tensor([0.1, 0.9]), Tensor([0.01, 0.99])],
ResultsKey.TRUE_LABEL: [0, 1, 1, 1, 1, 0, 0, 0],
ResultsKey.BAG_ATTN:
[Tensor([[0.10, 0.00, 0.20, 0.15]]),
Tensor([[0.10, 0.18, 0.15, 0.13]]),
Tensor([[0.25, 0.23, 0.20, 0.21]]),
Tensor([[0.33, 0.31, 0.37, 0.35]]),
Tensor([[0.43, 0.01, 0.07, 0.25]]),
Tensor([[0.53, 0.11, 0.17, 0.55]]),
Tensor([[0.63, 0.21, 0.27, 0.05]]),
Tensor([[0.73, 0.31, 0.37, 0.15]])],
ResultsKey.TILE_X:
[Tensor([200, 200, 424, 424]),
Tensor([200, 200, 424, 424]),
Tensor([200, 200, 424, 424]),
Tensor([200, 200, 424, 424])],
ResultsKey.TILE_Y:
[Tensor([200, 424, 200, 424]),
Tensor([200, 200, 424, 424]),
Tensor([200, 200, 424, 424]),
Tensor([200, 200, 424, 424])]
}
def test_select_k_tiles() -> None:
nslides = 2
ntiles = 2
# TP
top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'highest_att'))
bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'lowest_att'))
print(top_tp)
assert_equal_lists(top_tp, [(4, Tensor([0.0, 1.0]), [3, 4], [Tensor([0.37]), Tensor([0.35])]),
(2, Tensor([0.3, 0.7]), [2, 3], [Tensor([0.18]), Tensor([0.15])])])
assert_equal_lists(bottom_tp, [(4, Tensor([0.0, 1.0]), [2, 1], [Tensor([0.31]), Tensor([0.33])]),
(2, Tensor([0.3, 0.7]), [1, 4], [Tensor([0.10]), Tensor([0.13])])])
# FN
top_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'lowest_att'))
assert_equal_lists(top_fn, [(5, Tensor([0.7, 0.3]), [1, 4], [Tensor([0.43]), Tensor([0.25])]),
(3, Tensor([0.6, 0.4]), [1, 2], [Tensor([0.25]), Tensor([0.23])])])
assert_equal_lists(bottom_fn, [(5, Tensor([0.7, 0.3]), [2, 3], [Tensor([0.01]), Tensor([0.07])]),
(3, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.20]), Tensor([0.21])])])
# TN
top_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'highest_att'))
bottom_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'lowest_att'))
assert_equal_lists(top_tn, [(6, Tensor([0.8, 0.2]), [4, 1], [Tensor([0.55]), Tensor([0.53])]),
(1, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.2]), Tensor([0.15])])])
assert_equal_lists(bottom_tn, [(6, Tensor([0.8, 0.2]), [2, 3], [Tensor([0.11]), Tensor([0.17])]),
(1, Tensor([0.6, 0.4]), [2, 1], [Tensor([0.00]), Tensor([0.10])])])
# FP
top_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'highest_att'))
bottom_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'lowest_att'))
assert_equal_lists(top_fp, [(8, Tensor([0.01, 0.99]), [1, 3], [Tensor([0.73]), Tensor([0.37])]),
(7, Tensor([0.1, 0.9]), [1, 3], [Tensor([0.63]), Tensor([0.27])])])
assert_equal_lists(bottom_fp, [(8, Tensor([0.01, 0.99]), [4, 2], [Tensor([0.15]), Tensor([0.31])]),
(7, Tensor([0.1, 0.9]), [4, 2], [Tensor([0.05]), Tensor([0.21])])])
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
def test_plot_scores_hist(test_output_dirs: OutputFolderForTests) -> None:
fig = plot_scores_hist(test_dict)
assert isinstance(fig, matplotlib.figure.Figure)
file = Path(test_output_dirs.root_dir) / "plot_score_hist.png"
resize_and_save(5, 5, file)
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / "score_hist.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)
@pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6])
def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: float) -> None:
set_random_seed(0)
slide_image = np.random.rand(3, 1000, 2000)
fig = plot_slide(slide_image=slide_image, scale=scale)
assert isinstance(fig, matplotlib.figure.Figure)
file = Path(test_output_dirs.root_dir) / "plot_slide.png"
resize_and_save(5, 5, file)
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None:
set_random_seed(0)
slide_image = np.random.rand(3, 1000, 2000)
location_bbox = [100, 100]
slide = 1
tile_size = 224
level = 0
fig = plot_heatmap_overlay(slide=slide, # type: ignore
slide_image=slide_image,
results=test_dict, # type: ignore
location_bbox=location_bbox,
tile_size=tile_size,
level=level)
assert isinstance(fig, matplotlib.figure.Figure)
file = Path(test_output_dirs.root_dir) / "plot_heatmap_overlay.png"
resize_and_save(5, 5, file)
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / "heatmap_overlay.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows")
def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests, n_classes: int) -> None:
set_random_seed(0)
if n_classes > 1:
cm = np.random.randint(1, 1000, size=(n_classes, n_classes))
class_names = [str(i) for i in range(n_classes)]
else:
cm = np.random.randint(1, 1000, size=(n_classes + 1, n_classes + 1))
class_names = [str(i) for i in range(n_classes + 1)]
cm_n = cm / cm.sum(axis=1, keepdims=True)
assert (cm_n <= 1).all()
fig = plot_normalized_confusion_matrix(cm=cm_n, class_names=class_names)
assert isinstance(fig, matplotlib.figure.Figure)
file = Path(test_output_dirs.root_dir) / f"plot_confusion_matrix_{n_classes}.png"
resize_and_save(5, 5, file)
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)
@pytest.mark.parametrize("level", [0, 1, 2])
def test_location_selected_tiles(level: int) -> None:
set_random_seed(0)
slide = 1
location_bbox = [100, 100]
slide_image = np.random.rand(3, 1000, 2000)
coords = []
slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]] # type: ignore
slide_idx = slide_ids.index(slide)
for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])): # type: ignore
tile_coords = np.transpose(
np.array([test_dict[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), # type: ignore
test_dict[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) # type: ignore
coords.append(tile_coords)
coords = np.array(coords)
tile_coords_transformed = location_selected_tiles(tile_coords=coords,
location_bbox=location_bbox,
level=level)
tile_xs, tile_ys = tile_coords_transformed.T
level_dict = {0: 1, 1: 4, 2: 16}
factor = level_dict[level]
assert min(tile_xs) >= 0
assert max(tile_xs) <= slide_image.shape[2] // factor
assert min(tile_ys) >= 0
assert max(tile_ys) <= slide_image.shape[1] // factor
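# Note on level_dict above: it encodes the downsample factor of this test slide's pyramid,
# 4 per level (factor = 4 ** level), so coordinates at levels 0, 1 and 2 are bounded by the
# full-resolution slide size divided by 1, 4 and 16 respectively.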

Просмотреть файл

@ -1,28 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pandas as pd
from InnerEye.ML.Histopathology.utils.tcga_utils import extract_fields
def test_extract_fields() -> None:
slide_id = "TCGA-XX-0123"
tile_id = "ABCDEFGHIJKL"
split = "train"
label = 0
path = (f"CRC_DX_{split.upper()}/"
f"{['MSS', 'MSIMUT'][label]}/"
f"blk-{tile_id}-{slide_id}-01Z-00-DX1.png")
fields = {
'slide_id': slide_id,
'tile_id': tile_id,
'image': path,
'split': split,
'label': label
}
extracted_fields = extract_fields(pd.Series(fields))
assert fields == extracted_fields
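# A minimal sketch of how the fields could be recovered from such a path (an illustration of the
# naming convention exercised above, not necessarily how extract_fields is implemented):
def _parse_tcga_crck_path(path: str) -> dict:
    import re
    pattern = r"CRC_DX_(\w+)/(MSS|MSIMUT)/blk-(\w+)-(TCGA-\S+)-01Z-00-DX1\.png"
    match = re.fullmatch(pattern, path)
    assert match is not None, f"Unexpected path format: {path}"
    split, label_name, tile_id, slide_id = match.groups()
    return {"slide_id": slide_id, "tile_id": tile_id, "split": split.lower(),
            "label": int(label_name == "MSIMUT")}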

Просмотреть файл

@ -1,36 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
from InnerEye.ML.models.architectures.sequential.gru import LayerNormGRU, LayerNormGRUCell
def test_cell_initialisation() -> None:
model = LayerNormGRU(input_size=2, hidden_size=2, num_layers=3)
assert len(model.cells) == 3
def test_layer_norm_initialisation() -> None:
cell = LayerNormGRUCell(input_size=2, hidden_size=2, use_layer_norm=True, dropout=0.50)
assert isinstance(cell.ln_r, nn.LayerNorm)
assert cell.dropout == 0.50
def test_gru_forward_pass() -> None:
num_layers = 2
batch_size = 10
hidden_dim = 2
input_dim = 3
seq_length = 5
model = LayerNormGRU(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers)
input_sequence = torch.rand(size=(batch_size, seq_length, input_dim))
initial_hidden_state = torch.zeros(size=(num_layers, batch_size, hidden_dim))
output_sequence = model(input_sequence, initial_hidden_state)
assert output_sequence.size(0) == batch_size
assert output_sequence.size(1) == seq_length
assert output_sequence.size(2) == hidden_dim

Просмотреть файл

@ -1,596 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from io import StringIO
from typing import Any, List, Optional, Tuple
from unittest import mock
import numpy as np
import pandas as pd
import pytest
import torch
from torchvision.transforms import ColorJitter, RandomAffine
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, logging_to_stdout
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, SEQUENCE_POSITION_HUE_NAME_PREFIX
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
from InnerEye.ML.lightning_models import transfer_batch_to_device
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.model_testing import create_metrics_dict_for_scalar_models
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImageEncoder, ImagingFeatureType
from InnerEye.ML.models.architectures.sequential.rnn_classifier import RNNClassifier, RNNClassifierWithEncoder
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import ScalarLoss
from InnerEye.ML.sequence_config import SEQUENCE_LENGTH_FILE, SEQUENCE_LENGTH_STATS_FILE, SequenceModelBase
from InnerEye.ML.utils import ml_util
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.io_util import ImageAndSegmentations
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.split_dataset import DatasetSplits
from InnerEye.ML.visualizers.grad_cam_hooks import VisualizationMaps
from Tests.ML.util import get_default_azure_config, model_train_unittest
SCAN_SIZE = (6, 64, 60)
def prepare_sequences(num_sequences: int, sequence_length: int, batch_size: int) -> Tuple[List, List]:
# Returns [batch][sequence, label]
num_mini_batches = num_sequences // batch_size
inputs = np.random.choice([0, 1], size=(num_sequences, sequence_length), p=[1. / 3, 2. / 3]).astype(np.float32)
inputs = torch.tensor(inputs)
labels = torch.sum(inputs, dim=1) > (sequence_length // 2)
labels = labels.long()
data = list()
for batch_index in range(num_mini_batches):
_input = inputs[batch_index * batch_size: (batch_index + 1) * batch_size]
_label = labels[batch_index * batch_size: (batch_index + 1) * batch_size]
data.append((_input, _label))
return data[:num_mini_batches // 2], data[num_mini_batches // 2:]
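# Shape sketch for prepare_sequences (illustrative): num_sequences=100, sequence_length=10 and
# batch_size=5 yield 20 mini-batches, split 10/10 into training and validation lists; each entry
# is a tuple of a float32 input tensor of shape (5, 10) and a long label tensor of shape (5,),
# where the label marks whether more than half of the sequence entries are 1.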
class ToySequenceModel(SequenceModelBase):
def __init__(self, use_combined_model: bool = False,
imaging_feature_type: ImagingFeatureType = ImagingFeatureType.Image,
combine_hidden_states: bool = False,
use_encoder_layer_norm: bool = False,
sequence_target_positions: Optional[List[int]] = None,
use_mean_teacher_model: bool = False,
**kwargs: Any) -> None:
num_epochs = 3
mean_teacher_alpha = 0.999 if use_mean_teacher_model else None
sequence_target_positions = [2] if sequence_target_positions is None else sequence_target_positions
image_column = "image" if use_combined_model else None
categorical_feature_encoder = CategoricalToOneHotEncoder.create_from_dataframe(
dataframe=_get_mock_sequence_dataset(), columns=["cat1"])
super().__init__(
local_dataset=full_ml_test_data_path("sequence_data_for_classification"),
temperature_scaling_config=TemperatureScalingConfig(),
label_value_column="label",
numerical_columns=["numerical1", "numerical2"],
categorical_columns=["cat1"],
categorical_feature_encoder=categorical_feature_encoder,
sequence_column="seqColumn",
sequence_target_positions=sequence_target_positions,
image_file_column=image_column,
loss_type=ScalarLoss.WeightedCrossEntropyWithLogits,
num_epochs=num_epochs,
num_dataload_workers=0,
train_batch_size=3,
l_rate=1e-1,
load_segmentation=True,
use_mixed_precision=True,
label_smoothing_eps=0.05,
drop_last_batch_in_training=True,
mean_teacher_alpha=mean_teacher_alpha,
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
max_num_gpus=1,
**kwargs
)
self.use_combined_model = use_combined_model
self.imaging_feature_type = imaging_feature_type
self.combine_hidden_state = combine_hidden_states
self.use_encoder_layer_norm = use_encoder_layer_norm
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
return DatasetSplits.from_proportions(
df=dataset_df,
proportion_train=0.7,
proportion_test=0.2,
proportion_val=0.1,
)
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
if self.use_combined_model:
return ModelTransformsPerExecutionMode(
train=ImageTransformationPipeline(
transforms=[RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
ColorJitter(brightness=0.2)]))
else:
return ModelTransformsPerExecutionMode()
def create_model(self) -> RNNClassifier:
if self.use_combined_model:
image_encoder: Optional[ImageEncoder] = \
ImageEncoder(num_image_channels=1,
imaging_feature_type=self.imaging_feature_type,
num_non_image_features=self.get_total_number_of_non_imaging_features(),
stride_size_per_encoding_block=(1, 2, 2),
initial_feature_channels=4,
num_encoder_blocks=3,
)
assert image_encoder is not None # for mypy
input_dims = image_encoder.final_num_feature_channels
else:
image_encoder = None
input_dims = self.get_total_number_of_non_imaging_features()
ref_indices = [0, 1] if self.combine_hidden_state else None
return RNNClassifierWithEncoder(input_dim=input_dims,
hidden_dim=3,
output_dim=1,
num_rnn_layers=1,
rnn_dropout=0.0,
ref_indices=ref_indices,
image_encoder=image_encoder,
use_encoder_batch_norm=self.use_encoder_layer_norm,
target_indices=self.get_target_indices())
def _get_mock_sequence_dataset(dataset_contents: Optional[str] = None) -> pd.DataFrame:
# The dataset has "measurements" for 3 different positions 0, 1, and 2, with columns for numerical1 and numerical2.
# Labels are attached to position 3 only.
if dataset_contents is None:
dataset_contents = """subject,numerical1,numerical2,cat1,seqColumn,label,image
2137.00005,362,71,A,0,0,scan1.npy
2137.00005,357,69,B,1,0,scan2.npy
2137.00005,355,64,C,2,0,scan3.npy
2137.00005,355,63,C,3,1,scan4.npy
2137.00125,348,64,A,0,0,scan1.npy
2137.00125,316,68,A,1,0,scan3.npy
2137.00125,349,68,A,2,0,scan2.npy
2137.00125,361,67,B,3,0,scan1.npy
2137.00125,350,68,B,4,0,scan1.npy
2627.00001,477,58,A,0,0,scan2.npy
2627.00001,220,59,A,1,0,scan2.npy
2627.00001,222,60,A,2,0,scan1.npy
2627.00001,217,65,A,5,1,scan3.npy
2627.12341,210,60,B,0,0,scan4.npy
2627.12341,217,61,B,1,0,scan1.npy
2627.12341,224,63,B,2,1,scan2.npy
3250.00005,344,76,C,0,0,scan2.npy
3250.00005,233,76,C,1,0,scan4.npy
3250.00005,212,84,C,2,0,scan3.npy
3250.00005,215,84,C,3,0,scan1.npy
3250.00005,215,82,C,4,0,scan1.npy
3250.12345,233,84,C,0,0,scan3.npy
3250.12345,218,84,C,1,0,scan3.npy
3250.12345,221,84,C,2,0,scan1.npy
3250.12345,238,84,C,3,0,scan1.npy
"""
return pd.read_csv(StringIO(dataset_contents), dtype=str)
@pytest.mark.parametrize(["use_combined_model", "imaging_feature_type"],
[(False, ImagingFeatureType.Image),
(True, ImagingFeatureType.Image),
(True, ImagingFeatureType.Segmentation),
(True, ImagingFeatureType.ImageAndSegmentation)])
@pytest.mark.parametrize("combine_hidden_state", (True, False))
@pytest.mark.parametrize("use_encoder_layer_norm", (True, False))
@pytest.mark.parametrize("use_mean_teacher_model", (False,))
@pytest.mark.gpu
def test_rnn_classifier_via_config_1(use_combined_model: bool,
imaging_feature_type: ImagingFeatureType,
combine_hidden_state: bool,
use_encoder_layer_norm: bool,
use_mean_teacher_model: bool,
test_output_dirs: OutputFolderForTests) -> None:
"""
Test if we can build a simple RNN model that only feeds off non-image features.
This only exercises the mechanics of training, not whether the model actually learns.
"""
logging_to_stdout()
config = ToySequenceModel(use_combined_model,
imaging_feature_type=imaging_feature_type,
combine_hidden_states=combine_hidden_state,
use_encoder_layer_norm=use_encoder_layer_norm,
use_mean_teacher_model=use_mean_teacher_model,
should_validate=False)
config.use_mixed_precision = True
# Necessary because torch otherwise says "index_add_cuda_ does not have a deterministic implementation"
config.pl_deterministic = False
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_mock_sequence_dataset()
# Patch the load_images function that will be called once we access a dataset item
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
segmentations=np.random.randint(0, 2, SCAN_SIZE))
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
model_train_unittest(config, output_folder=test_output_dirs)
@pytest.mark.parametrize(["use_combined_model", "imaging_feature_type"],
[(False, ImagingFeatureType.Image),
(True, ImagingFeatureType.Image),
(True, ImagingFeatureType.Segmentation),
(True, ImagingFeatureType.ImageAndSegmentation)])
def test_run_ml_with_sequence_model(use_combined_model: bool,
imaging_feature_type: ImagingFeatureType,
test_output_dirs: OutputFolderForTests) -> None:
"""
Test training and testing of sequence models when they are run end-to-end via run_ml.
"""
logging_to_stdout()
config = ToySequenceModel(use_combined_model, imaging_feature_type,
should_validate=False, sequence_target_positions=[2, 10])
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_mock_sequence_dataset()
config.num_epochs = 1
config.max_batch_grad_cam = 1
# make sure we are testing with at least one sequence position that will not exist
# to ensure correct handling of sequences that do not contain all the expected target positions
assert max(config.sequence_target_positions) > config.dataset_data_frame[config.sequence_column].astype(float).max()
# Patch the load_images function that will be called once we access a dataset item
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
segmentations=np.random.randint(0, 2, SCAN_SIZE))
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
azure_config = get_default_azure_config()
azure_config.train = True
MLRunner(config, azure_config=azure_config).run()
@pytest.mark.parametrize(["use_combined_model", "imaging_feature_type"],
[(False, ImagingFeatureType.Image),
(True, ImagingFeatureType.Image),
(True, ImagingFeatureType.Segmentation),
(True, ImagingFeatureType.ImageAndSegmentation)])
def test_visualization_with_sequence_model(use_combined_model: bool,
imaging_feature_type: ImagingFeatureType,
test_output_dirs: OutputFolderForTests) -> None:
config = ToySequenceModel(use_combined_model, imaging_feature_type, should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_mock_sequence_dataset()
config.num_epochs = 1
model = config.create_model()
if config.use_gpu:
model = model.cuda()
dataloader = SequenceDataset(config,
data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
batch_size=2)
# Patch the load_images function that will be called once we access a dataset item
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
segmentations=np.random.randint(0, 2, SCAN_SIZE))
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
batch = next(iter(dataloader))
if config.use_gpu:
batch = transfer_batch_to_device(batch, torch.device(0))
model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
target_indices=config.get_target_indices(),
sample=batch) # type: ignore
number_sequences = model_inputs_and_labels.model_inputs[0].shape[1]
number_subjects = len(model_inputs_and_labels.subject_ids)
visualizer = VisualizationMaps(model, config)
guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
model_inputs_and_labels.model_inputs)
if use_combined_model:
if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
assert guided_grad_cams.shape[:2] == (number_subjects, number_sequences * 2)
assert grad_cams.shape[:2] == (number_subjects, number_sequences * 2)
else:
assert guided_grad_cams.shape[:2] == (number_subjects, number_sequences)
assert grad_cams.shape[:2] == (number_subjects, number_sequences)
else:
assert guided_grad_cams is None
assert grad_cams is None
assert pseudo_cam_non_img.shape[:2] == (number_subjects, number_sequences)
assert probas.shape[0] == number_subjects
non_image_features = config.numerical_columns + config.categorical_columns
non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
non_image_features,
index=0,
target_position=3)
assert non_imaging_plot_labels == ['numerical1_0',
'numerical2_0',
'cat1_0',
'numerical1_1',
'numerical2_1',
'cat1_1',
'numerical1_2',
'numerical2_2',
'cat1_2',
'numerical1_3',
'numerical2_3',
'cat1_3']
class ToySequenceModel2(SequenceModelBase):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
temperature_scaling_config=TemperatureScalingConfig(),
local_dataset=full_ml_test_data_path("sequence_data_for_classification"),
label_value_column="label",
numerical_columns=["feature"],
sequence_column="index",
sequence_target_positions=[2],
loss_type=ScalarLoss.BinaryCrossEntropyWithLogits,
num_epochs=20,
num_dataload_workers=0,
train_batch_size=40,
l_rate=1e-2,
drop_last_batch_in_training=True,
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
max_num_gpus=1,
**kwargs
)
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
return DatasetSplits.from_proportions(
df=dataset_df,
proportion_train=0.8,
proportion_test=0.1,
proportion_val=0.1,
)
def create_model(self) -> Any:
return RNNClassifier(input_dim=self.get_total_number_of_non_imaging_features(),
hidden_dim=12,
output_dim=1,
num_rnn_layers=1,
rnn_dropout=0.25,
use_layer_norm=False,
target_indices=self.get_target_indices())
@pytest.mark.gpu
def test_rnn_classifier_via_config_2(test_output_dirs: OutputFolderForTests) -> None:
"""
Test if we can build an RNN classifier that learns sequences of the same kind as in
test_rnn_classifier_toy_problem, but constructed via the config.
Only test the non-combined model because otherwise the build takes too much time.
"""
expected_max_train_loss = 0.71
expected_max_val_loss = 0.71
num_sequences = 100
ml_util.set_random_seed(123)
dataset_contents = "subject,index,feature,label\n"
for subject in range(num_sequences):
# Sequences have variable length
sequence_length = np.random.choice([9, 10, 11, 12])
# Each sequence is a series of 0 and 1
inputs = np.random.choice([0, 1], size=(sequence_length,), p=[1. / 3, 2. / 3])
label = np.sum(inputs) > (sequence_length // 2)
for i, value in enumerate(inputs.tolist()):
dataset_contents += f"S{subject},{i},{value},{label}\n"
logging_to_stdout()
config = ToySequenceModel2(should_validate=False)
# Necessary because torch otherwise says "index_add_cuda_ does not have a deterministic implementation"
config.pl_deterministic = False
config.num_epochs = 2
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_mock_sequence_dataset(dataset_contents)
results, _ = model_train_unittest(config, output_folder=test_output_dirs)
actual_train_loss = results.get_metric(is_training=True, metric_type=MetricType.LOSS.value)[-1]
actual_val_loss = results.get_metric(is_training=False, metric_type=MetricType.LOSS.value)[-1]
print(f"Training loss after {config.num_epochs} epochs: {actual_train_loss}")
print(f"Validation loss after {config.num_epochs} epochs: {actual_val_loss}")
assert actual_train_loss <= expected_max_train_loss, "Training loss too high"
assert actual_val_loss <= expected_max_val_loss, "Validation loss too high"
# Issue #374: put back in when temperature scaling is enabled again
# assert np.allclose(results.optimal_temperature_scale_values_per_checkpoint_epoch, [0.97], rtol=0.1)
class ToyMultiLabelSequenceModel(SequenceModelBase):
def __init__(self, **kwargs: Any) -> None:
num_epochs = 3
super().__init__(
temperature_scaling_config=TemperatureScalingConfig(),
label_value_column="Label",
numerical_columns=["NUM1", "NUM2"],
sequence_column="Position",
sequence_target_positions=[1, 2, 3],
loss_type=ScalarLoss.WeightedCrossEntropyWithLogits,
num_epochs=num_epochs,
num_dataload_workers=0,
train_batch_size=3,
l_rate=1e-1,
label_smoothing_eps=0.05,
categorical_columns=["CAT1"],
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.
max_num_gpus=1,
**kwargs
)
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
return DatasetSplits.from_proportions(
df=dataset_df,
proportion_train=0.7,
proportion_test=0.2,
proportion_val=0.1,
)
def create_model(self) -> Any:
return RNNClassifier(input_dim=self.get_total_number_of_non_imaging_features(),
hidden_dim=3,
output_dim=1,
num_rnn_layers=2,
rnn_dropout=0.0,
target_indices=self.get_target_indices())
def test_run_ml_with_multi_label_sequence_model(test_output_dirs: OutputFolderForTests) -> None:
"""
    Test training and testing of sequence models that predict at multiple time points,
    when started via run_ml.
"""
logging_to_stdout()
config = ToyMultiLabelSequenceModel()
assert config.get_target_indices() == [1, 2, 3]
expected_prediction_targets = [f"{SEQUENCE_POSITION_HUE_NAME_PREFIX} {x}"
for x in ["01", "02", "03"]]
_target_indices = config.get_target_indices()
assert _target_indices is not None
assert len(_target_indices) == len(expected_prediction_targets)
metrics_dict = create_metrics_dict_for_scalar_models(config)
assert metrics_dict.get_hue_names(include_default=False) == expected_prediction_targets
config.set_output_to(test_output_dirs.root_dir)
# Create a fake dataset directory to make config validation pass
config.local_dataset = test_output_dirs.root_dir
config.dataset_data_frame = _get_multi_label_sequence_dataframe()
config.pre_process_dataset_dataframe()
config.num_epochs = 1
config.max_batch_grad_cam = 1
azure_config = get_default_azure_config()
azure_config.train = True
MLRunner(config, azure_config=azure_config).run()
# The metrics file should have one entry per epoch per subject per prediction target,
# for all the 3 prediction targets.
metrics_file = config.outputs_folder / "Train" / SUBJECT_METRICS_FILE_NAME
assert metrics_file.exists()
metrics = pd.read_csv(metrics_file)
assert LoggingColumns.Patient.value in metrics
assert LoggingColumns.Epoch.value in metrics
assert LoggingColumns.Hue.value in metrics
assert metrics[LoggingColumns.Hue.value].unique().tolist() == expected_prediction_targets
group_by_subject = metrics.groupby(by=[LoggingColumns.Patient.value,
LoggingColumns.Epoch.value])
expected_prediction_target_lengths = [3, 2, 3, 3]
for i, x in enumerate(group_by_subject):
assert len(x[1]) == expected_prediction_target_lengths[i]
group_by_subject_and_target = metrics.groupby(by=[LoggingColumns.Patient.value,
LoggingColumns.Epoch.value,
LoggingColumns.Hue.value])
for _, group in group_by_subject_and_target:
assert len(group) == 1
@pytest.mark.parametrize("combine_hidden_states", [True, False])
def test_pad_gru_output(combine_hidden_states: bool) -> None:
"""
    Test to make sure that, if the model output does not cover the target indices, it is padded.
"""
config = ToySequenceModel(
sequence_target_positions=[5, 7],
combine_hidden_states=combine_hidden_states,
should_validate=False
)
model: RNNClassifier = config.create_model()
# base case where no padding is required
test_input = torch.rand(max(config.get_target_indices()) + 1, 1)
padded = model.pad_gru_output(test_input)
assert torch.equal(test_input, padded)
# case when padding is required
test_input = torch.rand(min(config.get_target_indices()) - 1, 1)
expected = torch.cat([test_input, test_input.new_full((4, 1), fill_value=0)], dim=0)
padded = model.pad_gru_output(test_input)
assert torch.allclose(expected, padded)
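The behaviour tested here can be shown in isolation. The sketch below uses a hypothetical pad_output helper — not the repository's RNNClassifier.pad_gru_output — to illustrate the same idea: if the model produced fewer time steps than the largest target index requires, zero rows are appended along the time dimension.

```python
import torch
import torch.nn.functional as F

def pad_output(output: torch.Tensor, required_steps: int) -> torch.Tensor:
    # Hypothetical helper: append zero rows until `required_steps` time steps exist.
    missing = required_steps - output.shape[0]
    return F.pad(output, (0, 0, 0, max(missing, 0)))

out = torch.rand(4, 1)
print(pad_output(out, 8).shape)  # torch.Size([8, 1])
```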
def test_visualization_for_different_target_weeks(test_output_dirs: OutputFolderForTests) -> None:
"""
    Tests that the visualizations differ depending on the target week
    for which they are computed.
"""
config = ToyMultiLabelSequenceModel(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_multi_label_sequence_dataframe()
config.pre_process_dataset_dataframe()
model = create_model_with_temperature_scaling(config)
dataloader = SequenceDataset(config,
data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
batch_size=2)
batch = next(iter(dataloader))
model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
target_indices=config.get_target_indices(),
sample=batch)
visualizer = VisualizationMaps(model, config)
if config.use_gpu:
device = visualizer.grad_cam.device
batch = transfer_batch_to_device(batch, device)
model = model.to(device)
# Pseudo-grad cam explaining the prediction at target sequence 2
_, _, pseudo_cam_non_img_3, probas_3 = visualizer.generate(model_inputs_and_labels.model_inputs,
target_position=2,
target_label_index=2)
# Pseudo-grad cam explaining the prediction at target sequence 0
_, _, pseudo_cam_non_img_1, probas_1 = visualizer.generate(model_inputs_and_labels.model_inputs,
target_position=0,
target_label_index=0)
assert pseudo_cam_non_img_1.shape[1] == 1
assert pseudo_cam_non_img_3.shape[1] == 3
# Both visualizations should not be equal
assert np.any(pseudo_cam_non_img_1 != pseudo_cam_non_img_3)
assert np.any(probas_3 != probas_1)
def _get_multi_label_sequence_dataframe() -> pd.DataFrame:
"""
    Returns a mock dataset for the multi-label sequence model.
"""
dataset_contents = """subject,NUM1,CAT1,NUM2,Position,Label
2137.00005,362,A,71,0,
2137.00005,357,B,69,1,0
2137.00005,355,C,64,2,0
2137.00005,355,C,63,3,1
2137.00125,348,A,64,0,0
2137.00125,316,A,68,1,1
2137.00125,349,B,68,2,0
2137.00125,361,B,67,3,1
2137.00125,350,B,68,4,0
2627.00001,477,C,58,0,0
2627.00001,220,C,59,1,0
2627.00001,222,A,60,2,0
2627.00001,217,A,65,5,1
2627.12341,210,B,60,0,0
2627.12341,217,B,61,1,0
2627.12341,224,B,63,2,1
3250.00005,344,B,76,0,0
3250.00005,233,A,76,1,0
3250.00005,212,A,84,2,0
3250.00005,215,A,84,3,0
3250.00005,215,A,82,4,0
3250.12345,233,A,84,0,1
3250.12345,218,A,84,1,0
3250.12345,221,B,84,2,0
3250.12345,238,B,84,3,0
"""
return pd.read_csv(StringIO(dataset_contents), dtype=str)
def test_sequence_dataset_stats_hook(test_output_dirs: OutputFolderForTests) -> None:
model = ToySequenceModel()
model.set_output_to(test_output_dirs.root_dir)
model.dataset_data_frame = _get_mock_sequence_dataset()
model.create_and_set_torch_datasets()
length_file = model.logs_folder / SEQUENCE_LENGTH_FILE
assert length_file.is_file()
assert length_file.read_text().splitlines() == [
"cross_validation_split_index,data_split,subject,sequence_length",
"-1,Train,2137.00005,4",
"-1,Train,2627.12341,3",
"-1,Train,3250.00005,5",
"-1,Train,3250.12345,4",
"-1,Test,2627.00001,3",
"-1,Val,2137.00125,5"]
stats_file = model.logs_folder / SEQUENCE_LENGTH_STATS_FILE
assert stats_file.is_file()
assert stats_file.read_text().splitlines() == [
" sequence_length ",
" count mean std min 25% 50% 75% max",
"data_split ",
"Test 1.0 3.0 NaN 3.0 3.00 3.0 3.00 3.0",
"Train 4.0 4.0 0.816497 3.0 3.75 4.0 4.25 5.0",
"Val 1.0 5.0 NaN 5.0 5.00 5.0 5.00 5.0"]


@ -19,19 +19,16 @@ from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset, ScalarItemAugmentation
from InnerEye.ML.lightning_models import transfer_batch_to_device
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImageEncoderWithMlp, \
ImagingFeatureType
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import AggregationType, ScalarLoss, ScalarModelBase, get_non_image_features_dict
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES, segmentation_to_one_hot
from InnerEye.ML.utils.io_util import ImageAndSegmentations, NumpyFile
from InnerEye.ML.utils.ml_util import is_gpu_available, set_random_seed
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling
from InnerEye.ML.utils.split_dataset import DatasetSplits
from InnerEye.ML.visualizers.grad_cam_hooks import VisualizationMaps
from InnerEye.ML.visualizers.model_summary import ModelSummary
from Tests.ML.util import get_default_azure_config, model_train_unittest
@ -324,86 +321,3 @@ def test_segmentation_to_one_hot(use_gpu: bool, input_on_gpu: bool) -> None:
else:
expected = torch.zeros((B,) + dim, device=one_hot.device)
assert one_hot[:, i, ...].float().allclose(expected), f"Dimension {i} should have all ones"
@pytest.mark.parametrize("encode_channels_jointly", [True, False])
@pytest.mark.parametrize("use_non_imaging_features", [True, False])
@pytest.mark.parametrize("imaging_feature_type", [ImagingFeatureType.Image,
ImagingFeatureType.Segmentation,
ImagingFeatureType.ImageAndSegmentation])
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
imaging_feature_type: ImagingFeatureType,
encode_channels_jointly: bool,
test_output_dirs: OutputFolderForTests) -> None:
dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
specific_channels={"categorical2": ["week1"]}) \
if use_non_imaging_features else {}
config = ImageEncoder(
local_dataset=Path(),
encode_channels_jointly=encode_channels_jointly,
should_validate=False,
numerical_columns=numerical_columns,
categorical_columns=categorical_columns,
imaging_feature_type=imaging_feature_type,
non_image_feature_channels=non_image_feature_channels,
categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
dataframe=dataset_dataframe, columns=categorical_columns)
)
dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
.as_data_loader(shuffle=False, batch_size=2)
config.set_output_to(test_output_dirs.root_dir)
config.num_epochs = 1
model = create_model_with_temperature_scaling(config)
visualizer = VisualizationMaps(model, config)
# Patch the load_images function that will be called once we access a dataset item
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
segmentations=np.random.randint(0, 2, (6, 64, 60)))
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
batch = next(iter(dataloader))
if config.use_gpu:
device = visualizer.grad_cam.device
batch = transfer_batch_to_device(batch, device)
visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
target_indices=[],
sample=batch)
number_channels = len(config.image_channels)
number_subjects = len(model_inputs_and_labels.subject_ids)
guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
model_inputs_and_labels.model_inputs)
if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
else:
assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)
assert grad_cams.shape[:2] == (number_subjects, 1) if encode_channels_jointly \
else (number_subjects, number_channels)
if use_non_imaging_features:
non_image_features = config.numerical_columns + config.categorical_columns
non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
non_image_features,
index=0)
assert non_imaging_plot_labels == ['numerical1_week1',
'numerical1_week0',
'numerical2_week1',
'numerical2_week0',
'categorical1_week1',
'categorical1_week0',
'categorical2_week1']
assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))
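The plot labels asserted here pair every non-image feature with the channels it uses: the two default channels for most features, but only week1 for categorical2, per the non_image_feature_channels configuration above. An illustrative sketch (not repository code) that reproduces the list:

```python
# Illustrative only: default channels for most features, week1-only for categorical2.
default_channels = ["week1", "week0"]
features = ["numerical1", "numerical2", "categorical1"]
labels = [f"{f}_{c}" for f in features for c in default_channels] + ["categorical2_week1"]
print(labels)
```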


@ -24,7 +24,6 @@ from InnerEye.ML.pipelines.scalar_inference import ScalarEnsemblePipeline, Scala
from InnerEye.ML.run_ml import is_classification_model
from InnerEye.ML.scalar_config import EnsembleAggregationType
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
from Tests.ML.models.architectures.sequential.test_rnn_classifier import ToySequenceModel
from Tests.ML.utils.test_model_util import create_model_and_store_checkpoint
@ -190,7 +189,6 @@ def test_is_classification_model() -> None:
assert is_classification_model(GlaucomaPublic())
assert is_classification_model(ClassificationModelForTesting())
assert not is_classification_model(BasicModel2Epochs())
assert not is_classification_model(ToySequenceModel())
def test_inference_required_single_runs() -> None:


@ -1,74 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Optional, Sequence
import numpy as np
import pytest
import torch
from torch.nn.utils.rnn import pack_sequence
from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels, map_packed_sequence_data, \
sequences_to_padded_tensor
def test_map_packed_sequence_data() -> None:
"""
    Test to ensure that the helper function applying a map transform to a packed sequence returns the expected results.
"""
packed = pack_sequence([torch.tensor([[1.0], [2.0]])], enforce_sorted=False)
mapped = map_packed_sequence_data(packed, lambda x: x * 2)
assert torch.equal(mapped.data, torch.tensor([[2.0], [4.0]]))
with pytest.raises(Exception):
map_packed_sequence_data(packed, lambda x: x.unsqueeze(dim=0))
def test_get_masked_model_outputs_and_labels() -> None:
"""
Test to ensure the helper function to get masked model outputs, labels and their associated subject ids
returns the expected results.
"""
def _create_masked_and_check_expected(_model_outputs: torch.Tensor,
_labels: torch.Tensor,
_subject_ids: Sequence[str],
_sorted_indices: Optional[torch.Tensor] = None) -> None:
_masked = get_masked_model_outputs_and_labels(_model_outputs, _labels, _subject_ids)
assert _masked is not None
sorted_indices = _masked.labels.sorted_indices if _sorted_indices is None else _sorted_indices
if sorted_indices is not None:
_labels = _labels[sorted_indices]
_model_outputs = _model_outputs[sorted_indices]
_subject_ids = np.array(_subject_ids)[sorted_indices].tolist()
_expected_labels = _labels.transpose(dim0=0, dim1=1).flatten()
_mask = ~torch.isnan(_expected_labels)
_expected_labels = _expected_labels[_mask]
_expected_model_outputs = _model_outputs.transpose(dim0=0, dim1=1).flatten()[_mask]
_expected_subject_ids = _subject_ids
assert torch.equal(_expected_model_outputs, _masked.model_outputs.data)
assert torch.equal(_expected_labels, _masked.labels.data)
assert _expected_subject_ids[:_masked.labels.sorted_indices.shape[0]] == _masked.subject_ids
# test base case where no masking needs to be applied
model_outputs = torch.rand((3, 4, 1))
labels = torch.rand((3, 4, 1)).round()
subject_ids = ['1', '2', '3']
_create_masked_and_check_expected(model_outputs, labels, subject_ids)
# test with unequal length sequences where masking will be performed
model_outputs = sequences_to_padded_tensor([torch.rand(x + 1, 1) for x in range(3)], padding_value=np.nan)
labels = sequences_to_padded_tensor([torch.rand(x + 1, 1) for x in range(3)], padding_value=np.nan)
_create_masked_and_check_expected(model_outputs, labels, subject_ids)
# test where one sequence is totally removed
model_outputs[0] = np.nan
labels[0] = np.nan
_create_masked_and_check_expected(model_outputs, labels, subject_ids, _sorted_indices=torch.tensor([2, 1, 0]))
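The masking exercised above relies on padding variable-length sequences with NaN and then dropping those entries. A minimal stand-alone sketch of the same idea, using torch's built-in pad_sequence in place of the repository's sequences_to_padded_tensor helper:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of lengths 1, 2 and 3, padded with NaN to shape (3, 3, 1).
seqs = [torch.rand(n + 1, 1) for n in range(3)]
padded = pad_sequence(seqs, batch_first=True, padding_value=float("nan"))

# Only the real (non-padded) entries survive the mask: 1 + 2 + 3 = 6 values.
mask = ~torch.isnan(padded)
print(padded.shape, int(mask.sum()))  # torch.Size([3, 3, 1]) 6
```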


@ -13,14 +13,13 @@ from pandas.core.dtypes.common import is_string_dtype
from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
from InnerEye.Common.common_util import CROSSVAL_RESULTS_FOLDER, FULL_METRICS_DATAFRAME_FILE, \
METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME, logging_to_stdout
METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import LoggingColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.deep_learning_config import ModelCategory
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.utils.csv_util import CSV_INSTITUTION_HEADER, CSV_SERIES_HEADER
from InnerEye.ML.visualizers.plot_cross_validation import COL_MODE, \
METRICS_BY_MODE_AND_STRUCTURE_FILE, METRICS_BY_MODE_FILE, \
@ -29,8 +28,6 @@ from InnerEye.ML.visualizers.plot_cross_validation import COL_MODE, \
create_results_breakdown, download_crossval_result_files, get_split_id, load_dataframes, \
plot_cross_validation_from_files, save_outliers
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, get_most_recent_run_id
from Tests.ML.models.architectures.sequential.test_rnn_classifier import ToyMultiLabelSequenceModel, \
_get_multi_label_sequence_dataframe
from Tests.ML.util import assert_text_files_match, get_default_azure_config
@ -430,29 +427,6 @@ def test_download_or_get_local_file_2(test_output_dirs: OutputFolderForTests) ->
assert file_in_folder == download_to_folder / file2
@pytest.mark.skip(reason="This test is only used to create input for test_load_files_with_prediction_target")
def test_run_ml_with_multi_label_sequence_in_crossval(test_output_dirs: OutputFolderForTests) -> None:
"""
    Test training and testing of sequence models that predict at multiple time points,
including aggregation of cross validation results.
"""
logging_to_stdout()
config = ToyMultiLabelSequenceModel(should_validate=False)
assert config.get_target_indices() == [1, 2, 3]
expected_prediction_targets = ["Seq_pos 01", "Seq_pos 02", "Seq_pos 03"]
target_indices = config.get_target_indices()
assert target_indices
assert len(target_indices) == len(expected_prediction_targets)
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_multi_label_sequence_dataframe()
config.pre_process_dataset_dataframe()
config.num_epochs = 1
config.number_of_cross_validation_splits = 2
azure_config = get_default_azure_config()
azure_config.train = True
MLRunner(config, azure_config=azure_config).run()
def test_load_files_with_prediction_target() -> None:
"""
For multi-week RNNs that predict at multiple sequence points: Test that the dataframes


@ -73,7 +73,7 @@ steps:
scanType: 'Register'
verbosity: 'Normal'
alertWarningLevel: 'High'
failOnAlert: true
failOnAlert: false
failOnStderr: true
- task: PublishBuildArtifacts@1


@ -16,5 +16,5 @@ steps:
scanType: 'Register'
verbosity: 'Normal'
alertWarningLevel: 'High'
failOnAlert: true
failOnAlert: false
failOnStderr: true



@ -119,64 +119,12 @@ You can find instructions for other Linux distributions on NVidia website: https
The following steps describe how to set up specific tools. You can execute most of those at a later
point, if you want to dig deeper into the code.
## PyCharm
## VSCode
Our team uses [PyCharm](https://www.jetbrains.com/pycharm/) for development, but any good editor
([VSCode](https://code.visualstudio.com/) for example) will do as well.
([VSCode](https://code.visualstudio.com/) for example)
This repository already contains a PyCharm configuration file in `.idea/InnerEye-DeepLearning.iml`. It should
automatically pick the WSL Python interpreter (see [WSL.md](WSL.md)) as the default (no need to import the settings file)
- if it doesn't happen you will need to adjust that as described [here](https://www.jetbrains.com/help/pycharm/configuring-python-interpreter.html).
## How to manually set up flake8 as a PyCharm external tool
Go to File / Settings / Tools / External Tools / Add.
* Name: Flake8
* Program: $PyInterpreterDirectory$/python
* Arguments: -m flake8 $ProjectFileDir$
* Working directory: $ProjectFileDir$
* Advanced Options / Output Filters: $FILE_PATH$\:$LINE$\:$COLUMN$\:.*
Run Flake8 by right-clicking on a source file, External Tools / Flake8
## How to manually set up mypy as a PyCharm external tool
Go to File / Settings / Tools / External Tools / Add.
* Name: mypy
* Program: $PyInterpreterDirectory$/python
* Arguments: $ProjectFileDir$/mypy_runner.py -m <path to mypy executable>
You can find the path to the mypy executable by typing `where mypy` on Windows or `which mypy` on Linux.
If you have configured a virtual environment in PyCharm, the path will usually be
`$PyInterpreterDirectory$/Scripts/mypy.exe` on Windows and `$PyInterpreterDirectory$/mypy` on Linux.
* Working directory: $ProjectFileDir$
* Advanced Options / Output Filters: $FILE_PATH$\:$LINE$\:.*
Run mypy by right-clicking on a source file, External Tools / mypy
## Deleting and creating a Conda environment
To delete, make sure the environment being deleted is not your current environment (just run `deactivate`). Then run
`conda env remove --name environmentToDelete`.
To create an environment from scratch and then export it to a YAML file:
conda create --name envName python
pip install whatEverPackage
pip install packageWithVersion==1.0.42
conda env export --no-builds --file=my_env.yml
With a conda installation, the Apex library is built without the C++ files that are intended to be used in backend-op
computations such as fused_adam and fused_layernorm. This is mainly because we are unable to pass the
required input install arguments to the setup file through a conda environment file. By building the library with
these arguments, one could expect further speed-ups in both forward-backward model passes. If you are interested in
installing Apex with these flags, please run the following commands in your shell:
git clone https://github.com/NVIDIA/apex; cd apex
git checkout 880ab925bce9f817a93988b021e12db5f67f7787
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
## Conda
- `conda env create -f environment.yml`
## Conda updates


@ -1,5 +1,6 @@
# How to setup Azure Machine Learning for InnerEye
Our preferred way to use AzureML is through the [AzureTRE](https://microsoft.github.io/AzureTRE/).
In order to be able to train models on Azure Machine Learning (AML), you will need to set up your environment in the
Azure Portal first. In this document we will walk you through this process step-by-step.