ENH: Enable LoadROId with OpenSlide (#672)

Refactor the LoadROId/LoadPandaROId transforms to work with OpenSlide and
remove parameter redundancy via the LoadingParams class.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Kenza Bouzid 2022-11-22 22:51:34 +00:00, committed by GitHub
Parent a1d24a8a63
Commit 615c8ad188
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 454 additions and 317 deletions
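At a glance, the refactor replaces the loading options that were previously scattered across containers (level, wsi_backend, wsi_has_mask) with a single LoadingParams object. A minimal sketch of the new entry point, with illustrative values:

from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend

loading_params = LoadingParams(
    level=1,                       # pyramid level to read from
    backend=WSIBackend.OPENSLIDE,  # now supported alongside cuCIM
    roi_type=ROIType.MASK,         # crop to the mask bounding box
)
load_transform = loading_params.get_load_roid_transform()
# sample = load_transform({"image": "slide.tiff", "mask": "mask.tiff"})  # placeholder paths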

.github/workflows/cpath-pr.yml (vendored, 22 changes)

@@ -98,7 +98,7 @@ jobs:
with:
flags: ${{ env.folder }}
smoke_test_slidespandaimagenetmil:
smoke_test_cucim_slidespandaimagenetmil:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
@@ -111,7 +111,22 @@ jobs:
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_slidespandaimagenetmil_aml
make smoke_test_cucim_slidespandaimagenetmil_aml
smoke_test_openslide_slidespandaimagenetmil:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_openslide_slidespandaimagenetmil_aml
smoke_test_tilespandaimagenetmil:
runs-on: ubuntu-20.04
@@ -254,7 +269,8 @@ jobs:
flake8,
mypy,
pytest,
smoke_test_slidespandaimagenetmil,
smoke_test_cucim_slidespandaimagenetmil,
smoke_test_openslide_slidespandaimagenetmil,
smoke_test_tilespandaimagenetmil,
smoke_test_tcgacrckimagenetmil,
smoke_test_tcgacrcksslmil,

.isort.cfg (new file, 4 changes)

@@ -0,0 +1,4 @@
[settings]
line_length=120
sections=FUTURE,STDLIB,THIRDPARTY,SUBMODULE,FIRSTPARTY,LOCALFOLDER
known_submodule=health_ml,health_azure,health_cpath,SSL
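With the SUBMODULE section defined above, isort groups the hi-ml packages between third-party and first-party imports. An illustrative ordering under this configuration:

from pathlib import Path  # STDLIB

import numpy as np  # THIRDPARTY

from health_azure.utils import create_from_matching_params  # SUBMODULE
from health_cpath.preprocessing.loading import LoadingParams  # SUBMODULE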


@@ -194,16 +194,28 @@ define DDP_SAMPLER_ARGS
--pl_replace_sampler_ddp=False
endef
define OPENSLIDE_BACKEND_ARGS
--backend=OpenSlide
endef
# The following test takes around 5 minutes
smoke_test_slidespandaimagenetmil_local:
smoke_test_cucim_slidespandaimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}
# Once running in AML the following test takes around 6 minutes
smoke_test_slidespandaimagenetmil_aml:
smoke_test_cucim_slidespandaimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke_test_openslide_slidespandaimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${OPENSLIDE_BACKEND_ARGS};}
smoke_test_openslide_slidespandaimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${OPENSLIDE_BACKEND_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
# The following test takes about 6 minutes
smoke_test_tilespandaimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
@@ -279,6 +291,6 @@ smoke_test_tiles_panda_no_ddp_sampler_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
smoke tests local: smoke_test_cucim_slidespandaimagenetmil_local smoke_test_openslide_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml
smoke tests AML: smoke_test_cucim_slidespandaimagenetmil_aml smoke_test_openslide_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml


@@ -15,6 +15,7 @@ from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from health_azure.utils import create_from_matching_params
from health_cpath.preprocessing.loading import LoadingParams
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
from health_cpath.utils.wsi_utils import TilingParams
@@ -34,7 +35,7 @@ from health_cpath.utils.naming import MetricsKey, PlotOption, SlideKey, ModelKey
from health_cpath.utils.tiles_selection_utils import TilesSelector
class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
"""BaseMIL is an abstract container defining basic functionality for running MIL experiments in both slides and
tiles settings. It is responsible for instantiating the encoder and pooling layer. Subclasses should define the
full DeepMIL model depending on the type of dataset (tiles/slides based).
@@ -55,12 +56,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
"If 0 (default), will return all samples in each bag. "
"If > 0 , bags larger than `max_bag_size_inf` will yield "
"random subsets of instances.")
# local_dataset (used as data module root_path) is declared in DatasetParams superclass
level: int = param.Integer(1, bounds=(0, None), doc="The whole slide image level at which the image is extracted."
"Whole slide images are represented in a pyramid consisting of"
"multiple images at different resolutions."
"If 1 (default), will extract baseline image at the resolution"
"at level 1.")
# Outputs Handler parameters:
num_top_slides: int = param.Integer(10, bounds=(0, None), doc="Number of slides to select when saving top and "
"bottom tiles. If set to 10 (default), it selects 10 "
@@ -78,10 +73,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
"very high for small num_gpus. This parameters sets an upper bound.")
wsi_has_mask: bool = param.Boolean(default=True,
doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
"region of the WSI. If False, will load the whole WSI.")
wsi_backend: str = param.String(default="cuCIM", doc="The backend to use for loading WSI. ")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@@ -149,14 +140,12 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
outputs_root=self.outputs_folder,
n_classes=n_classes,
tile_size=self.tile_size,
level=self.level,
class_names=self.class_names,
primary_val_metric=self.primary_val_metric,
maximise=self.maximise_primary_metric,
val_plot_options=self.get_val_plot_options(),
test_plot_options=self.get_test_plot_options(),
wsi_has_mask=self.wsi_has_mask,
backend=self.wsi_backend,
loading_params=create_from_matching_params(self, LoadingParams),
val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
)
if self.num_top_slides > 0:
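create_from_matching_params copies the parameter values that BaseMIL now shares with LoadingParams into a fresh instance. Roughly equivalent behaviour, as a simplified sketch (not the actual health_azure implementation; the helper name is made up):

import param

def copy_matching_params(source: param.Parameterized, cls_: type) -> param.Parameterized:
    # Keep only the parameter names that the target class also declares
    target_names = set(cls_.param.objects()) - {"name"}
    matching = {n: getattr(source, n) for n in target_names if hasattr(source, n)}
    return cls_(**matching)

# loading_params = copy_matching_params(container, LoadingParams)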


@@ -11,7 +11,8 @@ from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, PANDA
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.models.encoders import HistoSSLEncoder, ImageNetSimCLREncoder, Resnet18, SSLEncoder
from health_cpath.utils.naming import PlotOption
from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
from health_cpath.utils.naming import PlotOption, SlideKey
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_ml.utils.checkpoint_utils import CheckpointParser
@@ -36,7 +37,13 @@ class BaseDeepSMILEPanda(BaseMIL):
# declared in OptimizerParams:
l_rate=5e-4,
weight_decay=1e-4,
adam_betas=(0.9, 0.99))
adam_betas=(0.9, 0.99),
# loading params:
backend=WSIBackend.CUCIM,
roi_type=ROIType.WHOLE,
image_key=SlideKey.IMAGE,
mask_key=SlideKey.MASK,
)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
@@ -121,7 +128,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
def __init__(self, **kwargs: Any) -> None:
default_kwargs = dict(
# declared in BaseMILSlides:
level=1,
tile_size=224,
background_val=255,
@@ -145,8 +151,8 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
batch_size_inf=self.batch_size_inf,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
level=self.level,
tiling_params=create_from_matching_params(self, TilingParams),
loading_params=create_from_matching_params(self, LoadingParams),
seed=self.get_effective_random_seed(),
transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
crossval_count=self.crossval_count,
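Because the loading fields now live on the container itself (via the LoadingParams mixin), the backend or ROI type can be overridden at construction time or from the CLI, as the new --backend=OpenSlide smoke tests do. A hypothetical override:

from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda
from health_cpath.preprocessing.loading import ROIType, WSIBackend

container = DeepSMILESlidesPanda(backend=WSIBackend.OPENSLIDE, roi_type=ROIType.MASK)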


@@ -9,6 +9,7 @@ from torch import optim
from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
from health_azure.utils import create_from_matching_params
from health_cpath.preprocessing.loading import LoadingParams
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.networks.layers.attention_layers import (
TransformerPooling,
@@ -107,8 +108,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
batch_size_inf=self.batch_size_inf,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
level=self.level,
tiling_params=create_from_matching_params(self, TilingParams),
loading_params=create_from_matching_params(self, LoadingParams),
seed=self.get_effective_random_seed(),
transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
crossval_count=self.crossval_count,


@@ -3,13 +3,13 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import numpy as np
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar, Union
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, DistributedSampler
from health_cpath.preprocessing.loading import LoadingParams
from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator
@@ -17,11 +17,10 @@ from health_ml.utils.common_utils import _create_generator
from health_cpath.utils.wsi_utils import TilingParams, image_collate
from health_cpath.models.transforms import LoadTilesBatchd
from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset
from health_cpath.utils.naming import ModelKey, SlideKey
from health_cpath.utils.naming import ModelKey
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from monai.transforms import Compose, LoadImaged, SplitDimd
from monai.data.image_reader import WSIReader
from monai.transforms import Compose
_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
@@ -273,47 +272,29 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
def __init__(
self,
level: int = 1,
backend: str = 'cuCIM',
wsi_reader_args: Dict[str, Any] = {},
tiling_params: TilingParams = TilingParams(),
loading_params: LoadingParams,
tiling_params: TilingParams,
**kwargs: Any,
) -> None:
"""
:param level: the whole slide image level at which the image is extracted, defaults to 1
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default
:param backend: the WSI reader backend, defaults to "cuCIM".
:param wsi_reader_args: additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
enabled since monai 1.0.0 by specifying num_workers > 0 with CuCIM backend only.
:param tiling_params: the tiling on the fly parameters, defaults to TileOnTheFlyParams()
:param tiling_params: the tiling on the fly parameters.
:param loading_params: the loading parameters.
:param kwargs: additional parameters to pass to the parent class HistoDataModule
"""
super().__init__(**kwargs)
self.tiling_params = tiling_params
# WSIReader params
self.level = level
self.backend = backend
self.wsi_reader_args = wsi_reader_args
self.loading_params = loading_params
def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
base_transform = Compose(
[
LoadImaged(
keys=SlideKey.IMAGE,
reader=WSIReader,
dtype=np.uint8,
image_only=True,
level=self.level,
backend=self.backend,
**self.wsi_reader_args,
),
self.loading_params.get_load_roid_transform(),
self.tiling_params.get_tiling_transform(bag_size=self.bag_sizes[stage], stage=stage),
# GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
# tiles to be able to apply augmentations on each tile independently
SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True),
self.tiling_params.get_split_transform(),
]
)
if self.transforms_dict and self.transforms_dict[stage]:
transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten()
else:
transforms = base_transform


@@ -2,25 +2,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import pandas as pd
from pathlib import Path
from typing import Any, Dict, Union, Optional
import pandas as pd
from monai.config import KeysCollection
from monai.data.image_reader import ImageReader, WSIReader
from monai.transforms import MapTransform
from health_cpath.utils.naming import SlideKey
from health_ml.utils import box_utils
from health_cpath.datasets.base_dataset import SlidesDataset
try:
from cucim import CuImage
except ImportError: # noqa: E722
logging.warning("cucim library not available, code may fail")
class PandaDataset(SlidesDataset):
"""Dataset class for loading files from the PANDA challenge dataset.
@@ -52,89 +38,3 @@ class PandaDataset(SlidesDataset):
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
self.validate_columns()
# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
class ReadImaged(MapTransform):
"""Basic transform to read image files."""
def __init__(self, reader: ImageReader, keys: KeysCollection,
allow_missing_keys: bool = False, **kwargs: Any) -> None:
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.reader = reader
self.kwargs = kwargs
def __call__(self, data: Dict) -> Dict:
for key in self.keys:
if key in data or not self.allow_missing_keys:
data[key] = self.reader.read(data[key], **self.kwargs)
return data
class LoadPandaROId(MapTransform):
"""Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI).
Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
- `'location'` (tuple): top-right coordinates of the bounding box
- `'size'` (tuple): width and height of the bounding box
- `'level'` (int): chosen magnification level
- `'scale'` (float): corresponding scale, loaded from the file
"""
def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str = 'mask',
level: int = 0, margin: int = 0, **kwargs: Any) -> None:
"""
:param reader: And instance of MONAI's `WSIReader`.
:param image_key: Image key in the input and output dictionaries.
:param mask_key: Mask key in the input and output dictionaries.
:param level: Magnification level to load from the raw multi-scale files.
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
"""
super().__init__([image_key, mask_key], allow_missing_keys=False)
self.reader = reader
self.image_key = image_key
self.mask_key = mask_key
self.level = level
self.margin = margin
self.kwargs = kwargs
def _get_bounding_box(self, mask_obj: 'CuImage') -> box_utils.Box:
# Estimate bounding box at the lowest resolution (i.e. highest level)
highest_level = mask_obj.resolutions['level_count'] - 1
scale = mask_obj.resolutions['level_downsamples'][highest_level]
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
bbox = box_utils.get_bounding_box(foreground_mask)
padded_bbox = bbox.add_margin(self.margin)
scaled_bbox = scale * padded_bbox
return scaled_bbox
def __call__(self, data: Dict) -> Dict:
mask_obj: CuImage = self.reader.read(data[self.mask_key])
image_obj: CuImage = self.reader.read(data[self.image_key])
level0_bbox = self._get_bounding_box(mask_obj)
# cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
scale = mask_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
origin = (level0_bbox.y, level0_bbox.x)
get_data_kwargs = dict(
location=origin,
size=(scaled_bbox.h, scaled_bbox.w),
level=self.level,
)
mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs) # type: ignore
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
data.update(get_data_kwargs)
data[SlideKey.SCALE] = scale
data[SlideKey.ORIGIN] = origin
mask_obj.close()
image_obj.close()
return data


@@ -15,13 +15,13 @@ from typing import Tuple, Union, List
import numpy as np
from monai.data import Dataset
from monai.data.image_reader import WSIReader
from tqdm import tqdm
from health_ml.utils.box_utils import Box
from health_cpath.preprocessing import tiling
from health_cpath.utils.naming import SlideKey, TileKey
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.preprocessing.loading import LoadMaskROId, WSIBackend
from health_cpath.preprocessing.create_tiles_dataset import get_tile_id, save_image, merge_dataset_csv_files
CSV_COLUMNS = (
@@ -138,8 +138,7 @@ def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupan
dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
print(f"Loading slide {slide_id} ...")
reader = WSIReader(backend="cucim")
loader = LoadPandaROId(reader, level=level, margin=margin)
loader = LoadMaskROId(backend=WSIBackend.CUCIM, level=level, margin=margin)
try:
sample = loader(sample) # load 'image' and 'mask' from disk
failed = False
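LoadMaskROId now builds its own WSIReader from the backend enum, so call sites no longer construct a reader themselves. Illustrative usage with placeholder paths:

from health_cpath.preprocessing.loading import LoadMaskROId, WSIBackend

loader = LoadMaskROId(backend=WSIBackend.CUCIM, level=1, margin=64)
sample = loader({"image": "train_images/abc.tiff", "mask": "train_label_masks/abc_mask.tiff"})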


@@ -13,13 +13,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import numpy as np
import PIL
from monai.data import Dataset
from monai.data.image_reader import WSIReader
from tqdm import tqdm
from health_ml.utils.box_utils import Box
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.preprocessing import tiling
from health_cpath.preprocessing.loading import LoadROId, segment_foreground
from health_cpath.preprocessing.loading import LoadROId, WSIBackend, segment_foreground
from health_cpath.utils.naming import SlideKey, TileKey
@@ -177,7 +176,7 @@ def process_slide(sample: Dict[SlideKey, Any], level: int, margin: int, tile_siz
failed_tiles_file.write('tile_id' + '\n')
print(f"Loading slide {slide_id} ...")
loader = LoadROId(WSIReader('cuCIM'), level=level, margin=margin,
loader = LoadROId(backend=WSIBackend.CUCIM, level=level, margin=margin,
foreground_threshold=foreground_threshold)
sample = loader(sample) # load 'image' from disk


@@ -2,22 +2,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Dict, Optional, Tuple
import param
import numpy as np
import skimage.filters
from enum import Enum
from health_ml.utils import box_utils
from monai.data.image_reader import WSIReader
from monai.transforms import MapTransform
from health_cpath.utils.naming import SlideKey
try:
from cucim import CuImage
except ImportError: # noqa: E722
logging.warning("cucim library not available, code may fail")
from monai.data.wsi_reader import WSIReader
from monai.transforms import MapTransform, LoadImaged
from typing import Any, Callable, Dict, Optional, Tuple
def get_luminance(slide: np.ndarray) -> np.ndarray:
@@ -30,8 +24,7 @@ def get_luminance(slide: np.ndarray) -> np.ndarray:
return slide.mean(axis=-3) # type: ignore
def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
-> Tuple[np.ndarray, float]:
def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) -> Tuple[np.ndarray, float]:
"""Segment the given slide by thresholding its luminance.
:param slide: The RGB image array in (*, C, H, W) format.
@@ -45,24 +38,48 @@ def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
return luminance < threshold, threshold
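For example, a short sketch of calling segment_foreground on a (C, H, W) array; the threshold falls back to Otsu's method when omitted:

import numpy as np

rgb = np.random.randint(0, 256, size=(3, 512, 512), dtype=np.uint8)  # placeholder slide
foreground_mask, estimated_threshold = segment_foreground(rgb)  # Otsu threshold on luminance
foreground_mask, _ = segment_foreground(rgb, threshold=200.0)   # or a fixed cut-off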
def load_slide_at_level(reader: WSIReader, slide_obj: 'CuImage', level: int) -> np.ndarray:
"""Load full slide array at the given magnification level.
This is a manual workaround for a MONAI bug (https://github.com/Project-MONAI/MONAI/issues/3415)
fixed in a currently unreleased PR (https://github.com/Project-MONAI/MONAI/pull/3417).
:param reader: A MONAI `WSIReader` using cuCIM backend.
:param slide_obj: The cuCIM image object returned by `reader.read(<image_file>)`.
:param level: Index of the desired magnification level as defined in the `slide_obj` headers.
:return: The loaded image array in (C, H, W) format.
"""
size = slide_obj.resolutions['level_dimensions'][level][::-1]
slide, _ = reader.get_data(slide_obj, size=size, level=level) # loaded as RGB PIL image
return slide
class ROIType(str, Enum):
"""Options for the ROI selection. Either a bounding box defined by foreground or a mask can be used."""
FOREGROUND = 'foreground'
MASK = 'segmentation_mask'
WHOLE = 'whole_slide'
class LoadROId(MapTransform):
"""Transform that loads a pathology slide, cropped to an estimated bounding box (ROI).
class WSIBackend(str, Enum):
"""Options for the WSI reader backend."""
OPENSLIDE = 'OpenSlide'
CUCIM = 'cuCIM'
class BaseLoadROId:
"""Abstract base class for loading a region of interest (ROI) from a slide. The ROI is defined by a bounding box."""
def __init__(
self, backend: str = WSIBackend.CUCIM, image_key: str = SlideKey.IMAGE, level: int = 1, margin: int = 0,
backend_args: Dict = {}
) -> None:
"""
:param backend: The WSI reader backend to use. One of 'OpenSlide' or 'cuCIM'. Default: 'cuCIM'.
:param image_key: Image key in the input and output dictionaries. Default: 'image'.
:param level: Magnification level to load from the raw multi-scale file, 0 is the highest resolution. Default: 1
which loads the second highest resolution.
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping. Default: 0.
:param backend_args: Additional arguments to pass to the WSI reader backend. Default: {}.
"""
self.reader = WSIReader(backend=backend, **backend_args)
self.image_key = image_key
self.level = level
self.margin = margin
def _get_bounding_box(self, slide_obj: Any) -> box_utils.Box:
raise NotImplementedError
def __call__(self, data: Dict) -> Dict:
raise NotImplementedError
class LoadROId(MapTransform, BaseLoadROId):
"""Transform that loads a pathology slide, cropped to an estimated bounding box (ROI) of the foreground tissue.
Operates on dictionaries, replacing the file path in `image_key` with the loaded array in
(C, H, W) format. Also adds the following entries:
@@ -71,49 +88,169 @@ class LoadROId(MapTransform):
- `SlideKey.FOREGROUND_THRESHOLD` (float): threshold used to segment the foreground
"""
def __init__(self, reader: WSIReader, image_key: str = SlideKey.IMAGE, level: int = 0,
margin: int = 0, foreground_threshold: Optional[float] = None) -> None:
def __init__(
self, image_key: str = SlideKey.IMAGE, foreground_threshold: Optional[float] = None, **kwargs: Any
) -> None:
"""
:param reader: And instance of MONAI's `WSIReader`.
:param image_key: Image key in the input and output dictionaries.
:param level: Magnification level to load from the raw multi-scale file.
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
:param image_key: Image key in the input and output dictionaries. Default: 'image'.
:param foreground_threshold: Pixels with luminance below this value will be considered foreground.
If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
:param kwargs: Additional arguments for `BaseLoadROId`.
"""
super().__init__([image_key], allow_missing_keys=False)
self.reader = reader
self.image_key = image_key
self.level = level
self.margin = margin
MapTransform.__init__(self, [image_key], allow_missing_keys=False)
BaseLoadROId.__init__(self, image_key=image_key, **kwargs)
self.foreground_threshold = foreground_threshold
def _get_bounding_box(self, slide_obj: 'CuImage') -> Tuple[box_utils.Box, float]:
# Estimate bounding box at the lowest resolution (i.e. highest level)
highest_level = slide_obj.resolutions['level_count'] - 1
scale = slide_obj.resolutions['level_downsamples'][highest_level]
slide = load_slide_at_level(self.reader, slide_obj, level=highest_level)
def _get_bounding_box(self, slide_obj: Any) -> box_utils.Box:
"""Estimate bounding box at the lowest resolution (i.e. highest level) of the slide."""
highest_level = self.reader.get_level_count(slide_obj) - 1
scale = self.reader.get_downsample_ratio(slide_obj, highest_level)
slide, _ = self.reader.get_data(slide_obj, level=highest_level)
foreground_mask, threshold = segment_foreground(slide, self.foreground_threshold)
self.foreground_threshold = threshold
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
return bbox, threshold
return bbox
def __call__(self, data: Dict) -> Dict:
image_obj: CuImage = self.reader.read(data[self.image_key])
try:
image_obj = self.reader.read(data[self.image_key])
level0_bbox = self._get_bounding_box(image_obj)
level0_bbox, threshold = self._get_bounding_box(image_obj)
# cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
origin = (level0_bbox.y, level0_bbox.x)
scale = self.reader.get_downsample_ratio(image_obj, self.level)
scaled_bbox = level0_bbox / scale
# cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
origin = (level0_bbox.y, level0_bbox.x)
scale = image_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
size=(scaled_bbox.h, scaled_bbox.w))
data[SlideKey.ORIGIN] = origin
data[SlideKey.SCALE] = scale
data[SlideKey.FOREGROUND_THRESHOLD] = threshold
image_obj.close()
data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
size=(scaled_bbox.h, scaled_bbox.w))
data[SlideKey.ORIGIN] = origin
data[SlideKey.SCALE] = scale
data[SlideKey.FOREGROUND_THRESHOLD] = self.foreground_threshold
finally:
image_obj.close()
return data
class LoadMaskROId(MapTransform, BaseLoadROId):
"""Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI) defined by the mask.
Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
- `'location'` (tuple): top-right coordinates of the bounding box
- `'size'` (tuple): width and height of the bounding box
- `'level'` (int): chosen magnification level
- `'scale'` (float): corresponding scale, loaded from the file
"""
def __init__(self, image_key: str = SlideKey.IMAGE, mask_key: str = SlideKey.MASK, **kwargs: Any) -> None:
"""
:param image_key: Image key in the input and output dictionaries. Default: 'image'.
:param mask_key: Mask key in the input and output dictionaries. Default: 'mask'.
:param kwargs: Additional arguments for `BaseLoadROId`.
"""
MapTransform.__init__(self, [image_key, mask_key], allow_missing_keys=False)
BaseLoadROId.__init__(self, image_key=image_key, **kwargs)
self.mask_key = mask_key
def _get_bounding_box(self, mask_obj: Any) -> box_utils.Box:
"""Estimate bounding box at the lowest resolution (i.e. highest level) of the mask."""
highest_level = self.reader.get_level_count(mask_obj) - 1
scale = self.reader.get_downsample_ratio(mask_obj, highest_level)
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
bbox = box_utils.get_bounding_box(foreground_mask)
padded_bbox = bbox.add_margin(self.margin)
scaled_bbox = scale * padded_bbox
return scaled_bbox
def __call__(self, data: Dict) -> Dict:
try:
mask_obj = self.reader.read(data[self.mask_key])
image_obj = self.reader.read(data[self.image_key])
level0_bbox = self._get_bounding_box(mask_obj)
# cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
scale = self.reader.get_downsample_ratio(mask_obj, self.level)
scaled_bbox = level0_bbox / scale
origin = (level0_bbox.y, level0_bbox.x)
get_data_kwargs = dict(
location=origin,
size=(scaled_bbox.h, scaled_bbox.w),
level=self.level,
)
mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs) # type: ignore
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
data.update(get_data_kwargs)
data[SlideKey.SCALE] = scale
data[SlideKey.ORIGIN] = origin
finally:
mask_obj.close()
image_obj.close()
return data
class LoadingParams(param.Parameterized):
"""Parameters for loading a whole slide image."""
level: int = param.Integer(
default=1,
doc="Magnification level to load from the raw multi-scale files. Default: 1.")
margin: int = param.Integer(
default=0, doc="Amount in pixels by which to enlarge the estimated bounding box for cropping")
backend: WSIBackend = param.ClassSelector(
default=WSIBackend.CUCIM,
class_=WSIBackend,
doc="WSI reader backend. Default: cuCIM.")
roi_type: ROIType = param.ClassSelector(
default=ROIType.WHOLE,
class_=ROIType,
doc="ROI type to use for cropping the slide. Default: `ROIType.WHOLE`, in which case no cropping is performed.")
image_key: str = param.String(
default=SlideKey.IMAGE,
doc="Key for the image in the data dictionary.")
mask_key: str = param.String(
default=SlideKey.MASK,
doc="Key for the mask in the data dictionary. This only applies to `LoadMaskROId`.")
foreground_threshold: Optional[float] = param.Number(
default=None,
bounds=(0, 255.),
allow_None=True,
doc="Threshold for foreground mask. If None, the threshold is selected automatically with otsu thresholding."
"This only applies to `LoadROId`.")
def set_roi_type_to_foreground(self) -> None:
"""Set the ROI type to foreground. This is useful for plotting even if we load whole slides during
training. This helps us reduce the size of thumbnails to only meaningful tissue. We only hardcode it to
foreground in the WHOLE case. Otherwise, keep it as is if a mask is available."""
if self.roi_type == ROIType.WHOLE:
self.roi_type = ROIType.FOREGROUND
def get_load_roid_transform(self) -> Callable:
"""Returns a transform to load a slide and mask, cropped to the mask bounding box (ROI) defined by either the
mask or the foreground."""
if self.roi_type == ROIType.WHOLE:
return LoadImaged(keys=self.image_key, reader=WSIReader, image_only=True, level=self.level, # type: ignore
backend=self.backend, dtype=np.uint8, **self.get_additionl_backend_args())
elif self.roi_type == ROIType.FOREGROUND:
return LoadROId(backend=self.backend, image_key=self.image_key, level=self.level,
margin=self.margin, foreground_threshold=self.foreground_threshold,
backend_args=self.get_additionl_backend_args())
elif self.roi_type == ROIType.MASK:
return LoadMaskROId(backend=self.backend, image_key=self.image_key,
mask_key=self.mask_key, level=self.level, margin=self.margin,
backend_args=self.get_additionl_backend_args())
else:
raise ValueError(f"Unknown ROI type: {self.roi_type}. Choose from {list(ROIType)}.")
def get_additionl_backend_args(self) -> Dict[str, Any]:
"""Returns a dictionary of additional arguments for the WSI reader backend. Multi processing is
enabled since monai 1.0.0 by specifying num_workers > 0 with CuCIM backend only.
This function can be overridden in BaseMIL to add additional arguments for the backend."""
return dict()
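A hypothetical override of this hook, for instance to enable multi-process reading for cuCIM (subclass name and worker count are illustrative):

from typing import Any, Dict

from health_cpath.preprocessing.loading import LoadingParams, WSIBackend

class ContainerLoadingParams(LoadingParams):  # hypothetical subclass
    def get_additionl_backend_args(self) -> Dict[str, Any]:
        # num_workers > 0 enables multiprocessing for the cuCIM backend only (MONAI >= 1.0)
        return dict(num_workers=4) if self.backend == WSIBackend.CUCIM else dict()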


@@ -5,7 +5,8 @@
from pathlib import Path
import sys
import time
from typing import Any
from typing import Any, Optional
from azureml.core import Workspace
himl_histo_root_dir = Path(__file__).parent.parent.parent
himl_root = himl_histo_root_dir.parent.parent
@@ -16,11 +17,11 @@ from health_azure import DatasetConfig # noqa: E402
from health_azure.utils import get_workspace # noqa: E402
def mount_dataset(dataset_id: str) -> Any:
ws = get_workspace()
target_folder = "/tmp/datasets/" + dataset_id
def mount_dataset(dataset_id: str, tmp_root: str = "/tmp/datasets", aml_workspace: Optional[Workspace] = None) -> Any:
ws = get_workspace(aml_workspace)
target_folder = "/".join([tmp_root, dataset_id])
dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
_, mount_ctx = dataset.to_input_dataset_local(ws)
_, mount_ctx = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
assert mount_ctx is not None # for mypy
mount_ctx.start()
return mount_ctx
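Call-site sketch for the extended signature; the dataset ID is a placeholder, and stop() is assumed to mirror the start() call above on the Azure ML mount context:

mount_ctx = mount_dataset("PANDA", tmp_root="/tmp/datasets")  # mounts at /tmp/datasets/PANDA
try:
    pass  # read slides from the mount point here
finally:
    mount_ctx.stop()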


@@ -18,6 +18,7 @@ from torchmetrics.metric import Metric
from health_azure.utils import replace_directory
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.preprocessing.loading import LoadingParams
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, TilesSelector
from health_cpath.utils.naming import MetricsKey, ModelKey, PlotOption, ResultsKey
@@ -254,24 +255,22 @@ class OutputsPolicy:
class DeepMILOutputsHandler:
"""Class that manages writing validation and test outputs for DeepMIL models."""
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, loading_params: LoadingParams,
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
maximise: bool, val_plot_options: Collection[PlotOption],
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
backend: str = "cuCIM", val_set_is_dist: bool = True) -> None:
test_plot_options: Collection[PlotOption],
val_set_is_dist: bool = True,) -> None:
"""
:param outputs_root: Root directory where to save all produced outputs.
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
:param tile_size: The size of each tile.
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: List of class names. For binary (`n_classes == 1`), expects `len(class_names) == 2`.
If `None`, will return `('0', '1', ...)`.
:param primary_val_metric: Name of the validation metric to track for saving best epoch outputs.
:param maximise: Whether higher is better for `primary_val_metric`.
:param val_plot_options: The desired plot options for validation time.
:param test_plot_options: The desired plot options for test time.
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
:param backend: The backend to use for reading the tiles. Default is "cuCIM".
:param loading_params: Parameters for loading WSI to create plots. This parameter is passed to PlotsHandler.
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
@@ -280,7 +279,6 @@ class DeepMILOutputsHandler:
self.outputs_root = outputs_root
self.n_classes = n_classes
self.tile_size = tile_size
self.level = level
self.class_names = validate_class_names(class_names, self.n_classes)
self.outputs_policy = OutputsPolicy(outputs_root=outputs_root,
@@ -291,21 +289,17 @@ class DeepMILOutputsHandler:
self.val_plots_handler = DeepMILPlotsHandler(
plot_options=val_plot_options,
level=self.level,
tile_size=self.tile_size,
class_names=self.class_names,
stage=ModelKey.VAL,
wsi_has_mask=wsi_has_mask,
backend=backend,
loading_params=loading_params,
)
self.test_plots_handler = DeepMILPlotsHandler(
plot_options=test_plot_options,
level=self.level,
tile_size=self.tile_size,
class_names=self.class_names,
stage=ModelKey.TEST,
wsi_has_mask=wsi_has_mask,
backend=backend,
loading_params=loading_params,
)
self.val_set_is_dist = val_set_is_dist


@@ -11,6 +11,7 @@ from torch import Tensor
import matplotlib.pyplot as plt
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.preprocessing.loading import LoadingParams
from health_cpath.utils.viz_utils import (
plot_attention_tiles,
plot_heatmap_overlay,
@@ -176,38 +177,32 @@ class DeepMILPlotsHandler:
def __init__(
self,
plot_options: Collection[PlotOption],
level: int = 1,
loading_params: LoadingParams,
tile_size: int = 224,
num_columns: int = 4,
figsize: Tuple[int, int] = (10, 10),
stage: str = '',
class_names: Optional[Sequence[str]] = None,
wsi_has_mask: bool = True,
backend: str = "cuCIM",
) -> None:
"""Class that handles the plotting of DeepMIL results.
:param plot_options: A set of plot options to produce the desired plot outputs.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
1 for 4x downsampled, 2 for 16x downsampled). Default 1.
:param tile_size: _description_, defaults to 224
:param loading_params: The loading parameters to use when loading the whole slide images.
:param tile_size: The size of the tiles to use when plotting the attention tiles, defaults to 224
:param num_columns: Number of columns to create the subfigures grid, defaults to 4
:param figsize: The figure size of tiles attention plots, defaults to (10, 10)
:param stage: Test or Validation, used to name the plots
:param class_names: List of class names, defaults to None
:param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None
:param wsi_has_mask: Whether the whole slide images have a mask, defaults to True
:param backend: The backend to use for loading the whole slide images, defaults to "cuCIM"
"""
self.plot_options = plot_options
self.class_names = validate_class_names_for_plot_options(class_names, plot_options)
self.level = level
self.tile_size = tile_size
self.num_columns = num_columns
self.figsize = figsize
self.stage = stage
self.wsi_has_mask = wsi_has_mask
self.backend = backend
self.loading_params = loading_params
self.loading_params.set_roi_type_to_foreground()
self.slides_dataset: Optional[SlidesDataset] = None
def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
@@ -216,8 +211,7 @@ class DeepMILPlotsHandler:
slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = self.slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask,
backend=self.backend)
slide_dict = load_image_dict(slide_dict, loading_params=self.loading_params)
return slide_dict
def save_slide_node_figures(
@@ -237,7 +231,7 @@ class DeepMILPlotsHandler:
if PlotOption.ATTENTION_HEATMAP in self.plot_options:
save_attention_heatmap(
case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.level
case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.loading_params.level
)
def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:
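In brief, set_roi_type_to_foreground only downgrades whole-slide loading, as a small illustration:

from health_cpath.preprocessing.loading import LoadingParams, ROIType

params = LoadingParams(roi_type=ROIType.WHOLE)
params.set_roi_type_to_foreground()
assert params.roi_type == ROIType.FOREGROUND  # thumbnails crop to tissue only

params = LoadingParams(roi_type=ROIType.MASK)
params.set_roi_type_to_foreground()
assert params.roi_type == ROIType.MASK  # mask-based ROI is kept as-is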


@@ -17,20 +17,16 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
from monai.data.meta_tensor import MetaTensor
from monai.data.dataset import Dataset
from monai.data.image_reader import WSIReader
from torch.utils.data import DataLoader
from health_cpath.preprocessing.loading import LoadROId
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.heatmap_utils import location_selected_tiles
from health_cpath.utils.tiles_selection_utils import SlideNode
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
def load_image_dict(
sample: dict, level: int, margin: int, wsi_has_mask: bool = True, backend: str = "cuCIM"
) -> Dict[SlideKey, Any]:
def load_image_dict(sample: dict, loading_params: LoadingParams) -> Dict[SlideKey, Any]:
"""
Load image from metadata dictionary
:param sample: dict describing image metadata. Example:
@@ -40,13 +36,9 @@ def load_image_dict(
'data_provider': ['karolinska'],
'isup_grade': tensor([0]),
'gleason_score': ['0+0']}
:param level: level of resolution to be loaded
:param margin: margin to be included
:param wsi_has_mask: whether the WSI has a mask
:param backend: backend to be used to load the image (cuCIM or OpenSlide)
:param loading_params: LoadingParams object that contains the parameters to load the image.
"""
transform = LoadPandaROId if wsi_has_mask else LoadROId
loader = transform(WSIReader(backend=backend), level=level, margin=margin)
loader = loading_params.get_load_roid_transform()
img = loader(sample)
if isinstance(img[SlideKey.IMAGE], MetaTensor):
# New monai transforms return a MetaTensor, we need to convert it to a numpy array for backward compatibility
@@ -76,7 +68,8 @@ def plot_panda_data_sample(
slide_id = dict_images[SlideKey.SLIDE_ID]
title = dict_images[SlideKey.METADATA][title_key]
print(f">>> Slide {slide_id}")
img = load_image_dict(dict_images, level=level, margin=margin)
loading_params = LoadingParams(level=level, margin=margin, backend=WSIBackend.CUCIM, roi_type=ROIType.MASK)
img = load_image_dict(dict_images, loading_params)
ax.imshow(img[SlideKey.IMAGE].transpose(1, 2, 0))
ax.set_title(title)
fig.tight_layout()


@@ -6,7 +6,7 @@ from typing import Any, Callable, List, Optional
from health_cpath.utils.naming import ModelKey, SlideKey
from health_ml.utils.bag_utils import multibag_collate
from monai.data.meta_tensor import MetaTensor
from monai.transforms import RandGridPatchd, GridPatchd
from monai.transforms import RandGridPatchd, GridPatchd, SplitDimd
def image_collate(batch: List) -> Any:
@@ -99,3 +99,9 @@ class TilingParams(param.Parameterized):
pad_mode=self.tile_pad_mode, # type: ignore
constant_values=self.background_val, # this arg is passed to np.pad or torch.pad
)
def get_split_transform(self) -> Callable:
"""GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
tiles to be able to apply augmentations on each tile independently.
"""
return SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True)
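A sketch of what the split transform does to a stacked bag of tiles (shapes are illustrative):

import torch

from health_cpath.utils.naming import SlideKey
from health_cpath.utils.wsi_utils import TilingParams

data = {SlideKey.IMAGE: torch.rand(4, 3, 28, 28)}   # (bag_size, C, H, W)
tiles = TilingParams().get_split_transform()(data)  # list_output=True yields a list of dicts
assert len(tiles) == 4
assert tuple(tiles[0][SlideKey.IMAGE].shape) == (3, 28, 28)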


@@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import time
import pytest
import torch
from pathlib import Path
@@ -14,6 +15,7 @@ from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTi
from health_cpath.utils.naming import ModelKey, SlideKey
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.utils.common_utils import is_gpu_available
from testhisto.datamodules.test_slides_datamodule import get_loading_params
from testhisto.utils.utils_testhisto import run_distributed
@@ -64,7 +66,7 @@ def test_slides_datamodule_different_bag_sizes(
max_bag_size=max_bag_size,
max_bag_size_inf=max_bag_size_inf,
tiling_params=TilingParams(tile_size=28),
level=0,
loading_params=get_loading_params(level=0),
)
# For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
_assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])
@@ -98,7 +100,7 @@ def test_slides_datamodule_different_batch_sizes(
max_bag_size=16,
max_bag_size_inf=16,
tiling_params=TilingParams(tile_size=28),
level=0,
loading_params=get_loading_params(level=0),
)
_assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)
@@ -134,6 +136,8 @@ def _test_datamodule_pl_ddp_sampler_true(
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
datamodule.setup()
if rank == 0:
time.sleep(15) # slow down rank 0 to avoid concurrent file access
_validate_sampler_type(datamodule, [ModelKey.TRAIN, ModelKey.VAL, ModelKey.TEST], expected_none=True)
@@ -141,6 +145,8 @@ def _test_datamodule_pl_ddp_sampler_false(
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
datamodule.setup()
if rank == 0:
time.sleep(15) # slow down rank 0 to avoid concurrent file access
_validate_sampler_type(datamodule, [ModelKey.VAL, ModelKey.TEST], expected_none=True)
_validate_sampler_type(datamodule, [ModelKey.TRAIN], expected_none=False)
@@ -151,13 +157,13 @@ def _test_datamodule_pl_ddp_sampler_false(
def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Path) -> None:
slides_datamodule = PandaSlidesDataModule(root_path=mock_panda_slides_root_dir,
pl_replace_sampler_ddp=True,
seed=42)
seed=42, tiling_params=TilingParams(),
loading_params=get_loading_params())
run_distributed(_test_datamodule_pl_ddp_sampler_true, [slides_datamodule], world_size=2)
slides_datamodule.pl_replace_sampler_ddp = False
run_distributed(_test_datamodule_pl_ddp_sampler_false, [slides_datamodule], world_size=2)
@pytest.mark.skip(reason="Test fails with Broken Pipe Error. To be fixed.")
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
@@ -173,6 +179,7 @@ def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
with patch("torch.distributed.is_initialized", return_value=True):
with patch("torch.distributed.get_world_size", return_value=2):
slides_datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False
root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False,
tiling_params=TilingParams(), loading_params=get_loading_params()
)
slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)


@@ -12,6 +12,7 @@ from pathlib import Path
from monai.transforms import RandFlipd
from typing import Generator, Dict, Callable, Union, Tuple
from torch.utils.data import DataLoader
from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.utils.common_utils import is_gpu_available
@@ -20,15 +21,21 @@ from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.utils.naming import SlideKey, ModelKey
from health_cpath.datamodules.panda_module import PandaSlidesDataModule
from testhisto.mocks.slides_generator import (
MockPandaSlidesGenerator,
MockHistoDataType,
TilesPositioningType,
)
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, MockHistoDataType, TilesPositioningType
no_gpu = not is_gpu_available()
def get_loading_params(level: int = 0, roi_type: ROIType = ROIType.FOREGROUND) -> LoadingParams:
return LoadingParams(
level=level,
backend=WSIBackend.CUCIM,
roi_type=roi_type,
foreground_threshold=255,
margin=0,
)
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir_diagonal(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
@@ -84,11 +91,11 @@ def get_original_tile(mock_dir: Path, wsi_id: str) -> np.ndarray:
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.parametrize("roi_type", [ROIType.FOREGROUND, ROIType.WHOLE])
def test_tiling_on_the_fly(roi_type: ROIType, mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = 16
tile_size = 28
level = 0
channels = 3
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
@@ -96,7 +103,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size=batch_size,
max_bag_size=tile_count,
tiling_params=TilingParams(tile_size=28),
level=level,
loading_params=get_loading_params(level=0, roi_type=roi_type),
)
dataloader = datamodule.train_dataloader()
for sample in dataloader:
@@ -113,10 +120,10 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.parametrize("roi_type", [ROIType.FOREGROUND, ROIType.WHOLE])
def test_tiling_without_fixed_tile_count(roi_type: ROIType, mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = None
level = 0
assert_batch_index = 0
min_expected_tile_count = 16
datamodule = PandaSlidesDataModule(
@@ -124,7 +131,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Pa
batch_size=batch_size,
max_bag_size=tile_count,
tiling_params=TilingParams(tile_size=28),
level=level,
loading_params=get_loading_params(level=0, roi_type=roi_type),
)
dataloader = datamodule.train_dataloader()
for sample in dataloader:
@@ -146,7 +153,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
batch_size=batch_size,
max_bag_size=tile_count,
tiling_params=TilingParams(tile_size=tile_size),
level=level,
loading_params=get_loading_params(level=level),
)
dataloader = datamodule.train_dataloader()
for sample in dataloader:
@@ -165,7 +172,6 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [1, 2])
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
level = 0
overlap = .5
expected_tile_matches = 16
min_expected_tile_count = 32
@@ -175,7 +181,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal:
max_bag_size=None,
batch_size=batch_size,
tiling_params=TilingParams(tile_size=28, tile_overlap=overlap),
level=level
loading_params=get_loading_params(level=0),
)
dataloader = datamodule.train_dataloader()
for sample in dataloader:
@@ -206,14 +212,13 @@ def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> Non
batch_size = 1
tile_count = 4
level = 0
flipdatamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
max_bag_size_inf=0,
tiling_params=TilingParams(tile_size=28),
level=level,
loading_params=get_loading_params(level=0),
transforms_dict=get_transforms_dict(),
)
flip_train_tiles = retrieve_tiles(flipdatamodule.train_dataloader())
@@ -256,7 +261,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_different_n_tiles: Path) -> None:
tile_count = 2
level = 0
assert_batch_index = 0
n_tiles_list = [4, 5, 6, 7, 8, 9]
@@ -266,7 +270,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
max_bag_size=tile_count,
max_bag_size_inf=0,
tiling_params=TilingParams(tile_size=28),
level=level,
loading_params=get_loading_params(level=0),
)
train_dataloader = datamodule.train_dataloader()
for sample in train_dataloader:


@@ -15,6 +15,7 @@ from pytorch_lightning import seed_everything
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import DeepSMILESlidesPandaBenchmark
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.preprocessing.loading import ROIType, WSIBackend
from health_cpath.utils.naming import SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType
@@ -50,6 +51,10 @@ def test_panda_reproducibility(tmp_path: Path) -> None:
container.tile_size = tile_size
container.max_bag_size = num_tiles
container.local_datasets = [tmp_path]
container.backend = WSIBackend.CUCIM
container.roi_type = ROIType.FOREGROUND
container.margin = 0
container.level = 0
def test_data_items_are_equal(loader_fn_names: List[str]) -> None:
"""Creates a new data module from the container, and checks if all the data loaders specified in


@@ -4,6 +4,7 @@
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Optional, Set
from health_cpath.preprocessing.loading import ROIType, WSIBackend
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda, DeepSMILETilesPanda
@@ -21,15 +22,15 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
pool_hidden_dim=16,
num_transformer_pool_layers=1,
num_transformer_pool_heads=1,
class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
# Encoder parameters
encoder_type=Resnet18.__name__,
tile_size=28,
# Data Module parameters
batch_size=2,
max_bag_size=4,
max_bag_size_inf=4,
batch_size_inf=2,
encoding_chunk_size=4,
max_bag_size=4,
max_bag_size_inf=0,
cache_mode=CacheMode.NONE,
precache_location=CacheLocation.NONE,
# declared in DatasetParams:
@@ -39,6 +40,11 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
crossval_count=1,
ssl_checkpoint=None,
analyse_loss=analyse_loss,
# Loading parameters
level=0,
backend=WSIBackend.CUCIM,
roi_type=ROIType.FOREGROUND,
foreground_threshold=255,
)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
@@ -63,14 +69,13 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
pool_hidden_dim=16,
num_transformer_pool_layers=1,
num_transformer_pool_heads=1,
class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
# Encoder parameters
encoder_type=Resnet18.__name__,
tile_size=28,
# Data Module parameters
batch_size=2,
batch_size_inf=2,
encoding_chunk_size=4,
level=0,
max_bag_size=4,
max_bag_size_inf=0,
# declared in DatasetParams:
@@ -78,6 +83,11 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
# declared in TrainerParams:
max_epochs=2,
crossval_count=1,
# Loading parameters
level=0,
backend=WSIBackend.CUCIM,
roi_type=ROIType.FOREGROUND,
foreground_threshold=255,
)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
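
Both mock containers now pass the loading configuration as flat keyword arguments (level, backend, roi_type, foreground_threshold) instead of a bare level. A sketch of how a container could fold these into a single LoadingParams; the method name and exact field set are assumptions, since the aggregation happens in the refactored base classes:

# Hypothetical aggregation; the field names mirror the kwargs above.
def get_loading_params(self) -> LoadingParams:
    return LoadingParams(
        level=self.level,
        backend=self.backend,
        roi_type=self.roi_type,
        foreground_threshold=self.foreground_threshold,
    )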


@@ -0,0 +1,71 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
from pathlib import Path
from typing import List, Tuple
from monai.transforms import LoadImaged
from monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader, WSIReader
from health_cpath.datasets.default_paths import PANDA_DATASET_ID
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.preprocessing.loading import BaseLoadROId, LoadingParams, ROIType, WSIBackend, LoadROId, LoadMaskROId
from health_cpath.scripts.mount_azure_dataset import mount_dataset
from health_cpath.utils.naming import SlideKey
from health_ml.utils.common_utils import is_gpu_available
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
no_gpu = not is_gpu_available()
@pytest.mark.parametrize("roi_type", [r for r in ROIType])
@pytest.mark.parametrize("backend", [b for b in WSIBackend])
def test_get_load_roid_transform(backend: WSIBackend, roi_type: ROIType) -> None:
loading_params = LoadingParams(backend=backend, roi_type=roi_type)
transform = loading_params.get_load_roid_transform()
transform_type = {ROIType.MASK: LoadMaskROId, ROIType.FOREGROUND: LoadROId, ROIType.WHOLE: LoadImaged}
assert isinstance(transform, transform_type[roi_type])
reader_type = {WSIBackend.CUCIM: CuCIMWSIReader, WSIBackend.OPENSLIDE: OpenSlideWSIReader}
if roi_type in [ROIType.MASK, ROIType.FOREGROUND]:
assert isinstance(transform, BaseLoadROId)
assert isinstance(transform.reader, WSIReader) # type: ignore
assert isinstance(transform.reader.reader, reader_type[backend]) # type: ignore
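
The test above pins down the dispatch contract: roi_type selects the transform class and backend selects the MONAI reader wrapped inside it. A sketch of dispatch logic consistent with these assertions; the real LoadingParams.get_load_roid_transform may differ in details such as the exact keyword arguments:

from typing import Callable
from monai.data.wsi_reader import WSIReader
from monai.transforms import LoadImaged

def get_load_roid_transform(self) -> Callable:
    # backend.value is assumed to be the lowercase name MONAI expects, e.g. "cucim".
    reader = WSIReader(backend=self.backend.value, level=self.level)
    if self.roi_type == ROIType.WHOLE:
        return LoadImaged(keys=SlideKey.IMAGE, reader=reader, image_only=False)
    transform_cls = LoadMaskROId if self.roi_type == ROIType.MASK else LoadROId
    return transform_cls(reader=reader, level=self.level)  # both subclass BaseLoadROId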
@pytest.mark.skip(reason="This test is failing because of issue #655")
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_load_slide(tmp_path: Path) -> None:
_ = mount_dataset(dataset_id=PANDA_DATASET_ID, tmp_root=str(tmp_path), aml_workspace=DEFAULT_WORKSPACE.workspace)
root_path = tmp_path / PANDA_DATASET_ID
def _check_load_roi_transforms(
backend: WSIBackend, expected_keys: List[SlideKey], expected_shape: Tuple[int, int, int]
) -> None:
loading_params.backend = backend
load_transform = loading_params.get_load_roid_transform()
sample = PandaDataset(root_path)[0]
slide_dict = load_transform(sample)
assert all([k in slide_dict for k in expected_keys])
assert slide_dict[SlideKey.IMAGE].shape == expected_shape
# WSI ROIType
loading_params = LoadingParams(roi_type=ROIType.WHOLE, level=2)
wsi_expected_keys = [SlideKey.IMAGE, SlideKey.SLIDE_ID]
wsi_expected_shape = (3, 1840, 1728)
for backend in [WSIBackend.CUCIM, WSIBackend.OPENSLIDE]:
_check_load_roi_transforms(backend, wsi_expected_keys, wsi_expected_shape)
# Foreground ROIType
loading_params = LoadingParams(roi_type=ROIType.FOREGROUND, level=2)
foreground_expected_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.FOREGROUND_THRESHOLD, SlideKey.IMAGE]
foreground_expected_shape = (3, 1342, 340)
for backend in [WSIBackend.CUCIM, WSIBackend.OPENSLIDE]:
_check_load_roi_transforms(backend, foreground_expected_keys, foreground_expected_shape)
# Mask ROI transforms
loading_params = LoadingParams(roi_type=ROIType.MASK, level=2)
mask_expected_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.IMAGE]
mask_expected_shape = (3, 1344, 341)
for backend in [WSIBackend.CUCIM, WSIBackend.OPENSLIDE]:
_check_load_roi_transforms(backend, mask_expected_keys, mask_expected_shape)
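
Because the test is skipped in CI (issue #655) and needs a GPU plus workspace credentials, a quick local spot check of just the transform can look like the following sketch, assuming the PANDA dataset is already mounted at a local path:

local_panda_root = Path("/tmp/datasets") / PANDA_DATASET_ID  # assumed mount location
params = LoadingParams(roi_type=ROIType.FOREGROUND, level=2, backend=WSIBackend.OPENSLIDE)
transform = params.get_load_roid_transform()
slide_dict = transform(PandaDataset(local_panda_root)[0])
print(slide_dict[SlideKey.IMAGE].shape)  # expected (3, 1342, 340) per the assertions above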


@@ -7,6 +7,7 @@ import torch
import torch.distributed
import torch.multiprocessing
from ruamel.yaml import YAML
from health_cpath.preprocessing.loading import LoadingParams
from health_cpath.utils.tiles_selection_utils import TilesSelector
from testhisto.utils.utils_testhisto import run_distributed
from torch.testing import assert_close
@@ -31,7 +32,7 @@ def _create_outputs_handler(outputs_root: Path) -> DeepMILOutputsHandler:
outputs_root=outputs_root,
n_classes=1,
tile_size=224,
level=1,
loading_params=LoadingParams(level=1),
class_names=None,
primary_val_metric=_PRIMARY_METRIC_KEY,
maximise=True,


@@ -8,7 +8,7 @@ from pathlib import Path
from typing import Any, Collection, Dict, List
from unittest.mock import MagicMock, patch
import pytest
from health_cpath.preprocessing.loading import LoadingParams, ROIType
from health_cpath.utils.naming import PlotOption, ResultsKey
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, save_confusion_matrix, save_pr_curve
from health_cpath.utils.tiles_selection_utils import SlideNode, TilesSelector
@@ -18,7 +18,15 @@ from testhisto.mocks.container import MockDeepSMILETilesPanda
def test_plots_handler_wrong_class_names() -> None:
plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
with pytest.raises(ValueError, match=r"No class_names were provided while activating confusion matrix plotting."):
_ = DeepMILPlotsHandler(plot_options, class_names=[])
_ = DeepMILPlotsHandler(plot_options, class_names=[], loading_params=LoadingParams())
@pytest.mark.parametrize("roi_type", [r for r in ROIType])
def test_plots_handler_always_uses_roid_loading(roi_type: ROIType) -> None:
plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
loading_params = LoadingParams(roi_type=roi_type)
plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo", "bar"], loading_params=loading_params)
assert plots_handler.loading_params.roi_type in [ROIType.MASK, ROIType.FOREGROUND]
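
This assertion encodes a deliberate constraint: attention plots need ROI-cropped slides, so even a caller who configures ROIType.WHOLE must be redirected to an ROI-based loader. A sketch of the coercion this implies inside DeepMILPlotsHandler; the concrete fallback choice is an assumption:

# Assumed coercion, consistent with the assertion above; the handler may pick MASK instead.
if self.loading_params.roi_type == ROIType.WHOLE:
    self.loading_params.roi_type = ROIType.FOREGROUND  # never load whole slides for plotting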
@pytest.mark.parametrize(
@@ -71,7 +79,7 @@ def assert_plot_func_called_if_among_plot_options(
],
)
def test_plots_handler_plots_only_desired_plot_options(plot_options: Collection[PlotOption]) -> None:
plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo1", "foo2"])
plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo1", "foo2"], loading_params=LoadingParams())
plots_handler.slides_dataset = MagicMock()
n_tiles = 4


@@ -8,7 +8,6 @@ import math
import random
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch
import matplotlib
import numpy as np
@@ -25,7 +24,7 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist,
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.heatmap_utils import location_selected_tiles
from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode
from health_cpath.utils.viz_utils import save_figure, load_image_dict
from health_cpath.utils.viz_utils import save_figure
from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path
@@ -295,13 +294,3 @@ def test_location_selected_tiles(level: int) -> None:
assert max(tile_xs) <= slide_image.shape[2] // factor
assert min(tile_ys) >= 0
assert max(tile_ys) <= slide_image.shape[1] // factor
@pytest.mark.parametrize("backend", ["cuCIM", "OpenSlide"])
@pytest.mark.parametrize("wsi_has_mask", [True, False])
def test_load_image_dict(wsi_has_mask: bool, backend: str) -> None:
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask, backend=backend)
assert mock_load_panda_roi.called == wsi_has_mask
assert mock_load_roi.called == (not wsi_has_mask)


@@ -1,3 +1,4 @@
from unittest.mock import patch
import torch
import pytest
import numpy as np
@@ -84,3 +85,12 @@ def test_tiling_params(stage: ModelKey) -> None:
expected_transform_type = RandGridPatchd if stage == ModelKey.TRAIN else GridPatchd
transform = params.get_tiling_transform(stage=stage, bag_size=10)
assert isinstance(transform, expected_transform_type)
def test_tiling_params_split_transform() -> None:
params = TilingParams()
with patch("health_cpath.utils.wsi_utils.SplitDimd") as mock_split_dim:
_ = params.get_split_transform()
mock_split_dim.assert_called_once()
call_args = mock_split_dim.call_args_list[0][1]
assert call_args == {'keys': SlideKey.IMAGE, 'dim': 0, 'keepdim': False, 'list_output': True}
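
Since the asserted kwargs fully determine the transform, an equivalent of what get_split_transform presumably builds, inferred from the mocked call rather than from the implementation, is:

from monai.transforms import SplitDimd

# Splits the stacked (N, C, H, W) tile tensor along dim 0 into a list of per-tile images.
split_transform = SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True)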