зеркало из https://github.com/microsoft/hi-ml.git
Add image load profiling (#138)
Add sample code to profile loading and tiling large images
This commit is contained in:
Родитель
7e00c84daa
Коммит
9da95e1881
|
@ -9,6 +9,27 @@ RUN su vscode -c "umask 0002 && . /usr/local/share/nvm/nvm.sh && nvm install ${N
|
|||
# Install ncc to build format_coverage github action.
|
||||
RUN npm install -g @vercel/ncc
|
||||
|
||||
# Avoid warnings by switching to noninteractive
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# install additional OS packages, i.e. make.
|
||||
RUN apt-get update \
|
||||
&& apt-get -y install --no-install-recommends apt-utils dialog 2>&1 \
|
||||
#
|
||||
# make and gcc
|
||||
&& apt-get -y install build-essential \
|
||||
#
|
||||
# openslide c libs
|
||||
&& apt-get -y install openslide-tools \
|
||||
#
|
||||
# cleanup
|
||||
&& apt-get autoremove -y \
|
||||
&& apt-get clean -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Switch back to dialog for any ad-hoc use of apt-get
|
||||
ENV DEBIAN_FRONTEND=dialog
|
||||
|
||||
# Copy environment.yml (if found) to a temp location so we update the environment. Also
|
||||
# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
|
||||
COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
|
||||
|
@ -16,15 +37,6 @@ COPY build_requirements.txt test_requirements.txt /tmp/conda-tmp/
|
|||
RUN /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml \
|
||||
&& rm -rf /tmp/conda-tmp
|
||||
|
||||
# [Optional] Uncomment to install a different version of Python than the default
|
||||
# RUN conda install -y python=3.6 \
|
||||
# && pip install --no-cache-dir pipx \
|
||||
# && pipx reinstall-all
|
||||
|
||||
# install additional OS packages, i.e. make.
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends make
|
||||
|
||||
# [Optional] Uncomment this section to hard code HIML test environment variables into image.
|
||||
# ENV HIML_RESOURCE_GROUP="<< YOUR_HIML_RESOURCE_GROUP >>"
|
||||
# ENV HIML_SUBSCRIPTION_ID="<< YOUR_HIML_SUBSCRIPTION_ID >>"
|
||||
|
|
|
@ -41,5 +41,5 @@
|
|||
// "postCreateCommand": "python --version",
|
||||
|
||||
// Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
|
||||
"remoteUser": "vscode"
|
||||
// "remoteUser": "vscode"
|
||||
}
|
||||
|
|
|
@ -64,12 +64,11 @@ jobs:
|
|||
with:
|
||||
node-version: '14'
|
||||
- uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
activate-environment: InnerEye
|
||||
auto-activate-base: false
|
||||
environment-file: environment.yml
|
||||
- name: pyright
|
||||
run: make pyright
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
conda info
|
||||
make pyright
|
||||
|
||||
pytest_fast:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
@ -15,6 +15,8 @@ created.
|
|||
### Added
|
||||
- ([#142](https://github.com/microsoft/hi-ml/pull/142)) Adding AzureML progress bar and diagnostics for batch loading
|
||||
|
||||
- ([#138](https://github.com/microsoft/hi-ml/pull/138)) Guidelines and profiling for whole slide images.
|
||||
|
||||
### Changed
|
||||
|
||||
### Fixed
|
||||
|
|
20
Makefile
20
Makefile
|
@ -18,22 +18,24 @@ pip_build: pip_upgrade
|
|||
pip_test: pip_upgrade
|
||||
pip install -r test_requirements.txt
|
||||
|
||||
# pip install local package in editable mode for development and testing
|
||||
# pip install local packages in editable mode for development and testing
|
||||
call_pip_local:
|
||||
$(call call_packages,call_pip_local)
|
||||
|
||||
# pip upgrade and install local package in editable mode
|
||||
# pip upgrade and install local packages in editable mode
|
||||
pip_local: pip_upgrade call_pip_local
|
||||
|
||||
# pip install everything for local development and testing
|
||||
pip: pip_build pip_test call_pip_local
|
||||
|
||||
# Set the conda environment for local development work, that contains all packages need for both hi-ml and hi-ml-azure
|
||||
# This is built from the package requirements, which pull in hi-ml-azure as a dependency, but for local dev work,
|
||||
# we want to consume that from source rather than pypi.
|
||||
conda:
|
||||
conda env update --file environment.yml
|
||||
# update current conda environment
|
||||
conda_update:
|
||||
conda env update -n $(CONDA_DEFAULT_ENV) --file environment.yml
|
||||
conda env update -n $(CONDA_DEFAULT_ENV) --file hi-ml/testhiml/testhiml/utils/slide_image_loading/environment.yml
|
||||
|
||||
# Set the conda environment for local development work, that contains all packages need for both hi-ml and hi-ml-azure
|
||||
# with hi-ml and hi-ml-azure installed in editable mode
|
||||
conda: conda_update call_pip_local
|
||||
|
||||
## Actions
|
||||
|
||||
|
@ -67,8 +69,8 @@ call_pyright:
|
|||
npm install -g pyright
|
||||
pyright
|
||||
|
||||
# pip install test requirements and run pyright
|
||||
pyright: pip call_pyright
|
||||
# conda install test requirements and run pyright
|
||||
pyright: conda call_pyright
|
||||
|
||||
# run basic checks
|
||||
call_check: call_flake8 call_mypy
|
||||
|
|
|
@ -41,6 +41,12 @@ The `hi-ml` toolbox provides
|
|||
developers.md
|
||||
contributing.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Guidelines
|
||||
|
||||
whole_slide_images.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Changelog
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# Whole Slide Images
|
||||
|
||||
Computational Pathology works with image files that can be very large in size, up to many GB. These files may be too large to load entirely into memory at once, or at least too large to act as training data. Instead they may be split into multiple tiles of a much smaller size, e.g. 224x224 pixels before being used for training. There are two popular libraries used for handling this type of image:
|
||||
|
||||
* [OpenSlide](https://openslide.org/)
|
||||
* [cuCIM](https://github.com/rapidsai/cucim)
|
||||
|
||||
but they both come with trade offs and complications.
|
||||
|
||||
In development there is also [tifffile](https://github.com/cgohlke/tifffile/), but this is untested.
|
||||
|
||||
## OpenSlide
|
||||
|
||||
There is a Python interface for OpenSlide at [openslide-python](https://pypi.org/project/openslide-python/), but this first requires the installation of the OpenSlide library itself. This can be done on Ubuntu with:
|
||||
|
||||
```bash
|
||||
apt-get install openslide-tools
|
||||
```
|
||||
|
||||
On Windows follow the instructions [here](https://openslide.org/docs/windows/) and make sure that the install directory is added to the system path.
|
||||
|
||||
Once the shared library/dlls are installed, install the Python interface with:
|
||||
|
||||
```bash
|
||||
pip install openslide-python
|
||||
```
|
||||
|
||||
## cuCIM
|
||||
|
||||
cuCIM is much easier to install, it can be done entirely with the Python package: [cucim](https://pypi.org/project/cucim/). However, there are the following caveats:
|
||||
|
||||
* It requires a GPU, with NVIDIA driver 450.36+
|
||||
* It requires CUDA 11.0+
|
||||
* It supports only a subset of tiff image files.
|
||||
|
||||
The suitable AzureML base Docker images are therefore the ones containing `cuda11`, and the compute instance must contain a GPU.
|
||||
|
||||
## Performance
|
||||
|
||||
An exploratory set of scripts are at [slide_image_loading](./hi-ml/testhiml/utils/slide_image_loading) for comparing loading images with OpenSlide or cuCIM, and performing tiling using both libraries.
|
||||
|
||||
### Loading and saving at lowest resolution
|
||||
|
||||
Four test tiff files are used:
|
||||
|
||||
* a 44.5 MB file with level dimensions: ((27648, 29440), (6912, 7360), (1728, 1840))
|
||||
* a 19.9 MB file with level dimensions: ((5888, 25344), (1472, 6336), (368, 1584))
|
||||
* a 5.5 MB file with level dimensions: ((27648, 29440), (6912, 7360), (1728, 1840)), but acting as a mask
|
||||
* a 2.1 MB file with level dimensions: ((5888, 25344), (1472, 6336), (368, 1584)), but acting as a mask
|
||||
|
||||
For OpenSlide the following code:
|
||||
|
||||
```python
|
||||
with OpenSlide(str(input_file)) as img:
|
||||
count = img.level_count
|
||||
dimensions = img.level_dimensions
|
||||
|
||||
print(f"level_count: {count}")
|
||||
print(f"dimensions: {dimensions}")
|
||||
|
||||
for k, v in img.properties.items():
|
||||
print(k, v)
|
||||
|
||||
region = img.read_region(location=(0, 0),
|
||||
level=count-1,
|
||||
size=dimensions[count-1])
|
||||
region.save(output_file)
|
||||
```
|
||||
|
||||
took an average of 29ms to open the file, 88ms to read the region, and 243ms to save the region as a png.
|
||||
|
||||
For cuCIM the following code:
|
||||
|
||||
```python
|
||||
img = cucim.CuImage(str(input_file))
|
||||
|
||||
count = img.resolutions['level_count']
|
||||
dimensions = img.resolutions['level_dimensions']
|
||||
|
||||
print(f"level_count: {count}")
|
||||
print(f"level_dimensions: {dimensions}")
|
||||
|
||||
print(img.metadata)
|
||||
|
||||
region = img.read_region(location=(0, 0),
|
||||
size=dimensions[count-1],
|
||||
level=count-1)
|
||||
np_img_arr = np.asarray(region)
|
||||
img2 = Image.fromarray(np_img_arr)
|
||||
img2.save(output_file)
|
||||
```
|
||||
|
||||
took an average of 369ms to open the file, 7ms to read the region and 197ms to save the region as a png, but note that it failed to handle the mask images.
|
||||
|
||||
### Loading and saving as tiles at the medium resolution
|
||||
|
||||
Test code created tiles of size 224x224 pilfes, loaded the mask images, and used occupancy levels to decide which tiles to create and save from level 1 - the middle resolution. This was profiled against both images, as above.
|
||||
|
||||
For cuCIM the total time was 4.7s, 2.48s to retain the tiles as a Numpy stack but not save them as pngs. cuCIM has the option of cacheing images, but is actually made performance slightly worse, possibly because the natural tile sizes in the original tiffs were larger than the tile sizes.
|
||||
|
||||
For OpenSlide the comparable total times were 5.7s, and 3.26s.
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
# According to:
|
||||
# https://github.com/rapidsai/cucim/blob/main/CONTRIBUTING.md
|
||||
# there is a minimum requirement of CUDA 11.0+
|
||||
# so the following base images are not suitable.
|
||||
# FROM mcr.microsoft.com/azureml/intelmpi2018.3-cuda10.0-cudnn7-ubuntu16.04
|
||||
# FROM mcr.microsoft.com/azureml/intelmpi2018.3-cuda9.0-cudnn7-ubuntu16.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.0-cudnn7-ubuntu16.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.0-cudnn7-ubuntu18.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.1-cudnn7-ubuntu18.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn7-ubuntu18.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi3.1.2-cuda9.0-cudnn7-ubuntu16.04
|
||||
|
||||
# The following have been tested, and are suitable:
|
||||
# FROM mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.0.3-cudnn8-ubuntu18.04
|
||||
# FROM mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.1-cudnn8-ubuntu18.04
|
||||
FROM mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.1-cudnn8-ubuntu20.04
|
||||
|
||||
# Avoid warnings by switching to noninteractive
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# install additional OS packages, i.e. make.
|
||||
RUN apt-get update \
|
||||
&& apt-get -y install --no-install-recommends apt-utils dialog 2>&1 \
|
||||
#
|
||||
# openslide c libs
|
||||
&& apt-get -y install openslide-tools \
|
||||
#
|
||||
# cleanup
|
||||
&& apt-get autoremove -y \
|
||||
&& apt-get clean -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Switch back to dialog for any ad-hoc use of apt-get
|
||||
ENV DEBIAN_FRONTEND=dialog
|
|
@ -0,0 +1,11 @@
|
|||
name: image_loader
|
||||
channels:
|
||||
- defaults
|
||||
dependencies:
|
||||
- pip=20.1.1
|
||||
- python=3.7.3
|
||||
- pip:
|
||||
- azureml-defaults==1.32.0
|
||||
- hi-ml-azure==0.1.9
|
||||
- line_profiler==3.3.1
|
||||
- monai[all]==0.7.0
|
|
@ -0,0 +1,37 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
|
||||
from azureml.core import Environment, ScriptRunConfig
|
||||
from azureml.core.compute import ComputeTarget
|
||||
|
||||
from health_azure import get_workspace, submit_run
|
||||
from health_azure.utils import WORKSPACE_CONFIG_JSON
|
||||
|
||||
|
||||
# A compute instance, with a GPU, required for running cuCIM.
|
||||
GPU_TESTING_INSTANCE_NAME = "testing-standard-nc6"
|
||||
|
||||
here = Path(__file__).parent.resolve()
|
||||
|
||||
workspace = get_workspace(aml_workspace=None,
|
||||
workspace_config_path=here / WORKSPACE_CONFIG_JSON)
|
||||
|
||||
environment = Environment.from_dockerfile(name='image_load_env',
|
||||
dockerfile='./Dockerfile',
|
||||
conda_specification='./environment.yml')
|
||||
|
||||
compute_target = ComputeTarget(workspace=workspace, name=GPU_TESTING_INSTANCE_NAME)
|
||||
|
||||
config = ScriptRunConfig(source_directory='./src',
|
||||
script='profile_load_image.py',
|
||||
compute_target=compute_target,
|
||||
environment=environment)
|
||||
|
||||
run = submit_run(workspace=workspace,
|
||||
experiment_name='image_load_exp',
|
||||
script_run_config=config,
|
||||
wait_for_completion=True,
|
||||
wait_for_completion_show_output=True)
|
|
@ -0,0 +1,127 @@
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
import pandas as pd
|
||||
from monai.config import KeysCollection
|
||||
from monai.data.image_reader import ImageReader, WSIReader
|
||||
from monai.transforms import MapTransform
|
||||
from openslide import OpenSlide
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from Histopathology.utils import box_utils
|
||||
|
||||
|
||||
class PandaDataset(Dataset):
|
||||
"""Dataset class for loading files from the PANDA challenge dataset.
|
||||
|
||||
Iterating over this dataset returns a dictionary containing the `'image_id'`, paths to the `'image'`
|
||||
and `'mask'` files, and the remaining meta-data from the original dataset (`'data_provider'`,
|
||||
`'isup_grade'`, and `'gleason_score'`).
|
||||
|
||||
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
|
||||
"""
|
||||
def __init__(self, root_dir: Union[str, Path], n_slides: Optional[int] = None,
|
||||
frac_slides: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
self.root_dir = Path(root_dir)
|
||||
self.train_df = pd.read_csv(self.root_dir / "train.csv", index_col='image_id')
|
||||
if n_slides or frac_slides:
|
||||
self.train_df = self.train_df.sample(n=n_slides, frac=frac_slides, replace=False,
|
||||
random_state=1234)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.train_df.shape[0]
|
||||
|
||||
def _get_image_path(self, image_id: str) -> Path:
|
||||
return self.root_dir / "train_images" / f"{image_id}.tiff"
|
||||
|
||||
def _get_mask_path(self, image_id: str) -> Path:
|
||||
return self.root_dir / "train_label_masks" / f"{image_id}_mask.tiff"
|
||||
|
||||
def __getitem__(self, index: int) -> Dict:
|
||||
image_id = self.train_df.index[index]
|
||||
return {
|
||||
'image_id': image_id,
|
||||
'image': str(self._get_image_path(image_id).absolute()),
|
||||
'mask': str(self._get_mask_path(image_id).absolute()),
|
||||
**self.train_df.loc[image_id].to_dict()
|
||||
}
|
||||
|
||||
|
||||
# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
|
||||
class ReadImaged(MapTransform):
|
||||
"""Basic transform to read image files."""
|
||||
def __init__(self, reader: ImageReader, keys: KeysCollection,
|
||||
allow_missing_keys: bool = False, **kwargs: Any) -> None:
|
||||
super().__init__(keys, allow_missing_keys=allow_missing_keys)
|
||||
self.reader = reader
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, data: Dict) -> Dict:
|
||||
for key in self.keys:
|
||||
if key in data or not self.allow_missing_keys:
|
||||
data[key] = self.reader.read(data[key], **self.kwargs)
|
||||
return data
|
||||
|
||||
|
||||
class LoadPandaROId(MapTransform):
|
||||
"""Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI).
|
||||
|
||||
Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the
|
||||
respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries:
|
||||
- `'location'` (tuple): top-right coordinates of the bounding box
|
||||
- `'size'` (tuple): width and height of the bounding box
|
||||
- `'level'` (int): chosen magnification level
|
||||
- `'scale'` (float): corresponding scale, loaded from the file
|
||||
"""
|
||||
def __init__(self, image_reader: WSIReader, mask_reader: WSIReader,
|
||||
image_key: str = 'image', mask_key: str = 'mask',
|
||||
level: int = 0, margin: int = 0, **kwargs: Any) -> None:
|
||||
"""
|
||||
:param reader: An instance of MONAI's `WSIReader`.
|
||||
:param image_key: Image key in the input and output dictionaries.
|
||||
:param mask_key: Mask key in the input and output dictionaries.
|
||||
:param level: Magnification level to load from the raw multi-scale files.
|
||||
:param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
|
||||
"""
|
||||
super().__init__([image_key, mask_key], allow_missing_keys=False)
|
||||
self.image_reader = image_reader
|
||||
self.mask_reader = mask_reader
|
||||
self.image_key = image_key
|
||||
self.mask_key = mask_key
|
||||
self.level = level
|
||||
self.margin = margin
|
||||
self.kwargs = kwargs
|
||||
|
||||
def _get_bounding_box(self, mask_obj: OpenSlide) -> box_utils.Box:
|
||||
# Estimate bounding box at the lowest resolution (i.e. highest level)
|
||||
highest_level = mask_obj.level_count - 1
|
||||
scale = mask_obj.level_downsamples[highest_level]
|
||||
mask, _ = self.mask_reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image
|
||||
|
||||
foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
|
||||
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
|
||||
return bbox
|
||||
|
||||
def __call__(self, data: Dict) -> Dict:
|
||||
mask_obj: OpenSlide = self.mask_reader.read(data[self.mask_key])
|
||||
image_obj: OpenSlide = self.image_reader.read(data[self.image_key])
|
||||
|
||||
level0_bbox = self._get_bounding_box(mask_obj)
|
||||
|
||||
# OpenSlide takes absolute location coordinates in the level 0 reference frame,
|
||||
# but relative region size in pixels at the chosen level
|
||||
scale = mask_obj.level_downsamples[self.level]
|
||||
scaled_bbox = level0_bbox / scale
|
||||
get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y),
|
||||
size=(scaled_bbox.w, scaled_bbox.h),
|
||||
level=self.level)
|
||||
mask, _ = self.mask_reader.get_data(mask_obj, **get_data_kwargs) # type: ignore
|
||||
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
|
||||
data[self.image_key], _ = self.image_reader.get_data(image_obj, **get_data_kwargs) # type: ignore
|
||||
data.update(get_data_kwargs)
|
||||
data['scale'] = scale
|
||||
|
||||
mask_obj.close()
|
||||
image_obj.close()
|
||||
return data
|
|
@ -0,0 +1,212 @@
|
|||
import functools
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import traceback
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from monai.data import Dataset
|
||||
from monai.data.image_reader import WSIReader
|
||||
from tqdm import tqdm
|
||||
|
||||
from Histopathology.preprocessing import tiling
|
||||
from Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||
|
||||
|
||||
CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy',
|
||||
'data_provider', 'slide_isup_grade', 'slide_gleason_score']
|
||||
TMP_SUFFIX = "_tmp"
|
||||
|
||||
logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \
|
||||
-> Tuple[np.ndarray, np.ndarray]:
|
||||
if occupancy_threshold < 0. or occupancy_threshold > 1.:
|
||||
raise ValueError("Tile occupancy threshold must be between 0 and 1")
|
||||
foreground_mask = mask_tile > 0
|
||||
occupancy = foreground_mask.mean(axis=(-2, -1))
|
||||
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze()
|
||||
|
||||
|
||||
def get_tile_descriptor(tile_location: Sequence[int]) -> str:
|
||||
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
|
||||
|
||||
|
||||
def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
|
||||
return f"{slide_id}.{get_tile_descriptor(tile_location)}"
|
||||
|
||||
|
||||
def save_image(array_chw: np.ndarray, path: Path) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
|
||||
pil_image = Image.fromarray(array_hwc)
|
||||
pil_image.convert('RGB').save(path)
|
||||
|
||||
|
||||
def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
|
||||
-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
|
||||
image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
|
||||
constant_values=255)
|
||||
mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)
|
||||
|
||||
selected: np.ndarray
|
||||
occupancies: np.ndarray
|
||||
selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
|
||||
n_discarded = (~selected).sum()
|
||||
logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}")
|
||||
|
||||
image_tiles = image_tiles[selected]
|
||||
mask_tiles = mask_tiles[selected]
|
||||
tile_locations = tile_locations[selected]
|
||||
occupancies = occupancies[selected]
|
||||
|
||||
abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)
|
||||
|
||||
return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
|
||||
|
||||
|
||||
# TODO refactor this to separate metadata identification from saving. We might want the metadata
|
||||
# even if the saving fails
|
||||
def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
|
||||
tile_location: Sequence[int], output_dir: Path) -> dict:
|
||||
slide_id = sample['image_id']
|
||||
descriptor = get_tile_descriptor(tile_location)
|
||||
image_tile_filename = f"train_images/{descriptor}.png"
|
||||
mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"
|
||||
|
||||
save_image(image_tile, output_dir / image_tile_filename)
|
||||
save_image(mask_tile, output_dir / mask_tile_filename)
|
||||
|
||||
tile_metadata = {
|
||||
'slide_id': slide_id,
|
||||
'tile_id': get_tile_id(slide_id, tile_location),
|
||||
'image': image_tile_filename,
|
||||
'mask': mask_tile_filename,
|
||||
'tile_x': tile_location[0],
|
||||
'tile_y': tile_location[1],
|
||||
'data_provider': sample['data_provider'],
|
||||
'slide_isup_grade': sample['isup_grade'],
|
||||
'slide_gleason_score': sample['gleason_score'],
|
||||
}
|
||||
|
||||
return tile_metadata
|
||||
|
||||
|
||||
def process_slide(image_wsi_reader: str, save_images: bool,
|
||||
sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
||||
output_dir: Path, tile_progress: bool = False) -> \
|
||||
Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
|
||||
slide_id = sample['image_id']
|
||||
slide_dir: Path = output_dir / (slide_id + "/")
|
||||
logging.info(f">>> Slide dir {slide_dir}")
|
||||
if slide_dir.exists(): # already processed slide - skip
|
||||
logging.info(f">>> Skipping {slide_dir} - already processed")
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
slide_dir.mkdir(parents=True)
|
||||
|
||||
dataset_csv_path = slide_dir / "dataset.csv"
|
||||
dataset_csv_file = dataset_csv_path.open('w')
|
||||
dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
|
||||
|
||||
tiles_failure = 0
|
||||
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
|
||||
failed_tiles_file = failed_tiles_csv_path.open('w')
|
||||
failed_tiles_file.write('tile_id' + '\n')
|
||||
|
||||
logging.info(f"Loading slide {slide_id} ...")
|
||||
loader = LoadPandaROId(WSIReader(image_wsi_reader), WSIReader(), level=level, margin=margin)
|
||||
sample = loader(sample) # load 'image' and 'mask' from disk
|
||||
|
||||
logging.info(f"Tiling slide {slide_id} ...")
|
||||
image_tiles, mask_tiles, tile_locations, occupancies, _ = \
|
||||
generate_tiles(sample, tile_size, occupancy_threshold)
|
||||
if not save_images:
|
||||
return image_tiles, mask_tiles, tile_locations, occupancies
|
||||
n_tiles = image_tiles.shape[0]
|
||||
|
||||
for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
|
||||
try:
|
||||
tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
|
||||
slide_dir)
|
||||
tile_metadata['occupancy'] = occupancies[i]
|
||||
tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
|
||||
tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
|
||||
dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
|
||||
dataset_csv_file.write(dataset_row + '\n')
|
||||
except Exception as e:
|
||||
tiles_failure += 1
|
||||
descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
|
||||
failed_tiles_file.write(descriptor)
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while saving tile "
|
||||
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
|
||||
|
||||
dataset_csv_file.close()
|
||||
failed_tiles_file.close()
|
||||
if tiles_failure > 0:
|
||||
# TODO what we want to do with slides that have some failed tiles?
|
||||
logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def merge_dataset_csv_files(dataset_dir: Path) -> Path:
|
||||
full_csv = dataset_dir / "dataset.csv"
|
||||
with full_csv.open('w') as full_csv_file:
|
||||
# full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
|
||||
first_file = True
|
||||
for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
|
||||
logging.info(f"Merging slide {slide_csv}")
|
||||
content = slide_csv.read_text()
|
||||
if not first_file:
|
||||
content = content[content.index('\n') + 1:] # discard header row for all but the first file
|
||||
full_csv_file.write(content)
|
||||
first_file = False
|
||||
return full_csv
|
||||
|
||||
|
||||
def main(process_slide: Callable,
|
||||
panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
|
||||
margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:
|
||||
|
||||
# Ignoring some types here because mypy is getting confused with the MONAI Dataset class
|
||||
# to select a subsample use keyword n_slides
|
||||
dataset = Dataset(PandaDataset(panda_dir)) # type: ignore
|
||||
|
||||
output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}"
|
||||
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}")
|
||||
|
||||
if overwrite and output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=not overwrite)
|
||||
|
||||
func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
|
||||
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
|
||||
tile_progress=not parallel)
|
||||
|
||||
if parallel:
|
||||
import multiprocessing
|
||||
|
||||
pool = multiprocessing.Pool()
|
||||
map_func = pool.imap_unordered # type: ignore
|
||||
else:
|
||||
map_func = map # type: ignore
|
||||
|
||||
list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore
|
||||
|
||||
if parallel:
|
||||
pool.close() # type: ignore
|
||||
|
||||
logging.info("Merging slide files in a single file")
|
||||
merge_dataset_csv_files(output_dir)
|
|
@ -0,0 +1,123 @@
|
|||
# These tiling implementations are adapted from PANDA Kaggle solutions, for example:
|
||||
# https://github.com/kentaroy47/Kaggle-PANDA-1st-place-solution/blob/master/src/data_process/a00_save_tiles.py
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_1d_padding(length: int, tile_size: int) -> Tuple[int, int]:
|
||||
"""Computes symmetric padding for `length` to be divisible by `tile_size`."""
|
||||
pad = (tile_size - length % tile_size) % tile_size
|
||||
return (pad // 2, pad - pad // 2)
|
||||
|
||||
|
||||
def pad_for_tiling_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True,
|
||||
**pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Symmetrically pads a 2D `array` such that both dimensions are divisible by `tile_size`.
|
||||
|
||||
:param array: 2D image array.
|
||||
:param tile_size: Width/height of each tile in pixels.
|
||||
:param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout.
|
||||
:param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`).
|
||||
:return: A tuple containing:
|
||||
- `padded_array`: Resulting array, in the same CHW/HWC layout as the input.
|
||||
- `offset`: XY offset introduced by the padding. Add this to coordinates relative to the
|
||||
original array to obtain indices for the padded array.
|
||||
"""
|
||||
height, width = array.shape[1:] if channels_first else array.shape[:-1]
|
||||
padding_h = get_1d_padding(height, tile_size)
|
||||
padding_w = get_1d_padding(width, tile_size)
|
||||
padding = [padding_h, padding_w]
|
||||
channels_axis = 0 if channels_first else 2
|
||||
padding.insert(channels_axis, (0, 0)) # zero padding on channels axis
|
||||
padded_array = np.pad(array, padding, **pad_kwargs)
|
||||
offset = (padding_w[0], padding_h[0])
|
||||
return padded_array, np.array(offset)
|
||||
|
||||
|
||||
def tile_array_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True,
|
||||
**pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Split an image array into square non-overlapping tiles.
|
||||
|
||||
The array will be padded symmetrically if its dimensions are not exact multiples of `tile_size`.
|
||||
|
||||
:param array: Image array.
|
||||
:param tile_size: Width/height of each tile in pixels.
|
||||
:param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`).
|
||||
:param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout.
|
||||
:return: A tuple containing:
|
||||
- `tiles`: A batch of tiles in NCHW layout.
|
||||
- `coords`: XY coordinates of each tile, in the same order.
|
||||
"""
|
||||
padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first, **pad_kwargs)
|
||||
if channels_first:
|
||||
channels, height, width = padded_array.shape
|
||||
else:
|
||||
height, width, channels = padded_array.shape
|
||||
n_tiles_h = height // tile_size
|
||||
n_tiles_w = width // tile_size
|
||||
|
||||
if channels_first:
|
||||
intermediate_shape = (channels, n_tiles_h, tile_size, n_tiles_w, tile_size)
|
||||
axis_order = (1, 3, 0, 2, 4) # (n_tiles_h, n_tiles_w, channels, tile_size, tile_size)
|
||||
output_shape = (n_tiles_h * n_tiles_w, channels, tile_size, tile_size)
|
||||
else:
|
||||
intermediate_shape = (n_tiles_h, tile_size, n_tiles_w, tile_size, channels)
|
||||
axis_order = (0, 2, 1, 3, 4) # (n_tiles_h, n_tiles_w, tile_size, tile_size, channels)
|
||||
output_shape = (n_tiles_h * n_tiles_w, tile_size, tile_size, channels)
|
||||
|
||||
tiles = padded_array.reshape(intermediate_shape) # Split width and height axes
|
||||
tiles = tiles.transpose(axis_order)
|
||||
tiles = tiles.reshape(output_shape) # Flatten tile batch dimension
|
||||
|
||||
# Compute top-left coordinates of every tile, relative to the original array's origin
|
||||
coords_h = tile_size * np.arange(n_tiles_h) - offset_h
|
||||
coords_w = tile_size * np.arange(n_tiles_w) - offset_w
|
||||
# Shape: (n_tiles_h * n_tiles_w, 2)
|
||||
coords = np.stack(np.meshgrid(coords_w, coords_h), axis=-1).reshape(-1, 2)
|
||||
|
||||
return tiles, coords
|
||||
|
||||
|
||||
def assemble_tiles_2d(tiles: np.ndarray, coords: np.ndarray, fill_value: Optional[float] = np.nan,
|
||||
channels_first: Optional[bool] = True) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Assembles a 2D array from sequences of tiles and coordinates.
|
||||
|
||||
:param tiles: Stack of tiles with batch dimension first.
|
||||
:param coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]).
|
||||
:param tile_size: Size of each tile; must be >0.
|
||||
:param fill_value: Value to assign to empty elements (default: `NaN`).
|
||||
:param channels_first: Whether each tile is in CHW (`True`, default) or HWC (`False`) layout.
|
||||
:return: A tuple containing:
|
||||
- `array`: The reassembled 2D array with the smallest dimensions to contain all given tiles.
|
||||
- `offset`: The lowest XY coordinates.
|
||||
- `offset`: XY offset introduced by the assembly. Add this to tile coordinates to obtain
|
||||
indices for the assembled array.
|
||||
"""
|
||||
if coords.shape[0] != tiles.shape[0]:
|
||||
raise ValueError(f"Tile coordinates and values must have the same length, "
|
||||
f"got {coords.shape[0]} and {tiles.shape[0]}")
|
||||
|
||||
if channels_first:
|
||||
n_tiles, channels, tile_size, _ = tiles.shape
|
||||
else:
|
||||
n_tiles, tile_size, _, channels = tiles.shape
|
||||
tile_xs, tile_ys = coords.T
|
||||
|
||||
x_min, x_max = min(tile_xs), max(tile_xs + tile_size)
|
||||
y_min, y_max = min(tile_ys), max(tile_ys + tile_size)
|
||||
width = x_max - x_min
|
||||
height = y_max - y_min
|
||||
output_shape = (channels, height, width) if channels_first else (height, width, channels)
|
||||
array = np.full(output_shape, fill_value)
|
||||
|
||||
offset = np.array([-x_min, -y_min])
|
||||
for idx in range(n_tiles):
|
||||
row = coords[idx, 1] + offset[1]
|
||||
col = coords[idx, 0] + offset[0]
|
||||
if channels_first:
|
||||
array[:, row:row + tile_size, col:col + tile_size] = tiles[idx]
|
||||
else:
|
||||
array[row:row + tile_size, col:col + tile_size, :] = tiles[idx]
|
||||
|
||||
return array, offset
|
|
@ -0,0 +1,134 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Box:
|
||||
"""Utility class representing rectangular regions in 2D images.
|
||||
|
||||
:param x: Horizontal coordinate of the top-left corner.
|
||||
:param y: Vertical coordinate of the top-left corner.
|
||||
:param w: Box width.
|
||||
:param h: Box height.
|
||||
:raises ValueError: If either `w` or `h` are <= 0.
|
||||
"""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.w <= 0:
|
||||
raise ValueError(f"Width must be strictly positive, received {self.w}")
|
||||
if self.h <= 0:
|
||||
raise ValueError(f"Height must be strictly positive, received {self.w}")
|
||||
|
||||
def __add__(self, shift: Sequence[int]) -> 'Box':
|
||||
"""Translates the box's location by a given shift.
|
||||
|
||||
:param shift: A length-2 sequence containing horizontal and vertical shifts.
|
||||
:return: A new box with updated `x = x + shift[0]` and `y = y + shift[1]`.
|
||||
:raises ValueError: If `shift` does not have two elements.
|
||||
"""
|
||||
if len(shift) != 2:
|
||||
raise ValueError("Shift must be two-dimensional")
|
||||
return Box(x=self.x + shift[0],
|
||||
y=self.y + shift[1],
|
||||
w=self.w,
|
||||
h=self.h)
|
||||
|
||||
def __mul__(self, factor: float) -> 'Box':
|
||||
"""Scales the box by a given factor, e.g. when changing resolution.
|
||||
|
||||
:param factor: The factor by which to multiply the box's location and dimensions.
|
||||
:return: The updated box, with location and dimensions rounded to `int`.
|
||||
"""
|
||||
return Box(x=int(self.x * factor),
|
||||
y=int(self.y * factor),
|
||||
w=int(self.w * factor),
|
||||
h=int(self.h * factor))
|
||||
|
||||
def __rmul__(self, factor: float) -> 'Box':
|
||||
"""Scales the box by a given factor, e.g. when changing resolution.
|
||||
|
||||
:param factor: The factor by which to multiply the box's location and dimensions.
|
||||
:return: The updated box, with location and dimensions rounded to `int`.
|
||||
"""
|
||||
return self * factor
|
||||
|
||||
def __truediv__(self, factor: float) -> 'Box':
|
||||
"""Scales the box by a given factor, e.g. when changing resolution.
|
||||
|
||||
:param factor: The factor by which to divide the box's location and dimensions.
|
||||
:return: The updated box, with location and dimensions rounded to `int`.
|
||||
"""
|
||||
return self * (1. / factor)
|
||||
|
||||
def add_margin(self, margin: int) -> 'Box':
|
||||
"""Adds a symmetric margin on all sides of the box.
|
||||
|
||||
:param margin: The amount by which to enlarge the box.
|
||||
:return: A new box enlarged by `margin` on all sides.
|
||||
"""
|
||||
return Box(x=self.x - margin,
|
||||
y=self.y - margin,
|
||||
w=self.w + 2 * margin,
|
||||
h=self.h + 2 * margin)
|
||||
|
||||
def clip(self, other: 'Box') -> Optional['Box']:
|
||||
"""Clips a box to the interior of another.
|
||||
|
||||
This is useful to constrain a region to the interior of an image.
|
||||
|
||||
:param other: Box representing the new constraints.
|
||||
:return: A new constrained box, or `None` if the boxes do not overlap.
|
||||
"""
|
||||
x0 = max(self.x, other.x)
|
||||
y0 = max(self.y, other.y)
|
||||
x1 = min(self.x + self.w, other.x + other.w)
|
||||
y1 = min(self.y + self.h, other.y + other.h)
|
||||
try:
|
||||
return Box(x=x0, y=y0, w=x1 - x0, h=y1 - y0)
|
||||
except ValueError: # Empty result, boxes don't overlap
|
||||
return None
|
||||
|
||||
def to_slices(self) -> Tuple[slice, slice]:
|
||||
"""Converts the box to slices for indexing arrays.
|
||||
|
||||
For example: `my_2d_array[my_box.to_slices()]`.
|
||||
|
||||
:return: A 2-tuple with vertical and horizontal slices.
|
||||
"""
|
||||
return (slice(self.y, self.y + self.h),
|
||||
slice(self.x, self.x + self.w))
|
||||
|
||||
@staticmethod
|
||||
def from_slices(slices: Sequence[slice]) -> 'Box':
|
||||
"""Converts a pair of vertical and horizontal slices into a box.
|
||||
|
||||
:param slices: A length-2 sequence containing vertical and horizontal `slice` objects.
|
||||
:return: A box with corresponding location and dimensions.
|
||||
"""
|
||||
vert_slice, horz_slice = slices
|
||||
return Box(x=horz_slice.start,
|
||||
y=vert_slice.start,
|
||||
w=horz_slice.stop - horz_slice.start,
|
||||
h=vert_slice.stop - vert_slice.start)
|
||||
|
||||
|
||||
def get_bounding_box(mask: ArrayLike) -> Box:
|
||||
"""Extracts a bounding box from a binary 2D array.
|
||||
|
||||
:param mask: A 2D array with 0 (or `False`) as background and >0 (or `True`) as foreground.
|
||||
:return: The smallest box covering all non-zero elements of `mask`.
|
||||
"""
|
||||
xs = np.sum(mask, 1).nonzero()[0]
|
||||
ys = np.sum(mask, 0).nonzero()[0]
|
||||
x_min, x_max = xs.min(), xs.max()
|
||||
y_min, y_max = ys.min(), ys.max()
|
||||
width = x_max - x_min + 1
|
||||
height = y_max - y_min + 1
|
||||
return Box(x_min, y_min, width, height)
|
|
@ -0,0 +1,236 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from line_profiler import LineProfiler
|
||||
from openslide import OpenSlide
|
||||
from PIL import Image
|
||||
import cucim
|
||||
import numpy as np
|
||||
|
||||
from azureml.core import Dataset
|
||||
from health_azure import get_workspace, is_running_in_azure_ml
|
||||
|
||||
from Histopathology.preprocessing.create_tiles_dataset import process_slide, save_tile, generate_tiles
|
||||
|
||||
|
||||
def profile_cucim(input_file: Path,
|
||||
output_file: Path) -> None:
|
||||
"""
|
||||
Load an input_file with cuCIM, print out basic properties, and save as output_file.
|
||||
|
||||
:param input_file: Input file path.
|
||||
:param output_file: Output file path.
|
||||
:return: None
|
||||
"""
|
||||
img = cucim.CuImage(str(input_file))
|
||||
|
||||
count = img.resolutions['level_count']
|
||||
dimensions = img.resolutions['level_dimensions']
|
||||
|
||||
print(f"level_count: {count}")
|
||||
print(f"level_dimensions: {dimensions}")
|
||||
|
||||
print(img.metadata)
|
||||
|
||||
region = img.read_region(location=(0, 0),
|
||||
size=dimensions[count-1],
|
||||
level=count-1)
|
||||
np_img_arr = np.asarray(region)
|
||||
img2 = Image.fromarray(np_img_arr)
|
||||
img2.save(output_file)
|
||||
|
||||
|
||||
def profile_openslide(input_file: Path,
|
||||
output_file: Path) -> None:
|
||||
"""
|
||||
Load an input_file with OpenSlide, print out basic properties, and save as output_file.
|
||||
|
||||
:param input_file: Input file path.
|
||||
:param output_file: Output file path.
|
||||
:return: None
|
||||
"""
|
||||
with OpenSlide(str(input_file)) as img:
|
||||
count = img.level_count
|
||||
dimensions = img.level_dimensions
|
||||
|
||||
print(f"level_count: {count}")
|
||||
print(f"dimensions: {dimensions}")
|
||||
|
||||
for k, v in img.properties.items():
|
||||
print(k, v)
|
||||
|
||||
region = img.read_region(location=(0, 0),
|
||||
level=count-1,
|
||||
size=dimensions[count-1])
|
||||
region.save(output_file)
|
||||
|
||||
|
||||
def profile_folder(mount_point: Path,
|
||||
output_folder: Path,
|
||||
subfolder: str) -> None:
|
||||
"""
|
||||
For each *.tiff image file in the given subfolder or the mount_point,
|
||||
load each with cuCIM or OpenSlide, print out basic properties, and save as a png.
|
||||
|
||||
:param mount_point: Base path for source images.
|
||||
:param output_folder: Base path to save output images.
|
||||
:param subfolder: Subfolder of mount_point to search for tiffs.
|
||||
:return: None.
|
||||
"""
|
||||
cucim_output_folder = output_folder / "image_cucim" / subfolder
|
||||
cucim_output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
openslide_output_folder = output_folder / "image_openslide" / subfolder
|
||||
openslide_output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for image_file in (mount_point / subfolder).glob("*.tiff"):
|
||||
output_filename = image_file.with_suffix(".png").name
|
||||
|
||||
try:
|
||||
profile_cucim(image_file, cucim_output_folder / output_filename)
|
||||
except Exception as ex:
|
||||
print(f"Error calling cuCIM: {str(ex)}")
|
||||
profile_openslide(image_file, openslide_output_folder / output_filename)
|
||||
|
||||
|
||||
def profile_folders(mount_point: Path,
|
||||
output_folder: Path) -> None:
|
||||
profile_folder(mount_point, output_folder, "train_images")
|
||||
profile_folder(mount_point, output_folder, "train_label_masks")
|
||||
|
||||
|
||||
def print_cache_state(cache) -> None: # type: ignore
|
||||
"""
|
||||
Print out cuCIM cache state
|
||||
"""
|
||||
print(f"cache_hit: {cache.hit_count}, cache_miss: {cache.miss_count}")
|
||||
print(f"items in cache: {cache.size}/{cache.capacity}, "
|
||||
f"memory usage in cache: {cache.memory_size}/{cache.memory_capacity}")
|
||||
|
||||
|
||||
def wrap_profile_folders(mount_point: Path,
|
||||
output_folder: Path) -> None:
|
||||
"""
|
||||
Load some tiffs with cuCIM and OpenSlide, save them, and run line_profile.
|
||||
|
||||
:return: None.
|
||||
"""
|
||||
def wrap_profile_folders() -> None:
|
||||
profile_folders(mount_point, output_folder)
|
||||
|
||||
lp = LineProfiler()
|
||||
lp.add_function(profile_cucim)
|
||||
lp.add_function(profile_openslide)
|
||||
lp.add_function(profile_folder)
|
||||
lp_wrapper = lp(wrap_profile_folders)
|
||||
lp_wrapper()
|
||||
with open("outputs/profile_folders.txt", "w", encoding="utf-8") as f:
|
||||
lp.print_stats(f)
|
||||
|
||||
|
||||
def profile_main(mount_point: Path,
|
||||
output_folder: Path,
|
||||
label: str,
|
||||
process: Callable) -> None:
|
||||
def wrap_main() -> None:
|
||||
from Histopathology.preprocessing.create_tiles_dataset import main
|
||||
main(process,
|
||||
panda_dir=mount_point,
|
||||
root_output_dir=output_folder / label,
|
||||
level=1,
|
||||
tile_size=224,
|
||||
margin=64,
|
||||
occupancy_threshold=0.05,
|
||||
parallel=False,
|
||||
overwrite=True)
|
||||
|
||||
lp = LineProfiler()
|
||||
lp.add_function(process_slide)
|
||||
lp.add_function(save_tile)
|
||||
lp.add_function(generate_tiles)
|
||||
lp_wrapper = lp(wrap_main)
|
||||
lp_wrapper()
|
||||
with open(f"outputs/profile_{label}.txt", "w", encoding="utf-8") as f:
|
||||
lp.print_stats(f)
|
||||
|
||||
|
||||
def process_slide_open_slide_no_save(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
||||
output_dir: Path, tile_progress: bool = False) -> \
|
||||
Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
|
||||
return process_slide('openslide', False,
|
||||
sample, level, margin, tile_size, occupancy_threshold,
|
||||
output_dir, tile_progress)
|
||||
|
||||
|
||||
def process_slide_cucim_no_save(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
||||
output_dir: Path, tile_progress: bool = False) -> \
|
||||
Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
|
||||
return process_slide('cucim', False,
|
||||
sample, level, margin, tile_size, occupancy_threshold,
|
||||
output_dir, tile_progress)
|
||||
|
||||
|
||||
def process_slide_openslide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
||||
output_dir: Path, tile_progress: bool = False) -> None:
|
||||
process_slide('openslide', True,
|
||||
sample, level, margin, tile_size, occupancy_threshold,
|
||||
output_dir, tile_progress)
|
||||
|
||||
|
||||
def process_slide_cucim(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
|
||||
output_dir: Path, tile_progress: bool = False) -> None:
|
||||
process_slide('cucim', True,
|
||||
sample, level, margin, tile_size, occupancy_threshold,
|
||||
output_dir, tile_progress)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""
|
||||
Load some tiffs with cuCIM and OpenSlide, then run image tiling with both libraries.
|
||||
|
||||
:return: None.
|
||||
"""
|
||||
print(f"cucim.is_available(): {cucim.is_available()}")
|
||||
print(f"cucim.is_available('skimage'): {cucim.is_available('skimage')}")
|
||||
print(f"cucim.is_available('core'): {cucim.is_available('core')}")
|
||||
print(f"cucim.is_available('clara'): {cucim.is_available('clara')}")
|
||||
|
||||
ws = get_workspace(aml_workspace=None, workspace_config_path=None)
|
||||
dataset = Dataset.get_by_name(ws, name='panda')
|
||||
|
||||
output_folder = Path("outputs") if is_running_in_azure_ml() else Path("../outputs")
|
||||
|
||||
with dataset.mount("/tmp/datasets/panda") as mount_context:
|
||||
mount_point = Path(mount_context.mount_point)
|
||||
|
||||
wrap_profile_folders(mount_point, output_folder)
|
||||
|
||||
labels = ['tile_cucim_no_save', 'tile_openslide_no_save', 'tile_cucim', 'tile_openslide']
|
||||
|
||||
for i, process in enumerate([process_slide_cucim_no_save,
|
||||
process_slide_open_slide_no_save,
|
||||
process_slide_cucim,
|
||||
process_slide_openslide]):
|
||||
profile_main(mount_point, output_folder, labels[i], process)
|
||||
|
||||
cucim.CuImage.cache("per_process", memory_capacity=2048, record_stat=True)
|
||||
cache = cucim.CuImage.cache()
|
||||
print(f"cucim.cache.config: {cache.config}")
|
||||
print_cache_state(cache)
|
||||
|
||||
labels = ['tile_cucim_no_save_cache', 'tile_cucim_cache']
|
||||
|
||||
for i, process in enumerate([process_slide_cucim_no_save,
|
||||
process_slide_cucim]):
|
||||
profile_main(mount_point, output_folder, labels[i], process)
|
||||
|
||||
print_cache_state(cache)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -4,7 +4,16 @@
|
|||
"hi-ml",
|
||||
"hi-ml-azure"
|
||||
],
|
||||
"useLibraryCodeForTypes ": true,
|
||||
"useLibraryCodeForTypes": false,
|
||||
"reportMissingImports": true,
|
||||
"reportMissingTypeStubs": false
|
||||
"reportMissingTypeStubs": false,
|
||||
"reportPrivateImportUsage": false,
|
||||
"executionEnvironments": [
|
||||
{
|
||||
"root": "hi-ml/testhiml/testhiml/utils/slide_image_loading/src/Histopathology",
|
||||
"extraPaths": [
|
||||
"hi-ml/testhiml/testhiml/utils/slide_image_loading/src"
|
||||
]
|
||||
},
|
||||
],
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче