Ensure the shape of input patches is compatible with model constraints (#682)

* Ensure patches shape is compatible with model constraints

* Remove test for images that are too small
Fernando Pérez-García 2022-03-10 17:01:25 +00:00 committed by GitHub
Parent 3c919dca14
Commit d929ccbd46
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 17 deletions

View file

@@ -93,6 +93,7 @@ gets uploaded to AzureML, by skipping all test folders.
### Fixed
- ([#682](https://github.com/microsoft/InnerEye-DeepLearning/pull/682)) Ensure the shape of input patches is compatible with model constraints.
- ([#681](https://github.com/microsoft/InnerEye-DeepLearning/pull/681)) Pad model outputs if they are smaller than the inputs.
- ([#683](https://github.com/microsoft/InnerEye-DeepLearning/pull/683)) Fix missing separator error in docs Makefile.
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Fix caching and checkpointing for TCGA CRCk dataset.

View file

@@ -7,7 +7,7 @@ from __future__ import annotations
import logging
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
import numpy as np
import torch
@@ -235,7 +235,7 @@ class InferencePipeline(FullImageInferencePipelineBase):
@torch.no_grad()
def predict_whole_image(self, image_channels: np.ndarray,
voxel_spacing_mm: TupleFloat3,
mask: np.ndarray = None,
mask: Optional[np.ndarray] = None,
patient_id: int = 0) -> InferencePipeline.Result:
"""
Performs a single inference pass through the pipeline for the provided image
@@ -255,12 +255,26 @@ class InferencePipeline(FullImageInferencePipelineBase):
self.model.eval()
image = tio.ScalarImage(tensor=image_channels)
subject = tio.Subject(image=image)
INPUT = 'input_image'
MASK = 'mask'
subject_dict: Dict[str, tio.Image] = {INPUT: image}
if mask is not None:
    subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
subject = tio.Subject(subject_dict)
constraints = self.model.model.crop_size_constraints
# Make sure the image size is compatible with the model
multiple_constraints = constraints.multiple_of # type: ignore
if multiple_constraints is not None:
    ensure_shape_multiple = tio.EnsureShapeMultiple(constraints.multiple_of) # type: ignore
    subject = ensure_shape_multiple(subject) # type: ignore
# There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
# to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image # type: ignore
effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape, # type: ignore
restrict_patch_size = constraints.restrict_crop_size_to_image # type: ignore
effective_patch_size, effective_stride = restrict_patch_size(subject.spatial_shape, # type: ignore
self.model_config.test_crop_size,
self.model_config.inference_stride_size)
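
The padding relies on TorchIO applying the same transform to every image in a `Subject`, so the mask stays aligned with the scan. Below is a minimal, self-contained sketch of this step; the shapes, the multiple of 16 and the crop size are illustrative placeholders rather than values from a real model configuration, and the min-clipping at the end is a simplified stand-in for `crop_size_constraints.restrict_crop_size_to_image`:

```python
import numpy as np
import torchio as tio

# Illustrative input that is too small along X and Y for a model that needs
# shapes divisible by 16 (placeholder for crop_size_constraints.multiple_of).
image_channels = np.random.rand(1, 5, 10, 64).astype(np.float32)  # (channels, X, Y, Z)
mask = np.ones((5, 10, 64), dtype=np.uint8)

# Image and mask live in the same Subject so the padding is applied to both.
subject = tio.Subject({
    'input_image': tio.ScalarImage(tensor=image_channels),
    'mask': tio.LabelMap(tensor=mask[np.newaxis]),
})
padded = tio.EnsureShapeMultiple(16)(subject)  # pads (the default method) up to the next multiple
print(subject.spatial_shape, '->', padded.spatial_shape)  # (5, 10, 64) -> (16, 16, 64)

# Simplified stand-in for restrict_crop_size_to_image: never use a patch larger than the image.
test_crop_size = (64, 64, 64)
effective_patch_size = tuple(min(c, s) for c, s in zip(test_crop_size, padded.spatial_shape))
print(effective_patch_size)  # (16, 16, 64)
```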
@@ -276,10 +290,10 @@
aggregator = tio.inference.GridAggregator(grid_sampler)
logging.debug(
f"Inference on image size {image.spatial_shape} will run "
f"Inference on image size {subject.spatial_shape} will run "
f"with crop size {effective_patch_size} and stride {effective_stride}")
for patches_batch in patch_loader:
    input_tensor = patches_batch['image'][tio.DATA].float()
    input_tensor = patches_batch[INPUT][tio.DATA].float()
    if self.model_config.use_gpu:
        input_tensor = input_tensor.cuda()
    locations = patches_batch[tio.LOCATION]
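
The loop above follows TorchIO's grid-sampling pattern for sliding-window inference. A minimal sketch of that pattern with a placeholder model, patch size and overlap (the real pipeline derives these from `test_crop_size` and `inference_stride_size` and moves tensors to the GPU when configured):

```python
import torch
import torchio as tio

def run_patch_inference(subject: tio.Subject, model: torch.nn.Module,
                        patch_size=(16, 16, 16), patch_overlap=(0, 0, 0)) -> torch.Tensor:
    # Slide a regular grid of patches over the (already padded) subject.
    grid_sampler = tio.inference.GridSampler(subject, patch_size, patch_overlap)
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch['input_image'][tio.DATA].float()
            locations = patches_batch[tio.LOCATION]
            aggregator.add_batch(model(input_tensor), locations)
    # Stitch the per-patch outputs back into one tensor covering the whole volume.
    return aggregator.get_output_tensor()
```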
@@ -296,9 +310,24 @@
    # collect the predictions over each of the batches
    aggregator.add_batch(patches_posteriors, locations)
posteriors = aggregator.get_output_tensor().numpy()
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask)
posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask)
image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
# Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine)
segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine)
subject.add_image(posteriors_image, 'posteriors')
subject.add_image(segmentation_image, 'segmentation')
# Remove some images to avoid unnecessary computations
subject.remove_image(INPUT)
if mask is not None:
    subject.remove_image(MASK)
subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject
posteriors = subject_original_space.posteriors.numpy() # type: ignore
segmentation = subject_original_space.segmentation.numpy()[0] # type: ignore
# prepare pipeline results from the processed batch
return InferencePipeline.Result(
patient_id=patient_id,

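The last block above undoes the `EnsureShapeMultiple` padding by attaching the outputs to the transformed subject and inverting its transform history. A minimal, self-contained sketch of that round trip; the shapes, the multiple of 16 and the class count are placeholders, not values from a real configuration:

```python
import numpy as np
import torchio as tio

# Pad a small illustrative image, pretend we ran inference on the padded volume.
image = np.random.rand(1, 5, 10, 64).astype(np.float32)
subject = tio.Subject(input_image=tio.ScalarImage(tensor=image))
padded = tio.EnsureShapeMultiple(16)(subject)            # (5, 10, 64) -> (16, 16, 64)

num_classes = 2
posteriors = np.random.rand(num_classes, *padded.spatial_shape).astype(np.float32)
segmentation = posteriors.argmax(axis=0).astype(np.uint8)

# Attach the outputs to the padded subject so the inverse transform applies to them too.
padded.add_image(tio.ScalarImage(tensor=posteriors), 'posteriors')
padded.add_image(tio.LabelMap(tensor=segmentation[np.newaxis]), 'segmentation')
padded.remove_image('input_image')                       # only the outputs need undoing

restored = padded.apply_inverse_transform() if padded.applied_transforms else padded
print(restored['posteriors'].shape)     # (2, 5, 10, 64): back to the original spatial shape
print(restored['segmentation'].shape)   # (1, 5, 10, 64)
```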
View file

@@ -59,15 +59,6 @@ def run_inference_on_unet(size: TupleInt3) -> None:
    image_util.check_array_range(p)
def test_inference_on_too_small_image() -> None:
    """
    Running inference on a simplified Unet model when the input image is too small along an axis.
    """
    with pytest.raises(ValueError) as ex:
        run_inference_on_unet((5, 10, 64))
    assert "input image must have at least a size of (16, 16, 16)" in str(ex)
@pytest.mark.parametrize("size", [(26, 20, 50), (16, 16, 16)])
def test_inference_on_small_image(size: TupleInt3) -> None:
"""