Replace RadIO with TorchIO for patch-based inference (#666)
* Replace RadIO with TorchIO
* Ensure patches are float32 for forward pass
* Update changelog
* Ignore some types to fix mypy errors
* Remove APEX from conda environment in docs example

Co-authored-by: Javier <jaalvare@microsoft.com>
Parent: e2ec5cc839
Commit: d7e5d8b5e5
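For context on the change itself: the heart of this commit is TorchIO's grid sampling and aggregation API. The following is a minimal, self-contained sketch of that pattern; the toy model, shapes, and sizes are placeholders for illustration, not InnerEye code.

```python
import torch
import torchio as tio

# Toy stand-in for the segmentation network: 1 input channel, 2 output classes.
model = torch.nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3, padding=1)
model.eval()

# One 3D volume as a TorchIO subject; tensor layout is (channels, x, y, z).
subject = tio.Subject(image=tio.ScalarImage(tensor=torch.rand(1, 80, 80, 80)))

patch_size = (64, 64, 64)
patch_overlap = (16, 16, 16)  # patch_size - stride; must be even in every dimension
grid_sampler = tio.inference.GridSampler(subject, patch_size, patch_overlap)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
aggregator = tio.inference.GridAggregator(grid_sampler)

with torch.no_grad():
    for patches_batch in patch_loader:
        # Cast to float32 before the forward pass, as the commit message notes.
        inputs = patches_batch['image'][tio.DATA].float()
        aggregator.add_batch(model(inputs), patches_batch[tio.LOCATION])

posteriors = aggregator.get_output_tensor()  # shape (2, 80, 80, 80)
```

GridAggregator stitches the overlapping patch predictions back into a full-size volume, which is what replaces the RadIO `load_from_patches` stitching removed below.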
@@ -13,6 +13,8 @@ created.
 ## Upcoming
 
 ### Added
 
+- ([#666](https://github.com/microsoft/InnerEye-DeepLearning/pull/666)) Replace RadIO with TorchIO for patch-based inference.
 - ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train
   loss.
 - ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish runs.
@@ -7,21 +7,18 @@ from __future__ import annotations
 import logging
-from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Optional, Tuple
 
 import numpy as np
 import torch
-from radio import CTImagesMaskedBatch
-from radio.batchflow import Dataset, action, inbatch_parallel
+import torchio as tio
 
 from InnerEye.Common.type_annotations import TupleFloat3
 from InnerEye.ML import config
 from InnerEye.ML.common import ModelExecutionMode
 from InnerEye.ML.config import SegmentationModelBase
 from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_inference
 from InnerEye.ML.lightning_models import SegmentationLightning
 from InnerEye.ML.model_config_base import ModelConfigBase
-from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel
 from InnerEye.ML.utils import image_util, ml_util
 from InnerEye.ML.utils.image_util import compute_uncertainty_map_from_posteriors, gaussian_smooth_posteriors, \
     posteriors_to_segmentation
@@ -218,6 +215,24 @@ class InferencePipeline(FullImageInferencePipelineBase):
         assert isinstance(lightning_model, SegmentationLightning)
         return InferencePipeline(model=lightning_model, model_config=model_config, pipeline_id=pipeline_id)
 
+    def post_process_posteriors(self, posteriors: np.ndarray, mask: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Perform post processing on the computed outputs of a single pass of the pipeline.
+        Currently the following operations are performed:
+        -------------------------------------------------------------------------------------
+        1) the mask is applied to the posteriors (if required).
+        2) the final posteriors are used to perform an argmax to generate a multi-label segmentation.
+        3) extract the largest foreground connected component in the segmentation if required.
+        """
+        if mask is not None:
+            posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)
+
+        # create segmentation using an argmax over the posterior probabilities
+        segmentation = image_util.posteriors_to_segmentation(posteriors)
+
+        return posteriors, segmentation
+
+    @torch.no_grad()
     def predict_whole_image(self, image_channels: np.ndarray,
                             voxel_spacing_mm: TupleFloat3,
                             mask: np.ndarray = None,
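The new `post_process_posteriors` helper delegates to InnerEye's `image_util` functions. As a rough, hypothetical approximation of the mask-then-argmax flow it describes (the exact renormalisation inside `apply_mask_to_posteriors` may differ):

```python
import numpy as np

# posteriors: (classes, Z, Y, X) in [0, 1]; mask: (Z, Y, X) binary
posteriors = np.random.dirichlet(np.ones(3), size=(4, 4, 4)).transpose(3, 0, 1, 2).astype(np.float32)
mask = np.ones((4, 4, 4), dtype=np.uint8)
mask[0] = 0  # voxels to ignore

# Outside the mask, suppress foreground classes and assign their mass to background (class 0).
masked = posteriors.copy()
masked[1:] *= mask
masked[0] = 1.0 - masked[1:].sum(axis=0)

# Argmax over the class axis yields the multi-label segmentation, shape (Z, Y, X).
segmentation = np.argmax(masked, axis=0)
```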
@@ -238,259 +253,48 @@ class InferencePipeline(FullImageInferencePipelineBase):
         if mask is not None:
             ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
+        self.model.eval()
-        # create the dataset for the batch
-        batch_dataset = Dataset(index=[patient_id], batch_class=InferenceBatch)
-        # setup the pipeline
-        pipeline = (batch_dataset.p
-                    # define pipeline variables
-                    .init_variables([InferencePipeline.Variables.Model,
-                                     InferencePipeline.Variables.ModelConfig,
-                                     InferencePipeline.Variables.CropSize,
-                                     InferencePipeline.Variables.OutputSize,
-                                     InferencePipeline.Variables.OutputImageShape,
-                                     InferencePipeline.Variables.Stride])
-                    # update the variables for the batch actions
-                    .update_variable(name=InferencePipeline.Variables.Model, value=self.model)
-                    .update_variable(name=InferencePipeline.Variables.ModelConfig, value=self.model_config)
-                    # perform cascaded batch actions
-                    .load(image_channels=image_channels, mask=mask)
-                    .pre_process()
-                    .predict()
-                    .post_process())
+        image = tio.ScalarImage(tensor=image_channels)
+        subject = tio.Subject(image=image)
+
+        # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
+        # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
+        restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image  # type: ignore
+        effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape,  # type: ignore
+                                                                     self.model_config.test_crop_size,
+                                                                     self.model_config.inference_stride_size)
+
+        patch_overlap = np.array(effective_patch_size) - np.array(effective_stride)
+        grid_sampler = tio.inference.GridSampler(
+            subject,
+            effective_patch_size,
+            patch_overlap,
+            padding_mode=self.model_config.padding_mode.value,
+        )
-        # run the batch through the pipeline
         logging.info(f"Inference pipeline ({self.pipeline_id}), Predicting patient: {patient_id}")
-        processed_batch: InferenceBatch = pipeline.next_batch(batch_size=1)
-        posteriors = processed_batch.get_component(InferenceBatch.Components.Posteriors)
+        batch_size = self.model_config.inference_batch_size
+        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size)  # type: ignore
+        aggregator = tio.inference.GridAggregator(grid_sampler)
+
+        logging.debug(
+            f"Inference on image size {image.spatial_shape} will run "
+            f"with crop size {effective_patch_size} and stride {effective_stride}")
+        for patches_batch in patch_loader:
+            input_tensor = patches_batch['image'][tio.DATA].float()
+            if self.model_config.use_gpu:
+                input_tensor = input_tensor.cuda()
+            locations = patches_batch[tio.LOCATION]
+            # perform the forward pass
+            patches_posteriors = self.model(input_tensor).detach()
+            # collect the predictions over each of the batches
+            aggregator.add_batch(patches_posteriors, locations)
+        posteriors = aggregator.get_output_tensor().numpy()
+        posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask)
+
         image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
-        # prepare pipeline results from the processed batch
         return InferencePipeline.Result(
             patient_id=patient_id,
-            segmentation=processed_batch.get_component(InferenceBatch.Components.Segmentation),
+            segmentation=segmentation,
            posteriors=posteriors,
             voxel_spacing_mm=voxel_spacing_mm
         )
-
-
-class InferenceBatch(CTImagesMaskedBatch):
-    """
-    Batch class for IO with the inference pipeline. One instance of a batch will load the image
-    into the 'images' component of the pipeline, and store the results of the full pass
-    of the pipeline into the 'segmentation' and 'posteriors' components.
-    """
-
-    class Components(Enum):
-        """
-        Components associated with the inference batch class
-        """
-
-        # the input image channels in Channels x Z x Y x X format.
-        ImageChannels = 'channels'
-        # a set of 2D image slices (ie: a 3D image channel), stacked in Z x Y x X format.
-        Images = 'images'
-        # a binary mask used to ignore predictions in Z x Y x X format.
-        Mask = 'mask'
-        # a numpy.ndarray in Z x Y x X format with class labels for each voxel in the original image.
-        Segmentation = 'segmentation'
-        # a numpy.ndarray with the first dimension indexing each class in C x Z x Y x X format
-        # with each Z x Y x X being the same shape as the Images component, and consisting of
-        # [0, 1] values representing the model confidence for each voxel.
-        Posteriors = 'posteriors'
-
-    def __init__(self, index: int, *args: Any, **kwargs: Any):
-        super().__init__(index, *args, **kwargs)
-        self.components = [x.value for x in InferenceBatch.Components]
-
-    @action
-    def load(self, image_channels: np.ndarray, mask: np.ndarray) -> InferenceBatch:
-        """
-        Load image channels and mask into their respective pipeline components.
-        """
-        self.set_component(component=InferenceBatch.Components.ImageChannels, data=image_channels)
-        model_config = self.get_configs()
-        if model_config is None:
-            raise ValueError("model_config is None")
-        if model_config.test_crop_size is None:
-            raise ValueError("model_config.test_crop_size is None")
-        if model_config.inference_stride_size is None:
-            raise ValueError("model_config.inference_stride_size is None")
-
-        # fetch the image channels from the batch
-        image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
-        self.pipeline.set_variable(name=InferencePipeline.Variables.OutputImageShape, value=image_channels[0].shape)
-        # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
-        # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
-        image_size = image_channels.shape[1:]
-        model: BaseSegmentationModel = self.pipeline.get_variable(InferencePipeline.Variables.Model).model
-        effective_crop, effective_stride = \
-            model.crop_size_constraints.restrict_crop_size_to_image(image_size,
-                                                                    model_config.test_crop_size,
-                                                                    model_config.inference_stride_size)
-        self.pipeline.set_variable(name=InferencePipeline.Variables.CropSize, value=effective_crop)
-        self.pipeline.set_variable(name=InferencePipeline.Variables.Stride, value=effective_stride)
-        logging.debug(
-            f"Inference on image size {image_size} will run "
-            f"with crop size {effective_crop} and stride {effective_stride}")
-        # In most cases, we will be able to read the output size from the pre-computed values
-        # via get_output_size. Only if we have a non-standard (smaller) crop size, re-computed the output size.
-        output_size = model_config.get_output_size(execution_mode=ModelExecutionMode.TEST)
-        if effective_crop != model_config.test_crop_size:
-            output_size = model.get_output_shape(input_shape=effective_crop)  # type: ignore
-        self.pipeline.set_variable(name=InferencePipeline.Variables.OutputSize, value=output_size)
-
-        if mask is not None:
-            self.set_component(component=InferenceBatch.Components.Mask, data=mask)
-
-        return self
-
-    @action
-    def pre_process(self) -> InferenceBatch:
-        """
-        Prepare the input components of the batch for further processing.
-        """
-        model_config = self.get_configs()
-
-        # fetch the image channels from the batch
-        image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
-
-        crop_size = self.pipeline.get_variable(InferencePipeline.Variables.CropSize)
-        output_size = self.pipeline.get_variable(InferencePipeline.Variables.OutputSize)
-        image_channels = image_util.pad_images_for_inference(
-            images=image_channels,
-            crop_size=crop_size,
-            output_size=output_size,
-            padding_mode=model_config.padding_mode
-        )
-
-        # update the post-processed components
-        self.set_component(component=InferenceBatch.Components.ImageChannels, data=image_channels)
-
-        return self
-
-    @action
-    def predict(self) -> InferenceBatch:
-        """
-        Perform a forward pass of the model on the provided image, this generates
-        a set of posterior maps for each class, as well as a segmentation output
-        stored in the respective 'posteriors' and 'segmentation' components.
-        """
-        model_config = self.get_configs()
-
-        # extract patches for each image channel: Num patches x Channels x Z x Y x X
-        patches = self._extract_patches_for_image_channels()
-
-        # split the generated patches into batches and perform forward passes
-        predictions = []
-        batch_size = model_config.inference_batch_size
-
-        for batch_idx in range(0, len(patches), batch_size):
-            # slice over the batches to prepare batch
-            batch = torch.tensor(patches[batch_idx: batch_idx + batch_size, ...]).float()
-            if model_config.use_gpu:
-                batch = batch.cuda()
-            # perform the forward pass
-            batch_predictions = self._model_fn(batch).detach().cpu().numpy()
-            # collect the predictions over each of the batches
-            predictions.append(batch_predictions)
-
-        # map the batched predictions to the original batch shape
-        # of shape but with an added class dimension: Num patches x Class x Z x Y x X
-        predictions = np.concatenate(predictions, axis=0)
-
-        # create posterior output for each class with the shape: Class x Z x Y x X. We use float32 as these
-        # arrays can be big.
-        output_image_shape = self.pipeline.get_variable(InferencePipeline.Variables.OutputImageShape)
-        posteriors = np.zeros(shape=[model_config.number_of_classes] + list(output_image_shape), dtype=np.float32)
-        stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride)
-
-        for c in range(len(posteriors)):
-            # stitch the patches for each posterior class
-            self.load_from_patches(predictions[:, c, ...],  # type: ignore
-                                   stride=stride,
-                                   scan_shape=output_image_shape,
-                                   data_attr=InferenceBatch.Components.Posteriors.value)
-            # extract computed output from the component so the pipeline buffer can be reused
-            posteriors[c] = self.get_component(InferenceBatch.Components.Posteriors)
-
-        # store the stitched up results for the batch
-        self.set_component(component=InferenceBatch.Components.Posteriors, data=posteriors)
-
-        return self
-
-    @action
-    def post_process(self) -> InferenceBatch:
-        """
-        Perform post processing on the computed outputs of the a single pass of the pipelines.
-        Currently the following operations are performed:
-        -------------------------------------------------------------------------------------
-        1) the mask is applied to the posteriors (if required).
-        2) the final posteriors are used to perform an argmax to generate a multi-label segmentation.
-        3) extract the largest foreground connected component in the segmentation if required
-        """
-        mask = self.get_component(InferenceBatch.Components.Mask)
-        posteriors = self.get_component(InferenceBatch.Components.Posteriors)
-        if mask is not None:
-            posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)
-
-        # create segmentation using an argmax over the posterior probabilities
-        segmentation = image_util.posteriors_to_segmentation(posteriors)
-
-        # update the post-processed posteriors and save the segmentation
-        self.set_component(component=InferenceBatch.Components.Posteriors, data=posteriors)
-        self.set_component(component=InferenceBatch.Components.Segmentation, data=segmentation)
-
-        return self
-
-    def get_configs(self) -> config.SegmentationModelBase:
-        return self.pipeline.get_variable(InferencePipeline.Variables.ModelConfig)
-
-    def get_component(self, component: InferenceBatch.Components) -> np.ndarray:
-        return getattr(self, component.value) if hasattr(self, component.value) else None
-
-    @inbatch_parallel(init='indices', post='_post_custom_components', target='threads')
-    def set_component(self, batch_idx: int, component: InferenceBatch.Components, data: np.ndarray) \
-            -> Dict[str, Any]:
-        logging.debug("Updated data in pipeline component: {}, for batch: {}.".format(component.value, batch_idx))
-        return {
-            component.value: {'type': component.value, 'data': data}
-        }
-
-    def _extract_patches_for_image_channels(self) -> np.ndarray:
-        """
-        Extracts deterministically, patches from each image channel
-        :return: Patches for each image channel in format: Num patches x Channels x Z x Y x X
-        """
-        model_config = self.get_configs()
-        image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
-        # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
-        # to always fit into image, and adjust stride accordingly. If test_crop_size is smaller than the
-        # image, crop and stride will remain unchanged.
-        crop_size = self.pipeline.get_variable(InferencePipeline.Variables.CropSize)
-        stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride)
-        patches = []
-        for channel_index, channel in enumerate(image_channels):
-            # set the current image channel component to process
-            self.set_component(component=InferenceBatch.Components.Images, data=channel)
-            channel_patches = self.get_patches(patch_shape=crop_size,
-                                               stride=stride,
-                                               padding=model_config.padding_mode.value,
-                                               data_attr=InferenceBatch.Components.Images.value)
-            logging.debug(
-                f"Image channel {channel_index}: Tensor with extracted patches has size {channel_patches.shape}")
-            patches.append(channel_patches)
-            # reset the images component
-            self.set_component(component=InferenceBatch.Components.Images, data=[])
-
-        return np.stack(patches, axis=1)
-
-    def _model_fn(self, patches: torch.Tensor) -> torch.Tensor:
-        """
-        Wrapper function to handle the model forward pass
-        :param patches: Image patches to be passed to the model in format Patches x Channels x Z x Y x X
-        :return posteriors: Confidence maps [0,1] for each patch per class
-            in format: Patches x Channels x Class x Z x Y x X
-        """
-        model = self.pipeline.get_variable(InferencePipeline.Variables.Model)
-        # Model forward pass returns posteriors
-        with torch.no_grad():
-            return model(patches)
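The `effective_patch_size` / `effective_stride` restriction in the new code guards against test images smaller than `test_crop_size`. The exact rules live in InnerEye's `crop_size_constraints.restrict_crop_size_to_image`; a hypothetical helper with the same intent might look like this:

```python
import numpy as np

def restrict_patch_size_to_image(image_shape, patch_size, stride):
    # Hypothetical mirror of crop_size_constraints.restrict_crop_size_to_image:
    # shrink the patch to fit the image, then cap the stride by the patch.
    effective_patch = tuple(min(p, dim) for p, dim in zip(patch_size, image_shape))
    effective_stride = tuple(min(s, p) for s, p in zip(stride, effective_patch))
    return effective_patch, effective_stride

patch, stride = restrict_patch_size_to_image((48, 200, 200), (64, 128, 128), (32, 64, 64))
overlap = np.array(patch) - np.array(stride)
print(patch, stride, overlap)  # (48, 128, 128) (32, 64, 64) [16 64 64]
```

The overlap (patch size minus stride) is what the new code hands to `GridSampler` as `patch_overlap`; note that TorchIO requires each overlap value to be even.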
@@ -5,7 +5,7 @@
 * You need to have a Conda installation on your machine.
 * Create a Conda environment file `environment.yml` in your source code with these contents:
 
-```
+```yaml
 name: MyEnv
 channels:
   - defaults
@@ -15,8 +15,6 @@ dependencies:
   - python=3.7.3
   - pytorch=1.3.0
   - pip:
-    - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio
-    - git+https://github.com/ptrblck/apex.git@4ad9b3b#egg=apex
     - innereye
 ```
 
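Applying both hunks above, the documented example environment no longer pulls RadIO or APEX from GitHub. The full file would presumably read as follows (lines falling between the two hunks are not shown in this diff and are omitted here):

```yaml
name: MyEnv
channels:
  - defaults
# ... (lines between the two hunks not visible in this diff)
dependencies:
  - python=3.7.3
  - pytorch=1.3.0
  - pip:
    - innereye
```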
@@ -11,7 +11,6 @@ dependencies:
   - python-blosc=1.7.0
   - torchvision=0.11.1
   - pip:
-    - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio
     - azure-mgmt-resource==12.1.0
     - azure-mgmt-datafactory==1.1.0
     - azure-storage-blob==12.6.0
@@ -70,6 +69,7 @@ dependencies:
     - tabulate==0.8.7
     - tensorboard==2.3.0
     - tensorboardX==2.1
+    - torchio==0.18.73
     - torchmetrics==0.6.0
     - umap-learn==0.5.2
     - yacs==0.1.8