Ensure the shape of input patches is compatible with model constraints (#682)
* Ensure patches shape is compatible with model constraints * Remove test for images that are too small
This commit is contained in:
Родитель
3c919dca14
Коммит
d929ccbd46
|
@ -93,6 +93,7 @@ gets uploaded to AzureML, by skipping all test folders.
|
|||
|
||||
### Fixed
|
||||
|
||||
- ([#682](https://github.com/microsoft/InnerEye-DeepLearning/pull/682)) Ensure the shape of input patches is compatible with model constraints.
|
||||
- ([#681](https://github.com/microsoft/InnerEye-DeepLearning/pull/681)) Pad model outputs if they are smaller than the inputs.
|
||||
- ([#683](https://github.com/microsoft/InnerEye-DeepLearning/pull/683)) Fix missing separator error in docs Makefile.
|
||||
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Fix caching and checkpointing for TCGA CRCk dataset.
|
||||
|
|
|
@ -7,7 +7,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -235,7 +235,7 @@ class InferencePipeline(FullImageInferencePipelineBase):
|
|||
@torch.no_grad()
|
||||
def predict_whole_image(self, image_channels: np.ndarray,
|
||||
voxel_spacing_mm: TupleFloat3,
|
||||
mask: np.ndarray = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
patient_id: int = 0) -> InferencePipeline.Result:
|
||||
"""
|
||||
Performs a single inference pass through the pipeline for the provided image
|
||||
|
@ -255,12 +255,26 @@ class InferencePipeline(FullImageInferencePipelineBase):
|
|||
self.model.eval()
|
||||
|
||||
image = tio.ScalarImage(tensor=image_channels)
|
||||
subject = tio.Subject(image=image)
|
||||
INPUT = 'input_image'
|
||||
MASK = 'mask'
|
||||
|
||||
subject_dict: Dict[str, tio.Image] = {INPUT: image}
|
||||
if mask is not None:
|
||||
subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
|
||||
subject = tio.Subject(subject_dict)
|
||||
|
||||
constraints = self.model.model.crop_size_constraints
|
||||
|
||||
# Make sure the image size is compatible with the model
|
||||
multiple_constraints = constraints.multiple_of # type: ignore
|
||||
if multiple_constraints is not None:
|
||||
ensure_shape_multiple = tio.EnsureShapeMultiple(constraints.multiple_of) # type: ignore
|
||||
subject = ensure_shape_multiple(subject) # type: ignore
|
||||
|
||||
# There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
|
||||
# to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
|
||||
restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image # type: ignore
|
||||
effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape, # type: ignore
|
||||
restrict_patch_size = constraints.restrict_crop_size_to_image # type: ignore
|
||||
effective_patch_size, effective_stride = restrict_patch_size(subject.spatial_shape, # type: ignore
|
||||
self.model_config.test_crop_size,
|
||||
self.model_config.inference_stride_size)
|
||||
|
||||
|
@ -276,10 +290,10 @@ class InferencePipeline(FullImageInferencePipelineBase):
|
|||
aggregator = tio.inference.GridAggregator(grid_sampler)
|
||||
|
||||
logging.debug(
|
||||
f"Inference on image size {image.spatial_shape} will run "
|
||||
f"Inference on image size {subject.spatial_shape} will run "
|
||||
f"with crop size {effective_patch_size} and stride {effective_stride}")
|
||||
for patches_batch in patch_loader:
|
||||
input_tensor = patches_batch['image'][tio.DATA].float()
|
||||
input_tensor = patches_batch[INPUT][tio.DATA].float()
|
||||
if self.model_config.use_gpu:
|
||||
input_tensor = input_tensor.cuda()
|
||||
locations = patches_batch[tio.LOCATION]
|
||||
|
@ -296,9 +310,24 @@ class InferencePipeline(FullImageInferencePipelineBase):
|
|||
# collect the predictions over each of the batches
|
||||
aggregator.add_batch(patches_posteriors, locations)
|
||||
posteriors = aggregator.get_output_tensor().numpy()
|
||||
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask)
|
||||
posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
|
||||
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask)
|
||||
|
||||
image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
|
||||
|
||||
# Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
|
||||
posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine)
|
||||
segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine)
|
||||
subject.add_image(posteriors_image, 'posteriors')
|
||||
subject.add_image(segmentation_image, 'segmentation')
|
||||
# Remove some images to avoid unnecessary computations
|
||||
subject.remove_image(INPUT)
|
||||
if mask is not None:
|
||||
subject.remove_image(MASK)
|
||||
subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject
|
||||
posteriors = subject_original_space.posteriors.numpy() # type: ignore
|
||||
segmentation = subject_original_space.segmentation.numpy()[0] # type: ignore
|
||||
|
||||
# prepare pipeline results from the processed batch
|
||||
return InferencePipeline.Result(
|
||||
patient_id=patient_id,
|
||||
|
|
|
@ -59,15 +59,6 @@ def run_inference_on_unet(size: TupleInt3) -> None:
|
|||
image_util.check_array_range(p)
|
||||
|
||||
|
||||
def test_inference_on_too_small_image() -> None:
|
||||
"""
|
||||
Running inference on a simplified Unet model when the input image is too small along an axis.
|
||||
"""
|
||||
with pytest.raises(ValueError) as ex:
|
||||
run_inference_on_unet((5, 10, 64))
|
||||
assert "input image must have at least a size of (16, 16, 16)" in str(ex)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("size", [(26, 20, 50), (16, 16, 16)])
|
||||
def test_inference_on_small_image(size: TupleInt3) -> None:
|
||||
"""
|
||||
|
|
Загрузка…
Ссылка в новой задаче