зеркало из https://github.com/microsoft/hi-ml.git
ENH: Parameterize CuCim backend (#661)
Add wsi_reader_args to be able to use different backend
This commit is contained in:
Родитель
f97d368870
Коммит
a1f580576c
|
@ -79,6 +79,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
|
|||
wsi_has_mask: bool = param.Boolean(default=True,
|
||||
doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
|
||||
"region of the WSI. If False, will load the whole WSI.")
|
||||
wsi_backend: str = param.String(default="cuCIM", doc="The backend to use for loading WSI. ")
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
@ -152,6 +153,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
|
|||
val_plot_options=self.get_val_plot_options(),
|
||||
test_plot_options=self.get_test_plot_options(),
|
||||
wsi_has_mask=self.wsi_has_mask,
|
||||
backend=self.wsi_backend,
|
||||
val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
|
||||
)
|
||||
if self.num_top_slides > 0:
|
||||
|
|
|
@ -25,6 +25,7 @@ from monai.transforms.io.dictionary import LoadImaged
|
|||
from monai.apps.pathology.transforms import TileOnGridd
|
||||
from monai.data.image_reader import WSIReader
|
||||
|
||||
|
||||
_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
|
||||
|
||||
|
||||
|
@ -280,12 +281,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
random_offset: bool = True,
|
||||
pad_full: bool = False,
|
||||
background_val: int = 255,
|
||||
filter_mode: str = "min",
|
||||
filter_mode: str = 'min',
|
||||
backend: str = 'cuCIM',
|
||||
wsi_reader_args: Dict[str, Any] = {},
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
:param level: the whole slide image level at which the image is extracted, defaults to 1
|
||||
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend
|
||||
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default
|
||||
:param tile_size: size of the square tile, defaults to 224
|
||||
this param is passed to TileOnGridd monai transform for tiling on the fly.
|
||||
:param step: step size to create overlapping tiles, defaults to None (same as tile_size)
|
||||
|
@ -301,15 +304,22 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for
|
||||
random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd
|
||||
monai transform for tiling on the fly.
|
||||
:param backend: the WSI reader backend, defaults to "cuCIM". This param is passed to LoadImaged monai transform
|
||||
:param wsi_reader_args: Additional arguments to pass to the WSIReader, defaults to {}. Multi processing is
|
||||
enabled since monai 1.0.0 by specifying num_workers > 0 with CuCIM backend only.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.level = level
|
||||
# Tiling on the fly params
|
||||
self.tile_size = tile_size
|
||||
self.step = step
|
||||
self.random_offset = random_offset
|
||||
self.pad_full = pad_full
|
||||
self.background_val = background_val
|
||||
self.filter_mode = filter_mode
|
||||
# WSIReader params
|
||||
self.level = level
|
||||
self.backend = backend
|
||||
self.wsi_reader_args = wsi_reader_args
|
||||
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
|
||||
# max_bag_size_inf to None if set to 0
|
||||
for stage_key, max_bag_size in self.bag_sizes.items():
|
||||
|
@ -322,10 +332,11 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
LoadImaged(
|
||||
keys=slides_dataset.IMAGE_COLUMN,
|
||||
reader=WSIReader,
|
||||
backend="cuCIM",
|
||||
dtype=np.uint8,
|
||||
level=self.level,
|
||||
image_only=True,
|
||||
level=self.level,
|
||||
backend=self.backend,
|
||||
**self.wsi_reader_args,
|
||||
),
|
||||
TileOnGridd(
|
||||
keys=slides_dataset.IMAGE_COLUMN,
|
||||
|
|
|
@ -258,7 +258,7 @@ class DeepMILOutputsHandler:
|
|||
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
|
||||
maximise: bool, val_plot_options: Collection[PlotOption],
|
||||
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
|
||||
val_set_is_dist: bool = True) -> None:
|
||||
backend: str = "cuCIM", val_set_is_dist: bool = True) -> None:
|
||||
"""
|
||||
:param outputs_root: Root directory where to save all produced outputs.
|
||||
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
|
||||
|
@ -271,6 +271,7 @@ class DeepMILOutputsHandler:
|
|||
:param val_plot_options: The desired plot options for validation time.
|
||||
:param test_plot_options: The desired plot options for test time.
|
||||
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
|
||||
:param backend: The backend to use for reading the tiles. Default is "cuCIM".
|
||||
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
|
||||
set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
|
||||
set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
|
||||
|
@ -294,7 +295,8 @@ class DeepMILOutputsHandler:
|
|||
tile_size=self.tile_size,
|
||||
class_names=self.class_names,
|
||||
stage=ModelKey.VAL,
|
||||
wsi_has_mask=wsi_has_mask
|
||||
wsi_has_mask=wsi_has_mask,
|
||||
backend=backend,
|
||||
)
|
||||
self.test_plots_handler = DeepMILPlotsHandler(
|
||||
plot_options=test_plot_options,
|
||||
|
@ -302,7 +304,8 @@ class DeepMILOutputsHandler:
|
|||
tile_size=self.tile_size,
|
||||
class_names=self.class_names,
|
||||
stage=ModelKey.TEST,
|
||||
wsi_has_mask=wsi_has_mask
|
||||
wsi_has_mask=wsi_has_mask,
|
||||
backend=backend,
|
||||
)
|
||||
self.val_set_is_dist = val_set_is_dist
|
||||
|
||||
|
|
|
@ -183,6 +183,7 @@ class DeepMILPlotsHandler:
|
|||
stage: str = '',
|
||||
class_names: Optional[Sequence[str]] = None,
|
||||
wsi_has_mask: bool = True,
|
||||
backend: str = "cuCIM",
|
||||
) -> None:
|
||||
"""Class that handles the plotting of DeepMIL results.
|
||||
|
||||
|
@ -195,6 +196,8 @@ class DeepMILPlotsHandler:
|
|||
:param stage: Test or Validation, used to name the plots
|
||||
:param class_names: List of class names, defaults to None
|
||||
:param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None
|
||||
:param wsi_has_mask: Whether the whole slide images have a mask, defaults to True
|
||||
:param backend: The backend to use for loading the whole slide images, defaults to "cuCIM"
|
||||
"""
|
||||
self.plot_options = plot_options
|
||||
self.class_names = validate_class_names_for_plot_options(class_names, plot_options)
|
||||
|
@ -204,6 +207,7 @@ class DeepMILPlotsHandler:
|
|||
self.figsize = figsize
|
||||
self.stage = stage
|
||||
self.wsi_has_mask = wsi_has_mask
|
||||
self.backend = backend
|
||||
self.slides_dataset: Optional[SlidesDataset] = None
|
||||
|
||||
def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
|
||||
|
@ -212,7 +216,8 @@ class DeepMILPlotsHandler:
|
|||
slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
|
||||
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
|
||||
slide_dict = self.slides_dataset[slide_index]
|
||||
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask)
|
||||
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask,
|
||||
backend=self.backend)
|
||||
return slide_dict
|
||||
|
||||
def save_slide_node_figures(
|
||||
|
|
|
@ -27,7 +27,9 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
|
|||
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||
|
||||
|
||||
def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]:
|
||||
def load_image_dict(
|
||||
sample: dict, level: int, margin: int, wsi_has_mask: bool = True, backend: str = "cuCIM"
|
||||
) -> Dict[SlideKey, Any]:
|
||||
"""
|
||||
Load image from metadata dictionary
|
||||
:param sample: dict describing image metadata. Example:
|
||||
|
@ -39,10 +41,11 @@ def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool =
|
|||
'gleason_score': ['0+0']}
|
||||
:param level: level of resolution to be loaded
|
||||
:param margin: margin to be included
|
||||
:return: a dict containing the image data and metadata
|
||||
:param wsi_has_mask: whether the WSI has a mask
|
||||
:param backend: backend to be used to load the image (cuCIM or OpenSlide)
|
||||
"""
|
||||
transform = LoadPandaROId if wsi_has_mask else LoadROId
|
||||
loader = transform(WSIReader("cuCIM"), level=level, margin=margin)
|
||||
loader = transform(WSIReader(backend=backend), level=level, margin=margin)
|
||||
img = loader(sample)
|
||||
return img
|
||||
|
||||
|
|
|
@ -297,10 +297,11 @@ def test_location_selected_tiles(level: int) -> None:
|
|||
assert max(tile_ys) <= slide_image.shape[1] // factor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["cuCIM", "OpenSlide"])
|
||||
@pytest.mark.parametrize("wsi_has_mask", [True, False])
|
||||
def test_load_image_dict(wsi_has_mask: bool) -> None:
|
||||
def test_load_image_dict(wsi_has_mask: bool, backend: str) -> None:
|
||||
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
|
||||
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
|
||||
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore
|
||||
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask, backend=backend)
|
||||
assert mock_load_panda_roi.called == wsi_has_mask
|
||||
assert mock_load_roi.called == (not wsi_has_mask)
|
||||
|
|
Загрузка…
Ссылка в новой задаче