305 строки
15 KiB
Python
305 строки
15 KiB
Python
# ------------------------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
|
# ------------------------------------------------------------------------------------------
|
|
from collections import defaultdict
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
import zipfile
|
|
|
|
import numpy as np
|
|
import param
|
|
from azureml.core import Run
|
|
from InnerEye_DICOM_RT.nifti_to_dicom_rt_converter import rtconvert
|
|
|
|
from InnerEye.Azure.azure_util import is_offline_run_context
|
|
from InnerEye.Common import fixed_paths
|
|
from InnerEye.Common.fixed_paths import DEFAULT_RESULT_ZIP_DICOM_NAME
|
|
from InnerEye.Common.generic_parsing import GenericConfig
|
|
from InnerEye.Common.type_annotations import TupleFloat3, TupleInt3
|
|
from InnerEye.ML.config import SegmentationModelBase
|
|
from InnerEye.ML.model_inference_config import ModelInferenceConfig
|
|
from InnerEye.ML.model_testing import DEFAULT_RESULT_IMAGE_NAME
|
|
from InnerEye.ML.photometric_normalization import PhotometricNormalization
|
|
from InnerEye.ML.pipelines.ensemble import EnsemblePipeline
|
|
from InnerEye.ML.pipelines.inference import FullImageInferencePipelineBase, InferencePipeline
|
|
from InnerEye.ML.utils.config_util import ModelConfigLoader
|
|
from InnerEye.ML.utils.io_util import ImageWithHeader, load_nifti_image, reverse_tuple_float3, store_as_ubyte_nifti, \
|
|
load_dicom_series_and_save
|
|
|
|
|
|
class ScorePipelineConfig(GenericConfig):
|
|
data_folder: Path = param.ClassSelector(class_=Path, default=Path.cwd(),
|
|
doc="Path to the folder that contains the images that should be scored")
|
|
model_folder: str = param.String(doc="Path to the folder that contains the model, in particular inference "
|
|
"configuration and checkpoints. Defaults to the folder where the current "
|
|
"file lives.")
|
|
image_files: List[str] = param.List([fixed_paths.DEFAULT_TEST_IMAGE_NAME], class_=str, instantiate=False,
|
|
bounds=(1, None),
|
|
doc="The name of the images channels to run the pipeline on. These "
|
|
"files must exist in the data_folder.")
|
|
result_image_name: str = param.String(DEFAULT_RESULT_IMAGE_NAME,
|
|
doc="The name of the result image, created in the project root folder.")
|
|
use_gpu: bool = param.Boolean(True, doc="If GPU should be used or not.")
|
|
use_dicom: bool = param.Boolean(False, doc="If images to be scored are DICOM and output to be DICOM-RT. "
|
|
"If this is set then image_files should contain a single zip file "
|
|
"containing a set of DICOM files.")
|
|
result_zip_dicom_name: str = param.String(DEFAULT_RESULT_ZIP_DICOM_NAME,
|
|
doc="The name of the zipped DICOM-RT file if use_dicom set.")
|
|
|
|
|
|
def init_from_model_inference_json(model_folder: Path, use_gpu: bool = True) -> Tuple[FullImageInferencePipelineBase,
|
|
SegmentationModelBase]:
|
|
"""
|
|
Loads the config and inference pipeline from the current directory using fixed_paths.MODEL_INFERENCE_JSON_FILE_NAME
|
|
:return: Tuple[InferencePipeline, Config]
|
|
"""
|
|
logging.info('Python version: ' + sys.version)
|
|
path_to_model_inference_config = model_folder / fixed_paths.MODEL_INFERENCE_JSON_FILE_NAME
|
|
logging.info(f'path_to_model_inference_config: {path_to_model_inference_config}')
|
|
model_inference_config = read_model_inference_config(str(path_to_model_inference_config))
|
|
logging.info(f'model_inference_config: {model_inference_config}')
|
|
full_path_to_checkpoints = [model_folder / x for x in model_inference_config.checkpoint_paths]
|
|
logging.info(f'full_path_to_checkpoints: {full_path_to_checkpoints}')
|
|
loader = ModelConfigLoader[SegmentationModelBase](
|
|
model_configs_namespace=model_inference_config.model_configs_namespace)
|
|
model_config = loader.create_model_config_from_name(model_name=model_inference_config.model_name)
|
|
return create_inference_pipeline(model_config, full_path_to_checkpoints, use_gpu)
|
|
|
|
|
|
def create_inference_pipeline(model_config: SegmentationModelBase,
|
|
full_path_to_checkpoints: List[Path],
|
|
use_gpu: bool = True) \
|
|
-> Tuple[FullImageInferencePipelineBase, SegmentationModelBase]:
|
|
"""
|
|
Create pipeline for inference, this can be a single model inference pipeline or an ensemble, if multiple
|
|
checkpoints provided.
|
|
:param model_config: Model config to use to create the pipeline.
|
|
:param full_path_to_checkpoints: Checkpoints to use for model inference.
|
|
:param use_gpu: If GPU should be used or not.
|
|
"""
|
|
model_config.use_gpu = use_gpu
|
|
logging.info('test_config: ' + model_config.model_name)
|
|
|
|
inference_pipeline: Optional[FullImageInferencePipelineBase]
|
|
if len(full_path_to_checkpoints) == 1:
|
|
inference_pipeline = InferencePipeline.create_from_checkpoint(
|
|
path_to_checkpoint=full_path_to_checkpoints[0],
|
|
model_config=model_config)
|
|
else:
|
|
inference_pipeline = EnsemblePipeline.create_from_checkpoints(path_to_checkpoints=full_path_to_checkpoints,
|
|
model_config=model_config)
|
|
if inference_pipeline is None:
|
|
raise ValueError("Cannot create inference pipeline")
|
|
|
|
return inference_pipeline, model_config
|
|
|
|
|
|
def read_model_inference_config(path_to_model_inference_config: str) -> ModelInferenceConfig:
|
|
with open(path_to_model_inference_config, 'r', encoding='utf-8') as file:
|
|
model_inference_config = ModelInferenceConfig.from_json(file.read()) # type: ignore
|
|
return model_inference_config
|
|
|
|
|
|
def is_spacing_valid(spacing: TupleFloat3, dataset_expected_spacing_xyz: TupleFloat3) -> bool:
|
|
absolute_tolerance = 1e-1
|
|
return np.allclose(spacing, dataset_expected_spacing_xyz, atol=absolute_tolerance)
|
|
|
|
|
|
def run_inference(images_with_header: List[ImageWithHeader],
|
|
inference_pipeline: FullImageInferencePipelineBase,
|
|
config: SegmentationModelBase) -> np.ndarray:
|
|
"""
|
|
Runs inference on a list of channels and given a config and inference pipeline
|
|
:param images_with_header:
|
|
:param inference_pipeline:
|
|
:param config:
|
|
:return: segmentation
|
|
"""
|
|
# Check the image has the correct spacing
|
|
if config.dataset_expected_spacing_xyz:
|
|
for image_with_header in images_with_header:
|
|
spacing_xyz = reverse_tuple_float3(image_with_header.header.spacing)
|
|
if not is_spacing_valid(spacing_xyz, config.dataset_expected_spacing_xyz):
|
|
raise ValueError(f'Input image has spacing {spacing_xyz} '
|
|
f'but expected {config.dataset_expected_spacing_xyz}')
|
|
# Photo norm
|
|
photo_norm = PhotometricNormalization(config_args=config)
|
|
photo_norm_images = [photo_norm.transform(image_with_header.image) for image_with_header in images_with_header]
|
|
segmentation = inference_pipeline.predict_and_post_process_whole_image(
|
|
image_channels=np.array(photo_norm_images),
|
|
voxel_spacing_mm=images_with_header[0].header.spacing
|
|
).segmentation
|
|
|
|
return segmentation
|
|
|
|
|
|
def extract_zipped_files_and_flatten(zip_file_path: Path, extraction_folder: Path) -> None:
|
|
"""
|
|
Unzip a zip file and extract all the files discarding any folders they
|
|
may have in the zip file.
|
|
|
|
:param zip_file_path: Path to zip file.
|
|
:param extraction_folder: Path to extraction folder.
|
|
"""
|
|
with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
|
|
zipinfos_by_name = defaultdict(list)
|
|
for zipped_file in zip_file.infolist():
|
|
if not zipped_file.is_dir():
|
|
# discard the path, if any, to just get the filename and suffix
|
|
name = os.path.basename(zipped_file.filename)
|
|
zipinfos_by_name[name].append(zipped_file)
|
|
duplicates = {name: zipinfos for name, zipinfos in zipinfos_by_name.items() if len(zipinfos) > 1}
|
|
if len(duplicates) > 0:
|
|
warnings = ""
|
|
for name, zipinfos in duplicates.items():
|
|
joint_paths = ", ".join([os.path.dirname(zipinfo.filename) for zipinfo in zipinfos])
|
|
warnings += f"File {name} is duplicated in folders {joint_paths}.\n"
|
|
raise ValueError("Zip file contains duplicates.\n" + warnings)
|
|
for name, zipinfos in zipinfos_by_name.items():
|
|
zipinfo = zipinfos[0]
|
|
zipinfo.filename = name
|
|
zip_file.extract(zipinfo, str(extraction_folder))
|
|
|
|
|
|
def convert_zipped_dicom_to_nifti(zip_file_path: Path, reference_series_folder: Path,
|
|
nifti_file_path: Path) -> None:
|
|
"""
|
|
Given a zip file, extract DICOM series and convert to Nifti format.
|
|
|
|
This function:
|
|
1) Unzips the file at zip_file_path into reference_series_folder,
|
|
assumed to contain a DICOM series.
|
|
2) Creates a Nifti file from the DICOM series.
|
|
|
|
:param zip_file_path: Path to a zip file.
|
|
:param reference_series_folder: Folder to unzip DICOM series into.
|
|
:param nifti_file_path: Path to target Nifti file.
|
|
"""
|
|
extract_zipped_files_and_flatten(zip_file_path, reference_series_folder)
|
|
load_dicom_series_and_save(reference_series_folder, nifti_file_path)
|
|
|
|
|
|
def convert_rgb_colour_to_hex(colour: TupleInt3) -> str:
|
|
"""
|
|
Config colours are stored as TupleInt3's, but DICOM-RT convert expects
|
|
hexadecimal strings. This function converts them into the correct
|
|
format.
|
|
|
|
:param colour: RGB colour as a TupleInt3.
|
|
:return: Colour formatted as a hex string.
|
|
"""
|
|
return '{0:02X}{1:02X}{2:02X}'.format(colour[0], colour[1], colour[2])
|
|
|
|
|
|
def convert_nifti_to_zipped_dicom_rt(nifti_file: Path, reference_series: Path, scratch_folder: Path,
|
|
config: SegmentationModelBase, dicom_rt_zip_file_name: str) -> Path:
|
|
"""
|
|
Given a Nifti file and a reference DICOM series, create zip file containing a DICOM-RT file.
|
|
|
|
Calls rtconvert with the given Nifti file, reference DICOM series and configuration from
|
|
config to create a DICOM-RT file in the scratch folder. This is then zipped and a path to
|
|
the zip returned.
|
|
|
|
:param nifti_file: Path to Nifti file.
|
|
:param reference_series: Path to folder containing reference DICOM series.
|
|
:param scratch_folder: Scratch folder to extract files into.
|
|
:param config: Model config.
|
|
:param dicom_rt_zip_file_name: Target DICOM-RT zip file name, ending in .dcm.zip.
|
|
:return: Path to DICOM-RT file.
|
|
"""
|
|
dicom_rt_file_path = scratch_folder / Path(dicom_rt_zip_file_name).with_suffix("")
|
|
(stdout, stderr) = rtconvert(
|
|
in_file=nifti_file,
|
|
reference_series=reference_series,
|
|
out_file=dicom_rt_file_path,
|
|
struct_names=config.ground_truth_ids_display_names,
|
|
struct_colors=[convert_rgb_colour_to_hex(rgb) for rgb in config.colours],
|
|
fill_holes=config.fill_holes)
|
|
# Log stdout, stderr from DICOM-RT conversion.
|
|
logging.debug("stdout: %s", stdout)
|
|
logging.debug("stderr: %s", stderr)
|
|
dicom_rt_zip_file_path = scratch_folder / dicom_rt_zip_file_name
|
|
with zipfile.ZipFile(dicom_rt_zip_file_path, 'w') as dicom_rt_zip:
|
|
dicom_rt_zip.write(dicom_rt_file_path, dicom_rt_file_path.name)
|
|
return dicom_rt_zip_file_path
|
|
|
|
|
|
def check_input_file(data_folder: Path, filename: str) -> Path:
|
|
"""
|
|
Check the folder: data_folder contains a file with name: filename.
|
|
|
|
If the file does not exist then raise a FileNotFoundError exception. Otherwise return the
|
|
path to the file.
|
|
|
|
:param data_folder: Path to data folder.
|
|
:param filename: Filename within data folder to test.
|
|
:return: Full path to filename.
|
|
"""
|
|
full_file_path = data_folder / filename
|
|
if not full_file_path.exists():
|
|
message = \
|
|
str(data_folder) if data_folder.is_absolute() else f"{data_folder}, absolute: {data_folder.absolute()}"
|
|
raise FileNotFoundError(f"File {filename} does not exist in data folder {message}")
|
|
return full_file_path
|
|
|
|
|
|
def score_image(args: ScorePipelineConfig) -> Path:
|
|
"""
|
|
Perform model inference on a single image. By doing the following:
|
|
1) Copy the provided data root directory to the root (this contains the model checkpoints and image to infer)
|
|
2) Instantiate an inference pipeline based on the provided model_inference.json in the snapshot
|
|
3) Store the segmentation file in the current directory
|
|
4) Upload the segmentation to AML
|
|
:param args:
|
|
:return:
|
|
"""
|
|
logging.getLogger().setLevel(logging.INFO)
|
|
score_py_folder = Path(__file__).parent
|
|
model_folder = Path(args.model_folder or str(score_py_folder))
|
|
|
|
run_context = Run.get_context()
|
|
logging.info(f"Run context={run_context.id}")
|
|
|
|
if args.use_dicom:
|
|
# Only a single zip file is supported.
|
|
if len(args.image_files) > 1:
|
|
raise ValueError("Supply exactly one zip file in args.images.")
|
|
input_zip_file = check_input_file(args.data_folder, args.image_files[0])
|
|
reference_series_folder = model_folder / "temp_extraction"
|
|
nifti_filename = model_folder / "temp_nifti.nii.gz"
|
|
convert_zipped_dicom_to_nifti(input_zip_file, reference_series_folder, nifti_filename)
|
|
test_images = [nifti_filename]
|
|
else:
|
|
test_images = [check_input_file(args.data_folder, file) for file in args.image_files]
|
|
|
|
images = [load_nifti_image(file) for file in test_images]
|
|
|
|
inference_pipeline, config = init_from_model_inference_json(model_folder, args.use_gpu)
|
|
segmentation = run_inference(images, inference_pipeline, config)
|
|
|
|
segmentation_file_name = model_folder / args.result_image_name
|
|
result_dst = store_as_ubyte_nifti(segmentation, images[0].header, segmentation_file_name)
|
|
|
|
if args.use_dicom:
|
|
result_dst = convert_nifti_to_zipped_dicom_rt(result_dst, reference_series_folder, model_folder,
|
|
config, args.result_zip_dicom_name)
|
|
|
|
if not is_offline_run_context(run_context):
|
|
upload_file_name = args.result_zip_dicom_name if args.use_dicom else args.result_image_name
|
|
run_context.upload_file(upload_file_name, str(result_dst))
|
|
logging.info(f"Segmentation completed: {result_dst}")
|
|
return result_dst
|
|
|
|
|
|
def main() -> None:
|
|
print(f"PYTHONPATH: {os.environ.get('PYTHONPATH')}")
|
|
score_image(ScorePipelineConfig.parse_args())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|