Mirror of https://github.com/microsoft/hi-ml.git
ENH: Upgrade monai to 1.0.1 (#668)
Upgrade monai to 1.0.1. Note that we now need to convert monai outputs to torch tensors, since the Rand/GridPatchd transforms return MetaTensors, a new data structure in recent monai releases. This is handled in the collate function (see the sketch below).
Parent: 9e812e7ca8
Commit: e68f9b1db7

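In outline, the collate-side conversion amounts to the sketch below. This is a condensed restatement of the logic added to image_collate later in this diff; the helper name stack_tiles is ours and not part of the codebase.

    import torch
    from monai.data.meta_tensor import MetaTensor

    def stack_tiles(images):
        # Rand/GridPatchd return MetaTensors; as_tensor() strips the attached metadata
        # so that collation produces plain torch tensors instead of MetaTensors.
        if isinstance(images[0], MetaTensor):
            return torch.stack([img.as_tensor() for img in images], dim=0)
        # Plain torch tensors (the older code path) can be stacked directly.
        return torch.stack(images, dim=0)
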
@@ -44,15 +44,22 @@
"program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
"args": [
"--model=health_cpath.SlidesPandaImageNetMIL",
"--pl_fast_dev_run=10",
"--pl_limit_train_batches=5",
"--pl_limit_test_batches=5",
"--pl_limit_val_batches=5",
"--max_epochs=2",
"--max_num_gpus=1",
"--crossval_count=0",
"--batch_size=2",
"--batch_size_inf=2",
"--max_bag_size=4",
"--max_bag_size_inf=4",
"--num_top_slides=2",
"--num_top_tiles=2"
"--num_top_tiles=2",
"--strictly_aml_v1=True",
],
"console": "integratedTerminal",
"justMyCode": false,
},
{
"name": "Python: Run TilesPandaImageNetMIL locally",

@@ -61,15 +68,22 @@
"program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
"args": [
"--model=health_cpath.TilesPandaImageNetMIL",
"--pl_fast_dev_run=10",
"--pl_limit_train_batches=5",
"--pl_limit_test_batches=5",
"--pl_limit_val_batches=5",
"--max_epochs=2",
"--max_num_gpus=1",
"--crossval_count=0",
"--batch_size=2",
"--batch_size_inf=2",
"--max_bag_size=4",
"--max_bag_size_inf=4",
"--num_top_slides=2",
"--num_top_tiles=2"
"--num_top_tiles=2",
"--strictly_aml_v1=True",
],
"console": "integratedTerminal",
"justMyCode": false,
},
]
}

@@ -41,7 +41,7 @@ dependencies:
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.4.0=hecacb30_1
- libtiff=4.4.0=hecacb30_2
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libuv=1.40.0=h7b6447c_0

@@ -92,7 +92,7 @@ dependencies:
- astroid==2.12.12
- async-timeout==4.0.2
- attrs==21.4.0
- azure-ai-ml==1.1.0
- azure-ai-ml==1.1.1
- azure-common==1.1.28
- azure-core==1.26.1
- azure-graphrbac==0.61.1

@@ -158,7 +158,7 @@ dependencies:
- google-api-core==2.10.2
- google-auth==1.35.0
- google-auth-oauthlib==0.4.6
- googleapis-common-protos==1.56.4
- googleapis-common-protos==1.57.0
- greenlet==2.0.1
- grpcio==1.50.0
- gunicorn==20.1.0

@@ -199,7 +199,7 @@ dependencies:
- mdit-py-plugins==0.2.8
- mlflow==2.0.1
- mlflow-skinny==2.0.1
- monai==0.8.0
- monai==1.0.1
- more-itertools==8.10.0
- msal==1.20.0
- msal-extensions==0.3.1

@@ -3,7 +3,7 @@ cucim==22.04.00
girder-client==3.1.14
hi-ml
lightning-bolts==0.4.0
monai==0.8.0
monai==1.0.1
more-itertools==8.10.0
numpy==1.22.0
pillow==9.0.1

@@ -16,6 +16,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from health_azure.utils import create_from_matching_params
from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
from health_cpath.utils.wsi_utils import TilingParams

from health_ml.utils import fixed_paths
from health_ml.deep_learning_config import OptimizerParams

@@ -316,27 +317,11 @@ class BaseMILTiles(BaseMIL):
return deepmil_module


class BaseMILSlides(BaseMIL):
class BaseMILSlides(BaseMIL, TilingParams):
"""BaseSlidesMIL is an abstract subclass of BaseMIL for running MIL experiments on slides datasets. It is
responsible for instantiating the full DeepMIL model in slides settings. Subclasses should define their datamodules
and configure experiment-specific parameters.
"""
# Slides Data module parameters:
tile_size: int = param.Integer(224, bounds=(0, None), doc="Size of the square tile, defaults to 224.")
step: int = param.Integer(None, bounds=(0, None),
doc="Step size to define the offset between tiles."
"If None (default), it takes the same value as tile_size."
"If step < tile_size, it creates overlapping tiles."
"If step > tile_size, it skips some chunks in the wsi.")
random_offset: bool = param.Boolean(False, doc="If True, randomize position of the grid, instead of starting at"
"the top-left corner,")
pad_full: bool = param.Boolean(False, doc="If True, pad image to the size evenly divisible by tile_size")
background_val: int = param.Integer(255, bounds=(0, None),
doc="Threshold to estimate the foreground in a whole slide image.")
filter_mode: str = param.String("min", doc="mode must be in ['min', 'max', 'random']. If total number of tiles is"
"greater than tile_count, then sort by intensity sum, and take the "
"smallest (for min), largest (for max) or random (for random) subset, "
"defaults to 'min' (which assumes background is high value).")

def create_model(self) -> SlidesDeepMILModule:
self.data_module = self.get_data_module()

@@ -3,25 +3,17 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Optional, Set

from health_azure.utils import is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_azure.utils import is_running_in_azure_ml, create_from_matching_params
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILSlides, BaseMILTiles
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
from health_cpath.datamodules.panda_module import (
PandaSlidesDataModule,
PandaTilesDataModule)
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.models.encoders import (
HistoSSLEncoder,
ImageNetSimCLREncoder,
Resnet18,
SSLEncoder)
from health_cpath.configs.classification.BaseMIL import BaseMILSlides, BaseMILTiles, BaseMIL
from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datasets.default_paths import (
PANDA_DATASET_ID,
PANDA_5X_TILES_DATASET_ID)
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
from health_cpath.models.encoders import HistoSSLEncoder, ImageNetSimCLREncoder, Resnet18, SSLEncoder
from health_cpath.utils.naming import PlotOption
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_ml.utils.checkpoint_utils import CheckpointParser


@@ -132,7 +124,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
# declared in BaseMILSlides:
level=1,
tile_size=224,
random_offset=True,
background_val=255,
azure_datasets=[PANDA_DATASET_ID],)
default_kwargs.update(kwargs)

@@ -152,16 +143,11 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
root_path=self.local_datasets[0],
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
level=self.level,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
tile_size=self.tile_size,
step=self.step,
random_offset=self.random_offset,
level=self.level,
tiling_params=create_from_matching_params(self, TilingParams),
seed=self.get_effective_random_seed(),
pad_full=self.pad_full,
background_val=self.background_val,
filter_mode=self.filter_mode,
transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,

@@ -9,6 +9,7 @@ from torch import optim
from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
from health_azure.utils import create_from_matching_params
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.networks.layers.attention_layers import (
TransformerPooling,
TransformerPoolingBenchmark

@@ -107,13 +108,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
level=self.level,
tile_size=self.tile_size,
step=self.step,
random_offset=self.random_offset,
tiling_params=create_from_matching_params(self, TilingParams),
seed=self.get_effective_random_seed(),
pad_full=self.pad_full,
background_val=self.background_val,
filter_mode=self.filter_mode,
transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,

@@ -8,21 +8,19 @@ from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar, Union

from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, DistributedSampler

from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator

from health_cpath.utils.wsi_utils import image_collate
from health_cpath.utils.wsi_utils import TilingParams, image_collate
from health_cpath.models.transforms import LoadTilesBatchd
from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset
from health_cpath.utils.naming import ModelKey
from health_cpath.utils.naming import ModelKey, SlideKey

from monai.transforms.compose import Compose
from monai.transforms.io.dictionary import LoadImaged
from monai.apps.pathology.transforms import TileOnGridd
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from monai.transforms import Compose, LoadImaged, SplitDimd
from monai.data.image_reader import WSIReader


@@ -65,12 +63,12 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
:param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
:param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
tile_count used for tiling on the fly at training time.
random subsets of instances. For SlideDataModule, this parameter is used in Rand/GridPatchd Transform to set the
num_patches used for tiling on the fly at training time.
:param max_bag_size_inf: Upper bound on number of tiles in each loaded bag during validation and test stages.
If 0 (default), will return all samples in each bag. If > 0 , bags larger than `max_bag_size_inf` will yield
random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
tile_count used for tiling on the fly at validation and test time.
random subsets of instances. For SlideDataModule, this parameter is used in Rand/GridPatchd Transform to set the
num_patches used for tiling on the fly at validation and test time.
:param seed: pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in
train/val/test splits is handled independently in `get_splits()`. (default: `None`)
:param transforms_dict: A dictionary that contains transform, or a composition of transforms using

@@ -276,61 +274,31 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
def __init__(
self,
level: int = 1,
tile_size: int = 224,
step: Optional[int] = None,
random_offset: bool = True,
pad_full: bool = False,
background_val: int = 255,
filter_mode: str = 'min',
backend: str = 'cuCIM',
wsi_reader_args: Dict[str, Any] = {},
tiling_params: TilingParams = TilingParams(),
**kwargs: Any,
) -> None:
"""
:param level: the whole slide image level at which the image is extracted, defaults to 1
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default
:param tile_size: size of the square tile, defaults to 224
this param is passed to TileOnGridd monai transform for tiling on the fly.
:param step: step size to create overlapping tiles, defaults to None (same as tile_size)
Use a step < tile_size to create overlapping tiles, analogously a step > tile_size will skip some chunks in
the wsi. This param is passed to TileOnGridd monai transform for tiling on the fly.
:param random_offset: randomize position of the grid, instead of starting from the top-left corner,
defaults to True. This param is passed to TileOnGridd monai transform for tiling on the fly.
:param pad_full: pad image to the size evenly divisible by tile_size, defaults to False
This param is passed to TileOnGridd monai transform for tiling on the fly.
:param background_val: the background constant to ignore background tiles (e.g. 255 for white background),
defaults to 255. This param is passed to TileOnGridd monai transform for tiling on the fly.
:param filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is greater than
tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for
random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd
monai transform for tiling on the fly.
:param backend: the WSI reader backend, defaults to "cuCIM". This param is passed to LoadImaged monai transform
:param wsi_reader_args: Additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
:param backend: the WSI reader backend, defaults to "cuCIM".
:param wsi_reader_args: additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
enabled since monai 1.0.0 by specifying num_workers > 0 with CuCIM backend only.
:param tiling_params: the tiling on the fly parameters, defaults to TileOnTheFlyParams()
"""
super().__init__(**kwargs)
# Tiling on the fly params
self.tile_size = tile_size
self.step = step
self.random_offset = random_offset
self.pad_full = pad_full
self.background_val = background_val
self.filter_mode = filter_mode
self.tiling_params = tiling_params
# WSIReader params
self.level = level
self.backend = backend
self.wsi_reader_args = wsi_reader_args
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
# max_bag_size_inf to None if set to 0
for stage_key, max_bag_size in self.bag_sizes.items():
if max_bag_size == 0:
self.bag_sizes[stage_key] = None  # type: ignore

def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
base_transform = Compose(
[
LoadImaged(
keys=slides_dataset.IMAGE_COLUMN,
keys=SlideKey.IMAGE,
reader=WSIReader,
dtype=np.uint8,
image_only=True,

@@ -338,17 +306,10 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
backend=self.backend,
**self.wsi_reader_args,
),
TileOnGridd(
keys=slides_dataset.IMAGE_COLUMN,
tile_count=self.bag_sizes[stage],
tile_size=self.tile_size,
step=self.step,
random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
pad_full=self.pad_full,
background_val=self.background_val,
filter_mode=self.filter_mode,
return_list_of_dicts=True,
),
self.tiling_params.get_tiling_transform(bag_size=self.bag_sizes[stage], stage=stage),
# GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
# tiles to be able to apply augmentations on each tile independently
SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True),
]
)
if self.transforms_dict and self.transforms_dict[stage]:

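Taken together, the new slide-loading path boils down to the sketch below, assembled from the hunks above for illustration only: the level, backend, and bag_size values are placeholders, and the datamodule plumbing around the transform is omitted.

    import numpy as np
    from monai.transforms import Compose, LoadImaged, SplitDimd
    from monai.data.image_reader import WSIReader
    from health_cpath.utils.naming import ModelKey, SlideKey
    from health_cpath.utils.wsi_utils import TilingParams

    tiling_params = TilingParams(tile_size=224)
    stage = ModelKey.TRAIN
    base_transform = Compose([
        # Read the whole-slide image at the requested pyramid level with the cuCIM backend.
        LoadImaged(keys=SlideKey.IMAGE, reader=WSIReader, dtype=np.uint8, image_only=True,
                   level=1, backend="cuCIM"),
        # Tile on the fly: RandGridPatchd during training, GridPatchd for validation/test.
        tiling_params.get_tiling_transform(bag_size=4, stage=stage),
        # The patch transforms return stacked tiles (bag_size, C, H, W); split them into a
        # list of per-tile dicts so augmentations can be applied to each tile independently.
        SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True),
    ])
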
@@ -15,6 +15,7 @@ from math import ceil
from pathlib import Path
from typing import Sequence, List, Any, Dict, Optional, Union, Tuple

from monai.data.meta_tensor import MetaTensor
from monai.data.dataset import Dataset
from monai.data.image_reader import WSIReader
from torch.utils.data import DataLoader

@@ -47,6 +48,9 @@ def load_image_dict(
transform = LoadPandaROId if wsi_has_mask else LoadROId
loader = transform(WSIReader(backend=backend), level=level, margin=margin)
img = loader(sample)
if isinstance(img[SlideKey.IMAGE], MetaTensor):
# New monai transforms return a MetaTensor, we need to convert it to a numpy array for backward compatibility
img[SlideKey.IMAGE] = img[SlideKey.IMAGE].numpy()
return img

@@ -1,9 +1,12 @@
import torch
import param
import numpy as np

from typing import Any, List
from health_cpath.utils.naming import SlideKey
from typing import Any, Callable, List, Optional
from health_cpath.utils.naming import ModelKey, SlideKey
from health_ml.utils.bag_utils import multibag_collate
from monai.data.meta_tensor import MetaTensor
from monai.transforms import RandGridPatchd, GridPatchd


def image_collate(batch: List) -> Any:

@@ -11,15 +14,88 @@ def image_collate(batch: List) -> Any:
Combine instances from a list of dicts into a single dict, by stacking them along first dim
[{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] -> {'image' : Nx3xHxW}
followed by the default collate which will form a batch BxNx3xHxW.
The list of dicts refers to the list of tiles produced by the TileOnGridd transform applied on a WSI.
The list of dicts refers to the list of tiles produced by the Rand/GridPatchd transform applied on a WSI.
"""

for i, item in enumerate(batch):
data = item[0]
if isinstance(data[SlideKey.IMAGE], torch.Tensor):
if isinstance(data[SlideKey.IMAGE], MetaTensor):
# MetaTensor is a monai class that is used to store metadata along with the image
# We need to convert it to torch tensor to avoid adding the metadata to the batch
data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE].as_tensor() for ix in item], dim=0)
elif isinstance(data[SlideKey.IMAGE], torch.Tensor):
data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
else:
data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
batch[i] = data
return multibag_collate(batch)


class TilingParams(param.Parameterized):
"""Parameters for Tiling On the Fly a WSI using RandGridPatchd and GridPatchd monai transforms"""

tile_size: int = param.Integer(default=224, bounds=(1, None), doc="The size of the tile, Default: 224")
tile_overlap: int = param.Number(
default=0,
bounds=(0.0, 1.0),
doc="The amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).")
tile_sort_fn: Optional[str] = param.String(
default='min',
doc="When bag_size is fixed, it determines whether to keep tiles with highest intensity values (`'max'`), "
"lowest values (`'min'`) that assumes background is high values, or in their default order (`None`). ")
tile_pad_mode: Optional[str] = param.String(
default=None,
doc="The mode of padding, refer to NumpyPadMode and PytorchPadMode. Defaults to None, for no padding.")
intensity_threshold: float = param.Number(
default=255.,
doc="The intensity threshold to filter out tiles based on intensity values. Default to None.")
background_val: int = param.Integer(
default=255,
doc="The intensity value of background. Default to 255.")
rand_min_offset: int = param.Integer(
default=0,
bounds=(0, None),
doc="The minimum range of starting position to be selected randomly. This parameter is passed to RandGridPatchd."
"the random version of RandGridPatchd used at training time. Default to 0.")
rand_max_offset: int = param.Integer(
default=None,
bounds=(0, None),
doc="The maximum range of starting position to be selected randomly. This parameter is passed to RandGridPatchd."
"the random version of RandGridPatchd used at training time. Default to None.")
inf_offset: Optional[int] = param.Integer(
default=None,
doc="The offset to be used for inference sampling. This parameter is passed to GridPatchd. Default to None.")

@property
def scaled_threshold(self) -> float:
"""Returns the threshold to be used for filtering out tiles based on intensity values. We need to multiply
the threshold by the tile size to account for the fact that the intensity is computed on the entire tile"""
return 0.999 * 3 * self.intensity_threshold * self.tile_size * self.tile_size

def get_tiling_transform(self, bag_size: int, stage: ModelKey,) -> Callable:
if stage == ModelKey.TRAIN:
return RandGridPatchd(
keys=[SlideKey.IMAGE],
patch_size=(self.tile_size, self.tile_size),
min_offset=self.rand_min_offset,
max_offset=self.rand_max_offset,
num_patches=bag_size,
overlap=self.tile_overlap,
sort_fn=self.tile_sort_fn,
threshold=self.scaled_threshold,
pad_mode=self.tile_pad_mode,  # type: ignore
constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
)
else:
return GridPatchd(
keys=[SlideKey.IMAGE],
patch_size=(self.tile_size, self.tile_size),
offset=self.inf_offset,  # type: ignore
num_patches=bag_size,
overlap=self.tile_overlap,
sort_fn=self.tile_sort_fn,
threshold=self.scaled_threshold,
pad_mode=self.tile_pad_mode,  # type: ignore
constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
)

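A quick sanity check on the scaled_threshold property above (our arithmetic, using the defaults intensity_threshold=255 and tile_size=224, and assuming monai keeps patches whose intensity sum falls below the threshold): 0.999 * 3 * 255 * 224 * 224 ≈ 3.83e7, just under the intensity sum of a fully white RGB tile, so only tiles that are essentially all background get filtered out by the threshold argument of RandGridPatchd/GridPatchd.
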
@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sequ
from health_cpath.datamodules.base_module import HistoDataModule
from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
from health_cpath.utils.naming import ModelKey, SlideKey
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.utils.common_utils import is_gpu_available
from testhisto.utils.utils_testhisto import run_distributed

@@ -62,11 +63,9 @@ def test_slides_datamodule_different_bag_sizes(
batch_size=2,
max_bag_size=max_bag_size,
max_bag_size_inf=max_bag_size_inf,
tile_size=28,
tiling_params=TilingParams(tile_size=28),
level=0,
)
# To account for the fact that slides datamodule formats 0 to None so that it's compatible with TileOnGrid transform
max_bag_size_inf = max_bag_size_inf if max_bag_size_inf != 0 else None  # type: ignore
# For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
_assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])

@@ -98,7 +97,7 @@ def test_slides_datamodule_different_batch_sizes(
batch_size_inf=batch_size_inf,
max_bag_size=16,
max_bag_size_inf=16,
tile_size=28,
tiling_params=TilingParams(tile_size=28),
level=0,
)
_assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)

@@ -12,6 +12,7 @@ from pathlib import Path
from monai.transforms import RandFlipd
from typing import Generator, Dict, Callable, Union, Tuple
from torch.utils.data import DataLoader
from health_cpath.utils.wsi_utils import TilingParams

from health_ml.utils.common_utils import is_gpu_available
from health_cpath.datamodules.base_module import SlidesDataModule

@@ -94,7 +95,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
tiling_params=TilingParams(tile_size=28),
level=level,
)
dataloader = datamodule.train_dataloader()

@@ -115,7 +116,6 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = None
tile_size = 28
level = 0
assert_batch_index = 0
min_expected_tile_count = 16

@@ -123,7 +123,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Pa
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
tiling_params=TilingParams(tile_size=28),
level=level,
)
dataloader = datamodule.train_dataloader()

@@ -145,7 +145,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
tiling_params=TilingParams(tile_size=tile_size),
level=level,
)
dataloader = datamodule.train_dataloader()

@@ -165,9 +165,8 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [1, 2])
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
tile_size = 28
level = 0
step = 14
overlap = .5
expected_tile_matches = 16
min_expected_tile_count = 32
assert_batch_index = 0

@@ -175,8 +174,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal:
root_path=mock_panda_slides_root_dir_diagonal,
max_bag_size=None,
batch_size=batch_size,
tile_size=tile_size,
step=step,
tiling_params=TilingParams(tile_size=28, tile_overlap=overlap),
level=level
)
dataloader = datamodule.train_dataloader()

@@ -208,14 +206,13 @@ def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> Non
batch_size = 1
tile_count = 4
tile_size = 28
level = 0
flipdatamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
max_bag_size_inf=0,
tile_size=tile_size,
tiling_params=TilingParams(tile_size=28),
level=level,
transforms_dict=get_transforms_dict(),
)

@@ -259,7 +256,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_different_n_tiles: Path) -> None:
tile_count = 2
tile_size = 28
level = 0
assert_batch_index = 0
n_tiles_list = [4, 5, 6, 7, 8, 9]

@@ -269,7 +265,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
batch_size=batch_size,
max_bag_size=tile_count,
max_bag_size_inf=0,
tile_size=tile_size,
tiling_params=TilingParams(tile_size=28),
level=level,
)
train_dataloader = datamodule.train_dataloader()

@@ -2,10 +2,12 @@ import torch
import pytest
import numpy as np

from health_cpath.utils.naming import ModelKey, SlideKey
from health_cpath.utils.wsi_utils import TilingParams, image_collate
from monai.data.meta_tensor import MetaTensor
from monai.transforms import RandGridPatchd, GridPatchd
from typing import Any, Dict, List, Union
from typing import Sequence
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.wsi_utils import image_collate
from torch.utils.data import Dataset


@@ -35,8 +37,10 @@ class MockTiledWSIDataset(Dataset):
img: Union[np.ndarray, torch.Tensor]
if self.img_type == "np":
img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
else:
elif self.img_type == "torch":
img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
elif self.img_type == "metatensor":
img = MetaTensor(torch.randint(0, 255, size=(tile_count, *self.tile_size)))
return [{SlideKey.SLIDE_ID: self.slide_ids[index],
SlideKey.IMAGE: img,
SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",

@@ -45,7 +49,7 @@ class MockTiledWSIDataset(Dataset):
]

@pytest.mark.parametrize("img_type", ["np", "torch"])
@pytest.mark.parametrize("img_type", ["np", "torch", "metatensor"])
@pytest.mark.parametrize("random_n_tiles", [False, True])
def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
# random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during

@@ -72,3 +76,11 @@ def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
assert all((value_list[idx] == samples_list[idx][key]) for idx in range(batch_size))
else:
assert all(torch.equal(value_list[idx], samples_list[idx][key]) for idx in range(batch_size))


@pytest.mark.parametrize("stage", [m for m in ModelKey])
def test_tiling_params(stage: ModelKey) -> None:
params = TilingParams()
expected_transform_type = RandGridPatchd if stage == ModelKey.TRAIN else GridPatchd
transform = params.get_tiling_transform(stage=stage, bag_size=10)
assert isinstance(transform, expected_transform_type)