Mirror of https://github.com/microsoft/hi-ml.git
ENH: Upgrade monai to 1.0.1 (#668)
Upgrade monai to 1.0.1. Note that we now need to convert monai outputs to torch tensors, since the Rand/GridPatchd transforms return MetaTensors, a new data structure in the latest releases. This is handled in the collate function.
This commit is contained in:
Parent
9e812e7ca8
Commit
e68f9b1db7
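For context, here is a minimal, runnable sketch of the conversion described in the commit message. It is an illustration only, assuming monai>=1.0 and torch are installed; the helper name `to_plain_tensor` is hypothetical, not part of the repo, and the actual handling lives in `image_collate` in the diff below.

import torch
from monai.data.meta_tensor import MetaTensor

def to_plain_tensor(x):
    # MetaTensor subclasses torch.Tensor but carries metadata; as_tensor()
    # strips that metadata so default collation does not try to batch it.
    return x.as_tensor() if isinstance(x, MetaTensor) else torch.as_tensor(x)

tiles = [MetaTensor(torch.rand(3, 224, 224)) for _ in range(4)]  # stand-in for Rand/GridPatchd output
bag = torch.stack([to_plain_tensor(t) for t in tiles], dim=0)    # shape (4, 3, 224, 224)
assert not isinstance(bag, MetaTensor)  # plain tensor, safe for the default collate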
@@ -44,15 +44,22 @@
             "program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
             "args": [
                 "--model=health_cpath.SlidesPandaImageNetMIL",
-                "--pl_fast_dev_run=10",
+                "--pl_limit_train_batches=5",
+                "--pl_limit_test_batches=5",
+                "--pl_limit_val_batches=5",
+                "--max_epochs=2",
+                "--max_num_gpus=1",
                 "--crossval_count=0",
                 "--batch_size=2",
+                "--batch_size_inf=2",
                 "--max_bag_size=4",
                 "--max_bag_size_inf=4",
                 "--num_top_slides=2",
-                "--num_top_tiles=2"
+                "--num_top_tiles=2",
+                "--strictly_aml_v1=True",
             ],
             "console": "integratedTerminal",
+            "justMyCode": false,
         },
         {
             "name": "Python: Run TilesPandaImageNetMIL locally",
@@ -61,15 +68,22 @@
             "program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
             "args": [
                 "--model=health_cpath.TilesPandaImageNetMIL",
-                "--pl_fast_dev_run=10",
+                "--pl_limit_train_batches=5",
+                "--pl_limit_test_batches=5",
+                "--pl_limit_val_batches=5",
+                "--max_epochs=2",
+                "--max_num_gpus=1",
                 "--crossval_count=0",
                 "--batch_size=2",
+                "--batch_size_inf=2",
                 "--max_bag_size=4",
                 "--max_bag_size_inf=4",
                 "--num_top_slides=2",
-                "--num_top_tiles=2"
+                "--num_top_tiles=2",
+                "--strictly_aml_v1=True",
             ],
             "console": "integratedTerminal",
+            "justMyCode": false,
         },
     ]
 }

@@ -41,7 +41,7 @@ dependencies:
   - libpng=1.6.37=hbc83047_0
   - libstdcxx-ng=11.2.0=h1234567_1
   - libtasn1=4.16.0=h27cfd23_0
-  - libtiff=4.4.0=hecacb30_1
+  - libtiff=4.4.0=hecacb30_2
   - libunistring=0.9.10=h27cfd23_0
   - libuuid=1.41.5=h5eee18b_0
   - libuv=1.40.0=h7b6447c_0
@@ -92,7 +92,7 @@ dependencies:
     - astroid==2.12.12
     - async-timeout==4.0.2
     - attrs==21.4.0
-    - azure-ai-ml==1.1.0
+    - azure-ai-ml==1.1.1
     - azure-common==1.1.28
     - azure-core==1.26.1
     - azure-graphrbac==0.61.1
@@ -158,7 +158,7 @@ dependencies:
    - google-api-core==2.10.2
     - google-auth==1.35.0
     - google-auth-oauthlib==0.4.6
-    - googleapis-common-protos==1.56.4
+    - googleapis-common-protos==1.57.0
     - greenlet==2.0.1
     - grpcio==1.50.0
     - gunicorn==20.1.0
@@ -199,7 +199,7 @@ dependencies:
     - mdit-py-plugins==0.2.8
     - mlflow==2.0.1
     - mlflow-skinny==2.0.1
-    - monai==0.8.0
+    - monai==1.0.1
     - more-itertools==8.10.0
     - msal==1.20.0
     - msal-extensions==0.3.1

@@ -3,7 +3,7 @@ cucim==22.04.00
 girder-client==3.1.14
 hi-ml
 lightning-bolts==0.4.0
-monai==0.8.0
+monai==1.0.1
 more-itertools==8.10.0
 numpy==1.22.0
 pillow==9.0.1

@@ -16,6 +16,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 
 from health_azure.utils import create_from_matching_params
 from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParams
+from health_cpath.utils.wsi_utils import TilingParams
 
 from health_ml.utils import fixed_paths
 from health_ml.deep_learning_config import OptimizerParams
@@ -316,27 +317,11 @@ class BaseMILTiles(BaseMIL):
         return deepmil_module
 
 
-class BaseMILSlides(BaseMIL):
+class BaseMILSlides(BaseMIL, TilingParams):
     """BaseSlidesMIL is an abstract subclass of BaseMIL for running MIL experiments on slides datasets. It is
     responsible for instantiating the full DeepMIL model in slides settings. Subclasses should define their datamodules
     and configure experiment-specific parameters.
     """
-    # Slides Data module parameters:
-    tile_size: int = param.Integer(224, bounds=(0, None), doc="Size of the square tile, defaults to 224.")
-    step: int = param.Integer(None, bounds=(0, None),
-                              doc="Step size to define the offset between tiles."
-                                  "If None (default), it takes the same value as tile_size."
-                                  "If step < tile_size, it creates overlapping tiles."
-                                  "If step > tile_size, it skips some chunks in the wsi.")
-    random_offset: bool = param.Boolean(False, doc="If True, randomize position of the grid, instead of starting at"
-                                                   "the top-left corner,")
-    pad_full: bool = param.Boolean(False, doc="If True, pad image to the size evenly divisible by tile_size")
-    background_val: int = param.Integer(255, bounds=(0, None),
-                                        doc="Threshold to estimate the foreground in a whole slide image.")
-    filter_mode: str = param.String("min", doc="mode must be in ['min', 'max', 'random']. If total number of tiles is"
-                                               "greater than tile_count, then sort by intensity sum, and take the "
-                                               "smallest (for min), largest (for max) or random (for random) subset, "
-                                               "defaults to 'min' (which assumes background is high value).")
 
     def create_model(self) -> SlidesDeepMILModule:
         self.data_module = self.get_data_module()

@@ -3,25 +3,17 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 from typing import Any, Optional, Set
-from health_azure.utils import is_running_in_azure_ml
-from health_ml.networks.layers.attention_layers import AttentionLayer
+from health_azure.utils import is_running_in_azure_ml, create_from_matching_params
+from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILSlides, BaseMILTiles
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
-from health_cpath.datamodules.panda_module import (
-    PandaSlidesDataModule,
-    PandaTilesDataModule)
-from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
-from health_cpath.models.encoders import (
-    HistoSSLEncoder,
-    ImageNetSimCLREncoder,
-    Resnet18,
-    SSLEncoder)
-from health_cpath.configs.classification.BaseMIL import BaseMILSlides, BaseMILTiles, BaseMIL
+from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
+from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID
 from health_cpath.datasets.panda_dataset import PandaDataset
-from health_cpath.datasets.default_paths import (
-    PANDA_DATASET_ID,
-    PANDA_5X_TILES_DATASET_ID)
+from health_cpath.datasets.panda_tiles_dataset import PandaTilesDataset
+from health_cpath.models.encoders import HistoSSLEncoder, ImageNetSimCLREncoder, Resnet18, SSLEncoder
 from health_cpath.utils.naming import PlotOption
+from health_cpath.utils.wsi_utils import TilingParams
+from health_ml.networks.layers.attention_layers import AttentionLayer
 from health_ml.utils.checkpoint_utils import CheckpointParser
 
 
@@ -132,7 +124,6 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
             # declared in BaseMILSlides:
             level=1,
             tile_size=224,
-            random_offset=True,
             background_val=255,
             azure_datasets=[PANDA_DATASET_ID],)
         default_kwargs.update(kwargs)
@@ -152,16 +143,11 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
             root_path=self.local_datasets[0],
             batch_size=self.batch_size,
             batch_size_inf=self.batch_size_inf,
-            level=self.level,
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
-            tile_size=self.tile_size,
-            step=self.step,
-            random_offset=self.random_offset,
+            level=self.level,
+            tiling_params=create_from_matching_params(self, TilingParams),
             seed=self.get_effective_random_seed(),
-            pad_full=self.pad_full,
-            background_val=self.background_val,
-            filter_mode=self.filter_mode,
             transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
             crossval_count=self.crossval_count,
             crossval_index=self.crossval_index,

@@ -9,6 +9,7 @@ from torch import optim
 from monai.transforms import Compose, ScaleIntensityRanged, RandRotate90d, RandFlipd
 from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary
 from health_azure.utils import create_from_matching_params
+from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.networks.layers.attention_layers import (
     TransformerPooling,
     TransformerPoolingBenchmark
@@ -107,13 +108,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
             max_bag_size=self.max_bag_size,
             max_bag_size_inf=self.max_bag_size_inf,
             level=self.level,
-            tile_size=self.tile_size,
-            step=self.step,
-            random_offset=self.random_offset,
+            tiling_params=create_from_matching_params(self, TilingParams),
             seed=self.get_effective_random_seed(),
-            pad_full=self.pad_full,
-            background_val=self.background_val,
-            filter_mode=self.filter_mode,
             transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
             crossval_count=self.crossval_count,
             crossval_index=self.crossval_index,

@@ -8,21 +8,19 @@ from enum import Enum
 from pathlib import Path
 from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar, Union
 
-from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, DistributedSampler
 
 from health_ml.utils.bag_utils import BagDataset, multibag_collate
 from health_ml.utils.common_utils import _create_generator
 
-from health_cpath.utils.wsi_utils import image_collate
+from health_cpath.utils.wsi_utils import TilingParams, image_collate
 from health_cpath.models.transforms import LoadTilesBatchd
 from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset
-from health_cpath.utils.naming import ModelKey
+from health_cpath.utils.naming import ModelKey, SlideKey
 
-from monai.transforms.compose import Compose
-from monai.transforms.io.dictionary import LoadImaged
-from monai.apps.pathology.transforms import TileOnGridd
+from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
+from monai.transforms import Compose, LoadImaged, SplitDimd
 from monai.data.image_reader import WSIReader
 
 
@@ -65,12 +63,12 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
     :param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
     :param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
         will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
-        random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
-        tile_count used for tiling on the fly at training time.
+        random subsets of instances. For SlideDataModule, this parameter is used in Rand/GridPatchd Transform to set the
+        num_patches used for tiling on the fly at training time.
     :param max_bag_size_inf: Upper bound on number of tiles in each loaded bag during validation and test stages.
         If 0 (default), will return all samples in each bag. If > 0 , bags larger than `max_bag_size_inf` will yield
-        random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
-        tile_count used for tiling on the fly at validation and test time.
+        random subsets of instances. For SlideDataModule, this parameter is used in Rand/GridPatchd Transform to set the
+        num_patches used for tiling on the fly at validation and test time.
     :param seed: pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in
         train/val/test splits is handled independently in `get_splits()`. (default: `None`)
     :param transforms_dict: A dictionary that contains transform, or a composition of transforms using
@@ -276,61 +274,31 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
     def __init__(
         self,
         level: int = 1,
-        tile_size: int = 224,
-        step: Optional[int] = None,
-        random_offset: bool = True,
-        pad_full: bool = False,
-        background_val: int = 255,
-        filter_mode: str = 'min',
         backend: str = 'cuCIM',
         wsi_reader_args: Dict[str, Any] = {},
+        tiling_params: TilingParams = TilingParams(),
         **kwargs: Any,
     ) -> None:
         """
         :param level: the whole slide image level at which the image is extracted, defaults to 1
             this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default
-        :param tile_size: size of the square tile, defaults to 224
-            this param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param step: step size to create overlapping tiles, defaults to None (same as tile_size)
-            Use a step < tile_size to create overlapping tiles, analogousely a step > tile_size will skip some chunks in
-            the wsi. This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param random_offset: randomize position of the grid, instead of starting from the top-left corner,
-            defaults to True. This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param pad_full: pad image to the size evenly divisible by tile_size, defaults to False
-            This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param background_val: the background constant to ignore background tiles (e.g. 255 for white background),
-            defaults to 255. This param is passed to TileOnGridd monai transform for tiling on the fly.
-        :param filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is greater than
-            tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for
-            random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd
-            monai transform for tiling on the fly.
-        :param backend: the WSI reader backend, defaults to "cuCIM". This param is passed to LoadImaged monai transform
-        :param wsi_reader_args: Additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
+        :param backend: the WSI reader backend, defaults to "cuCIM".
+        :param wsi_reader_args: additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
             enabled since monai 1.0.0 by specifying num_workers > 0 with CuCIM backend only.
+        :param tiling_params: the tiling on the fly parameters, defaults to TileOnTheFlyParams()
         """
         super().__init__(**kwargs)
-        # Tiling on the fly params
-        self.tile_size = tile_size
-        self.step = step
-        self.random_offset = random_offset
-        self.pad_full = pad_full
-        self.background_val = background_val
-        self.filter_mode = filter_mode
+        self.tiling_params = tiling_params
         # WSIReader params
         self.level = level
         self.backend = backend
         self.wsi_reader_args = wsi_reader_args
-        # TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
-        # max_bag_size_inf to None if set to 0
-        for stage_key, max_bag_size in self.bag_sizes.items():
-            if max_bag_size == 0:
-                self.bag_sizes[stage_key] = None  # type: ignore
 
     def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
         base_transform = Compose(
             [
                 LoadImaged(
-                    keys=slides_dataset.IMAGE_COLUMN,
+                    keys=SlideKey.IMAGE,
                     reader=WSIReader,
                     dtype=np.uint8,
                     image_only=True,
@@ -338,17 +306,10 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
                     backend=self.backend,
                     **self.wsi_reader_args,
                 ),
-                TileOnGridd(
-                    keys=slides_dataset.IMAGE_COLUMN,
-                    tile_count=self.bag_sizes[stage],
-                    tile_size=self.tile_size,
-                    step=self.step,
-                    random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
-                    pad_full=self.pad_full,
-                    background_val=self.background_val,
-                    filter_mode=self.filter_mode,
-                    return_list_of_dicts=True,
-                ),
+                self.tiling_params.get_tiling_transform(bag_size=self.bag_sizes[stage], stage=stage),
+                # GridPatchd returns stacked tiles (bag_size, C, H, W), however we need to split them into separate
+                # tiles to be able to apply augmentations on each tile independently
+                SplitDimd(keys=SlideKey.IMAGE, dim=0, keepdim=False, list_output=True),
             ]
         )
         if self.transforms_dict and self.transforms_dict[stage]:

@@ -15,6 +15,7 @@ from math import ceil
 from pathlib import Path
 from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
 
+from monai.data.meta_tensor import MetaTensor
 from monai.data.dataset import Dataset
 from monai.data.image_reader import WSIReader
 from torch.utils.data import DataLoader
@@ -47,6 +48,9 @@ def load_image_dict(
     transform = LoadPandaROId if wsi_has_mask else LoadROId
     loader = transform(WSIReader(backend=backend), level=level, margin=margin)
     img = loader(sample)
+    if isinstance(img[SlideKey.IMAGE], MetaTensor):
+        # New monai transforms return a MetaTensor, we need to convert it to a numpy array for backward compatibility
+        img[SlideKey.IMAGE] = img[SlideKey.IMAGE].numpy()
     return img
 
 
@@ -1,9 +1,12 @@
 import torch
+import param
 import numpy as np
 
-from typing import Any, List
-from health_cpath.utils.naming import SlideKey
+from typing import Any, Callable, List, Optional
+from health_cpath.utils.naming import ModelKey, SlideKey
 from health_ml.utils.bag_utils import multibag_collate
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms import RandGridPatchd, GridPatchd
 
 
 def image_collate(batch: List) -> Any:
@@ -11,15 +14,88 @@ def image_collate(batch: List) -> Any:
     Combine instances from a list of dicts into a single dict, by stacking them along first dim
     [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
     followed by the default collate which will form a batch BxNx3xHxW.
-    The list of dicts refers to the the list of tiles produced by the TileOnGridd transform applied on a WSI.
+    The list of dicts refers to the the list of tiles produced by the Rand/GridPatchd transform applied on a WSI.
     """
 
     for i, item in enumerate(batch):
         data = item[0]
-        if isinstance(data[SlideKey.IMAGE], torch.Tensor):
+        if isinstance(data[SlideKey.IMAGE], MetaTensor):
+            # MetaTensor is a monai class that is used to store metadata along with the image
+            # We need to convert it to torch tensor to avoid adding the metadata to the batch
+            data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE].as_tensor() for ix in item], dim=0)
+        elif isinstance(data[SlideKey.IMAGE], torch.Tensor):
             data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
         else:
             data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
         data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
         batch[i] = data
     return multibag_collate(batch)
 
 
+class TilingParams(param.Parameterized):
+    """Parameters for Tiling On the Fly a WSI using RandGridPatchd and GridPatchd monai transforms"""
+
+    tile_size: int = param.Integer(default=224, bounds=(1, None), doc="The size of the tile, Default: 224")
+    tile_overlap: int = param.Number(
+        default=0,
+        bounds=(0.0, 1.0),
+        doc="The amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).")
+    tile_sort_fn: Optional[str] = param.String(
+        default='min',
+        doc="When bag_size is fixed, it determines whether to keep tiles with highest intensity values (`'max'`), "
+            "lowest values (`'min'`) that assumes background is high values, or in their default order (`None`). ")
+    tile_pad_mode: Optional[str] = param.String(
+        default=None,
+        doc="The mode of padding, refer to NumpyPadMode and PytorchPadMode. Defaults to None, for no padding.")
+    intensity_threshold: float = param.Number(
+        default=255.,
+        doc="The intensity threshold to filter out tiles based on intensity values. Default to None.")
+    background_val: int = param.Integer(
+        default=255,
+        doc="The intensity value of background. Default to 255.")
+    rand_min_offset: int = param.Integer(
+        default=0,
+        bounds=(0, None),
+        doc="The minimum range of sarting position to be selected randomly. This parameter is passed to RandGridPatchd."
+            "the random version of RandGridPatchd used at training time. Default to 0.")
+    rand_max_offset: int = param.Integer(
+        default=None,
+        bounds=(0, None),
+        doc="The maximum range of sarting position to be selected randomly. This parameter is passed to RandGridPatchd."
+            "the random version of RandGridPatchd used at training time. Default to None.")
+    inf_offset: Optional[int] = param.Integer(
+        default=None,
+        doc="The offset to be used for inference sampling. This parameter is passed to GridPatchd. Default to None.")
+
+    @property
+    def scaled_threshold(self) -> float:
+        """Returns the threshold to be used for filtering out tiles based on intensity values. We need to multiply
+        the threshold by the tile size to account for the fact that the intensity is computed on the entire tile"""
+        return 0.999 * 3 * self.intensity_threshold * self.tile_size * self.tile_size
+
+    def get_tiling_transform(self, bag_size: int, stage: ModelKey,) -> Callable:
+        if stage == ModelKey.TRAIN:
+            return RandGridPatchd(
+                keys=[SlideKey.IMAGE],
+                patch_size=(self.tile_size, self.tile_size),
+                min_offset=self.rand_min_offset,
+                max_offset=self.rand_max_offset,
+                num_patches=bag_size,
+                overlap=self.tile_overlap,
+                sort_fn=self.tile_sort_fn,
+                threshold=self.scaled_threshold,
+                pad_mode=self.tile_pad_mode,  # type: ignore
+                constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
+            )
+        else:
+            return GridPatchd(
+                keys=[SlideKey.IMAGE],
+                patch_size=(self.tile_size, self.tile_size),
+                offset=self.inf_offset,  # type: ignore
+                num_patches=bag_size,
+                overlap=self.tile_overlap,
+                sort_fn=self.tile_sort_fn,
+                threshold=self.scaled_threshold,
+                pad_mode=self.tile_pad_mode,  # type: ignore
+                constant_values=self.background_val,  # this arg is passed to np.pad or torch.pad
+            )

@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sequ
 from health_cpath.datamodules.base_module import HistoDataModule
 from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
 from health_cpath.utils.naming import ModelKey, SlideKey
+from health_cpath.utils.wsi_utils import TilingParams
 from health_ml.utils.common_utils import is_gpu_available
 from testhisto.utils.utils_testhisto import run_distributed
 
@@ -62,11 +63,9 @@ def test_slides_datamodule_different_bag_sizes(
         batch_size=2,
         max_bag_size=max_bag_size,
         max_bag_size_inf=max_bag_size_inf,
-        tile_size=28,
+        tiling_params=TilingParams(tile_size=28),
         level=0,
     )
-    # To account for the fact that slides datamodule fomats 0 to None so that it's compatible with TileOnGrid transform
-    max_bag_size_inf = max_bag_size_inf if max_bag_size_inf != 0 else None  # type: ignore
     # For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
     _assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])
 
@@ -98,7 +97,7 @@ def test_slides_datamodule_different_batch_sizes(
         batch_size_inf=batch_size_inf,
         max_bag_size=16,
         max_bag_size_inf=16,
-        tile_size=28,
+        tiling_params=TilingParams(tile_size=28),
         level=0,
     )
     _assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)

@@ -12,6 +12,7 @@ from pathlib import Path
 from monai.transforms import RandFlipd
 from typing import Generator, Dict, Callable, Union, Tuple
 from torch.utils.data import DataLoader
+from health_cpath.utils.wsi_utils import TilingParams
 
 from health_ml.utils.common_utils import is_gpu_available
 from health_cpath.datamodules.base_module import SlidesDataModule
@@ -94,7 +95,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
     )
     dataloader = datamodule.train_dataloader()
@@ -115,7 +116,6 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
 def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
     batch_size = 1
     tile_count = None
-    tile_size = 28
     level = 0
     assert_batch_index = 0
     min_expected_tile_count = 16
@@ -123,7 +123,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Pa
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
     )
     dataloader = datamodule.train_dataloader()
@@ -145,7 +145,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=tile_size),
         level=level,
     )
     dataloader = datamodule.train_dataloader()
@@ -165,9 +165,8 @@
 @pytest.mark.gpu
 @pytest.mark.parametrize("batch_size", [1, 2])
 def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
-    tile_size = 28
     level = 0
-    step = 14
+    overlap = .5
     expected_tile_matches = 16
     min_expected_tile_count = 32
     assert_batch_index = 0
@@ -175,8 +174,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal:
         root_path=mock_panda_slides_root_dir_diagonal,
         max_bag_size=None,
         batch_size=batch_size,
-        tile_size=tile_size,
-        step=step,
+        tiling_params=TilingParams(tile_size=28, tile_overlap=overlap),
         level=level
     )
     dataloader = datamodule.train_dataloader()
@@ -208,14 +206,13 @@ def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> Non
 
     batch_size = 1
     tile_count = 4
-    tile_size = 28
     level = 0
     flipdatamodule = PandaSlidesDataModule(
         root_path=mock_panda_slides_root_dir_diagonal,
         batch_size=batch_size,
         max_bag_size=tile_count,
         max_bag_size_inf=0,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
         transforms_dict=get_transforms_dict(),
     )
@@ -259,7 +256,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
 @pytest.mark.parametrize("batch_size", [1, 2])
 def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_different_n_tiles: Path) -> None:
     tile_count = 2
-    tile_size = 28
     level = 0
     assert_batch_index = 0
     n_tiles_list = [4, 5, 6, 7, 8, 9]
@@ -269,7 +265,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
         batch_size=batch_size,
         max_bag_size=tile_count,
         max_bag_size_inf=0,
-        tile_size=tile_size,
+        tiling_params=TilingParams(tile_size=28),
         level=level,
     )
     train_dataloader = datamodule.train_dataloader()

@@ -2,10 +2,12 @@ import torch
 import pytest
 import numpy as np
 
+from health_cpath.utils.naming import ModelKey, SlideKey
+from health_cpath.utils.wsi_utils import TilingParams, image_collate
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms import RandGridPatchd, GridPatchd
 from typing import Any, Dict, List, Union
 from typing import Sequence
-from health_cpath.utils.naming import SlideKey
-from health_cpath.utils.wsi_utils import image_collate
 from torch.utils.data import Dataset
 
 
@@ -35,8 +37,10 @@ class MockTiledWSIDataset(Dataset):
         img: Union[np.ndarray, torch.Tensor]
         if self.img_type == "np":
            img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
-        else:
+        elif self.img_type == "torch":
             img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
+        elif self.img_type == "metatensor":
+            img = MetaTensor(torch.randint(0, 255, size=(tile_count, *self.tile_size)))
         return [{SlideKey.SLIDE_ID: self.slide_ids[index],
                  SlideKey.IMAGE: img,
                  SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",
@@ -45,7 +49,7 @@
                  ]
 
 
-@pytest.mark.parametrize("img_type", ["np", "torch"])
+@pytest.mark.parametrize("img_type", ["np", "torch", "metatensor"])
 @pytest.mark.parametrize("random_n_tiles", [False, True])
 def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
     # random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during
@@ -72,3 +76,11 @@
         assert all((value_list[idx] == samples_list[idx][key]) for idx in range(batch_size))
     else:
         assert all(torch.equal(value_list[idx], samples_list[idx][key]) for idx in range(batch_size))
+
+
+@pytest.mark.parametrize("stage", [m for m in ModelKey])
+def test_tiling_params(stage: ModelKey) -> None:
+    params = TilingParams()
+    expected_transform_type = RandGridPatchd if stage == ModelKey.TRAIN else GridPatchd
+    transform = params.get_tiling_transform(stage=stage, bag_size=10)
+    assert isinstance(transform, expected_transform_type)
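For reference, a short usage sketch of the `TilingParams` class introduced above, mirroring the new `test_tiling_params` test. Imports assume the hi-ml-cpath package layout shown in the diff, and `ModelKey.VAL` is assumed to be one of the non-training stage members.

from health_cpath.utils.naming import ModelKey
from health_cpath.utils.wsi_utils import TilingParams
from monai.transforms import GridPatchd, RandGridPatchd

params = TilingParams(tile_size=224, tile_overlap=0)
# Training uses the random variant, so the grid offset varies between epochs,
train_tiling = params.get_tiling_transform(bag_size=8, stage=ModelKey.TRAIN)
# while inference stages fall back to the deterministic grid.
infer_tiling = params.get_tiling_transform(bag_size=8, stage=ModelKey.VAL)
assert isinstance(train_tiling, RandGridPatchd)
assert isinstance(infer_tiling, GridPatchd)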