From d929ccbd4633e9119a954a837d5680f82124fc55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Thu, 10 Mar 2022 17:01:25 +0000 Subject: [PATCH] Ensure the shape of input patches is compatible with model constraints (#682) * Ensure patches shape is compatible with model constraints * Remove test for images that are too small --- CHANGELOG.md | 1 + InnerEye/ML/pipelines/inference.py | 45 +++++++++++++++---- .../pipelines/test_inference_smallimages.py | 9 ---- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cff60588..81f258b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,7 @@ gets uploaded to AzureML, by skipping all test folders. ### Fixed +- ([#682](https://github.com/microsoft/InnerEye-DeepLearning/pull/682)) Ensure the shape of input patches is compatible with model constraints. - ([#681](https://github.com/microsoft/InnerEye-DeepLearning/pull/681)) Pad model outputs if they are smaller than the inputs. - ([#683](https://github.com/microsoft/InnerEye-DeepLearning/pull/683)) Fix missing separator error in docs Makefile. - ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Fix caching and checkpointing for TCGA CRCk dataset. diff --git a/InnerEye/ML/pipelines/inference.py b/InnerEye/ML/pipelines/inference.py index 5b397c6e..4cbebaa8 100644 --- a/InnerEye/ML/pipelines/inference.py +++ b/InnerEye/ML/pipelines/inference.py @@ -7,7 +7,7 @@ from __future__ import annotations import logging from enum import Enum from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import numpy as np import torch @@ -235,7 +235,7 @@ class InferencePipeline(FullImageInferencePipelineBase): @torch.no_grad() def predict_whole_image(self, image_channels: np.ndarray, voxel_spacing_mm: TupleFloat3, - mask: np.ndarray = None, + mask: Optional[np.ndarray] = None, patient_id: int = 0) -> InferencePipeline.Result: """ Performs a single inference pass through the pipeline for the provided image @@ -255,12 +255,26 @@ class InferencePipeline(FullImageInferencePipelineBase): self.model.eval() image = tio.ScalarImage(tensor=image_channels) - subject = tio.Subject(image=image) + INPUT = 'input_image' + MASK = 'mask' + + subject_dict: Dict[str, tio.Image] = {INPUT: image} + if mask is not None: + subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis]) + subject = tio.Subject(subject_dict) + + constraints = self.model.model.crop_size_constraints + + # Make sure the image size is compatible with the model + multiple_constraints = constraints.multiple_of # type: ignore + if multiple_constraints is not None: + ensure_shape_multiple = tio.EnsureShapeMultiple(constraints.multiple_of) # type: ignore + subject = ensure_shape_multiple(subject) # type: ignore # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged. - restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image # type: ignore - effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape, # type: ignore + restrict_patch_size = constraints.restrict_crop_size_to_image # type: ignore + effective_patch_size, effective_stride = restrict_patch_size(subject.spatial_shape, # type: ignore self.model_config.test_crop_size, self.model_config.inference_stride_size) @@ -276,10 +290,10 @@ class InferencePipeline(FullImageInferencePipelineBase): aggregator = tio.inference.GridAggregator(grid_sampler) logging.debug( - f"Inference on image size {image.spatial_shape} will run " + f"Inference on image size {subject.spatial_shape} will run " f"with crop size {effective_patch_size} and stride {effective_stride}") for patches_batch in patch_loader: - input_tensor = patches_batch['image'][tio.DATA].float() + input_tensor = patches_batch[INPUT][tio.DATA].float() if self.model_config.use_gpu: input_tensor = input_tensor.cuda() locations = patches_batch[tio.LOCATION] @@ -296,9 +310,24 @@ class InferencePipeline(FullImageInferencePipelineBase): # collect the predictions over each of the batches aggregator.add_batch(patches_posteriors, locations) posteriors = aggregator.get_output_tensor().numpy() - posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask) + posteriors_mask = None if mask is None else subject[MASK].numpy()[0] + posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask) image_util.check_array_range(posteriors, error_prefix="Whole image posteriors") + + # Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any) + posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine) + segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine) + subject.add_image(posteriors_image, 'posteriors') + subject.add_image(segmentation_image, 'segmentation') + # Remove some images to avoid unnecessary computations + subject.remove_image(INPUT) + if mask is not None: + subject.remove_image(MASK) + subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject + posteriors = subject_original_space.posteriors.numpy() # type: ignore + segmentation = subject_original_space.segmentation.numpy()[0] # type: ignore + # prepare pipeline results from the processed batch return InferencePipeline.Result( patient_id=patient_id, diff --git a/Tests/ML/pipelines/test_inference_smallimages.py b/Tests/ML/pipelines/test_inference_smallimages.py index db2a4fcd..d16a354c 100644 --- a/Tests/ML/pipelines/test_inference_smallimages.py +++ b/Tests/ML/pipelines/test_inference_smallimages.py @@ -59,15 +59,6 @@ def run_inference_on_unet(size: TupleInt3) -> None: image_util.check_array_range(p) -def test_inference_on_too_small_image() -> None: - """ - Running inference on a simplified Unet model when the input image is too small along an axis. - """ - with pytest.raises(ValueError) as ex: - run_inference_on_unet((5, 10, 64)) - assert "input image must have at least a size of (16, 16, 16)" in str(ex) - - @pytest.mark.parametrize("size", [(26, 20, 50), (16, 16, 16)]) def test_inference_on_small_image(size: TupleInt3) -> None: """