Mirror of https://github.com/microsoft/hi-ml.git
ENH: Enable LoadROId with Openslide (#672)
Refactor the LoadROId/LoadPandaROId transforms to work with OpenSlide and remove parameter redundancy via the LoadingParams class. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: a1d24a8a63
Commit: 615c8ad188

@@ -98,7 +98,7 @@ jobs:
       with:
         flags: ${{ env.folder }}

-  smoke_test_slidespandaimagenetmil:
+  smoke_test_cucim_slidespandaimagenetmil:
     runs-on: ubuntu-20.04
     steps:
       - uses: actions/checkout@v3
@@ -111,7 +111,22 @@ jobs:
       - name: smoke test
         run: |
           cd ${{ env.folder }}
-          make smoke_test_slidespandaimagenetmil_aml
+          make smoke_test_cucim_slidespandaimagenetmil_aml
+
+  smoke_test_openslide_slidespandaimagenetmil:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          lfs: true
+
+      - name: Prepare Conda environment
+        uses: ./.github/actions/prepare_cpath_environment
+
+      - name: smoke test
+        run: |
+          cd ${{ env.folder }}
+          make smoke_test_openslide_slidespandaimagenetmil_aml

   smoke_test_tilespandaimagenetmil:
     runs-on: ubuntu-20.04
@@ -254,7 +269,8 @@ jobs:
       flake8,
       mypy,
       pytest,
-      smoke_test_slidespandaimagenetmil,
+      smoke_test_cucim_slidespandaimagenetmil,
+      smoke_test_openslide_slidespandaimagenetmil,
       smoke_test_tilespandaimagenetmil,
       smoke_test_tcgacrckimagenetmil,
       smoke_test_tcgacrcksslmil,

@@ -0,0 +1,4 @@
+[settings]
+line_length=120
+sections=FUTURE,STDLIB,THIRDPARTY,SUBMODULE,FIRSTPARTY,LOCALFOLDER
+known_submodule=health_ml,health_azure,health_cpath,SSL
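With this configuration, isort sorts repo-internal imports into a dedicated SUBMODULE section between third-party and first-party imports. A minimal sketch of the resulting grouping (module names taken from this repo; the grouping is what the config enforces):

# STDLIB
from pathlib import Path

# THIRDPARTY
import numpy as np

# SUBMODULE (health_ml, health_azure, health_cpath, SSL)
from health_cpath.preprocessing.loading import LoadingParams
from health_ml.utils import box_utils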

@@ -194,16 +194,28 @@ define DDP_SAMPLER_ARGS
 	--pl_replace_sampler_ddp=False
 endef

+define OPENSLIDE_BACKEND_ARGS
+	--backend=OpenSlide
+endef
+
 # The following test takes around 5 minutes
-smoke_test_slidespandaimagenetmil_local:
+smoke_test_cucim_slidespandaimagenetmil_local:
 	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
 	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS};}

 # Once running in AML the following test takes around 6 minutes
-smoke_test_slidespandaimagenetmil_aml:
+smoke_test_cucim_slidespandaimagenetmil_aml:
 	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
 	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}

+smoke_test_openslide_slidespandaimagenetmil_local:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${OPENSLIDE_BACKEND_ARGS};}
+
+smoke_test_openslide_slidespandaimagenetmil_aml:
+	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
+	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${OPENSLIDE_BACKEND_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
+
 # The following test takes about 6 minutes
 smoke_test_tilespandaimagenetmil_local:
 	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
@@ -279,6 +291,6 @@ smoke_test_tiles_panda_no_ddp_sampler_aml:
 	{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
 	${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}

-smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
+smoke tests local: smoke_test_cucim_slidespandaimagenetmil_local smoke_test_openslide_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local

-smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml
+smoke tests AML: smoke_test_cucim_slidespandaimagenetmil_aml smoke_test_openslide_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml

@@ -15,6 +15,7 @@ from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

 from health_azure.utils import create_from_matching_params
+from health_cpath.preprocessing.loading import LoadingParams
 from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
 from health_cpath.utils.wsi_utils import TilingParams

@@ -34,7 +35,7 @@ from health_cpath.utils.naming import MetricsKey, PlotOption, SlideKey, ModelKey
 from health_cpath.utils.tiles_selection_utils import TilesSelector


-class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
+class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
     """BaseMIL is an abstract container defining basic functionality for running MIL experiments in both slides and
     tiles settings. It is responsible for instantiating the encoder and pooling layer. Subclasses should define the
     full DeepMIL model depending on the type of dataset (tiles/slides based).
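Because BaseMIL now inherits LoadingParams, the loading options (level, margin, backend, roi_type, image_key, mask_key, foreground_threshold) become ordinary container params. A minimal sketch of what a subclass inherits (the subclass below is hypothetical, for illustration only):

class MyMILContainer(BaseMIL):  # hypothetical subclass
    def __init__(self, **kwargs: Any) -> None:
        # LoadingParams fields can now be set like any other container param:
        super().__init__(backend=WSIBackend.OPENSLIDE, roi_type=ROIType.FOREGROUND, **kwargs)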

@@ -55,12 +56,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
                "If 0 (default), will return all samples in each bag. "
                "If > 0, bags larger than `max_bag_size_inf` will yield "
                "random subsets of instances.")
-    # local_dataset (used as data module root_path) is declared in DatasetParams superclass
-    level: int = param.Integer(1, bounds=(0, None), doc="The whole slide image level at which the image is extracted."
-                               "Whole slide images are represented in a pyramid consisting of"
-                               "multiple images at different resolutions."
-                               "If 1 (default), will extract baseline image at the resolution"
-                               "at level 1.")
     # Outputs Handler parameters:
     num_top_slides: int = param.Integer(10, bounds=(0, None), doc="Number of slides to select when saving top and "
                                         "bottom tiles. If set to 10 (default), it selects 10 "
@@ -78,10 +73,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
                                      doc="The maximum number of worker processes for dataloaders. Dataloaders use"
                                          "a heuristic num_cpus/num_gpus to set the number of workers, which can be"
                                          "very high for small num_gpus. This parameter sets an upper bound.")
-    wsi_has_mask: bool = param.Boolean(default=True,
-                                       doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
-                                           "region of the WSI. If False, will load the whole WSI.")
-    wsi_backend: str = param.String(default="cuCIM", doc="The backend to use for loading WSI.")

     def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
@@ -149,14 +140,12 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
             outputs_root=self.outputs_folder,
             n_classes=n_classes,
             tile_size=self.tile_size,
-            level=self.level,
             class_names=self.class_names,
             primary_val_metric=self.primary_val_metric,
             maximise=self.maximise_primary_metric,
             val_plot_options=self.get_val_plot_options(),
             test_plot_options=self.get_test_plot_options(),
-            wsi_has_mask=self.wsi_has_mask,
-            backend=self.wsi_backend,
+            loading_params=create_from_matching_params(self, LoadingParams),
             val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
         )
         if self.num_top_slides > 0:
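For context, `create_from_matching_params(self, LoadingParams)` builds a LoadingParams instance from the container's fields of the same names — roughly the idea sketched below (a sketch only, not the exact health_azure implementation):

import param
from typing import Any

def create_from_matching_params_sketch(source: param.Parameterized, target_cls: type) -> Any:
    # Copy every parameter declared on target_cls (except 'name') that source also defines.
    names = [n for n in target_cls.param.objects() if n != "name"]
    return target_cls(**{n: getattr(source, n) for n in names if hasattr(source, n)})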

@@ -11,7 +11,8 @@ from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, PANDA
 from health_cpath.datasets.panda_dataset import PandaDataset
 from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
 from health_cpath.models.encoders import HistoSSLEncoder, ImageNetSimCLREncoder, Resnet18, SSLEncoder
-from health_cpath.utils.naming import PlotOption
+from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
+from health_cpath.utils.naming import PlotOption, SlideKey
 from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.networks.layers.attention_layers import AttentionLayer
 from health_ml.utils.checkpoint_utils import CheckpointParser
@@ -36,7 +37,13 @@ class BaseDeepSMILEPanda(BaseMIL):
             # declared in OptimizerParams:
             l_rate=5e-4,
             weight_decay=1e-4,
-            adam_betas=(0.9, 0.99))
+            adam_betas=(0.9, 0.99),
+            # loading params:
+            backend=WSIBackend.CUCIM,
+            roi_type=ROIType.WHOLE,
+            image_key=SlideKey.IMAGE,
+            mask_key=SlideKey.MASK,
+        )
         default_kwargs.update(kwargs)
         super().__init__(**default_kwargs)
         self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
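With the loading params folded into default_kwargs, a caller can swap the reader backend without touching the data module — a minimal sketch (the override below is illustrative, not part of this diff):

# Same PANDA slides config, but reading slides with OpenSlide:
container = DeepSMILESlidesPanda(backend=WSIBackend.OPENSLIDE)
# or equivalently, on the hi-ml runner command line: --backend=OpenSlide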

@@ -121,7 +128,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):

     def __init__(self, **kwargs: Any) -> None:
         default_kwargs = dict(
             # declared in BaseMILSlides:
-            level=1,
             tile_size=224,
             background_val=255,
@@ -145,8 +151,8 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
             batch_size_inf=self.batch_size_inf,
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
-            level=self.level,
             tiling_params=create_from_matching_params(self, TilingParams),
+            loading_params=create_from_matching_params(self, LoadingParams),
             seed=self.get_effective_random_seed(),
             transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
             crossval_count=self.crossval_count,

@@ -9,6 +9,7 @@ from torch import optim
 from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
 from health_azure.utils import create_from_matching_params
+from health_cpath.preprocessing.loading import LoadingParams
 from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.networks.layers.attention_layers import (
     TransformerPooling,
@@ -107,8 +108,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             batch_size_inf=self.batch_size_inf,
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
-            level=self.level,
             tiling_params=create_from_matching_params(self, TilingParams),
+            loading_params=create_from_matching_params(self, LoadingParams),
             seed=self.get_effective_random_seed(),
             transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
             crossval_count=self.crossval_count,

@@ -3,13 +3,13 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 import torch
 import numpy as np
 from enum import Enum
 from pathlib import Path
 from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar, Union

 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, DistributedSampler
+from health_cpath.preprocessing.loading import LoadingParams

 from health_ml.utils.bag_utils import BagDataset, multibag_collate
 from health_ml.utils.common_utils import _create_generator
@@ -17,11 +17,10 @@ from health_ml.utils.common_utils import _create_generator
 from health_cpath.utils.wsi_utils import TilingParams, image_collate
 from health_cpath.models.transforms import LoadTilesBatchd
 from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset
-from health_cpath.utils.naming import ModelKey, SlideKey
+from health_cpath.utils.naming import ModelKey

 from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
-from monai.transforms import Compose, LoadImaged, SplitDimd
-from monai.data.image_reader import WSIReader
+from monai.transforms import Compose


 _SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
@@ -273,47 +272,29 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):

     def __init__(
         self,
-        level: int = 1,
-        backend: str = 'cuCIM',
-        wsi_reader_args: Dict[str, Any] = {},
-        tiling_params: TilingParams = TilingParams(),
+        loading_params: LoadingParams,
+        tiling_params: TilingParams,
         **kwargs: Any,
     ) -> None:
         """
-        :param level: the whole slide image level at which the image is extracted, defaults to 1.
-            This param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default.
-        :param backend: the WSI reader backend, defaults to "cuCIM".
-        :param wsi_reader_args: additional arguments to pass to the WSIReader, defaults to {}. Multiprocessing is
-            enabled since monai 1.0.0 by specifying num_workers > 0 with the cuCIM backend only.
-        :param tiling_params: the tiling on the fly parameters, defaults to TileOnTheFlyParams()
+        :param tiling_params: the tiling on the fly parameters.
+        :param loading_params: the loading parameters.
         :param kwargs: additional parameters to pass to the parent class HistoDataModule
         """
         super().__init__(**kwargs)
         self.tiling_params = tiling_params
-        # WSIReader params
-        self.level = level
-        self.backend = backend
-        self.wsi_reader_args = wsi_reader_args
+        self.loading_params = loading_params

     def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
         base_transform = Compose(
             [
-                LoadImaged(
-                    keys=SlideKey.IMAGE,
-                    reader=WSIReader,
-                    dtype=np.uint8,
-                    image_only=True,
-                    level=self.level,
-                    backend=self.backend,
-                    **self.wsi_reader_args,
-                ),
+                self.loading_params.get_load_roid_transform(),
                 self.tiling_params.get_tiling_transform(bag_size=self.bag_sizes[stage], stage=stage),
-                # GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
-                # tiles to be able to apply augmentations on each tile independently
-                SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True),
+                self.tiling_params.get_split_transform(),
             ]
         )
-        if self.transforms_dict and self.transforms_dict[stage]:
+
+        if self.transforms_dict and self.transforms_dict[stage]:
             transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten()
         else:
             transforms = base_transform
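Putting the pieces together, constructing the slides data module now takes the two param objects explicitly — a minimal wiring sketch (root path and values illustrative):

datamodule = PandaSlidesDataModule(
    root_path=Path("/data/PANDA"),  # illustrative path
    loading_params=LoadingParams(level=1, backend=WSIBackend.OPENSLIDE, roi_type=ROIType.MASK),
    tiling_params=TilingParams(tile_size=224),
)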

@@ -2,25 +2,11 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-import logging
-import pandas as pd
 from pathlib import Path
 from typing import Any, Dict, Union, Optional

+import pandas as pd
-from monai.config import KeysCollection
-from monai.data.image_reader import ImageReader, WSIReader
-from monai.transforms import MapTransform
-from health_cpath.utils.naming import SlideKey
-
-from health_ml.utils import box_utils

 from health_cpath.datasets.base_dataset import SlidesDataset

-try:
-    from cucim import CuImage
-except ImportError:  # noqa: E722
-    logging.warning("cucim library not available, code may fail")


 class PandaDataset(SlidesDataset):
     """Dataset class for loading files from the PANDA challenge dataset.
@@ -52,89 +38,3 @@ class PandaDataset(SlidesDataset):
         self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
         self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
         self.validate_columns()
-
-
-# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
-class ReadImaged(MapTransform):
-    """Basic transform to read image files."""
-
-    def __init__(self, reader: ImageReader, keys: KeysCollection,
-                 allow_missing_keys: bool = False, **kwargs: Any) -> None:
-        super().__init__(keys, allow_missing_keys=allow_missing_keys)
-        self.reader = reader
-        self.kwargs = kwargs
-
-    def __call__(self, data: Dict) -> Dict:
-        for key in self.keys:
-            if key in data or not self.allow_missing_keys:
-                data[key] = self.reader.read(data[key], **self.kwargs)
-        return data
-
-
-class LoadPandaROId(MapTransform):
-    """Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI).
-
-    Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
-    respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
-    - `'location'` (tuple): top-right coordinates of the bounding box
-    - `'size'` (tuple): width and height of the bounding box
-    - `'level'` (int): chosen magnification level
-    - `'scale'` (float): corresponding scale, loaded from the file
-    """
-
-    def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str = 'mask',
-                 level: int = 0, margin: int = 0, **kwargs: Any) -> None:
-        """
-        :param reader: And instance of MONAI's `WSIReader`.
-        :param image_key: Image key in the input and output dictionaries.
-        :param mask_key: Mask key in the input and output dictionaries.
-        :param level: Magnification level to load from the raw multi-scale files.
-        :param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
-        """
-        super().__init__([image_key, mask_key], allow_missing_keys=False)
-        self.reader = reader
-        self.image_key = image_key
-        self.mask_key = mask_key
-        self.level = level
-        self.margin = margin
-        self.kwargs = kwargs
-
-    def _get_bounding_box(self, mask_obj: 'CuImage') -> box_utils.Box:
-        # Estimate bounding box at the lowest resolution (i.e. highest level)
-        highest_level = mask_obj.resolutions['level_count'] - 1
-        scale = mask_obj.resolutions['level_downsamples'][highest_level]
-        mask, _ = self.reader.get_data(mask_obj, level=highest_level)  # loaded as RGB PIL image
-
-        foreground_mask = mask[0] > 0  # PANDA segmentation mask is in 'R' channel
-
-        bbox = box_utils.get_bounding_box(foreground_mask)
-        padded_bbox = bbox.add_margin(self.margin)
-        scaled_bbox = scale * padded_bbox
-        return scaled_bbox
-
-    def __call__(self, data: Dict) -> Dict:
-        mask_obj: CuImage = self.reader.read(data[self.mask_key])
-        image_obj: CuImage = self.reader.read(data[self.image_key])
-
-        level0_bbox = self._get_bounding_box(mask_obj)
-
-        # cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
-        # but relative region size in pixels at the chosen level
-        scale = mask_obj.resolutions['level_downsamples'][self.level]
-        scaled_bbox = level0_bbox / scale
-        origin = (level0_bbox.y, level0_bbox.x)
-        get_data_kwargs = dict(
-            location=origin,
-            size=(scaled_bbox.h, scaled_bbox.w),
-            level=self.level,
-        )
-        mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs)  # type: ignore
-        data[self.mask_key] = mask[:1]  # PANDA segmentation mask is in 'R' channel
-        data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs)  # type: ignore
-        data.update(get_data_kwargs)
-        data[SlideKey.SCALE] = scale
-        data[SlideKey.ORIGIN] = origin
-
-        mask_obj.close()
-        image_obj.close()
-        return data

@@ -15,13 +15,13 @@ from typing import Tuple, Union, List

 import numpy as np
 from monai.data import Dataset
-from monai.data.image_reader import WSIReader
 from tqdm import tqdm
 from health_ml.utils.box_utils import Box

 from health_cpath.preprocessing import tiling
 from health_cpath.utils.naming import SlideKey, TileKey
-from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
+from health_cpath.datasets.panda_dataset import PandaDataset
+from health_cpath.preprocessing.loading import LoadMaskROId, WSIBackend
 from health_cpath.preprocessing.create_tiles_dataset import get_tile_id, save_image, merge_dataset_csv_files

 CSV_COLUMNS = (
@@ -138,8 +138,7 @@ def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupan
         dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n')  # write CSV header

         print(f"Loading slide {slide_id} ...")
-        reader = WSIReader(backend="cucim")
-        loader = LoadPandaROId(reader, level=level, margin=margin)
+        loader = LoadMaskROId(backend=WSIBackend.CUCIM, level=level, margin=margin)
         try:
             sample = loader(sample)  # load 'image' and 'mask' from disk
             failed = False

@@ -13,13 +13,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import numpy as np
 import PIL
 from monai.data import Dataset
-from monai.data.image_reader import WSIReader
 from tqdm import tqdm
 from health_ml.utils.box_utils import Box

 from health_cpath.datasets.base_dataset import SlidesDataset
 from health_cpath.preprocessing import tiling
-from health_cpath.preprocessing.loading import LoadROId, segment_foreground
+from health_cpath.preprocessing.loading import LoadROId, WSIBackend, segment_foreground
 from health_cpath.utils.naming import SlideKey, TileKey
@@ -177,7 +176,7 @@ def process_slide(sample: Dict[SlideKey, Any], level: int, margin: int, tile_siz
         failed_tiles_file.write('tile_id' + '\n')

         print(f"Loading slide {slide_id} ...")
-        loader = LoadROId(WSIReader('cuCIM'), level=level, margin=margin,
+        loader = LoadROId(backend=WSIBackend.CUCIM, level=level, margin=margin,
                           foreground_threshold=foreground_threshold)
         sample = loader(sample)  # load 'image' from disk


@@ -2,22 +2,16 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-import logging
-from typing import Dict, Optional, Tuple
-
+import param
 import numpy as np
 import skimage.filters

+from enum import Enum
 from health_ml.utils import box_utils
-from monai.data.image_reader import WSIReader
-from monai.transforms import MapTransform

 from health_cpath.utils.naming import SlideKey

-try:
-    from cucim import CuImage
-except ImportError:  # noqa: E722
-    logging.warning("cucim library not available, code may fail")
+from monai.data.wsi_reader import WSIReader
+from monai.transforms import MapTransform, LoadImaged
+from typing import Any, Callable, Dict, Optional, Tuple


 def get_luminance(slide: np.ndarray) -> np.ndarray:
@@ -30,8 +24,7 @@ def get_luminance(slide: np.ndarray) -> np.ndarray:
     return slide.mean(axis=-3)  # type: ignore


-def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
-        -> Tuple[np.ndarray, float]:
+def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) -> Tuple[np.ndarray, float]:
     """Segment the given slide by thresholding its luminance.

     :param slide: The RGB image array in (*, C, H, W) format.
@@ -45,24 +38,48 @@ def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \
     return luminance < threshold, threshold


-def load_slide_at_level(reader: WSIReader, slide_obj: 'CuImage', level: int) -> np.ndarray:
-    """Load full slide array at the given magnification level.
-
-    This is a manual workaround for a MONAI bug (https://github.com/Project-MONAI/MONAI/issues/3415)
-    fixed in a currently unreleased PR (https://github.com/Project-MONAI/MONAI/pull/3417).
-
-    :param reader: A MONAI `WSIReader` using cuCIM backend.
-    :param slide_obj: The cuCIM image object returned by `reader.read(<image_file>)`.
-    :param level: Index of the desired magnification level as defined in the `slide_obj` headers.
-    :return: The loaded image array in (C, H, W) format.
-    """
-    size = slide_obj.resolutions['level_dimensions'][level][::-1]
-    slide, _ = reader.get_data(slide_obj, size=size, level=level)  # loaded as RGB PIL image
-    return slide
+class ROIType(str, Enum):
+    """Options for the ROI selection. Either a bounding box defined by foreground or a mask can be used."""
+    FOREGROUND = 'foreground'
+    MASK = 'segmentation_mask'
+    WHOLE = 'whole_slide'


-class LoadROId(MapTransform):
-    """Transform that loads a pathology slide, cropped to an estimated bounding box (ROI).
+class WSIBackend(str, Enum):
+    """Options for the WSI reader backend."""
+    OPENSLIDE = 'OpenSlide'
+    CUCIM = 'cuCIM'
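Both enums subclass str, so members can be passed anywhere a plain string is expected, e.g. straight into MONAI's WSIReader — a quick illustration:

assert WSIBackend.OPENSLIDE == 'OpenSlide'  # str-Enum members compare equal to their values
reader = WSIReader(backend=WSIBackend.CUCIM)  # accepted wherever a backend string is expected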
+
+
+class BaseLoadROId:
+    """Abstract base class for loading a region of interest (ROI) from a slide. The ROI is defined by a bounding box."""
+
+    def __init__(
+        self, backend: str = WSIBackend.CUCIM, image_key: str = SlideKey.IMAGE, level: int = 1, margin: int = 0,
+        backend_args: Dict = {}
+    ) -> None:
+        """
+        :param backend: The WSI reader backend to use. One of 'OpenSlide' or 'cuCIM'. Default: 'cuCIM'.
+        :param image_key: Image key in the input and output dictionaries. Default: 'image'.
+        :param level: Magnification level to load from the raw multi-scale file, 0 being the highest resolution.
+            Default: 1, which loads the second-highest resolution.
+        :param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping. Default: 0.
+        :param backend_args: Additional arguments to pass to the WSI reader backend. Default: {}.
+        """
+        self.reader = WSIReader(backend=backend, **backend_args)
+        self.image_key = image_key
+        self.level = level
+        self.margin = margin
+
+    def _get_bounding_box(self, slide_obj: Any) -> box_utils.Box:
+        raise NotImplementedError
+
+    def __call__(self, data: Dict) -> Dict:
+        raise NotImplementedError
+
+
+class LoadROId(MapTransform, BaseLoadROId):
+    """Transform that loads a pathology slide, cropped to an estimated bounding box (ROI) of the foreground tissue.

     Operates on dictionaries, replacing the file path in `image_key` with the loaded array in
     (C, H, W) format. Also adds the following entries:
@@ -71,49 +88,169 @@ class LoadROId(MapTransform):
     - `SlideKey.FOREGROUND_THRESHOLD` (float): threshold used to segment the foreground
     """

-    def __init__(self, reader: WSIReader, image_key: str = SlideKey.IMAGE, level: int = 0,
-                 margin: int = 0, foreground_threshold: Optional[float] = None) -> None:
+    def __init__(
+        self, image_key: str = SlideKey.IMAGE, foreground_threshold: Optional[float] = None, **kwargs: Any
+    ) -> None:
         """
-        :param reader: And instance of MONAI's `WSIReader`.
-        :param image_key: Image key in the input and output dictionaries.
-        :param level: Magnification level to load from the raw multi-scale file.
-        :param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
+        :param image_key: Image key in the input and output dictionaries. Default: 'image'.
         :param foreground_threshold: Pixels with luminance below this value will be considered foreground.
             If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
+        :param kwargs: Additional arguments for `BaseLoadROId`.
         """
-        super().__init__([image_key], allow_missing_keys=False)
-        self.reader = reader
-        self.image_key = image_key
-        self.level = level
-        self.margin = margin
+        MapTransform.__init__(self, [image_key], allow_missing_keys=False)
+        BaseLoadROId.__init__(self, image_key=image_key, **kwargs)
         self.foreground_threshold = foreground_threshold

-    def _get_bounding_box(self, slide_obj: 'CuImage') -> Tuple[box_utils.Box, float]:
-        # Estimate bounding box at the lowest resolution (i.e. highest level)
-        highest_level = slide_obj.resolutions['level_count'] - 1
-        scale = slide_obj.resolutions['level_downsamples'][highest_level]
-        slide = load_slide_at_level(self.reader, slide_obj, level=highest_level)
+    def _get_bounding_box(self, slide_obj: Any) -> box_utils.Box:
+        """Estimate bounding box at the lowest resolution (i.e. highest level) of the slide."""
+        highest_level = self.reader.get_level_count(slide_obj) - 1
+        scale = self.reader.get_downsample_ratio(slide_obj, highest_level)
+        slide, _ = self.reader.get_data(slide_obj, level=highest_level)

         foreground_mask, threshold = segment_foreground(slide, self.foreground_threshold)
+        self.foreground_threshold = threshold
         bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
-        return bbox, threshold
+        return bbox

     def __call__(self, data: Dict) -> Dict:
-        image_obj: CuImage = self.reader.read(data[self.image_key])
+        try:
+            image_obj = self.reader.read(data[self.image_key])
+            level0_bbox = self._get_bounding_box(image_obj)

-        level0_bbox, threshold = self._get_bounding_box(image_obj)
+            # cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
+            # but relative region size in pixels at the chosen level
+            origin = (level0_bbox.y, level0_bbox.x)
+            scale = self.reader.get_downsample_ratio(image_obj, self.level)
+            scaled_bbox = level0_bbox / scale

-        # cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame,
-        # but relative region size in pixels at the chosen level
-        origin = (level0_bbox.y, level0_bbox.x)
-        scale = image_obj.resolutions['level_downsamples'][self.level]
-        scaled_bbox = level0_bbox / scale
-
-        data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
-                                                       size=(scaled_bbox.h, scaled_bbox.w))
-        data[SlideKey.ORIGIN] = origin
-        data[SlideKey.SCALE] = scale
-        data[SlideKey.FOREGROUND_THRESHOLD] = threshold
-
-        image_obj.close()
+            data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level,
+                                                           size=(scaled_bbox.h, scaled_bbox.w))
+            data[SlideKey.ORIGIN] = origin
+            data[SlideKey.SCALE] = scale
+            data[SlideKey.FOREGROUND_THRESHOLD] = self.foreground_threshold
+        finally:
+            image_obj.close()
         return data
+
+
+class LoadMaskROId(MapTransform, BaseLoadROId):
+    """Transform that loads a pathology slide and mask, cropped to the bounding box (ROI) defined by the mask.
+
+    Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
+    respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
+    - `'location'` (tuple): top-right coordinates of the bounding box
+    - `'size'` (tuple): width and height of the bounding box
+    - `'level'` (int): chosen magnification level
+    - `'scale'` (float): corresponding scale, loaded from the file
+    """
+
+    def __init__(self, image_key: str = SlideKey.IMAGE, mask_key: str = SlideKey.MASK, **kwargs: Any) -> None:
+        """
+        :param image_key: Image key in the input and output dictionaries. Default: 'image'.
+        :param mask_key: Mask key in the input and output dictionaries. Default: 'mask'.
+        :param kwargs: Additional arguments for `BaseLoadROId`.
+        """
+        MapTransform.__init__(self, [image_key, mask_key], allow_missing_keys=False)
+        BaseLoadROId.__init__(self, image_key=image_key, **kwargs)
+        self.mask_key = mask_key
+
+    def _get_bounding_box(self, mask_obj: Any) -> box_utils.Box:
+        """Estimate bounding box at the lowest resolution (i.e. highest level) of the mask."""
+        highest_level = self.reader.get_level_count(mask_obj) - 1
+        scale = self.reader.get_downsample_ratio(mask_obj, highest_level)
+        mask, _ = self.reader.get_data(mask_obj, level=highest_level)  # loaded as RGB PIL image
+
+        foreground_mask = mask[0] > 0  # PANDA segmentation mask is in 'R' channel
+
+        bbox = box_utils.get_bounding_box(foreground_mask)
+        padded_bbox = bbox.add_margin(self.margin)
+        scaled_bbox = scale * padded_bbox
+        return scaled_bbox
+
+    def __call__(self, data: Dict) -> Dict:
+        try:
+            mask_obj = self.reader.read(data[self.mask_key])
+            image_obj = self.reader.read(data[self.image_key])
+
+            level0_bbox = self._get_bounding_box(mask_obj)
+
+            # cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
+            # but relative region size in pixels at the chosen level
+            scale = self.reader.get_downsample_ratio(mask_obj, self.level)
+            scaled_bbox = level0_bbox / scale
+            origin = (level0_bbox.y, level0_bbox.x)
+            get_data_kwargs = dict(
+                location=origin,
+                size=(scaled_bbox.h, scaled_bbox.w),
+                level=self.level,
+            )
+            mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs)  # type: ignore
+            data[self.mask_key] = mask[:1]  # PANDA segmentation mask is in 'R' channel
+            data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs)  # type: ignore
+            data.update(get_data_kwargs)
+            data[SlideKey.SCALE] = scale
+            data[SlideKey.ORIGIN] = origin
+        finally:
+            mask_obj.close()
+            image_obj.close()
+        return data
+
+
+class LoadingParams(param.Parameterized):
+    """Parameters for loading a whole slide image."""
+
+    level: int = param.Integer(
+        default=1,
+        doc="Magnification level to load from the raw multi-scale files. Default: 1.")
+    margin: int = param.Integer(
+        default=0, doc="Amount in pixels by which to enlarge the estimated bounding box for cropping.")
+    backend: WSIBackend = param.ClassSelector(
+        default=WSIBackend.CUCIM,
+        class_=WSIBackend,
+        doc="WSI reader backend. Default: cuCIM.")
+    roi_type: ROIType = param.ClassSelector(
+        default=ROIType.WHOLE,
+        class_=ROIType,
+        doc="ROI type to use for cropping the slide. Default: `ROIType.WHOLE`, where no cropping is performed.")
+    image_key: str = param.String(
+        default=SlideKey.IMAGE,
+        doc="Key for the image in the data dictionary.")
+    mask_key: str = param.String(
+        default=SlideKey.MASK,
+        doc="Key for the mask in the data dictionary. This only applies to `LoadMaskROId`.")
+    foreground_threshold: Optional[float] = param.Number(
+        default=None,
+        bounds=(0, 255.),
+        allow_None=True,
+        doc="Threshold for the foreground mask. If None, the threshold is selected automatically with Otsu "
+            "thresholding. This only applies to `LoadROId`.")
+
+    def set_roi_type_to_foreground(self) -> None:
+        """Set the ROI type to foreground. This is useful for plotting even if we load whole slides during
+        training. This helps us reduce the size of thumbnails to only meaningful tissue. We only hardcode it to
+        foreground in the WHOLE case. Otherwise, keep it as is if a mask is available."""
+        if self.roi_type == ROIType.WHOLE:
+            self.roi_type = ROIType.FOREGROUND
+
+    def get_load_roid_transform(self) -> Callable:
+        """Returns a transform to load a slide and mask, cropped to the bounding box (ROI) defined by either the
+        mask or the foreground."""
+        if self.roi_type == ROIType.WHOLE:
+            return LoadImaged(keys=self.image_key, reader=WSIReader, image_only=True, level=self.level,  # type: ignore
+                              backend=self.backend, dtype=np.uint8, **self.get_additionl_backend_args())
+        elif self.roi_type == ROIType.FOREGROUND:
+            return LoadROId(backend=self.backend, image_key=self.image_key, level=self.level,
+                            margin=self.margin, foreground_threshold=self.foreground_threshold,
+                            backend_args=self.get_additionl_backend_args())
+        elif self.roi_type == ROIType.MASK:
+            return LoadMaskROId(backend=self.backend, image_key=self.image_key,
+                                mask_key=self.mask_key, level=self.level, margin=self.margin,
+                                backend_args=self.get_additionl_backend_args())
+        else:
+            raise ValueError(f"Unknown ROI type: {self.roi_type}. Choose from {list(ROIType)}.")
+
+    def get_additionl_backend_args(self) -> Dict[str, Any]:
+        """Returns a dictionary of additional arguments for the WSI reader backend. Multiprocessing is
+        enabled since monai 1.0.0 by specifying num_workers > 0, with the cuCIM backend only.
+        This function can be overridden in BaseMIL to add additional arguments for the backend."""
+        return dict()
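One LoadingParams instance now selects both the reader backend and the cropping strategy — a minimal usage sketch (file paths illustrative):

params = LoadingParams(level=1, backend=WSIBackend.OPENSLIDE, roi_type=ROIType.MASK)
transform = params.get_load_roid_transform()  # returns LoadMaskROId for ROIType.MASK
sample = transform({SlideKey.IMAGE: "slide.tiff", SlideKey.MASK: "mask.tiff"})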

@@ -5,7 +5,8 @@
 from pathlib import Path
 import sys
 import time
-from typing import Any
+from typing import Any, Optional
+from azureml.core import Workspace

 himl_histo_root_dir = Path(__file__).parent.parent.parent
 himl_root = himl_histo_root_dir.parent.parent
@@ -16,11 +17,11 @@ from health_azure import DatasetConfig  # noqa: E402
 from health_azure.utils import get_workspace  # noqa: E402


-def mount_dataset(dataset_id: str) -> Any:
-    ws = get_workspace()
-    target_folder = "/tmp/datasets/" + dataset_id
+def mount_dataset(dataset_id: str, tmp_root: str = "/tmp/datasets", aml_workspace: Optional[Workspace] = None) -> Any:
+    ws = get_workspace(aml_workspace)
+    target_folder = "/".join([tmp_root, dataset_id])
     dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
-    _, mount_ctx = dataset.to_input_dataset_local(ws)
+    _, mount_ctx = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
     assert mount_ctx is not None  # for mypy
     mount_ctx.start()
     return mount_ctx
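A usage sketch for the extended signature (dataset name illustrative; assumes the returned AzureML mount context exposes stop() to release the mount):

mount_ctx = mount_dataset("PANDA", tmp_root="/tmp/datasets")
try:
    ...  # read files under /tmp/datasets/PANDA
finally:
    mount_ctx.stop()  # assumption: stop() unmounts the dataset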

@@ -18,6 +18,7 @@ from torchmetrics.metric import Metric

 from health_azure.utils import replace_directory
 from health_cpath.datasets.base_dataset import SlidesDataset
+from health_cpath.preprocessing.loading import LoadingParams
 from health_cpath.utils.plots_utils import DeepMILPlotsHandler, TilesSelector
 from health_cpath.utils.naming import MetricsKey, ModelKey, PlotOption, ResultsKey

@@ -254,24 +255,22 @@ class OutputsPolicy:
 class DeepMILOutputsHandler:
     """Class that manages writing validation and test outputs for DeepMIL models."""

-    def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
+    def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, loading_params: LoadingParams,
                  class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
                  maximise: bool, val_plot_options: Collection[PlotOption],
-                 test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
-                 backend: str = "cuCIM", val_set_is_dist: bool = True) -> None:
+                 test_plot_options: Collection[PlotOption],
+                 val_set_is_dist: bool = True,) -> None:
         """
         :param outputs_root: Root directory where to save all produced outputs.
         :param n_classes: Number of MIL classes (set `n_classes=1` for binary).
         :param tile_size: The size of each tile.
-        :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
         :param class_names: List of class names. For binary (`n_classes == 1`), expects `len(class_names) == 2`.
             If `None`, will return `('0', '1', ...)`.
         :param primary_val_metric: Name of the validation metric to track for saving best epoch outputs.
         :param maximise: Whether higher is better for `primary_val_metric`.
         :param val_plot_options: The desired plot options for validation time.
         :param test_plot_options: The desired plot options for test time.
-        :param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
-        :param backend: The backend to use for reading the tiles. Default is "cuCIM".
+        :param loading_params: Parameters for loading WSI to create plots. This parameter is passed to PlotsHandler.
         :param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
             set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
             set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
@@ -280,7 +279,6 @@ class DeepMILOutputsHandler:
         self.outputs_root = outputs_root
         self.n_classes = n_classes
         self.tile_size = tile_size
-        self.level = level
         self.class_names = validate_class_names(class_names, self.n_classes)

         self.outputs_policy = OutputsPolicy(outputs_root=outputs_root,
@@ -291,21 +289,17 @@ class DeepMILOutputsHandler:

         self.val_plots_handler = DeepMILPlotsHandler(
             plot_options=val_plot_options,
-            level=self.level,
             tile_size=self.tile_size,
             class_names=self.class_names,
             stage=ModelKey.VAL,
-            wsi_has_mask=wsi_has_mask,
-            backend=backend,
+            loading_params=loading_params,
         )
         self.test_plots_handler = DeepMILPlotsHandler(
             plot_options=test_plot_options,
-            level=self.level,
             tile_size=self.tile_size,
             class_names=self.class_names,
             stage=ModelKey.TEST,
-            wsi_has_mask=wsi_has_mask,
-            backend=backend,
+            loading_params=loading_params,
         )
         self.val_set_is_dist = val_set_is_dist

@@ -11,6 +11,7 @@ from torch import Tensor
 import matplotlib.pyplot as plt

 from health_cpath.datasets.base_dataset import SlidesDataset
+from health_cpath.preprocessing.loading import LoadingParams
 from health_cpath.utils.viz_utils import (
     plot_attention_tiles,
     plot_heatmap_overlay,
@@ -176,38 +177,32 @@ class DeepMILPlotsHandler:
     def __init__(
         self,
         plot_options: Collection[PlotOption],
-        level: int = 1,
+        loading_params: LoadingParams,
         tile_size: int = 224,
         num_columns: int = 4,
         figsize: Tuple[int, int] = (10, 10),
         stage: str = '',
         class_names: Optional[Sequence[str]] = None,
-        wsi_has_mask: bool = True,
-        backend: str = "cuCIM",
     ) -> None:
         """Class that handles the plotting of DeepMIL results.

         :param plot_options: A set of plot options to produce the desired plot outputs.
-        :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
-            1 for 4x downsampled, 2 for 16x downsampled). Default 1.
-        :param tile_size: _description_, defaults to 224
+        :param loading_params: The loading parameters to use when loading the whole slide images.
+        :param tile_size: The size of the tiles to use when plotting the attention tiles, defaults to 224
         :param num_columns: Number of columns to create the subfigures grid, defaults to 4
         :param figsize: The figure size of tiles attention plots, defaults to (10, 10)
         :param stage: Test or Validation, used to name the plots
         :param class_names: List of class names, defaults to None
-        :param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None
-        :param wsi_has_mask: Whether the whole slide images have a mask, defaults to True
-        :param backend: The backend to use for loading the whole slide images, defaults to "cuCIM"
         """
         self.plot_options = plot_options
         self.class_names = validate_class_names_for_plot_options(class_names, plot_options)
-        self.level = level
         self.tile_size = tile_size
         self.num_columns = num_columns
         self.figsize = figsize
         self.stage = stage
-        self.wsi_has_mask = wsi_has_mask
-        self.backend = backend
+        self.loading_params = loading_params
+        self.loading_params.set_roi_type_to_foreground()
         self.slides_dataset: Optional[SlidesDataset] = None

     def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
@@ -216,8 +211,7 @@ class DeepMILPlotsHandler:
         slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
         assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
         slide_dict = self.slides_dataset[slide_index]
-        slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask,
-                                     backend=self.backend)
+        slide_dict = load_image_dict(slide_dict, loading_params=self.loading_params)
         return slide_dict

     def save_slide_node_figures(
@@ -237,7 +231,7 @@ class DeepMILPlotsHandler:

         if PlotOption.ATTENTION_HEATMAP in self.plot_options:
             save_attention_heatmap(
-                case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.level
+                case, slide_node, slide_dict, case_dir, results, self.tile_size, level=self.loading_params.level
             )

     def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:

@@ -17,20 +17,16 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple

 from monai.data.meta_tensor import MetaTensor
 from monai.data.dataset import Dataset
-from monai.data.image_reader import WSIReader
 from torch.utils.data import DataLoader
-from health_cpath.preprocessing.loading import LoadROId

+from health_cpath.datasets.panda_dataset import PandaDataset
+from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
 from health_cpath.utils.naming import SlideKey
 from health_cpath.utils.naming import ResultsKey
 from health_cpath.utils.heatmap_utils import location_selected_tiles
 from health_cpath.utils.tiles_selection_utils import SlideNode
-from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId


-def load_image_dict(
-    sample: dict, level: int, margin: int, wsi_has_mask: bool = True, backend: str = "cuCIM"
-) -> Dict[SlideKey, Any]:
+def load_image_dict(sample: dict, loading_params: LoadingParams) -> Dict[SlideKey, Any]:
     """
     Load image from metadata dictionary
     :param sample: dict describing image metadata. Example:
@@ -40,13 +36,9 @@ def load_image_dict(
         'data_provider': ['karolinska'],
         'isup_grade': tensor([0]),
         'gleason_score': ['0+0']}
-    :param level: level of resolution to be loaded
-    :param margin: margin to be included
-    :param wsi_has_mask: whether the WSI has a mask
-    :param backend: backend to be used to load the image (cuCIM or OpenSlide)
+    :param loading_params: LoadingParams object that contains the parameters to load the image.
     """
-    transform = LoadPandaROId if wsi_has_mask else LoadROId
-    loader = transform(WSIReader(backend=backend), level=level, margin=margin)
+    loader = loading_params.get_load_roid_transform()
     img = loader(sample)
     if isinstance(img[SlideKey.IMAGE], MetaTensor):
         # New monai transforms return a MetaTensor, we need to convert it to a numpy array for backward compatibility
@@ -76,7 +68,8 @@ def plot_panda_data_sample(
     slide_id = dict_images[SlideKey.SLIDE_ID]
     title = dict_images[SlideKey.METADATA][title_key]
     print(f">>> Slide {slide_id}")
-    img = load_image_dict(dict_images, level=level, margin=margin)
+    loading_params = LoadingParams(level=level, margin=margin, backend=WSIBackend.CUCIM, roi_type=ROIType.MASK)
+    img = load_image_dict(dict_images, loading_params)
     ax.imshow(img[SlideKey.IMAGE].transpose(1, 2, 0))
     ax.set_title(title)
     fig.tight_layout()

@@ -6,7 +6,7 @@ from typing import Any, Callable, List, Optional
 from health_cpath.utils.naming import ModelKey, SlideKey
 from health_ml.utils.bag_utils import multibag_collate
 from monai.data.meta_tensor import MetaTensor
-from monai.transforms import RandGridPatchd, GridPatchd
+from monai.transforms import RandGridPatchd, GridPatchd, SplitDimd


 def image_collate(batch: List) -> Any:
@@ -99,3 +99,9 @@ class TilingParams(param.Parameterized):
             pad_mode=self.tile_pad_mode,  # type: ignore
             constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
         )
+
+    def get_split_transform(self) -> Callable:
+        """GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
+        tiles to be able to apply augmentations on each tile independently.
+        """
+        return SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True)
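For intuition, the split transform turns the stacked bag back into per-tile entries so augmentations can run independently — the shape contract, sketched (bag size illustrative):

# Input:  {'image': tensor of shape (16, 3, 224, 224)}   # (bag_size, C, H, W)
# Output: list of 16 dicts, each {'image': tensor of shape (3, 224, 224)}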

@@ -2,6 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
+import time
 import pytest
 import torch
 from pathlib import Path
@@ -14,6 +15,7 @@ from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTi
 from health_cpath.utils.naming import ModelKey, SlideKey
 from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.utils.common_utils import is_gpu_available
+from testhisto.datamodules.test_slides_datamodule import get_loading_params
 from testhisto.utils.utils_testhisto import run_distributed
@@ -64,7 +66,7 @@ def test_slides_datamodule_different_bag_sizes(
         max_bag_size=max_bag_size,
         max_bag_size_inf=max_bag_size_inf,
         tiling_params=TilingParams(tile_size=28),
-        level=0,
+        loading_params=get_loading_params(level=0),
     )
     # For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
     _assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])
@@ -98,7 +100,7 @@ def test_slides_datamodule_different_batch_sizes(
         max_bag_size=16,
         max_bag_size_inf=16,
         tiling_params=TilingParams(tile_size=28),
-        level=0,
+        loading_params=get_loading_params(level=0),
     )
     _assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)

@@ -134,6 +136,8 @@ def _test_datamodule_pl_ddp_sampler_true(
     datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
 ) -> None:
     datamodule.setup()
+    if rank == 0:
+        time.sleep(15)  # slow down rank 0 to avoid concurrent file access
     _validate_sampler_type(datamodule, [ModelKey.TRAIN, ModelKey.VAL, ModelKey.TEST], expected_none=True)

@@ -141,6 +145,8 @@ def _test_datamodule_pl_ddp_sampler_false(
     datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
 ) -> None:
     datamodule.setup()
+    if rank == 0:
+        time.sleep(15)  # slow down rank 0 to avoid concurrent file access
     _validate_sampler_type(datamodule, [ModelKey.VAL, ModelKey.TEST], expected_none=True)
     _validate_sampler_type(datamodule, [ModelKey.TRAIN], expected_none=False)

@@ -151,13 +157,13 @@ def _test_datamodule_pl_ddp_sampler_false(
 def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Path) -> None:
     slides_datamodule = PandaSlidesDataModule(root_path=mock_panda_slides_root_dir,
                                               pl_replace_sampler_ddp=True,
-                                              seed=42)
+                                              seed=42, tiling_params=TilingParams(),
+                                              loading_params=get_loading_params())
     run_distributed(_test_datamodule_pl_ddp_sampler_true, [slides_datamodule], world_size=2)
     slides_datamodule.pl_replace_sampler_ddp = False
     run_distributed(_test_datamodule_pl_ddp_sampler_false, [slides_datamodule], world_size=2)


 @pytest.mark.skip(reason="Test fails with Broken Pipe Error. To be fixed.")
 @pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
 @pytest.mark.gpu
@@ -173,6 +179,7 @@ def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
     with patch("torch.distributed.is_initialized", return_value=True):
         with patch("torch.distributed.get_world_size", return_value=2):
             slides_datamodule = PandaSlidesDataModule(
-                root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False
+                root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False,
+                tiling_params=TilingParams(), loading_params=get_loading_params()
             )
             slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)
@ -12,6 +12,7 @@ from pathlib import Path
|
|||
from monai.transforms import RandFlipd
|
||||
from typing import Generator, Dict, Callable, Union, Tuple
|
||||
from torch.utils.data import DataLoader
|
||||
from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
|
||||
from health_cpath.utils.wsi_utils import TilingParams
|
||||
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
|
@ -20,15 +21,21 @@ from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.utils.naming import SlideKey, ModelKey
from health_cpath.datamodules.panda_module import PandaSlidesDataModule
from testhisto.mocks.slides_generator import (
    MockPandaSlidesGenerator,
    MockHistoDataType,
    TilesPositioningType,
)
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, MockHistoDataType, TilesPositioningType

no_gpu = not is_gpu_available()


def get_loading_params(level: int = 0, roi_type: ROIType = ROIType.FOREGROUND) -> LoadingParams:
    return LoadingParams(
        level=level,
        backend=WSIBackend.CUCIM,
        roi_type=roi_type,
        foreground_threshold=255,
        margin=0,
    )


@pytest.fixture(scope="session")
def mock_panda_slides_root_dir_diagonal(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path

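The `get_loading_params` helper above captures the new calling convention: every WSI-loading option travels in a single LoadingParams object alongside TilingParams, replacing the separate level/backend/margin keyword arguments that this diff removes. A minimal usage sketch under that convention (the dataset path is a placeholder):

from pathlib import Path

from health_cpath.datamodules.panda_module import PandaSlidesDataModule
from health_cpath.preprocessing.loading import LoadingParams, ROIType, WSIBackend
from health_cpath.utils.wsi_utils import TilingParams

# Hypothetical call site: swapping WSIBackend.CUCIM for WSIBackend.OPENSLIDE
# switches the reader without touching any other argument.
datamodule = PandaSlidesDataModule(
    root_path=Path("/data/PANDA"),  # placeholder path
    tiling_params=TilingParams(tile_size=28),
    loading_params=LoadingParams(
        level=0,
        backend=WSIBackend.OPENSLIDE,
        roi_type=ROIType.FOREGROUND,
        foreground_threshold=255,
        margin=0,
    ),
)
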
@ -84,11 +91,11 @@ def get_original_tile(mock_dir: Path, wsi_id: str) -> np.ndarray:
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.parametrize("roi_type", [ROIType.FOREGROUND, ROIType.WHOLE])
def test_tiling_on_the_fly(roi_type: ROIType, mock_panda_slides_root_dir_diagonal: Path) -> None:
    batch_size = 1
    tile_count = 16
    tile_size = 28
    level = 0
    channels = 3
    assert_batch_index = 0
    datamodule = PandaSlidesDataModule(

@ -96,7 +103,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
        batch_size=batch_size,
        max_bag_size=tile_count,
        tiling_params=TilingParams(tile_size=28),
        level=level,
        loading_params=get_loading_params(level=0, roi_type=roi_type),
    )
    dataloader = datamodule.train_dataloader()
    for sample in dataloader:

@ -113,10 +120,10 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.parametrize("roi_type", [ROIType.FOREGROUND, ROIType.WHOLE])
def test_tiling_without_fixed_tile_count(roi_type: ROIType, mock_panda_slides_root_dir_diagonal: Path) -> None:
    batch_size = 1
    tile_count = None
    level = 0
    assert_batch_index = 0
    min_expected_tile_count = 16
    datamodule = PandaSlidesDataModule(

@ -124,7 +131,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
        batch_size=batch_size,
        max_bag_size=tile_count,
        tiling_params=TilingParams(tile_size=28),
        level=level,
        loading_params=get_loading_params(level=0, roi_type=roi_type),
    )
    dataloader = datamodule.train_dataloader()
    for sample in dataloader:

@ -146,7 +153,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
        batch_size=batch_size,
        max_bag_size=tile_count,
        tiling_params=TilingParams(tile_size=tile_size),
        level=level,
        loading_params=get_loading_params(level=level),
    )
    dataloader = datamodule.train_dataloader()
    for sample in dataloader:

@ -165,7 +172,6 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [1, 2])
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
    level = 0
    overlap = .5
    expected_tile_matches = 16
    min_expected_tile_count = 32

@ -175,7 +181,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
        max_bag_size=None,
        batch_size=batch_size,
        tiling_params=TilingParams(tile_size=28, tile_overlap=overlap),
        level=level,
        loading_params=get_loading_params(level=0),
    )
    dataloader = datamodule.train_dataloader()
    for sample in dataloader:

@ -206,14 +212,13 @@ def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> None:
    batch_size = 1
    tile_count = 4
    level = 0
    flipdatamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir_diagonal,
        batch_size=batch_size,
        max_bag_size=tile_count,
        max_bag_size_inf=0,
        tiling_params=TilingParams(tile_size=28),
        level=level,
        loading_params=get_loading_params(level=0),
        transforms_dict=get_transforms_dict(),
    )
    flip_train_tiles = retrieve_tiles(flipdatamodule.train_dataloader())

@ -256,7 +261,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_different_n_tiles: Path) -> None:
    tile_count = 2
    level = 0
    assert_batch_index = 0
    n_tiles_list = [4, 5, 6, 7, 8, 9]

@ -266,7 +270,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_different_n_tiles: Path) -> None:
        max_bag_size=tile_count,
        max_bag_size_inf=0,
        tiling_params=TilingParams(tile_size=28),
        level=level,
        loading_params=get_loading_params(level=0),
    )
    train_dataloader = datamodule.train_dataloader()
    for sample in train_dataloader:

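A note on `max_bag_size_inf=0` in the module above: zero reads as "no cap", so inference bags contain every tile of a slide, which is what lets this test pair slides of 4 to 9 tiles (`n_tiles_list`) with whole-slide inference. Assumed semantics, expressed as a hypothetical helper (not repository code):

def effective_bag_size(n_tiles_on_slide: int, max_bag_size_inf: int) -> int:
    # max_bag_size_inf == 0 is assumed to mean "load all tiles at inference";
    # any positive value caps the bag at that many tiles.
    return n_tiles_on_slide if max_bag_size_inf == 0 else min(n_tiles_on_slide, max_bag_size_inf)
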
@ -15,6 +15,7 @@ from pytorch_lightning import seed_everything
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import DeepSMILESlidesPandaBenchmark
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.preprocessing.loading import ROIType, WSIBackend
from health_cpath.utils.naming import SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType

@ -50,6 +51,10 @@ def test_panda_reproducibility(tmp_path: Path) -> None:
    container.tile_size = tile_size
    container.max_bag_size = num_tiles
    container.local_datasets = [tmp_path]
    container.backend = WSIBackend.CUCIM
    container.roi_type = ROIType.FOREGROUND
    container.margin = 0
    container.level = 0

    def test_data_items_are_equal(loader_fn_names: List[str]) -> None:
        """Creates a new data module from the container, and checks if all the data loaders specified in

@ -4,6 +4,7 @@
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Optional, Set
from health_cpath.preprocessing.loading import ROIType, WSIBackend

from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda, DeepSMILETilesPanda

@ -21,15 +22,15 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
            pool_hidden_dim=16,
            num_transformer_pool_layers=1,
            num_transformer_pool_heads=1,
            class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
            # Encoder parameters
            encoder_type=Resnet18.__name__,
            tile_size=28,
            # Data Module parameters
            batch_size=2,
            max_bag_size=4,
            max_bag_size_inf=4,
            batch_size_inf=2,
            encoding_chunk_size=4,
            max_bag_size=4,
            max_bag_size_inf=0,
            cache_mode=CacheMode.NONE,
            precache_location=CacheLocation.NONE,
            # declared in DatasetParams:

@ -39,6 +40,11 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
            crossval_count=1,
            ssl_checkpoint=None,
            analyse_loss=analyse_loss,
            # Loading parameters
            level=0,
            backend=WSIBackend.CUCIM,
            roi_type=ROIType.FOREGROUND,
            foreground_threshold=255,
        )
        default_kwargs.update(kwargs)
        super().__init__(**default_kwargs)

@ -63,14 +69,13 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
            pool_hidden_dim=16,
            num_transformer_pool_layers=1,
            num_transformer_pool_heads=1,
            class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
            # Encoder parameters
            encoder_type=Resnet18.__name__,
            tile_size=28,
            # Data Module parameters
            batch_size=2,
            batch_size_inf=2,
            encoding_chunk_size=4,
            level=0,
            max_bag_size=4,
            max_bag_size_inf=0,
            # declared in DatasetParams:

@ -78,6 +83,11 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
            # declared in TrainerParams:
            max_epochs=2,
            crossval_count=1,
            # Loading parameters
            level=0,
            backend=WSIBackend.CUCIM,
            roi_type=ROIType.FOREGROUND,
            foreground_threshold=255,
        )
        default_kwargs.update(kwargs)
        super().__init__(**default_kwargs)

@ -0,0 +1,71 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
from pathlib import Path
from typing import List, Tuple
from monai.transforms import LoadImaged
from monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader, WSIReader
from health_cpath.datasets.default_paths import PANDA_DATASET_ID
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.preprocessing.loading import BaseLoadROId, LoadingParams, ROIType, WSIBackend, LoadROId, LoadMaskROId
from health_cpath.scripts.mount_azure_dataset import mount_dataset
from health_cpath.utils.naming import SlideKey
from health_ml.utils.common_utils import is_gpu_available
from testhiml.utils_testhiml import DEFAULT_WORKSPACE

no_gpu = not is_gpu_available()


@pytest.mark.parametrize("roi_type", [r for r in ROIType])
@pytest.mark.parametrize("backend", [b for b in WSIBackend])
def test_get_load_roid_transform(backend: WSIBackend, roi_type: ROIType) -> None:
    loading_params = LoadingParams(backend=backend, roi_type=roi_type)
    transform = loading_params.get_load_roid_transform()
    transform_type = {ROIType.MASK: LoadMaskROId, ROIType.FOREGROUND: LoadROId, ROIType.WHOLE: LoadImaged}
    assert isinstance(transform, transform_type[roi_type])
    reader_type = {WSIBackend.CUCIM: CuCIMWSIReader, WSIBackend.OPENSLIDE: OpenSlideWSIReader}
    if roi_type in [ROIType.MASK, ROIType.FOREGROUND]:
        assert isinstance(transform, BaseLoadROId)
        assert isinstance(transform.reader, WSIReader)  # type: ignore
        assert isinstance(transform.reader.reader, reader_type[backend])  # type: ignore


@pytest.mark.skip(reason="This test is failing because of issue #655")
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_load_slide(tmp_path: Path) -> None:
    _ = mount_dataset(dataset_id=PANDA_DATASET_ID, tmp_root=str(tmp_path), aml_workspace=DEFAULT_WORKSPACE.workspace)
    root_path = tmp_path / PANDA_DATASET_ID

    def _check_load_roi_transforms(
        backend: WSIBackend, expected_keys: List[SlideKey], expected_shape: Tuple[int, int, int]
    ) -> None:
        loading_params.backend = backend
        load_transform = loading_params.get_load_roid_transform()
        sample = PandaDataset(root_path)[0]
        slide_dict = load_transform(sample)
        assert all([k in slide_dict for k in expected_keys])
        assert slide_dict[SlideKey.IMAGE].shape == expected_shape

    # WSI ROIType
    loading_params = LoadingParams(roi_type=ROIType.WHOLE, level=2)
    wsi_expected_keys = [SlideKey.IMAGE, SlideKey.SLIDE_ID]
    wsi_expected_shape = (3, 1840, 1728)
    for backend in [WSIBackend.CUCIM, WSIBackend.OPENSLIDE]:
        _check_load_roi_transforms(backend, wsi_expected_keys, wsi_expected_shape)

    # Foreground ROIType
    loading_params = LoadingParams(roi_type=ROIType.FOREGROUND, level=2)
    foreground_expected_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.FOREGROUND_THRESHOLD, SlideKey.IMAGE]
    foreground_expected_shape = (3, 1342, 340)
    for backend in [WSIBackend.CUCIM, WSIBackend.OPENSLIDE]:
        _check_load_roi_transforms(backend, foreground_expected_keys, foreground_expected_shape)

    # Mask ROI transforms
    loading_params = LoadingParams(roi_type=ROIType.MASK, level=2)
    mask_expected_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.IMAGE]
    mask_expected_shape = (3, 1344, 341)
    for backend in [WSIBackend.CUCIM, WSIBackend.OPENSLIDE]:
        _check_load_roi_transforms(backend, mask_expected_keys, mask_expected_shape)

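The backend-parity assertions in the new test file above lean on MONAI's `WSIReader` acting as a dispatcher over concrete readers. A standalone illustration of that dispatch (a sketch against MONAI's public API; the slide path is a placeholder):

from monai.data.wsi_reader import WSIReader

# The backend string selects the concrete reader (CuCIMWSIReader,
# OpenSlideWSIReader, ...), which the dispatcher exposes as `.reader`.
reader = WSIReader(backend="OpenSlide", level=2)
wsi = reader.read("/data/PANDA/train_images/slide.tiff")  # placeholder path
image, meta = reader.get_data(wsi)  # channel-first array plus metadata dict
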
@ -7,6 +7,7 @@ import torch
import torch.distributed
import torch.multiprocessing
from ruamel.yaml import YAML
from health_cpath.preprocessing.loading import LoadingParams
from health_cpath.utils.tiles_selection_utils import TilesSelector
from testhisto.utils.utils_testhisto import run_distributed
from torch.testing import assert_close

@ -31,7 +32,7 @@ def _create_outputs_handler(outputs_root: Path) -> DeepMILOutputsHandler:
        outputs_root=outputs_root,
        n_classes=1,
        tile_size=224,
        level=1,
        loading_params=LoadingParams(level=1),
        class_names=None,
        primary_val_metric=_PRIMARY_METRIC_KEY,
        maximise=True,

@ -8,7 +8,7 @@ from pathlib import Path
from typing import Any, Collection, Dict, List
from unittest.mock import MagicMock, patch
import pytest

from health_cpath.preprocessing.loading import LoadingParams, ROIType
from health_cpath.utils.naming import PlotOption, ResultsKey
from health_cpath.utils.plots_utils import DeepMILPlotsHandler, save_confusion_matrix, save_pr_curve
from health_cpath.utils.tiles_selection_utils import SlideNode, TilesSelector

@ -18,7 +18,15 @@ from testhisto.mocks.container import MockDeepSMILETilesPanda
def test_plots_handler_wrong_class_names() -> None:
    plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
    with pytest.raises(ValueError, match=r"No class_names were provided while activating confusion matrix plotting."):
        _ = DeepMILPlotsHandler(plot_options, class_names=[])
        _ = DeepMILPlotsHandler(plot_options, class_names=[], loading_params=LoadingParams())


@pytest.mark.parametrize("roi_type", [r for r in ROIType])
def test_plots_handler_always_uses_roid_loading(roi_type: ROIType) -> None:
    plot_options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
    loading_params = LoadingParams(roi_type=roi_type)
    plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo", "bar"], loading_params=loading_params)
    assert plots_handler.loading_params.roi_type in [ROIType.MASK, ROIType.FOREGROUND]


@pytest.mark.parametrize(

@ -71,7 +79,7 @@ def assert_plot_func_called_if_among_plot_options(
    ],
)
def test_plots_handler_plots_only_desired_plot_options(plot_options: Collection[PlotOption]) -> None:
    plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo1", "foo2"])
    plots_handler = DeepMILPlotsHandler(plot_options, class_names=["foo1", "foo2"], loading_params=LoadingParams())
    plots_handler.slides_dataset = MagicMock()

    n_tiles = 4

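The `roi_type` coercion asserted by `test_plots_handler_always_uses_roid_loading` above implies logic along these lines inside the plots handler (a hypothetical reconstruction from the assertion, not the repository's code):

# Hypothetical sketch: plotting operates on tiles from a bounded region, so a
# WHOLE-slide roi_type is assumed to be downgraded to FOREGROUND before loading.
if loading_params.roi_type == ROIType.WHOLE:
    loading_params.roi_type = ROIType.FOREGROUND
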
@ -8,7 +8,6 @@ import math
import random
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch

import matplotlib
import numpy as np

@ -25,7 +24,7 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist,
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.heatmap_utils import location_selected_tiles
from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode
from health_cpath.utils.viz_utils import save_figure, load_image_dict
from health_cpath.utils.viz_utils import save_figure
from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path

@ -295,13 +294,3 @@ def test_location_selected_tiles(level: int) -> None:
    assert max(tile_xs) <= slide_image.shape[2] // factor
    assert min(tile_ys) >= 0
    assert max(tile_ys) <= slide_image.shape[1] // factor


@pytest.mark.parametrize("backend", ["cuCIM", "OpenSlide"])
@pytest.mark.parametrize("wsi_has_mask", [True, False])
def test_load_image_dict(wsi_has_mask: bool, backend: str) -> None:
    with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
        with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
            _ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask, backend=backend)
            assert mock_load_panda_roi.called == wsi_has_mask
            assert mock_load_roi.called == (not wsi_has_mask)

@ -1,3 +1,4 @@
from unittest.mock import patch
import torch
import pytest
import numpy as np

@ -84,3 +85,12 @@ def test_tiling_params(stage: ModelKey) -> None:
    expected_transform_type = RandGridPatchd if stage == ModelKey.TRAIN else GridPatchd
    transform = params.get_tiling_transform(stage=stage, bag_size=10)
    assert isinstance(transform, expected_transform_type)


def test_tiling_params_split_transform() -> None:
    params = TilingParams()
    with patch("health_cpath.utils.wsi_utils.SplitDimd") as mock_split_dim:
        _ = params.get_split_transform()
        mock_split_dim.assert_called_once()
        call_args = mock_split_dim.call_args_list[0][1]
        assert call_args == {'keys': SlideKey.IMAGE, 'dim': 0, 'keepdim': False, 'list_output': True}

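The kwargs pinned down by `test_tiling_params_split_transform` match MONAI's `SplitDimd` signature; a standalone illustration of what that split produces (a sketch using MONAI's public API):

import numpy as np
from monai.transforms import SplitDimd

# With list_output=True, one dictionary holding a stacked tile array of shape
# (N, C, H, W) becomes a list of N dictionaries, each holding a (C, H, W) tile,
# because keepdim=False drops the split dimension.
data = {"image": np.zeros((4, 3, 28, 28))}
split = SplitDimd(keys="image", dim=0, keepdim=False, list_output=True)
tiles = split(data)
assert len(tiles) == 4
assert tiles[0]["image"].shape == (3, 28, 28)
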