Upgrade monai to 1.0.1
Note that we now need to convert monai outputs to torch tensors, since the
Rand/GridPatchd transforms return MetaTensors, a new data structure introduced in
the latest monai releases. This conversion is handled in the collate function.
Kenza Bouzid 2022-11-18 14:20:26 +00:00 committed by GitHub
Parent 9e812e7ca8
Commit e68f9b1db7
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
12 changed files: 166 additions and 137 deletions
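As a rough standalone sketch of the conversion described in the commit message (assuming monai >= 1.0 and a plain "image" key rather than the repo's SlideKey constants), the essence of the MetaTensor handling in the collate step is:

import torch
from monai.data.meta_tensor import MetaTensor

def stack_tiles(tiles):
    """Stack a list of per-tile dicts {'image': C x H x W} into one N x C x H x W tensor."""
    images = [t["image"] for t in tiles]
    if isinstance(images[0], MetaTensor):
        # Rand/GridPatchd return MetaTensors; drop the metadata so the default
        # collate only sees plain torch tensors
        images = [img.as_tensor() for img in images]
    return torch.stack(images, dim=0)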

hi-ml-cpath/.vscode/launch.json
View file

@@ -44,15 +44,22 @@
             "program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
             "args": [
                 "--model=health_cpath.SlidesPandaImageNetMIL",
-                "--pl_fast_dev_run=10",
+                "--pl_limit_train_batches=5",
+                "--pl_limit_test_batches=5",
+                "--pl_limit_val_batches=5",
+                "--max_epochs=2",
+                "--max_num_gpus=1",
                 "--crossval_count=0",
                 "--batch_size=2",
+                "--batch_size_inf=2",
                 "--max_bag_size=4",
                 "--max_bag_size_inf=4",
                 "--num_top_slides=2",
-                "--num_top_tiles=2"
+                "--num_top_tiles=2",
+                "--strictly_aml_v1=True",
             ],
             "console": "integratedTerminal",
+            "justMyCode": false,
         },
         {
             "name": "Python: Run TilesPandaImageNetMIL locally",
@@ -61,15 +68,22 @@
             "program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
             "args": [
                 "--model=health_cpath.TilesPandaImageNetMIL",
-                "--pl_fast_dev_run=10",
+                "--pl_limit_train_batches=5",
+                "--pl_limit_test_batches=5",
+                "--pl_limit_val_batches=5",
+                "--max_epochs=2",
+                "--max_num_gpus=1",
                 "--crossval_count=0",
                 "--batch_size=2",
+                "--batch_size_inf=2",
                 "--max_bag_size=4",
                 "--max_bag_size_inf=4",
                 "--num_top_slides=2",
-                "--num_top_tiles=2"
+                "--num_top_tiles=2",
+                "--strictly_aml_v1=True",
             ],
             "console": "integratedTerminal",
+            "justMyCode": false,
         },
     ]
 }

View file

@@ -41,7 +41,7 @@ dependencies:
   - libpng=1.6.37=hbc83047_0
   - libstdcxx-ng=11.2.0=h1234567_1
   - libtasn1=4.16.0=h27cfd23_0
-  - libtiff=4.4.0=hecacb30_1
+  - libtiff=4.4.0=hecacb30_2
   - libunistring=0.9.10=h27cfd23_0
   - libuuid=1.41.5=h5eee18b_0
   - libuv=1.40.0=h7b6447c_0
@@ -92,7 +92,7 @@ dependencies:
     - astroid==2.12.12
     - async-timeout==4.0.2
     - attrs==21.4.0
-    - azure-ai-ml==1.1.0
+    - azure-ai-ml==1.1.1
     - azure-common==1.1.28
     - azure-core==1.26.1
     - azure-graphrbac==0.61.1
@@ -158,7 +158,7 @@ dependencies:
     - google-api-core==2.10.2
     - google-auth==1.35.0
     - google-auth-oauthlib==0.4.6
-    - googleapis-common-protos==1.56.4
+    - googleapis-common-protos==1.57.0
    - greenlet==2.0.1
     - grpcio==1.50.0
     - gunicorn==20.1.0
@@ -199,7 +199,7 @@ dependencies:
     - mdit-py-plugins==0.2.8
     - mlflow==2.0.1
     - mlflow-skinny==2.0.1
-    - monai==0.8.0
+    - monai==1.0.1
     - more-itertools==8.10.0
     - msal==1.20.0
     - msal-extensions==0.3.1

View file

@@ -3,7 +3,7 @@ cucim==22.04.00
 girder-client==3.1.14
 hi-ml
 lightning-bolts==0.4.0
-monai==0.8.0
+monai==1.0.1
 more-itertools==8.10.0
 numpy==1.22.0
 pillow==9.0.1

View file

@@ -16,6 +16,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from health_azure.utils import create_from_matching_params
 from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
+from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.utils import fixed_paths
 from health_ml.deep_learning_config import OptimizerParams
@@ -316,27 +317,11 @@ class BaseMILTiles(BaseMIL):
         return deepmil_module

-class BaseMILSlides(BaseMIL):
+class BaseMILSlides(BaseMIL, TilingParams):
     """BaseSlidesMIL is an abstract subclass of BaseMIL for running MIL experiments on slides datasets. It is
     responsible for instantiating the full DeepMIL model in slides settings. Subclasses should define their datamodules
     and configure experiment-specific parameters.
     """
-    # Slides Data module parameters:
-    tile_size: int = param.Integer(224, bounds=(0, None), doc="Size of the square tile, defaults to 224.")
-    step: int = param.Integer(None, bounds=(0, None),
-                              doc="Step size to define the offset between tiles."
-                                  "If None (default), it takes the same value as tile_size."
-                                  "If step < tile_size, it creates overlapping tiles."
-                                  "If step > tile_size, it skips some chunks in the wsi.")
-    random_offset: bool = param.Boolean(False, doc="If True, randomize position of the grid, instead of starting at"
-                                                   "the top-left corner,")
-    pad_full: bool = param.Boolean(False, doc="If True, pad image to the size evenly divisible by tile_size")
-    background_val: int = param.Integer(255, bounds=(0, None),
-                                        doc="Threshold to estimate the foreground in a whole slide image.")
-    filter_mode: str = param.String("min", doc="mode must be in ['min', 'max', 'random']. If total number of tiles is"
-                                               "greater than tile_count, then sort by intensity sum, and take the "
-                                               "smallest (for min), largest (for max) or random (for random) subset, "
-                                               "defaults to 'min' (which assumes background is high value).")

     def create_model(self) -> SlidesDeepMILModule:
         self.data_module = self.get_data_module()

View file

@@ -3,25 +3,17 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 from typing import Any, Optional, Set
-from health_azure.utils import is_running_in_azure_ml
-from health_ml.networks.layers.attention_layers import AttentionLayer
+from health_azure.utils import is_running_in_azure_ml, create_from_matching_params
+from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILSlides, BaseMILTiles
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
-from health_cpath.datamodules.panda_module import (
-    PandaSlidesDataModule,
-    PandaTilesDataModule)
-from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
-from health_cpath.models.encoders import (
-    HistoSSLEncoder,
-    ImageNetSimCLREncoder,
-    Resnet18,
-    SSLEncoder)
-from health_cpath.configs.classification.BaseMIL import BaseMILSlides, BaseMILTiles, BaseMIL
+from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
+from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID
 from health_cpath.datasets.panda_dataset import PandaDataset
-from health_cpath.datasets.default_paths import (
-    PANDA_DATASET_ID,
-    PANDA_5X_TILES_DATASET_ID)
+from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
+from health_cpath.models.encoders import HistoSSLEncoder, ImageNetSimCLREncoder, Resnet18, SSLEncoder
 from health_cpath.utils.naming import PlotOption
+from health_cpath.utils.wsi_utils import TilingParams
+from health_ml.networks.layers.attention_layers import AttentionLayer
 from health_ml.utils.checkpoint_utils import CheckpointParser
@@ -132,7 +124,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
             # declared in BaseMILSlides:
             level=1,
             tile_size=224,
-            random_offset=True,
             background_val=255,
             azure_datasets=[PANDA_DATASET_ID],)
         default_kwargs.update(kwargs)
@@ -152,16 +143,11 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
             root_path=self.local_datasets[0],
             batch_size=self.batch_size,
             batch_size_inf=self.batch_size_inf,
-            level=self.level,
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
-            tile_size=self.tile_size,
-            step=self.step,
-            random_offset=self.random_offset,
+            level=self.level,
+            tiling_params=create_from_matching_params(self, TilingParams),
             seed=self.get_effective_random_seed(),
-            pad_full=self.pad_full,
-            background_val=self.background_val,
-            filter_mode=self.filter_mode,
             transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
             crossval_count=self.crossval_count,
             crossval_index=self.crossval_index,

View file

@@ -9,6 +9,7 @@ from torch import optim
 from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
 from health_azure.utils import create_from_matching_params
+from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.networks.layers.attention_layers import (
     TransformerPooling,
     TransformerPoolingBenchmark
@@ -107,13 +108,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
             level=self.level,
-            tile_size=self.tile_size,
-            step=self.step,
-            random_offset=self.random_offset,
+            tiling_params=create_from_matching_params(self, TilingParams),
             seed=self.get_effective_random_seed(),
-            pad_full=self.pad_full,
-            background_val=self.background_val,
-            filter_mode=self.filter_mode,
             transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
             crossval_count=self.crossval_count,
             crossval_index=self.crossval_index,

View file

@@ -8,21 +8,19 @@ from enum import Enum
 from pathlib import Path
 from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar, Union
-from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, DistributedSampler
 from health_ml.utils.bag_utils import BagDataset, multibag_collate
 from health_ml.utils.common_utils import _create_generator
-from health_cpath.utils.wsi_utils import image_collate
+from health_cpath.utils.wsi_utils import TilingParams, image_collate
 from health_cpath.models.transforms import LoadTilesBatchd
 from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset
-from health_cpath.utils.naming import ModelKey
-from monai.transforms.compose import Compose
-from monai.transforms.io.dictionary import LoadImaged
-from monai.apps.pathology.transforms import TileOnGridd
+from health_cpath.utils.naming import ModelKey, SlideKey
+from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
+from monai.transforms import Compose, LoadImaged, SplitDimd
 from monai.data.image_reader import WSIReader
@@ -65,12 +63,12 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
     :param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
     :param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
         will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
-        random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
-        tile_count used for tiling on the fly at training time.
+        random subsets of instances. For SlideDataModule, this parameter is used in Rand/GridPatchd Transform to set the
+        num_patches used for tiling on the fly at training time.
     :param max_bag_size_inf: Upper bound on number of tiles in each loaded bag during validation and test stages.
         If 0 (default), will return all samples in each bag. If > 0 , bags larger than `max_bag_size_inf` will yield
-        random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
-        tile_count used for tiling on the fly at validation and test time.
+        random subsets of instances. For SlideDataModule, this parameter is used in Rand/GridPatchd Transform to set the
+        num_patches used for tiling on the fly at validation and test time.
     :param seed: pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in
         train/val/test splits is handled independently in `get_splits()`. (default: `None`)
     :param transforms_dict: A dictionary that contains transform, or a composition of transforms using
@@ -276,61 +274,31 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
     def __init__(
         self,
         level: int = 1,
-        tile_size: int = 224,
-        step: Optional[int] = None,
-        random_offset: bool = True,
-        pad_full: bool = False,
-        background_val: int = 255,
-        filter_mode: str = 'min',
         backend: str = 'cuCIM',
         wsi_reader_args: Dict[str, Any] = {},
+        tiling_params: TilingParams = TilingParams(),
         **kwargs: Any,
     ) -> None:
         """
         :param level: the whole slide image level at which the image is extracted, defaults to 1
             this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default
-        :param tile_size: size of the square tile, defaults to 224
-            this param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param step: step size to create overlapping tiles, defaults to None (same as tile_size)
-            Use a step < tile_size to create overlapping tiles, analogousely a step > tile_size will skip some chunks in
-            the wsi. This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param random_offset: randomize position of the grid, instead of starting from the top-left corner,
-            defaults to True. This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param pad_full: pad image to the size evenly divisible by tile_size, defaults to False
-            This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param background_val: the background constant to ignore background tiles (e.g. 255 for white background),
-            defaults to 255. This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is greater than
-            tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for
-            random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd
-            monai transform for tiling on the fly.
-        :param backend: the WSI reader backend, defaults to "cuCIM". This param is passed to LoadImaged monai transform
-        :param wsi_reader_args: Additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
+        :param backend: the WSI reader backend, defaults to "cuCIM".
+        :param wsi_reader_args: additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
             enabled since monai 1.0.0 by specifying num_workers > 0 with CuCIM backend only.
+        :param tiling_params: the tiling on the fly parameters, defaults to TilingParams()
         """
         super().__init__(**kwargs)
-        # Tiling on the fly params
-        self.tile_size = tile_size
-        self.step = step
-        self.random_offset = random_offset
-        self.pad_full = pad_full
-        self.background_val = background_val
-        self.filter_mode = filter_mode
+        self.tiling_params = tiling_params
         # WSIReader params
         self.level = level
         self.backend = backend
         self.wsi_reader_args = wsi_reader_args
-        # TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
-        # max_bag_size_inf to None if set to 0
-        for stage_key, max_bag_size in self.bag_sizes.items():
-            if max_bag_size == 0:
-                self.bag_sizes[stage_key] = None  # type: ignore

     def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
         base_transform = Compose(
             [
                 LoadImaged(
-                    keys=slides_dataset.IMAGE_COLUMN,
+                    keys=SlideKey.IMAGE,
                     reader=WSIReader,
                     dtype=np.uint8,
                     image_only=True,
@@ -338,17 +306,10 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
                     backend=self.backend,
                     **self.wsi_reader_args,
                 ),
-                TileOnGridd(
-                    keys=slides_dataset.IMAGE_COLUMN,
-                    tile_count=self.bag_sizes[stage],
-                    tile_size=self.tile_size,
-                    step=self.step,
-                    random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
-                    pad_full=self.pad_full,
-                    background_val=self.background_val,
-                    filter_mode=self.filter_mode,
-                    return_list_of_dicts=True,
-                ),
+                self.tiling_params.get_tiling_transform(bag_size=self.bag_sizes[stage], stage=stage),
+                # GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
+                # tiles to be able to apply augmentations on each tile independently
+                SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True),
             ]
         )
         if self.transforms_dict and self.transforms_dict[stage]:
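For context, a minimal, self-contained sketch of what the new on-the-fly tiling pipeline does (assuming monai 1.0.x and a plain "image" key instead of the repo's SlideKey.IMAGE constant):

import torch
from monai.transforms import Compose, RandGridPatchd, SplitDimd

tiling = Compose([
    # RandGridPatchd extracts patches over a randomly offset grid and stacks them
    # along dim 0, yielding an array of shape (num_patches, C, tile_size, tile_size)
    RandGridPatchd(keys=["image"], patch_size=(224, 224), num_patches=4),
    # SplitDimd with list_output=True turns the stacked patches back into a list of
    # per-tile dicts, so downstream augmentations can act on each tile independently
    SplitDimd(keys=["image"], dim=0, keepdim=False, list_output=True),
])

sample = {"image": torch.randint(0, 255, (3, 512, 512), dtype=torch.uint8)}
tiles = tiling(sample)  # list of 4 dicts, each holding a (3, 224, 224) MetaTensor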

View file

@@ -15,6 +15,7 @@ from math import ceil
 from pathlib import Path
 from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
+from monai.data.meta_tensor import MetaTensor
 from monai.data.dataset import Dataset
 from monai.data.image_reader import WSIReader
 from torch.utils.data import DataLoader
@@ -47,6 +48,9 @@ def load_image_dict(
     transform = LoadPandaROId if wsi_has_mask else LoadROId
     loader = transform(WSIReader(backend=backend), level=level, margin=margin)
     img = loader(sample)
+    if isinstance(img[SlideKey.IMAGE], MetaTensor):
+        # New monai transforms return a MetaTensor, we need to convert it to a numpy array for backward compatibility
+        img[SlideKey.IMAGE] = img[SlideKey.IMAGE].numpy()
     return img

View file

@@ -1,9 +1,12 @@
 import torch
+import param
 import numpy as np

-from typing import Any, List
-from health_cpath.utils.naming import SlideKey
+from typing import Any, Callable, List, Optional
+from health_cpath.utils.naming import ModelKey, SlideKey
 from health_ml.utils.bag_utils import multibag_collate
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms import RandGridPatchd, GridPatchd

 def image_collate(batch: List) -> Any:
@@ -11,15 +14,88 @@ def image_collate(batch: List) -> Any:
     Combine instances from a list of dicts into a single dict, by stacking them along first dim
     [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
     followed by the default collate which will form a batch BxNx3xHxW.
-    The list of dicts refers to the the list of tiles produced by the TileOnGridd transform applied on a WSI.
+    The list of dicts refers to the list of tiles produced by the Rand/GridPatchd transform applied on a WSI.
     """
     for i, item in enumerate(batch):
         data = item[0]
-        if isinstance(data[SlideKey.IMAGE], torch.Tensor):
+        if isinstance(data[SlideKey.IMAGE], MetaTensor):
+            # MetaTensor is a monai class that is used to store metadata along with the image
+            # We need to convert it to torch tensor to avoid adding the metadata to the batch
+            data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE].as_tensor() for ix in item], dim=0)
+        elif isinstance(data[SlideKey.IMAGE], torch.Tensor):
             data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
         else:
             data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
         data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
         batch[i] = data
     return multibag_collate(batch)
+
+
+class TilingParams(param.Parameterized):
+    """Parameters for Tiling On the Fly a WSI using RandGridPatchd and GridPatchd monai transforms"""
+
+    tile_size: int = param.Integer(default=224, bounds=(1, None), doc="The size of the tile, Default: 224")
+    tile_overlap: int = param.Number(
+        default=0,
+        bounds=(0.0, 1.0),
+        doc="The amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).")
+    tile_sort_fn: Optional[str] = param.String(
+        default='min',
+        doc="When bag_size is fixed, it determines whether to keep tiles with highest intensity values (`'max'`), "
+            "lowest values (`'min'`) that assumes background is high values, or in their default order (`None`). ")
+    tile_pad_mode: Optional[str] = param.String(
+        default=None,
+        doc="The mode of padding, refer to NumpyPadMode and PytorchPadMode. Defaults to None, for no padding.")
+    intensity_threshold: float = param.Number(
+        default=255.,
+        doc="The intensity threshold to filter out tiles based on intensity values. Defaults to 255.")
+    background_val: int = param.Integer(
+        default=255,
+        doc="The intensity value of background. Default to 255.")
+    rand_min_offset: int = param.Integer(
+        default=0,
+        bounds=(0, None),
+        doc="The minimum range of starting position to be selected randomly. This parameter is passed to "
+            "RandGridPatchd, the random version of GridPatchd used at training time. Default to 0.")
+    rand_max_offset: int = param.Integer(
+        default=None,
+        bounds=(0, None),
+        doc="The maximum range of starting position to be selected randomly. This parameter is passed to "
+            "RandGridPatchd, the random version of GridPatchd used at training time. Default to None.")
+    inf_offset: Optional[int] = param.Integer(
+        default=None,
+        doc="The offset to be used for inference sampling. This parameter is passed to GridPatchd. Default to None.")
+
+    @property
+    def scaled_threshold(self) -> float:
+        """Returns the threshold to be used for filtering out tiles based on intensity values. We need to multiply
+        the threshold by the tile size to account for the fact that the intensity is computed on the entire tile"""
+        return 0.999 * 3 * self.intensity_threshold * self.tile_size * self.tile_size
+
+    def get_tiling_transform(self, bag_size: int, stage: ModelKey,) -> Callable:
+        if stage == ModelKey.TRAIN:
+            return RandGridPatchd(
+                keys=[SlideKey.IMAGE],
+                patch_size=(self.tile_size, self.tile_size),
+                min_offset=self.rand_min_offset,
+                max_offset=self.rand_max_offset,
+                num_patches=bag_size,
+                overlap=self.tile_overlap,
+                sort_fn=self.tile_sort_fn,
+                threshold=self.scaled_threshold,
+                pad_mode=self.tile_pad_mode,  # type: ignore
+                constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
+            )
+        else:
+            return GridPatchd(
+                keys=[SlideKey.IMAGE],
+                patch_size=(self.tile_size, self.tile_size),
+                offset=self.inf_offset,  # type: ignore
+                num_patches=bag_size,
+                overlap=self.tile_overlap,
+                sort_fn=self.tile_sort_fn,
+                threshold=self.scaled_threshold,
+                pad_mode=self.tile_pad_mode,  # type: ignore
+                constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
+            )
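For reference, a minimal usage sketch of the TilingParams class added above (the bag_size and tile values are arbitrary, and the ModelKey.TRAIN/VAL stages are assumed to match the repo's ModelKey enum):

from health_cpath.utils.naming import ModelKey
from health_cpath.utils.wsi_utils import TilingParams

params = TilingParams(tile_size=224, tile_overlap=0)
# RandGridPatchd (random grid offset) is used at training time ...
train_tiling = params.get_tiling_transform(bag_size=10, stage=ModelKey.TRAIN)
# ... and the deterministic GridPatchd at validation/test time
val_tiling = params.get_tiling_transform(bag_size=10, stage=ModelKey.VAL)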

View file

@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sequ
 from health_cpath.datamodules.base_module import HistoDataModule
 from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
 from health_cpath.utils.naming import ModelKey, SlideKey
+from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.utils.common_utils import is_gpu_available
 from testhisto.utils.utils_testhisto import run_distributed
@@ -62,11 +63,9 @@ def test_slides_datamodule_different_bag_sizes(
         batch_size=2,
         max_bag_size=max_bag_size,
         max_bag_size_inf=max_bag_size_inf,
-        tile_size=28,
+        tiling_params=TilingParams(tile_size=28),
         level=0,
     )
-    # To account for the fact that slides datamodule fomats 0 to None so that it's compatible with TileOnGrid transform
-    max_bag_size_inf = max_bag_size_inf if max_bag_size_inf != 0 else None  # type: ignore
     # For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
     _assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])
@@ -98,7 +97,7 @@ def test_slides_datamodule_different_batch_sizes(
         batch_size_inf=batch_size_inf,
         max_bag_size=16,
         max_bag_size_inf=16,
-        tile_size=28,
+        tiling_params=TilingParams(tile_size=28),
         level=0,
     )
     _assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)

View file

@@ -12,6 +12,7 @@ from pathlib import Path
 from monai.transforms import RandFlipd
 from typing import Generator, Dict, Callable, Union, Tuple
 from torch.utils.data import DataLoader
+from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.utils.common_utils import is_gpu_available
 from health_cpath.datamodules.base_module import SlidesDataModule
@@ -94,7 +95,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
     )
     dataloader = datamodule.train_dataloader()
@@ -115,7 +116,6 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
 def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
     batch_size = 1
     tile_count = None
-    tile_size = 28
     level = 0
     assert_batch_index = 0
     min_expected_tile_count = 16
@@ -123,7 +123,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Pa
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
     )
     dataloader = datamodule.train_dataloader()
@@ -145,7 +145,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=tile_size),
         level=level,
     )
     dataloader = datamodule.train_dataloader()
@@ -165,9 +165,8 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
 @pytest.mark.gpu
 @pytest.mark.parametrize("batch_size", [1, 2])
 def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
-    tile_size = 28
     level = 0
-    step = 14
+    overlap = .5
     expected_tile_matches = 16
     min_expected_tile_count = 32
     assert_batch_index = 0
@@ -175,8 +174,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal:
         root_path=mock_panda_slides_root_dir_diagonal,
         max_bag_size=None,
         batch_size=batch_size,
-        tile_size=tile_size,
-        step=step,
+        tiling_params=TilingParams(tile_size=28, tile_overlap=overlap),
         level=level
     )
     dataloader = datamodule.train_dataloader()
@@ -208,14 +206,13 @@ def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> Non
     batch_size = 1
     tile_count = 4
-    tile_size = 28
     level = 0
     flipdatamodule = PandaSlidesDataModule(
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
         max_bag_size_inf=0,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
         transforms_dict=get_transforms_dict(),
     )
@@ -259,7 +256,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
 @pytest.mark.parametrize("batch_size", [1, 2])
 def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_different_n_tiles: Path) -> None:
     tile_count = 2
-    tile_size = 28
     level = 0
     assert_batch_index = 0
     n_tiles_list = [4, 5, 6, 7, 8, 9]
@@ -269,7 +265,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
         batch_size=batch_size,
         max_bag_size=tile_count,
         max_bag_size_inf=0,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
     )
     train_dataloader = datamodule.train_dataloader()

View file

@@ -2,10 +2,12 @@ import torch
 import pytest
 import numpy as np

+from health_cpath.utils.naming import ModelKey, SlideKey
+from health_cpath.utils.wsi_utils import TilingParams, image_collate
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms import RandGridPatchd, GridPatchd
 from typing import Any, Dict, List, Union
 from typing import Sequence
-from health_cpath.utils.naming import SlideKey
-from health_cpath.utils.wsi_utils import image_collate
 from torch.utils.data import Dataset
@@ -35,8 +37,10 @@ class MockTiledWSIDataset(Dataset):
         img: Union[np.ndarray, torch.Tensor]
         if self.img_type == "np":
             img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
-        else:
+        elif self.img_type == "torch":
             img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
+        elif self.img_type == "metatensor":
+            img = MetaTensor(torch.randint(0, 255, size=(tile_count, *self.tile_size)))
         return [{SlideKey.SLIDE_ID: self.slide_ids[index],
                  SlideKey.IMAGE: img,
                  SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",
@@ -45,7 +49,7 @@ class MockTiledWSIDataset(Dataset):
                 ]

-@pytest.mark.parametrize("img_type", ["np", "torch"])
+@pytest.mark.parametrize("img_type", ["np", "torch", "metatensor"])
 @pytest.mark.parametrize("random_n_tiles", [False, True])
 def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
     # random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during
@@ -72,3 +76,11 @@ def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
             assert all((value_list[idx] == samples_list[idx][key]) for idx in range(batch_size))
         else:
             assert all(torch.equal(value_list[idx], samples_list[idx][key]) for idx in range(batch_size))
+
+
+@pytest.mark.parametrize("stage", [m for m in ModelKey])
+def test_tiling_params(stage: ModelKey) -> None:
+    params = TilingParams()
+    expected_transform_type = RandGridPatchd if stage == ModelKey.TRAIN else GridPatchd
+    transform = params.get_tiling_transform(stage=stage, bag_size=10)
+    assert isinstance(transform, expected_transform_type)