Enable tiling non-PANDA WSI datasets (#621)
* Add basic dataset and environment changes * Add loading/preproc utils * Back-up PANDA tiling scripts * Refactor and generalise tiling scripts * Remove Azure scripts * Add test WSI file * Add preprocessing tests * Update changelog * Add Linux condition for cuCIM in environment.yml * Use PANDA instead of TCGA-PRAD in test * Leave TcgaPradDataset as an example * Fix skipped InnerEye dataset tests * Create and test mock slides dataset * Remove Tests/ML/datasets from pytest discovery
This commit is contained in:
Родитель
276e0f5253
Коммит
6a4d334a99
|
@ -15,3 +15,4 @@
|
||||||
*.dcm filter=lfs diff=lfs merge=lfs -text
|
*.dcm filter=lfs diff=lfs merge=lfs -text
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||||
|
|
|
@ -38,6 +38,7 @@ jobs that run in AzureML.
|
||||||
- ([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
|
- ([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
|
||||||
- ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
|
- ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
|
||||||
- ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
|
- ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
|
||||||
|
- ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
|
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -12,6 +12,8 @@ import torch
|
||||||
from sklearn.utils.class_weight import compute_class_weight
|
from sklearn.utils.class_weight import compute_class_weight
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||||
|
|
||||||
|
|
||||||
class TilesDataset(Dataset):
|
class TilesDataset(Dataset):
|
||||||
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.
|
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.
|
||||||
|
@ -71,7 +73,7 @@ class TilesDataset(Dataset):
|
||||||
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
|
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
|
||||||
dataset_df = pd.read_csv(self.dataset_csv)
|
dataset_df = pd.read_csv(self.dataset_csv)
|
||||||
|
|
||||||
columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN, self.LABEL_COLUMN,
|
columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
|
||||||
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
|
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
|
||||||
for column in columns:
|
for column in columns:
|
||||||
if column is not None and column not in dataset_df.columns:
|
if column is not None and column not in dataset_df.columns:
|
||||||
|
@ -110,3 +112,109 @@ class TilesDataset(Dataset):
|
||||||
classes = np.unique(slide_labels)
|
classes = np.unique(slide_labels)
|
||||||
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
|
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
|
||||||
return torch.as_tensor(class_weights)
|
return torch.as_tensor(class_weights)
|
||||||
|
|
||||||
|
|
||||||
|
class SlidesDataset(Dataset):
    """Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.

    The output dictionaries are indexed by `..utils.naming.SlideKey`.

    :param SLIDE_ID_COLUMN: CSV column name for slide ID.
    :param IMAGE_COLUMN: CSV column name for relative path to image file.
    :param LABEL_COLUMN: CSV column name for slide label.
    :param MASK_COLUMN: CSV column name for relative path to mask file (optional).
    :param SPLIT_COLUMN: CSV column name for train/test split (optional).
    :param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
    :param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
    :param METADATA_COLUMNS: Names of additional CSV columns copied into each sample's metadata.
    :param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
    :param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
    """
    SLIDE_ID_COLUMN: str = 'slide_id'
    IMAGE_COLUMN: str = 'image'
    LABEL_COLUMN: str = 'label'
    MASK_COLUMN: Optional[str] = None
    SPLIT_COLUMN: Optional[str] = None

    TRAIN_SPLIT_LABEL: str = 'train'
    TEST_SPLIT_LABEL: str = 'test'

    METADATA_COLUMNS: Tuple[str, ...] = ()

    DEFAULT_CSV_FILENAME: str = "dataset.csv"

    N_CLASSES: int = 1  # binary classification by default

    def __init__(self,
                 root: Union[str, Path],
                 dataset_csv: Optional[Union[str, Path]] = None,
                 dataset_df: Optional[pd.DataFrame] = None,
                 train: Optional[bool] = None,
                 validate_columns: bool = True) -> None:
        """
        :param root: Root directory of the dataset.
        :param dataset_csv: Full path to a dataset CSV file, containing at least
        `SLIDE_ID_COLUMN`, `IMAGE_COLUMN`, and `LABEL_COLUMN`. If omitted, the CSV will be read
        from `"{root}/{DEFAULT_CSV_FILENAME}"`.
        :param dataset_df: A potentially pre-processed dataframe in the same format as would be read
        from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
        :param train: If `True`, loads only the training split (resp. `False` for test split). By
        default (`None`), loads the entire dataset as-is.
        :param validate_columns: Whether to call `validate_columns()` at the end of `__init__()`.
        :raises ValueError: If `train` is given but the dataset defines no `SPLIT_COLUMN`, or if
        column validation fails.
        """
        if self.SPLIT_COLUMN is None and train is not None:
            raise ValueError("Train/test split was specified but dataset has no split column")

        self.root_dir = Path(root)

        if dataset_df is not None:
            self.dataset_csv = None
        else:
            self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
            dataset_df = pd.read_csv(self.dataset_csv)

        # The slide ID becomes the dataframe index, so it is no longer a regular column
        dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
        if train is None:
            self.dataset_df = dataset_df
        else:
            split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
            self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]

        if validate_columns:
            self.validate_columns()

    def validate_columns(self) -> None:
        """Check that loaded dataframe contains expected columns, raises `ValueError` otherwise.

        If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
        call `validate_columns()` after creating derived columns, for example.
        """
        columns = [self.IMAGE_COLUMN, self.LABEL_COLUMN, self.MASK_COLUMN,
                   self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS)
        for column in columns:
            # Optional columns are declared as `None` and are skipped
            if column is not None and column not in self.dataset_df.columns:
                raise ValueError(f"Expected column '{column}' not found in the dataframe")

    def __len__(self) -> int:
        """Return the number of slides (dataframe rows) in the dataset."""
        return self.dataset_df.shape[0]

    def __getitem__(self, index: int) -> Dict[SlideKey, Any]:
        """Return the sample dictionary for the slide at position `index`.

        Image and mask paths are resolved relative to the dataset root directory.
        """
        slide_id = self.dataset_df.index[index]
        # Positional lookup: unlike `.loc[slide_id]`, this always yields a single row, even if
        # the same slide ID appears more than once in the index.
        slide_row = self.dataset_df.iloc[index]
        sample = {SlideKey.SLIDE_ID: slide_id}

        rel_image_path = slide_row[self.IMAGE_COLUMN]
        sample[SlideKey.IMAGE] = str(self.root_dir / rel_image_path)
        # we're replicating this column because we want to propagate the path to the batch
        sample[SlideKey.IMAGE_PATH] = sample[SlideKey.IMAGE]

        if self.MASK_COLUMN:
            rel_mask_path = slide_row[self.MASK_COLUMN]
            sample[SlideKey.MASK] = str(self.root_dir / rel_mask_path)
            sample[SlideKey.MASK_PATH] = sample[SlideKey.MASK]

        sample[SlideKey.LABEL] = slide_row[self.LABEL_COLUMN]
        sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.METADATA_COLUMNS}
        return sample

    @classmethod
    def has_mask(cls) -> bool:
        """Return whether this dataset class declares a mask column."""
        return cls.MASK_COLUMN is not None
|
||||||
|
|
|
@ -3,11 +3,13 @@
|
||||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
PANDA_DATASET_ID = "PANDA"
|
||||||
PANDA_TILES_DATASET_ID = "PANDA_tiles"
|
PANDA_TILES_DATASET_ID = "PANDA_tiles"
|
||||||
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
|
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
|
||||||
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"
|
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"
|
||||||
|
|
||||||
DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
|
DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
|
||||||
|
PANDA_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_DATASET_ID
|
||||||
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
|
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
|
||||||
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
|
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
|
||||||
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID
|
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID
|
||||||
|
|
|
@ -7,50 +7,42 @@ from pathlib import Path
|
||||||
from typing import Any, Dict, Union, Optional
|
from typing import Any, Dict, Union, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from cucim import CuImage
|
||||||
|
from health_ml.utils import box_utils
|
||||||
from monai.config import KeysCollection
|
from monai.config import KeysCollection
|
||||||
from monai.data.image_reader import ImageReader, WSIReader
|
from monai.data.image_reader import ImageReader, WSIReader
|
||||||
from monai.transforms import MapTransform
|
from monai.transforms import MapTransform
|
||||||
from openslide import OpenSlide
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from health_ml.utils import box_utils
|
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||||
|
|
||||||
|
|
||||||
class PandaDataset(Dataset):
|
class PandaDataset(SlidesDataset):
|
||||||
"""Dataset class for loading files from the PANDA challenge dataset.
|
"""Dataset class for loading files from the PANDA challenge dataset.
|
||||||
|
|
||||||
Iterating over this dataset returns a dictionary containing the `'image_id'`, paths to the `'image'`
|
Iterating over this dataset returns a dictionary following the `SlideKey` schema plus meta-data
|
||||||
and `'mask'` files, and the remaining meta-data from the original dataset (`'data_provider'`,
|
from the original dataset (`'data_provider'`, `'isup_grade'`, and `'gleason_score'`).
|
||||||
`'isup_grade'`, and `'gleason_score'`).
|
|
||||||
|
|
||||||
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
|
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
|
||||||
"""
|
"""
|
||||||
def __init__(self, root_dir: Union[str, Path], n_slides: Optional[int] = None,
|
SLIDE_ID_COLUMN = 'image_id'
|
||||||
frac_slides: Optional[float] = None) -> None:
|
IMAGE_COLUMN = 'image'
|
||||||
super().__init__()
|
MASK_COLUMN = 'mask'
|
||||||
self.root_dir = Path(root_dir)
|
LABEL_COLUMN = 'isup_grade'
|
||||||
self.train_df = pd.read_csv(self.root_dir / "train.csv", index_col='image_id')
|
|
||||||
if n_slides or frac_slides:
|
|
||||||
self.train_df = self.train_df.sample(n=n_slides, frac=frac_slides, replace=False,
|
|
||||||
random_state=1234)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
METADATA_COLUMNS = ('data_provider', 'isup_grade', 'gleason_score')
|
||||||
return self.train_df.shape[0]
|
|
||||||
|
|
||||||
def _get_image_path(self, image_id: str) -> Path:
|
DEFAULT_CSV_FILENAME = "train.csv"
|
||||||
return self.root_dir / "train_images" / f"{image_id}.tiff"
|
|
||||||
|
|
||||||
def _get_mask_path(self, image_id: str) -> Path:
|
def __init__(self,
|
||||||
return self.root_dir / "train_label_masks" / f"{image_id}_mask.tiff"
|
root: Union[str, Path],
|
||||||
|
dataset_csv: Optional[Union[str, Path]] = None,
|
||||||
def __getitem__(self, index: int) -> Dict:
|
dataset_df: Optional[pd.DataFrame] = None) -> None:
|
||||||
image_id = self.train_df.index[index]
|
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
|
||||||
return {
|
# PANDA CSV does not come with paths for image and mask files
|
||||||
'image_id': image_id,
|
slide_ids = self.dataset_df.index
|
||||||
'image': str(self._get_image_path(image_id).absolute()),
|
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
|
||||||
'mask': str(self._get_mask_path(image_id).absolute()),
|
self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
|
||||||
**self.train_df.loc[image_id].to_dict()
|
self.validate_columns()
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
|
# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
|
||||||
|
@ -96,10 +88,10 @@ class LoadPandaROId(MapTransform):
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def _get_bounding_box(self, mask_obj: OpenSlide) -> box_utils.Box:
|
def _get_bounding_box(self, mask_obj: CuImage) -> box_utils.Box:
|
||||||
# Estimate bounding box at the lowest resolution (i.e. highest level)
|
# Estimate bounding box at the lowest resolution (i.e. highest level)
|
||||||
highest_level = mask_obj.level_count - 1
|
highest_level = mask_obj.resolutions['level_count'] - 1
|
||||||
scale = mask_obj.level_downsamples[highest_level]
|
scale = mask_obj.resolutions['level_downsamples'][highest_level]
|
||||||
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
|
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
|
||||||
|
|
||||||
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
|
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
|
||||||
|
@ -107,14 +99,14 @@ class LoadPandaROId(MapTransform):
|
||||||
return bbox
|
return bbox
|
||||||
|
|
||||||
def __call__(self, data: Dict) -> Dict:
|
def __call__(self, data: Dict) -> Dict:
|
||||||
mask_obj: OpenSlide = self.reader.read(data[self.mask_key])
|
mask_obj: CuImage = self.reader.read(data[self.mask_key])
|
||||||
image_obj: OpenSlide = self.reader.read(data[self.image_key])
|
image_obj: CuImage = self.reader.read(data[self.image_key])
|
||||||
|
|
||||||
level0_bbox = self._get_bounding_box(mask_obj)
|
level0_bbox = self._get_bounding_box(mask_obj)
|
||||||
|
|
||||||
# OpenSlide takes absolute location coordinates in the level 0 reference frame,
|
# cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
|
||||||
# but relative region size in pixels at the chosen level
|
# but relative region size in pixels at the chosen level
|
||||||
scale = mask_obj.level_downsamples[self.level]
|
scale = mask_obj.resolutions['level_downsamples'][self.level]
|
||||||
scaled_bbox = level0_bbox / scale
|
scaled_bbox = level0_bbox / scale
|
||||||
get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y),
|
get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y),
|
||||||
size=(scaled_bbox.w, scaled_bbox.h),
|
size=(scaled_bbox.w, scaled_bbox.h),
|
||||||
|
|
|
@ -4,13 +4,14 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||||
|
|
||||||
|
|
||||||
class TcgaPradDataset(Dataset):
|
class TcgaPradDataset(SlidesDataset):
|
||||||
"""Dataset class for loading TCGA-PRAD slides.
|
"""Dataset class for loading TCGA-PRAD slides.
|
||||||
|
|
||||||
Iterating over this dataset returns a dictionary containing:
|
Iterating over this dataset returns a dictionary containing:
|
||||||
|
@ -19,16 +20,14 @@ class TcgaPradDataset(Dataset):
|
||||||
- `'image_path'` (str): absolute slide image path
|
- `'image_path'` (str): absolute slide image path
|
||||||
- `'label'` (int, 0 or 1): label for predicting positive or negative
|
- `'label'` (int, 0 or 1): label for predicting positive or negative
|
||||||
"""
|
"""
|
||||||
SLIDE_ID_COLUMN: str = 'slide_id'
|
|
||||||
CASE_ID_COLUMN: str = 'case_id'
|
|
||||||
IMAGE_COLUMN: str = 'image_path'
|
IMAGE_COLUMN: str = 'image_path'
|
||||||
LABEL_COLUMN: str = 'label'
|
LABEL_COLUMN: str = 'label'
|
||||||
|
|
||||||
DEFAULT_CSV_FILENAME: str = "dataset.csv"
|
DEFAULT_CSV_FILENAME: str = "dataset.csv"
|
||||||
|
|
||||||
def __init__(self, root_dir: Union[str, Path],
|
def __init__(self, root: Union[str, Path],
|
||||||
dataset_csv: Optional[Union[str, Path]] = None,
|
dataset_csv: Optional[Union[str, Path]] = None,
|
||||||
dataset_df: Optional[pd.DataFrame] = None,) -> None:
|
dataset_df: Optional[pd.DataFrame] = None) -> None:
|
||||||
"""
|
"""
|
||||||
:param root: Root directory of the dataset.
|
:param root: Root directory of the dataset.
|
||||||
:param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from
|
:param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from
|
||||||
|
@ -36,27 +35,8 @@ class TcgaPradDataset(Dataset):
|
||||||
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
|
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
|
||||||
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
|
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
|
||||||
"""
|
"""
|
||||||
self.root_dir = Path(root_dir)
|
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
|
||||||
|
# Example of how to define a custom label column from existing columns:
|
||||||
if dataset_df is not None:
|
self.dataset_df[self.LABEL_COLUMN] = (self.dataset_df['label1']
|
||||||
self.dataset_csv = None
|
| self.dataset_df['label2']).astype(int)
|
||||||
else:
|
self.validate_columns()
|
||||||
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
|
|
||||||
dataset_df = pd.read_csv(self.dataset_csv)
|
|
||||||
|
|
||||||
dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
|
|
||||||
dataset_df[self.LABEL_COLUMN] = (dataset_df['label1_mutation']
|
|
||||||
| dataset_df['label2_mutation']).astype(int)
|
|
||||||
self.dataset_df = dataset_df
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.dataset_df.shape[0]
|
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
|
||||||
slide_id = self.dataset_df.index[index]
|
|
||||||
sample = {
|
|
||||||
self.SLIDE_ID_COLUMN: slide_id,
|
|
||||||
**self.dataset_df.loc[slide_id].to_dict()
|
|
||||||
}
|
|
||||||
sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
|
|
||||||
return sample
|
|
||||||
|
|
|
@ -0,0 +1,230 @@
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""This script is specific to PANDA and is kept only for retrocompatibility.
|
||||||
|
`create_tiles_dataset.py` is the new supported way to process slide datasets.
|
||||||
|
"""
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from monai.data import Dataset
|
||||||
|
from monai.data.image_reader import WSIReader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from InnerEye.ML.Histopathology.preprocessing import tiling
|
||||||
|
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||||
|
|
||||||
|
|
||||||
|
# Column order of the per-slide and merged `dataset.csv` files written by this script
CSV_COLUMNS = [
    'slide_id', 'tile_id', 'image', 'mask',
    'tile_x', 'tile_y', 'occupancy',
    'data_provider', 'slide_isup_grade', 'slide_gleason_score',
]
TMP_SUFFIX = "_tmp"

# Timestamped messages on the root logger, with everything from DEBUG upwards enabled
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \
        -> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]:
    """Decide which tiles contain enough foreground to be kept.

    Occupancy is the fraction of strictly positive mask values over the last two axes.

    :param mask_tile: Mask array for one tile or a batch of tiles.
    :param occupancy_threshold: Tiles are selected if their occupancy is strictly above this value.
    :return: A tuple of the selection flag(s) and the occupancy value(s), both squeezed.
    :raises ValueError: If the threshold is below 0 or above 1.
    """
    if occupancy_threshold > 1. or occupancy_threshold < 0.:
        raise ValueError("Tile occupancy threshold must be between 0 and 1")
    is_foreground = mask_tile > 0
    tile_occupancy = is_foreground.mean(axis=(-2, -1))
    is_selected = tile_occupancy > occupancy_threshold
    return is_selected.squeeze(), tile_occupancy.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
    """Format a tile location as a zero-padded `"<x>x_<y>y"` string (at least 5 digits each)."""
    tile_x = tile_location[0]
    tile_y = tile_location[1]
    return f"{tile_x:05d}x_{tile_y:05d}y"
|
||||||
|
|
||||||
|
|
||||||
|
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
    """Build a tile ID by joining the slide ID and the tile descriptor with a dot."""
    descriptor = get_tile_descriptor(tile_location)
    return f"{slide_id}.{descriptor}"
|
||||||
|
|
||||||
|
|
||||||
|
def save_image(array_chw: np.ndarray, path: Path) -> "PIL.Image.Image":
    """Save an image array to disk as an RGB image, creating parent directories as needed.

    :param array_chw: Image array whose first axis is moved last before saving (channels-first
        layout, going by the argument name). Values are cast to `uint8`.
    :param path: Destination file path.
    :return: The PIL image built from the array, before the conversion to RGB.
    """
    # `import PIL` alone does not guarantee that the `PIL.Image` submodule is loaded, so import
    # it explicitly instead of relying on another package having done so.
    from PIL import Image

    path.parent.mkdir(parents=True, exist_ok=True)
    # Singleton axes (e.g. a single-channel mask) are dropped so PIL receives a 2D array
    array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
    pil_image = Image.fromarray(array_hwc)
    pil_image.convert('RGB').save(path)
    return pil_image
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
        -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """Split a loaded slide into tiles and keep those with enough foreground.

    :param sample: Dictionary with the loaded `'image'` and `'mask'` arrays, plus the `'scale'`
        and `'location'` values used to convert tile locations to absolute coordinates.
    :param tile_size: Tile size passed to `tiling.tile_array_2d`.
    :param occupancy_threshold: Minimum mask occupancy (exclusive) for a tile to be kept.
    :return: A tuple of the selected image tiles, the selected mask tiles, their absolute
        locations (`scale * location_in_array + sample location`, cast to int), their
        occupancies, and the number of discarded tiles.
    """
    # Images are padded with white (255), masks with background (0)
    image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
                                                       constant_values=255)
    mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)

    selected: np.ndarray
    occupancies: np.ndarray
    selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
    # `select_tile` squeezes its outputs, so a single tile yields 0-d arrays. Indexing with a
    # 0-d boolean would insert an extra axis, hence the tile axis is restored here.
    selected = np.atleast_1d(selected)
    occupancies = np.atleast_1d(occupancies)

    n_discarded = int((~selected).sum())
    # Discarded tiles as a share of all tiles; guarded against slides that produce no tiles
    percentage_discarded = 100 * n_discarded / selected.size if selected.size else 0.
    logging.info(f"Percentage tiles discarded: {round(percentage_discarded, 2)}")

    image_tiles = image_tiles[selected]
    mask_tiles = mask_tiles[selected]
    tile_locations = tile_locations[selected]
    occupancies = occupancies[selected]

    abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)

    return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
|
||||||
|
|
||||||
|
|
||||||
|
# TODO refactor this to separate metadata identification from saving. We might want the metadata
|
||||||
|
# even if the saving fails
|
||||||
|
def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
              tile_location: Sequence[int], output_dir: Path) -> dict:
    """Write one image tile and its mask tile as PNG files and return the tile's metadata.

    :param sample: Slide dictionary providing `'image_id'`, `'data_provider'`, `'isup_grade'`,
        and `'gleason_score'`.
    :param image_tile: Image tile array passed to `save_image`.
    :param mask_tile: Mask tile array passed to `save_image`.
    :param tile_location: Tile coordinates; the first two entries are used as x and y.
    :param output_dir: Directory under which the tile files are written.
    :return: Dictionary with the tile's IDs, paths relative to `output_dir`, coordinates, and
        slide-level metadata.
    """
    slide_id = sample['image_id']
    tile_name = get_tile_descriptor(tile_location)
    image_rel_path = f"train_images/{tile_name}.png"
    mask_rel_path = f"train_label_masks/{tile_name}_mask.png"

    # Image first, then mask, each under its own sub-folder of the output directory
    for tile_array, rel_path in ((image_tile, image_rel_path), (mask_tile, mask_rel_path)):
        save_image(tile_array, output_dir / rel_path)

    return {
        'slide_id': slide_id,
        'tile_id': get_tile_id(slide_id, tile_location),
        'image': image_rel_path,
        'mask': mask_rel_path,
        'tile_x': tile_location[0],
        'tile_y': tile_location[1],
        'data_provider': sample['data_provider'],
        'slide_isup_grade': sample['isup_grade'],
        'slide_gleason_score': sample['gleason_score'],
    }
|
||||||
|
|
||||||
|
|
||||||
|
def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: float,
                  output_dir: Path, tile_progress: bool = False) -> None:
    """Load one PANDA slide, tile it, and save the tiles with a per-slide `dataset.csv`.

    A slide whose output folder already exists is skipped. Tiles that fail to save are listed in
    `failed_tiles.csv` inside the slide folder. Any error is reported as a warning rather than
    raised, so that processing of the other slides can carry on.

    NOTE(review): this reads `sample['image_id']` (and, via `save_tile`, `'data_provider'`,
    `'isup_grade'`, `'gleason_score'`) at the top level of `sample` - confirm this matches the
    dictionaries currently produced by `PandaDataset`.

    :param sample: Slide dictionary, as consumed by `LoadPandaROId`.
    :param level: Magnification level passed to `LoadPandaROId`.
    :param margin: Margin passed to `LoadPandaROId`.
    :param tile_size: Tile size passed to `generate_tiles`.
    :param occupancy_threshold: Occupancy threshold passed to `generate_tiles`.
    :param output_dir: Root directory of the tiles dataset; one sub-folder is created per slide.
    :param tile_progress: Whether to display a progress bar over the tiles of this slide.
    """
    slide_id = sample['image_id']
    slide_dir: Path = output_dir / (slide_id + "/")
    logging.info(f">>> Slide dir {slide_dir}")
    if slide_dir.exists():  # already processed slide - skip
        logging.info(f">>> Skipping {slide_dir} - already processed")
        return

    try:
        slide_dir.mkdir(parents=True)

        dataset_csv_path = slide_dir / "dataset.csv"
        failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
        tiles_failure = 0

        # Context managers ensure both files are flushed and closed even if loading or tiling
        # the slide raises
        with dataset_csv_path.open('w') as dataset_csv_file, \
                failed_tiles_csv_path.open('w') as failed_tiles_file:
            dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n')  # write CSV header
            failed_tiles_file.write('tile_id' + '\n')

            logging.info(f"Loading slide {slide_id} ...")
            loader = LoadPandaROId(WSIReader(), level=level, margin=margin)
            sample = loader(sample)  # load 'image' and 'mask' from disk

            logging.info(f"Tiling slide {slide_id} ...")
            image_tiles, mask_tiles, tile_locations, occupancies, _ = \
                generate_tiles(sample, tile_size, occupancy_threshold)
            n_tiles = image_tiles.shape[0]

            for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
                try:
                    tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
                                              slide_dir)
                    tile_metadata['occupancy'] = occupancies[i]
                    # Paths in the CSV are relative to the dataset root, hence the slide folder prefix
                    tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
                    tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
                    dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
                    dataset_csv_file.write(dataset_row + '\n')
                except Exception as e:
                    # Best effort: record the failed tile and carry on with the remaining ones
                    tiles_failure += 1
                    descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
                    failed_tiles_file.write(descriptor)
                    traceback.print_exc()
                    warnings.warn(f"An error occurred while saving tile "
                                  f"{get_tile_id(slide_id, tile_locations[i])}: {e}")

        if tiles_failure > 0:
            # TODO what we want to do with slides that have some failed tiles?
            logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.")
    except Exception as e:
        traceback.print_exc()
        warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
    """Concatenate all per-slide `*/dataset.csv` files into `dataset_dir / "dataset.csv"`.

    The header row is taken from the first non-empty file and dropped from all the others.
    Empty per-slide files (e.g. left behind by an interrupted run) are ignored.

    :param dataset_dir: Root directory of the tiles dataset, containing one folder per slide.
    :return: Path to the merged CSV file.
    """
    full_csv = dataset_dir / "dataset.csv"
    # TODO change how we retrieve these filenames, probably because mounted, the operation is slow
    #  and it seems to find many more files
    with full_csv.open('w') as full_csv_file:
        header_written = False
        for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
            logging.info(f"Merging slide {slide_csv}")
            content = slide_csv.read_text()
            if not content:
                continue  # nothing to merge, and no header to take from this file
            if header_written:
                # discard header row for all but the first file; unlike `str.index`, `partition`
                # does not raise if the file holds a single line without a trailing newline
                _, _, content = content.partition('\n')
            if content and not content.endswith('\n'):
                content += '\n'  # keep the last row separate from the next file's first row
            full_csv_file.write(content)
            header_written = True
    return full_csv
|
||||||
|
|
||||||
|
|
||||||
|
def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
         margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:
    """Tile every slide of the PANDA dataset and merge the per-slide CSV files.

    Tiles are written to `"{root_output_dir}/panda_tiles_level{level}_{tile_size}"`.

    :param panda_dir: Root directory of the PANDA slides dataset.
    :param root_output_dir: Directory in which the tiles dataset folder is created.
    :param level: Magnification level at which slides are loaded.
    :param tile_size: Tile size passed on to `process_slide`.
    :param margin: Margin passed on to `process_slide`.
    :param occupancy_threshold: Minimum mask occupancy (exclusive) for a tile to be kept.
    :param parallel: Whether to process slides in a `multiprocessing` pool.
    :param overwrite: If `True`, an existing output directory is deleted first. If `False`, an
        existing directory is reused and slides already present in it are skipped.
    """
    # Ignoring some types here because mypy is getting confused with the MONAI Dataset class.
    # To process a subsample, pass a pre-filtered `dataset_df` to `PandaDataset`.
    dataset = Dataset(PandaDataset(panda_dir))  # type: ignore

    output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}"
    logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}")

    if overwrite and output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=not overwrite)

    # Per-tile progress bars are only shown in sequential mode, where they cannot interleave
    func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
                             occupancy_threshold=occupancy_threshold, output_dir=output_dir,
                             tile_progress=not parallel)

    if parallel:
        import multiprocessing

        # The context manager releases the worker processes even if iterating the dataset raises;
        # `list()` drains all results before the pool is torn down
        with multiprocessing.Pool() as pool:
            list(tqdm(pool.imap_unordered(func, dataset),  # type: ignore
                      desc="Slides", unit="img", total=len(dataset)))  # type: ignore
    else:
        list(tqdm(map(func, dataset), desc="Slides", unit="img", total=len(dataset)))  # type: ignore

    logging.info("Merging slide files in a single file")
    merge_dataset_csv_files(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main(panda_dir="/tmp/datasets/PANDA",
|
||||||
|
root_output_dir="/datadrive",
|
||||||
|
level=1,
|
||||||
|
tile_size=224,
|
||||||
|
margin=64,
|
||||||
|
occupancy_threshold=0.05,
|
||||||
|
parallel=True,
|
||||||
|
overwrite=False)
|
|
@ -4,13 +4,12 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence, Tuple, Union
|
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
|
@ -18,37 +17,43 @@ from monai.data import Dataset
|
||||||
from monai.data.image_reader import WSIReader
|
from monai.data.image_reader import WSIReader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||||
from InnerEye.ML.Histopathology.preprocessing import tiling
|
from InnerEye.ML.Histopathology.preprocessing import tiling
|
||||||
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId, segment_foreground
|
||||||
|
from InnerEye.ML.Histopathology.utils.naming import SlideKey, TileKey
|
||||||
|
|
||||||
CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy',
|
|
||||||
'data_provider', 'slide_isup_grade', 'slide_gleason_score']
|
|
||||||
TMP_SUFFIX = "_tmp"
|
|
||||||
|
|
||||||
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
|
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \
|
def select_tiles(foreground_mask: np.ndarray, occupancy_threshold: float) \
|
||||||
-> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]:
|
-> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Exclude tiles that are mostly background based on estimated occupancy.
|
||||||
|
|
||||||
|
:param foreground_mask: Boolean array of shape (*, H, W).
|
||||||
|
:param occupancy_threshold: Tiles with lower occupancy (between 0 and 1) will be discarded.
|
||||||
|
:return: A tuple containing which tiles were selected and the estimated occupancies. These will
|
||||||
|
be boolean and float arrays of shape (*,), or scalars if `foreground_mask` is a single tile.
|
||||||
|
"""
|
||||||
if occupancy_threshold < 0. or occupancy_threshold > 1.:
|
if occupancy_threshold < 0. or occupancy_threshold > 1.:
|
||||||
raise ValueError("Tile occupancy threshold must be between 0 and 1")
|
raise ValueError("Tile occupancy threshold must be between 0 and 1")
|
||||||
foreground_mask = mask_tile > 0
|
|
||||||
occupancy = foreground_mask.mean(axis=(-2, -1))
|
occupancy = foreground_mask.mean(axis=(-2, -1))
|
||||||
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze()
|
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
|
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
|
||||||
|
"""Format the XY tile coordinates into a tile descriptor."""
|
||||||
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
|
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
|
||||||
|
|
||||||
|
|
||||||
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
|
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
|
||||||
|
"""Format the slide ID and XY tile coordinates into a unique tile ID."""
|
||||||
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
|
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
|
||||||
|
|
||||||
|
|
||||||
def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
|
def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
|
||||||
|
"""Save an image array in (C, H, W) format to disk."""
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
|
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
|
||||||
pil_image = PIL.Image.fromarray(array_hwc)
|
pil_image = PIL.Image.fromarray(array_hwc)
|
||||||
|
@ -56,59 +61,102 @@ def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
|
||||||
return pil_image
|
return pil_image
|
||||||
|
|
||||||
|
|
||||||
def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
|
def generate_tiles(slide_image: np.ndarray, tile_size: int, foreground_threshold: float,
|
||||||
-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
|
occupancy_threshold: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
|
||||||
image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
|
"""Split the foreground of an input slide image into tiles.
|
||||||
constant_values=255)
|
|
||||||
mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)
|
|
||||||
|
|
||||||
selected: np.ndarray
|
:param slide_image: The RGB image array in (C, H, W) format.
|
||||||
occupancies: np.ndarray
|
:param tile_size: Lateral dimensions of each tile, in pixels.
|
||||||
selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
|
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
|
||||||
|
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
|
||||||
|
:return: A tuple containing the image tiles (N, C, H, W), tile coordinates (N, 2), occupancies
|
||||||
|
(N,), and total number of discarded empty tiles.
|
||||||
|
"""
|
||||||
|
image_tiles, tile_locations = tiling.tile_array_2d(slide_image, tile_size=tile_size,
|
||||||
|
constant_values=255)
|
||||||
|
foreground_mask, _ = segment_foreground(image_tiles, foreground_threshold)
|
||||||
|
|
||||||
|
selected, occupancies = select_tiles(foreground_mask, occupancy_threshold)
|
||||||
n_discarded = (~selected).sum()
|
n_discarded = (~selected).sum()
|
||||||
logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}")
|
logging.info(f"Percentage tiles discarded: {n_discarded / len(selected) * 100:.2f}")
|
||||||
|
|
||||||
image_tiles = image_tiles[selected]
|
image_tiles = image_tiles[selected]
|
||||||
mask_tiles = mask_tiles[selected]
|
|
||||||
tile_locations = tile_locations[selected]
|
tile_locations = tile_locations[selected]
|
||||||
occupancies = occupancies[selected]
|
occupancies = occupancies[selected]
|
||||||
|
|
||||||
abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)
|
return image_tiles, tile_locations, occupancies, n_discarded
|
||||||
|
|
||||||
return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
|
|
||||||
|
|
||||||
|
|
||||||
# TODO refactor this to separate metadata identification from saving. We might want the metadata
|
def get_tile_info(sample: Dict[SlideKey, Any], occupancy: float, tile_location: Sequence[int],
|
||||||
# even if the saving fails
|
rel_slide_dir: Path) -> Dict[TileKey, Any]:
|
||||||
def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
|
"""Map slide information and tiling outputs into tile-specific information dictionary.
|
||||||
tile_location: Sequence[int], output_dir: Path) -> dict:
|
|
||||||
slide_id = sample['image_id']
|
:param sample: Slide dictionary.
|
||||||
|
:param occupancy: Estimated tile foreground occupancy.
|
||||||
|
:param tile_location: Tile XY coordinates.
|
||||||
|
:param rel_slide_dir: Directory where tiles are saved, relative to dataset root.
|
||||||
|
:return: Tile information dictionary.
|
||||||
|
"""
|
||||||
|
slide_id = sample[SlideKey.SLIDE_ID]
|
||||||
descriptor = get_tile_descriptor(tile_location)
|
descriptor = get_tile_descriptor(tile_location)
|
||||||
image_tile_filename = f"train_images/{descriptor}.png"
|
rel_image_path = f"{rel_slide_dir}/{descriptor}.png"
|
||||||
mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"
|
|
||||||
|
|
||||||
save_image(image_tile, output_dir / image_tile_filename)
|
tile_info = {
|
||||||
save_image(mask_tile, output_dir / mask_tile_filename)
|
TileKey.SLIDE_ID: slide_id,
|
||||||
|
TileKey.TILE_ID: get_tile_id(slide_id, tile_location),
|
||||||
tile_metadata = {
|
TileKey.IMAGE: rel_image_path,
|
||||||
'slide_id': slide_id,
|
TileKey.LABEL: sample[SlideKey.LABEL],
|
||||||
'tile_id': get_tile_id(slide_id, tile_location),
|
TileKey.TILE_X: tile_location[0],
|
||||||
'image': image_tile_filename,
|
TileKey.TILE_Y: tile_location[1],
|
||||||
'mask': mask_tile_filename,
|
TileKey.OCCUPANCY: occupancy,
|
||||||
'tile_x': tile_location[0],
|
TileKey.SLIDE_METADATA: {TileKey.from_slide_metadata_key(key): value
|
||||||
'tile_y': tile_location[1],
|
for key, value in sample[SlideKey.METADATA].items()}
|
||||||
'data_provider': sample['data_provider'],
|
|
||||||
'slide_isup_grade': sample['isup_grade'],
|
|
||||||
'slide_gleason_score': sample['gleason_score'],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tile_metadata
|
return tile_info
|
||||||
|
|
||||||
|
|
||||||
def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
def format_csv_row(tile_info: Dict[TileKey, Any], keys_to_save: Iterable[TileKey],
|
||||||
output_dir: Path, tile_progress: bool = False) -> None:
|
metadata_keys: Iterable[str]) -> str:
|
||||||
slide_id = sample['image_id']
|
"""Format tile information dictionary as a row to write to a dataset CSV tile.
|
||||||
slide_dir: Path = output_dir / (slide_id + "/")
|
|
||||||
|
:param tile_info: Tile information dictionary.
|
||||||
|
:param keys_to_save: Which main keys to include in the row, and in which order.
|
||||||
|
:param metadata_keys: Likewise for metadata keys.
|
||||||
|
:return: The formatted CSV row.
|
||||||
|
"""
|
||||||
|
tile_slide_metadata = tile_info.pop(TileKey.SLIDE_METADATA)
|
||||||
|
fields = [str(tile_info[key]) for key in keys_to_save]
|
||||||
|
fields.extend(str(tile_slide_metadata[key]) for key in metadata_keys)
|
||||||
|
dataset_row = ','.join(fields)
|
||||||
|
return dataset_row
|
||||||
|
|
||||||
|
|
||||||
|
def process_slide(sample: Dict[SlideKey, Any], level: int, margin: int, tile_size: int,
|
||||||
|
foreground_threshold: Optional[float], occupancy_threshold: float, output_dir: Path,
|
||||||
|
tile_progress: bool = False) -> None:
|
||||||
|
"""Load and process a slide, saving tile images and information to a CSV file.
|
||||||
|
|
||||||
|
:param sample: Slide information dictionary, returned by the input slide dataset.
|
||||||
|
:param level: Magnification level at which to process the slide.
|
||||||
|
:param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
|
||||||
|
:param tile_size: Lateral dimensions of each tile, in pixels.
|
||||||
|
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
|
||||||
|
If `None`, an optimal threshold will be estimated automatically.
|
||||||
|
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
|
||||||
|
:param output_dir: Root directory for the output dataset; outputs for a single slide will be
|
||||||
|
saved inside `output_dir/slide_id/`.
|
||||||
|
:param tile_progress: Whether to display a progress bar in the terminal.
|
||||||
|
"""
|
||||||
|
slide_metadata: Dict[str, Any] = sample[SlideKey.METADATA]
|
||||||
|
keys_to_save = (TileKey.SLIDE_ID, TileKey.TILE_ID, TileKey.IMAGE, TileKey.LABEL,
|
||||||
|
TileKey.TILE_X, TileKey.TILE_Y, TileKey.OCCUPANCY)
|
||||||
|
metadata_keys = tuple(TileKey.from_slide_metadata_key(key) for key in slide_metadata)
|
||||||
|
csv_columns: Tuple[str, ...] = (*keys_to_save, *metadata_keys)
|
||||||
|
|
||||||
|
slide_id: str = sample[SlideKey.SLIDE_ID]
|
||||||
|
rel_slide_dir = Path(slide_id)
|
||||||
|
slide_dir = output_dir / rel_slide_dir
|
||||||
logging.info(f">>> Slide dir {slide_dir}")
|
logging.info(f">>> Slide dir {slide_dir}")
|
||||||
if slide_dir.exists(): # already processed slide - skip
|
if slide_dir.exists(): # already processed slide - skip
|
||||||
logging.info(f">>> Skipping {slide_dir} - already processed")
|
logging.info(f">>> Skipping {slide_dir} - already processed")
|
||||||
|
@ -119,50 +167,57 @@ def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupan
|
||||||
|
|
||||||
dataset_csv_path = slide_dir / "dataset.csv"
|
dataset_csv_path = slide_dir / "dataset.csv"
|
||||||
dataset_csv_file = dataset_csv_path.open('w')
|
dataset_csv_file = dataset_csv_path.open('w')
|
||||||
dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
|
dataset_csv_file.write(','.join(csv_columns) + '\n') # write CSV header
|
||||||
|
|
||||||
tiles_failure = 0
|
n_failed_tiles = 0
|
||||||
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
|
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
|
||||||
failed_tiles_file = failed_tiles_csv_path.open('w')
|
failed_tiles_file = failed_tiles_csv_path.open('w')
|
||||||
failed_tiles_file.write('tile_id' + '\n')
|
failed_tiles_file.write('tile_id' + '\n')
|
||||||
|
|
||||||
logging.info(f"Loading slide {slide_id} ...")
|
logging.info(f"Loading slide {slide_id} ...")
|
||||||
loader = LoadPandaROId(WSIReader(), level=level, margin=margin)
|
loader = LoadROId(WSIReader('cuCIM'), level=level, margin=margin,
|
||||||
sample = loader(sample) # load 'image' and 'mask' from disk
|
foreground_threshold=foreground_threshold)
|
||||||
|
sample = loader(sample) # load 'image' from disk
|
||||||
|
|
||||||
logging.info(f"Tiling slide {slide_id} ...")
|
logging.info(f"Tiling slide {slide_id} ...")
|
||||||
image_tiles, mask_tiles, tile_locations, occupancies, _ = \
|
image_tiles, rel_tile_locations, occupancies, _ = \
|
||||||
generate_tiles(sample, tile_size, occupancy_threshold)
|
generate_tiles(sample[SlideKey.IMAGE], tile_size,
|
||||||
|
sample[SlideKey.FOREGROUND_THRESHOLD],
|
||||||
|
occupancy_threshold)
|
||||||
|
|
||||||
|
tile_locations = (sample[SlideKey.SCALE] * rel_tile_locations
|
||||||
|
+ sample[SlideKey.ORIGIN]).astype(int)
|
||||||
|
|
||||||
n_tiles = image_tiles.shape[0]
|
n_tiles = image_tiles.shape[0]
|
||||||
|
|
||||||
|
logging.info(f"Saving tiles for slide {slide_id} ...")
|
||||||
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
|
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
|
||||||
try:
|
try:
|
||||||
tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
|
tile_info = get_tile_info(sample, occupancies[i], tile_locations[i], rel_slide_dir)
|
||||||
slide_dir)
|
save_image(image_tiles[i], output_dir / tile_info[TileKey.IMAGE])
|
||||||
tile_metadata['occupancy'] = occupancies[i]
|
dataset_row = format_csv_row(tile_info, keys_to_save, metadata_keys)
|
||||||
tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
|
|
||||||
tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
|
|
||||||
dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
|
|
||||||
dataset_csv_file.write(dataset_row + '\n')
|
dataset_csv_file.write(dataset_row + '\n')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tiles_failure += 1
|
n_failed_tiles += 1
|
||||||
descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
|
descriptor = get_tile_descriptor(tile_locations[i])
|
||||||
failed_tiles_file.write(descriptor)
|
failed_tiles_file.write(descriptor + '\n')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
warnings.warn(f"An error occurred while saving tile "
|
warnings.warn(f"An error occurred while saving tile "
|
||||||
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
|
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
|
||||||
|
|
||||||
dataset_csv_file.close()
|
dataset_csv_file.close()
|
||||||
failed_tiles_file.close()
|
failed_tiles_file.close()
|
||||||
if tiles_failure > 0:
|
if n_failed_tiles > 0:
|
||||||
# TODO what we want to do with slides that have some failed tiles?
|
# TODO what we want to do with slides that have some failed tiles?
|
||||||
logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.")
|
logging.warning(f"{slide_id} is incomplete. {n_failed_tiles} tiles failed.")
|
||||||
|
logging.info(f"Finished processing slide {slide_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
|
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
|
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
|
||||||
|
"""Combines all "*/dataset.csv" files into a single "dataset.csv" file in the given directory."""
|
||||||
full_csv = dataset_dir / "dataset.csv"
|
full_csv = dataset_dir / "dataset.csv"
|
||||||
# TODO change how we retrieve these filenames, probably because mounted, the operation is slow
|
# TODO change how we retrieve these filenames, probably because mounted, the operation is slow
|
||||||
# and it seems to find many more files
|
# and it seems to find many more files
|
||||||
|
@ -181,21 +236,40 @@ def merge_dataset_csv_files(dataset_dir: Path) -> Path:
|
||||||
return full_csv
|
return full_csv
|
||||||
|
|
||||||
|
|
||||||
def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
|
def main(slides_dataset: SlidesDataset, root_output_dir: Union[str, Path],
|
||||||
margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:
|
level: int, tile_size: int, margin: int, foreground_threshold: Optional[float],
|
||||||
|
occupancy_threshold: float, parallel: bool = False, overwrite: bool = False,
|
||||||
|
n_slides: Optional[int] = None) -> None:
|
||||||
|
"""Process a slides dataset to produce a tiles dataset.
|
||||||
|
|
||||||
|
:param slides_dataset: Input slides dataset object.
|
||||||
|
:param root_output_dir: The root directory of the output tiles dataset.
|
||||||
|
:param level: Magnification level at which to process the slide.
|
||||||
|
:param tile_size: Lateral dimensions of each tile, in pixels.
|
||||||
|
:param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
|
||||||
|
:param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
|
||||||
|
If `None`, an optimal threshold will be estimated automatically.
|
||||||
|
:param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
|
||||||
|
:param parallel: Whether slides should be processed in parallel with multiprocessing.
|
||||||
|
:param overwrite: Whether to overwrite an existing output tiles dataset. If `True`, will delete
|
||||||
|
and recreate `root_output_dir`, otherwise will resume by skipping already processed slides.
|
||||||
|
:param n_slides: If given, limit the total number of slides for debugging.
|
||||||
|
"""
|
||||||
|
|
||||||
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
|
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
|
||||||
# to select a subsample use keyword n_slides
|
# to select a subsample use keyword n_slides
|
||||||
dataset = Dataset(PandaDataset(panda_dir)) # type: ignore
|
dataset = Dataset(slides_dataset)[:n_slides] # type: ignore
|
||||||
|
|
||||||
output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}"
|
output_dir = Path(root_output_dir)
|
||||||
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}")
|
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} "
|
||||||
|
f"{slides_dataset.__class__.__name__} tiles at: {output_dir}")
|
||||||
|
|
||||||
if overwrite and output_dir.exists():
|
if overwrite and output_dir.exists():
|
||||||
shutil.rmtree(output_dir)
|
shutil.rmtree(output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=not overwrite)
|
output_dir.mkdir(parents=True, exist_ok=not overwrite)
|
||||||
|
|
||||||
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
|
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
|
||||||
|
foreground_threshold=foreground_threshold,
|
||||||
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
|
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
|
||||||
tile_progress=not parallel)
|
tile_progress=not parallel)
|
||||||
|
|
||||||
|
@ -217,11 +291,16 @@ def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main(panda_dir="/tmp/datasets/PANDA",
|
from InnerEye.ML.Histopathology.datasets.tcga_prad_dataset import TcgaPradDataset
|
||||||
root_output_dir="/datadrive",
|
|
||||||
level=1,
|
# Example set up for an existing slides dataset:
|
||||||
|
main(slides_dataset=TcgaPradDataset("/tmp/datasets/TCGA-PRAD"),
|
||||||
|
root_output_dir="/datadrive/TCGA-PRAD_tiles",
|
||||||
|
n_slides=5,
|
||||||
|
level=3,
|
||||||
tile_size=224,
|
tile_size=224,
|
||||||
margin=64,
|
margin=64,
|
||||||
|
foreground_threshold=None,
|
||||||
occupancy_threshold=0.05,
|
occupancy_threshold=0.05,
|
||||||
parallel=True,
|
parallel=False,
|
||||||
overwrite=False)
|
overwrite=True)
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import skimage.filters
|
||||||
|
from cucim import CuImage
|
||||||
|
from health_ml.utils import box_utils
|
||||||
|
from monai.data.image_reader import WSIReader
|
||||||
|
from monai.transforms import MapTransform
|
||||||
|
|
||||||
|
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||||
|
|
||||||
|
|
||||||
|
def get_luminance(slide: np.ndarray) -> np.ndarray:
|
||||||
|
"""Compute a grayscale version of the input slide.
|
||||||
|
|
||||||
|
:param slide: The RGB image array in (*, C, H, W) format.
|
||||||
|
:return: The single-channel luminance array as (*, H, W).
|
||||||
|
"""
|
||||||
|
# TODO: Consider more sophisticated luminance calculation if necessary
|
||||||
|
return slide.mean(axis=-3) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
|
||||||
|
-> Tuple[np.ndarray, float]:
|
||||||
|
"""Segment the given slide by thresholding its luminance.
|
||||||
|
|
||||||
|
:param slide: The RGB image array in (*, C, H, W) format.
|
||||||
|
:param threshold: Pixels with luminance below this value will be considered foreground.
|
||||||
|
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
|
||||||
|
:return: A tuple containing the boolean output array in (*, H, W) format and the threshold used.
|
||||||
|
"""
|
||||||
|
luminance = get_luminance(slide)
|
||||||
|
if threshold is None:
|
||||||
|
threshold = skimage.filters.threshold_otsu(luminance)
|
||||||
|
return luminance < threshold, threshold
|
||||||
|
|
||||||
|
|
||||||
|
def load_slide_at_level(reader: WSIReader, slide_obj: CuImage, level: int) -> np.ndarray:
|
||||||
|
"""Load full slide array at the given magnification level.
|
||||||
|
|
||||||
|
This is a manual workaround for a MONAI bug (https://github.com/Project-MONAI/MONAI/issues/3415)
|
||||||
|
fixed in a currently unreleased PR (https://github.com/Project-MONAI/MONAI/pull/3417).
|
||||||
|
|
||||||
|
:param reader: A MONAI `WSIReader` using cuCIM backend.
|
||||||
|
:param slide_obj: The cuCIM image object returned by `reader.read(<image_file>)`.
|
||||||
|
:param level: Index of the desired magnification level as defined in the `slide_obj` headers.
|
||||||
|
:return: The loaded image array in (C, H, W) format.
|
||||||
|
"""
|
||||||
|
size = slide_obj.resolutions['level_dimensions'][level][::-1]
|
||||||
|
slide, _ = reader.get_data(slide_obj, size=size, level=level) # loaded as an RGB array in (C, H, W) format
|
||||||
|
return slide
|
||||||
|
|
||||||
|
|
||||||
|
class LoadROId(MapTransform):
|
||||||
|
"""Transform that loads a pathology slide, cropped to an estimated bounding box (ROI).
|
||||||
|
|
||||||
|
Operates on dictionaries, replacing the file path in `image_key` with the loaded array in
|
||||||
|
(C, H, W) format. Also adds the following entries:
|
||||||
|
- `SlideKey.ORIGIN` (tuple): top-left coordinates of the bounding box
|
||||||
|
- `SlideKey.SCALE` (float): corresponding scale, loaded from the file
|
||||||
|
- `SlideKey.FOREGROUND_THRESHOLD` (float): threshold used to segment the foreground
|
||||||
|
"""
|
||||||
|
def __init__(self, reader: WSIReader, image_key: str = SlideKey.IMAGE, level: int = 0,
|
||||||
|
margin: int = 0, foreground_threshold: Optional[float] = None) -> None:
|
||||||
|
"""
|
||||||
|
:param reader: An instance of MONAI's `WSIReader`.
|
||||||
|
:param image_key: Image key in the input and output dictionaries.
|
||||||
|
:param level: Magnification level to load from the raw multi-scale file.
|
||||||
|
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
|
||||||
|
:param foreground_threshold: Pixels with luminance below this value will be considered foreground.
|
||||||
|
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
|
||||||
|
"""
|
||||||
|
super().__init__([image_key], allow_missing_keys=False)
|
||||||
|
self.reader = reader
|
||||||
|
self.image_key = image_key
|
||||||
|
self.level = level
|
||||||
|
self.margin = margin
|
||||||
|
self.foreground_threshold = foreground_threshold
|
||||||
|
|
||||||
|
def _get_bounding_box(self, slide_obj: CuImage) -> Tuple[box_utils.Box, float]:
|
||||||
|
# Estimate bounding box at the lowest resolution (i.e. highest level)
|
||||||
|
highest_level = slide_obj.resolutions['level_count'] - 1
|
||||||
|
scale = slide_obj.resolutions['level_downsamples'][highest_level]
|
||||||
|
slide = load_slide_at_level(self.reader, slide_obj, level=highest_level)
|
||||||
|
|
||||||
|
foreground_mask, threshold = segment_foreground(slide, self.foreground_threshold)
|
||||||
|
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
|
||||||
|
return bbox, threshold
|
||||||
|
|
||||||
|
def __call__(self, data: Dict) -> Dict:
|
||||||
|
image_obj: CuImage = self.reader.read(data[self.image_key])
|
||||||
|
|
||||||
|
level0_bbox, threshold = self._get_bounding_box(image_obj)
|
||||||
|
|
||||||
|
# cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
|
||||||
|
# but relative region size in pixels at the chosen level
|
||||||
|
origin = (level0_bbox.x, level0_bbox.y)
|
||||||
|
scale = image_obj.resolutions['level_downsamples'][self.level]
|
||||||
|
scaled_bbox = level0_bbox / scale
|
||||||
|
|
||||||
|
data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
|
||||||
|
size=(scaled_bbox.w, scaled_bbox.h))
|
||||||
|
data[SlideKey.ORIGIN] = origin
|
||||||
|
data[SlideKey.SCALE] = scale
|
||||||
|
data[SlideKey.FOREGROUND_THRESHOLD] = threshold
|
||||||
|
|
||||||
|
image_obj.close()
|
||||||
|
return data
|
|
@ -1,61 +0,0 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
|
||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
|
||||||
# ------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
"""
|
|
||||||
This script is an example of how to use the submit_to_azure_if_needed function from the hi-ml package to run the
|
|
||||||
main pre-processing function that creates tiles from slides in the PANDA dataset. The advantage of using this script
|
|
||||||
is the ability to submit to a cluster on azureml and to have the output files directly saved as a registered dataset.
|
|
||||||
|
|
||||||
To run, execute the following from inside the pre-processing folder:
|
|
||||||
python azure_tiles_creation.py --azureml
|
|
||||||
|
|
||||||
A json configuration file containing the credentials to the Azure workspace and an environment.yml file are expected
|
|
||||||
in input.
|
|
||||||
|
|
||||||
This has been tested on hi-ml v0.1.4.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
current_file = Path(__file__)
|
|
||||||
radiomics_root = current_file.absolute().parent.parent.parent.parent.parent
|
|
||||||
sys.path.append(str(radiomics_root))
|
|
||||||
from health_azure.himl import submit_to_azure_if_needed, DatasetConfig # noqa
|
|
||||||
from InnerEye.ML.Histopathology.preprocessing.create_tiles_dataset import main # noqa
|
|
||||||
|
|
||||||
# Pre-built environment file that contains all the requirements (RadiomicsNN + histo)
|
|
||||||
# Assuming ENV_NAME is a complete environment, `conda env export -n ENV_NAME -f ENV_NAME.yml` will create the desired file
|
|
||||||
ENVIRONMENT_FILE = radiomics_root.joinpath(Path("/envs/innereyeprivatetiles.yml"))
|
|
||||||
DATASET_NAME = "PANDA_tiles"
|
|
||||||
timestr = time.strftime("%Y%m%d-%H%M%S")
|
|
||||||
folder_name = DATASET_NAME + '_' + timestr
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print(f"Running {str(current_file)}")
|
|
||||||
input_dataset = DatasetConfig(name="PANDA", datastore="innereyedatasets", local_folder=Path("/tmp/datasets/PANDA"), use_mounting=True)
|
|
||||||
output_dataset = DatasetConfig(name=DATASET_NAME, datastore="innereyedatasets", local_folder=Path("/datadrive/"), use_mounting=True)
|
|
||||||
run_info = submit_to_azure_if_needed(entry_script=current_file,
|
|
||||||
snapshot_root_directory=radiomics_root,
|
|
||||||
workspace_config_file=Path("config.json"),
|
|
||||||
compute_cluster_name='training-pr-nc12', # training-nd24
|
|
||||||
default_datastore="innereyedatasets",
|
|
||||||
conda_environment_file=Path(ENVIRONMENT_FILE),
|
|
||||||
input_datasets=[input_dataset],
|
|
||||||
output_datasets=[output_dataset],
|
|
||||||
)
|
|
||||||
input_folder = run_info.input_datasets[0]
|
|
||||||
output_folder = Path(run_info.output_datasets[0], folder_name)
|
|
||||||
print(f'This will be the final ouput folder {str(output_folder)}')
|
|
||||||
|
|
||||||
main(panda_dir=str(input_folder),
|
|
||||||
root_output_dir=str(output_folder),
|
|
||||||
level=1,
|
|
||||||
tile_size=224,
|
|
||||||
margin=64,
|
|
||||||
occupancy_threshold=0.05,
|
|
||||||
parallel=True,
|
|
||||||
overwrite=False)
|
|
|
@ -5,6 +5,41 @@
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class SlideKey(str, Enum):
    """String-valued keys used in slide sample dictionaries.

    Members subclass `str`, so each one compares equal to its plain string value.
    """
    SLIDE_ID = 'slide_id'
    IMAGE = 'image'
    IMAGE_PATH = 'image_path'
    MASK = 'mask'
    MASK_PATH = 'mask_path'
    LABEL = 'label'
    SPLIT = 'split'
    SCALE = 'scale'
    ORIGIN = 'origin'
    FOREGROUND_THRESHOLD = 'foreground_threshold'
    METADATA = 'metadata'
|
||||||
|
|
||||||
|
|
||||||
|
class TileKey(str, Enum):
    """String-valued keys used for tile-level fields.

    Members subclass `str`, so each one compares equal to its plain string value.
    """
    TILE_ID = 'tile_id'
    SLIDE_ID = 'slide_id'
    IMAGE = 'image'
    IMAGE_PATH = 'image_path'
    MASK = 'mask'
    MASK_PATH = 'mask_path'
    LABEL = 'label'
    SPLIT = 'split'
    TILE_X = 'tile_x'
    TILE_Y = 'tile_y'
    OCCUPANCY = 'occupancy'
    FOREGROUND_THRESHOLD = 'foreground_threshold'
    SLIDE_METADATA = 'slide_metadata'

    @staticmethod
    def from_slide_metadata_key(slide_metadata_key: str) -> str:
        """Return the given slide metadata key prefixed with 'slide_'.

        :param slide_metadata_key: Name of a slide-level metadata field.
        :return: The prefixed key, e.g. 'meta1' -> 'slide_meta1'.
        """
        return 'slide_' + slide_metadata_key
|
||||||
|
|
||||||
|
|
||||||
class ResultsKey(str, Enum):
|
class ResultsKey(str, Enum):
|
||||||
SLIDE_ID = 'slide_id'
|
SLIDE_ID = 'slide_id'
|
||||||
TILE_ID = 'tile_id'
|
TILE_ID = 'tile_id'
|
||||||
|
|
|
@ -4,29 +4,32 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import matplotlib.pyplot as plt
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from monai.data.dataset import Dataset
|
||||||
from monai.data.image_reader import WSIReader
|
from monai.data.image_reader import WSIReader
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||||
|
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||||
|
|
||||||
|
|
||||||
def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
    """
    Load image from metadata dictionary
    :param sample: dict describing image metadata. Example:
        {'image_id': ['1ca999adbbc948e69783686e5b5414e4'],
        'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'],
        'mask': ['/tmp/datasets/PANDA/train_label_masks/1ca999adbbc948e69783686e5b5414e4_mask.tiff'],
        'data_provider': ['karolinska'],
        'isup_grade': tensor([0]),
        'gleason_score': ['0+0']}
        NOTE(review): this example may predate the SlideKey-keyed sample format - confirm.
    :param level: level of resolution to be loaded
    :param margin: margin to be included
    :return: a dict containing the image data and metadata
    """
    # The cuCIM backend is requested explicitly when constructing the reader.
    loader = LoadPandaROId(WSIReader('cuCIM'), level=level, margin=margin)
    img = loader(sample)
    return img
|
||||||
|
|
||||||
|
@ -34,25 +37,25 @@ def load_image_dict(sample: dict, level: int, margin: int) -> dict:
|
||||||
def plot_panda_data_sample(panda_dir: str, nsamples: int, ncols: int, level: int, margin: int,
                           title_key: str = 'data_provider') -> None:
    """
    Plot a grid of PANDA slides, one subplot per sample.
    :param panda_dir: path to the dataset, it's expected a file called "train.csv" exists at the path.
        Look at the PandaDataset for more detail
    :param nsamples: number of random samples to be visualized
    :param ncols: number of columns in the figure grid. Nrows is automatically inferred
    :param level: level of resolution to be loaded
    :param margin: margin to be included
    :param title_key: metadata key in image_dict used to label each subplot
    """
    panda_dataset = Dataset(PandaDataset(root=panda_dir))[:nsamples]  # type: ignore
    loader = DataLoader(panda_dataset, batch_size=1)

    nrows = math.ceil(nsamples/ncols)
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid; otherwise a single
    # Axes object is returned, which has no `.flat` attribute.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 9), squeeze=False)

    for dict_images, ax in zip(loader, axes.flat):
        slide_id = dict_images[SlideKey.SLIDE_ID]
        title = dict_images[SlideKey.METADATA][title_key]
        print(f">>> Slide {slide_id}")
        img = load_image_dict(dict_images, level=level, margin=margin)
        # Transpose the loaded image to channels-last before display.
        ax.imshow(img[SlideKey.IMAGE].transpose(1, 2, 0))
        ax.set_title(title)
    fig.tight_layout()
|
||||||
|
|
|
@ -8,7 +8,7 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning.core.step_result import Result
|
from pytorch_lightning.utilities.data import extract_batch_size
|
||||||
|
|
||||||
from InnerEye.Common import common_util
|
from InnerEye.Common import common_util
|
||||||
from InnerEye.ML.config import PaddingMode, SegmentationModelBase
|
from InnerEye.ML.config import PaddingMode, SegmentationModelBase
|
||||||
|
@ -502,7 +502,7 @@ def test_sample_metadata_field() -> None:
|
||||||
assert SAMPLE_METADATA_FIELD in fields
|
assert SAMPLE_METADATA_FIELD in fields
|
||||||
# Lightning attempts to determine the batch size by trying to find a tensor field in the sample.
|
# Lightning attempts to determine the batch size by trying to find a tensor field in the sample.
|
||||||
# This only works if any field other than Metadata is first.
|
# This only works if any field other than Metadata is first.
|
||||||
assert Result.unpack_batch_size(fields) == batch_size
|
assert extract_batch_size(fields) == batch_size
|
||||||
|
|
||||||
|
|
||||||
def test_custom_collate() -> None:
|
def test_custom_collate() -> None:
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
|
||||||
|
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||||
|
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||||
|
|
||||||
|
HISTO_TEST_DATA_DIR = str(tests_root_directory("ML/histopathology/test_data"))
|
||||||
|
|
||||||
|
|
||||||
|
class MockSlidesDataset(SlidesDataset):
    """Minimal SlidesDataset subclass backed by the CSV in the histopathology test data folder."""
    DEFAULT_CSV_FILENAME = "test_slides_dataset.csv"
    METADATA_COLUMNS = ('meta1', 'meta2')

    def __init__(self) -> None:
        super().__init__(root=HISTO_TEST_DATA_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slides_dataset() -> None:
    """Check the dataframe, length and sample structure of the mock slides dataset."""
    dataset = MockSlidesDataset()
    assert isinstance(dataset.dataset_df, pd.DataFrame)
    assert dataset.dataset_df.index.name == dataset.SLIDE_ID_COLUMN
    assert len(dataset) == len(dataset.dataset_df)

    sample = dataset[0]
    assert isinstance(sample, dict)
    assert all(isinstance(key, SlideKey) for key in sample)

    expected_keys = [SlideKey.SLIDE_ID, SlideKey.IMAGE, SlideKey.IMAGE_PATH, SlideKey.LABEL,
                     SlideKey.METADATA]
    assert all(key in sample for key in expected_keys)

    image_path = sample[SlideKey.IMAGE_PATH]
    assert isinstance(image_path, str)
    assert os.path.isfile(image_path)

    metadata = sample[SlideKey.METADATA]
    assert isinstance(metadata, dict)
    assert all(meta_col in metadata for meta_col in type(dataset).METADATA_COLUMNS)
|
|
@ -1,34 +0,0 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
|
||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
|
||||||
# ------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_PRAD_DATASET_DIR
|
|
||||||
from InnerEye.ML.Histopathology.datasets.tcga_prad_dataset import TcgaPradDataset
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.path.isdir(TCGA_PRAD_DATASET_DIR),
                    reason="TCGA-PRAD dataset is unavailable")
def test_dataset() -> None:
    """Check the size, label counts and sample structure of the TCGA-PRAD dataset.

    Skipped when the dataset directory is not present on this machine.
    """
    dataset = TcgaPradDataset(TCGA_PRAD_DATASET_DIR)

    expected_length = 449
    assert len(dataset) == expected_length

    expected_num_positives = 10
    assert dataset.dataset_df[dataset.LABEL_COLUMN].sum() == expected_num_positives

    sample = dataset[0]
    assert isinstance(sample, dict)

    expected_keys = [dataset.SLIDE_ID_COLUMN, dataset.CASE_ID_COLUMN,
                     dataset.IMAGE_COLUMN, dataset.LABEL_COLUMN]
    assert all(key in sample for key in expected_keys)

    image_path = sample[dataset.IMAGE_COLUMN]
    assert isinstance(image_path, str)
    assert os.path.isfile(image_path)
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from cucim import CuImage
|
||||||
|
from monai.data.image_reader import WSIReader
|
||||||
|
|
||||||
|
from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
|
||||||
|
from InnerEye.ML.Histopathology.preprocessing.tiling import tile_array_2d
|
||||||
|
from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId, get_luminance, load_slide_at_level, segment_foreground
|
||||||
|
from InnerEye.ML.Histopathology.utils.naming import SlideKey
|
||||||
|
from Tests.ML.histopathology.datasets.test_slides_dataset import MockSlidesDataset
|
||||||
|
|
||||||
|
TEST_IMAGE_PATH = str(tests_root_directory("ML/histopathology/test_data/panda_wsi_example.tiff"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_slide() -> None:
    """Check that `load_slide_at_level` returns the full slide at the requested level."""
    level = 2
    reader = WSIReader('cuCIM')
    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
    # The level dimensions are reversed to match the spatial axes of the loaded array.
    dims = slide_obj.resolutions['level_dimensions'][level][::-1]

    slide = load_slide_at_level(reader, slide_obj, level)
    assert isinstance(slide, np.ndarray)
    expected_shape = (3, *dims)
    assert slide.shape == expected_shape
    frac_empty = (slide == 0).mean()
    assert frac_empty == 0.0

    # Requesting a region twice as large must pad beyond the slide boundaries.
    larger_dims = (2 * dims[0], 2 * dims[1])
    larger_slide, _ = reader.get_data(slide_obj, size=larger_dims, level=level)
    assert isinstance(larger_slide, np.ndarray)
    assert larger_slide.shape == (3, *larger_dims)
    # Overlapping parts match exactly
    assert np.array_equal(larger_slide[:, :dims[0], :dims[1]], slide)
    # Non-overlapping parts are all empty
    empty_fill_value = 0  # fill value seems to depend on the image
    assert np.array_equiv(larger_slide[:, dims[0]:, :], empty_fill_value)
    assert np.array_equiv(larger_slide[:, :, dims[1]:], empty_fill_value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_luminance() -> None:
    """Check `get_luminance` on a whole slide and on tiles, and that the two agree."""
    level = 2  # here we only need to test at a single resolution
    reader = WSIReader('cuCIM')
    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)

    slide = load_slide_at_level(reader, slide_obj, level)
    slide_luminance = get_luminance(slide)
    assert isinstance(slide_luminance, np.ndarray)
    assert slide_luminance.shape == slide.shape[1:]
    assert (slide_luminance <= 255).all() and (slide_luminance >= 0).all()

    tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
    tiles_luminance = get_luminance(tiles)
    assert isinstance(tiles_luminance, np.ndarray)
    assert tiles_luminance.shape == (tiles.shape[0], *tiles.shape[2:])
    assert (tiles_luminance <= 255).all() and (tiles_luminance >= 0).all()

    # Tiling the slide luminance must give the same result as the luminance of the tiles.
    slide_luminance_tiles, _ = tile_array_2d(np.expand_dims(slide_luminance, axis=0),
                                             tile_size=224, constant_values=255)
    assert np.array_equal(slide_luminance_tiles.squeeze(1), tiles_luminance)
|
||||||
|
|
||||||
|
|
||||||
|
def test_segment_foreground() -> None:
    """Check `segment_foreground` with automatic and explicit thresholds, on a slide and on tiles."""
    level = 2  # here we only need to test at a single resolution
    reader = WSIReader('cuCIM')
    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
    slide = load_slide_at_level(reader, slide_obj, level)

    auto_mask, auto_threshold = segment_foreground(slide, threshold=None)
    assert isinstance(auto_mask, np.ndarray)
    assert auto_mask.dtype == bool
    assert auto_mask.shape == slide.shape[1:]
    assert 0 < auto_mask.sum() < auto_mask.size  # auto-seg should not produce trivial mask
    luminance = get_luminance(slide)
    assert luminance.min() < auto_threshold < luminance.max()

    # Passing the automatically chosen threshold back in must reproduce the same mask.
    mask, returned_threshold = segment_foreground(slide, threshold=auto_threshold)
    assert isinstance(mask, np.ndarray)
    assert mask.dtype == bool
    assert mask.shape == slide.shape[1:]
    assert np.array_equal(mask, auto_mask)
    assert returned_threshold == auto_threshold

    tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
    tiles_mask, _ = segment_foreground(tiles, threshold=auto_threshold)
    assert isinstance(tiles_mask, np.ndarray)
    assert tiles_mask.dtype == bool
    assert tiles_mask.shape == (tiles.shape[0], *tiles.shape[2:])

    # Tiling the slide mask must give the same result as segmenting the tiles.
    slide_mask_tiles, _ = tile_array_2d(np.expand_dims(mask, axis=0),
                                        tile_size=224, constant_values=False)
    assert np.array_equal(slide_mask_tiles.squeeze(1), tiles_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('level', [1, 2])
@pytest.mark.parametrize('foreground_threshold', [None, 215])
def test_get_bounding_box(level: int, foreground_threshold: Optional[float]) -> None:
    """Check that the estimated bounding box lies within the slide and grows with the margin."""
    margin = 0
    reader = WSIReader('cuCIM')
    loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
                      foreground_threshold=foreground_threshold)
    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
    level0_bbox, _ = loader._get_bounding_box(slide_obj)

    highest_level = slide_obj.resolutions['level_count'] - 1
    slide = load_slide_at_level(reader, slide_obj, level=level)
    scale = slide_obj.resolutions['level_downsamples'][level]
    # Rescale the box to the coordinates of the slide loaded at `level`.
    bbox = level0_bbox / scale
    assert bbox.x >= 0 and bbox.y >= 0
    assert bbox.x + bbox.w <= slide.shape[1]
    assert bbox.y + bbox.h <= slide.shape[2]

    # Now with nonzero margin
    margin = 42
    loader_margin = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
                             foreground_threshold=foreground_threshold)
    level0_bbox_margin, _ = loader_margin._get_bounding_box(slide_obj)
    # Here we test the box differences at the highest resolution, because margin is
    # specified in low-res pixels. Otherwise could fail due to rounding error.
    # Despite its name, `level0_scale` is the downsample factor of the highest level,
    # i.e. the factor that converts low-res pixels into level-0 pixels.
    level0_scale: float = slide_obj.resolutions['level_downsamples'][highest_level]
    level0_margin = int(level0_scale * margin)
    assert level0_bbox_margin.x == level0_bbox.x - level0_margin
    assert level0_bbox_margin.y == level0_bbox.y - level0_margin
    assert level0_bbox_margin.w == level0_bbox.w + 2 * level0_margin
    assert level0_bbox_margin.h == level0_bbox.h + 2 * level0_margin
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('level', [1, 2])
@pytest.mark.parametrize('margin', [0, 42])
@pytest.mark.parametrize('foreground_threshold', [None, 215])
def test_load_roi(level: int, margin: int, foreground_threshold: Optional[float]) -> None:
    """Check the keys and value types of the sample returned by `LoadROId`."""
    dataset = MockSlidesDataset()
    sample = dataset[0]
    reader = WSIReader('cuCIM')
    loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
                      foreground_threshold=foreground_threshold)
    loaded_sample = loader(sample)
    assert isinstance(loaded_sample, dict)
    # Check that none of the input keys were removed
    assert all(key in loaded_sample for key in sample)

    # Check that the expected new keys were inserted
    additional_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.FOREGROUND_THRESHOLD]
    assert all(key in loaded_sample for key in additional_keys)

    assert isinstance(loaded_sample[SlideKey.IMAGE], np.ndarray)
    image_shape = loaded_sample[SlideKey.IMAGE].shape
    # A bare `assert len(image_shape)` passes for any non-scalar array; pin the rank
    # to channels plus two spatial axes.
    assert len(image_shape) == 3
    assert image_shape[0] == 3

    origin = loaded_sample[SlideKey.ORIGIN]
    assert isinstance(origin, tuple)
    assert len(origin) == 2
    assert all(isinstance(coord, int) for coord in origin)

    assert isinstance(loaded_sample[SlideKey.SCALE], (int, float))
    assert loaded_sample[SlideKey.SCALE] >= 1.0

    assert isinstance(loaded_sample[SlideKey.FOREGROUND_THRESHOLD], (int, float))
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:06eb0acaa2883181e9b6ab976863f71cc843a75ed9175fae8fe9b879635af1b0
|
||||||
|
size 816563
|
|
@ -0,0 +1,2 @@
|
||||||
|
slide_id,image,label,meta1,meta2
|
||||||
|
foo,panda_wsi_example.tiff,0,bar,baz
|
|
|
@ -20,6 +20,7 @@ dependencies:
|
||||||
- azureml-tensorboard==1.36.0
|
- azureml-tensorboard==1.36.0
|
||||||
- conda-merge==0.1.5
|
- conda-merge==0.1.5
|
||||||
- cryptography==3.3.2
|
- cryptography==3.3.2
|
||||||
|
- cucim==21.10.1; platform_system=="Linux"
|
||||||
- dataclasses-json==0.5.2
|
- dataclasses-json==0.5.2
|
||||||
- docker==4.3.1
|
- docker==4.3.1
|
||||||
- flake8==3.8.3
|
- flake8==3.8.3
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[pytest]
|
[pytest]
|
||||||
testpaths=Tests TestsOutsidePackage TestSubmodule
|
testpaths=Tests TestsOutsidePackage TestSubmodule
|
||||||
norecursedirs=azure-pipelines docs datasets sphinx-docs InnerEye logs outputs test_data
|
norecursedirs=azure-pipelines docs sphinx-docs InnerEye logs outputs test_data Tests/ML/datasets
|
||||||
addopts=--strict-markers
|
addopts=--strict-markers
|
||||||
markers=
|
markers=
|
||||||
gpu: Test needs a GPU to run
|
gpu: Test needs a GPU to run
|
||||||
|
|
Загрузка…
Ссылка в новой задаче