ENH: Add text and image inference code for CXR multimodal (#438)

* add image encoder and utils

* update the imports

* add file headers

* pin the library versions

* restructure the dir, add text related code, and update the env

* add text inference tests

* insert the missing headers

* Fix PEP8 issues

* Remove placeholder files

* flake8 fixes

* update the env and add img model tests

* fix image related tests and environment

* cosmetic changes related to linters

* add test image preprocessing

* update the test env

* update env

* pr comments -- env update

* fix mypy and add text module

* add tests for cxrbert

* revert

* mypy - pyright disable

* Add minor edits to io module

* Add __init__ files

* Fix some mypy errors

* add image and text inference engines

* add init

* add headers and update docstrings

* sort imports

* Unify imports

* Replace Health Intelligence with Health Futures

* Update GitHub workflow

* Add multimodal package to flake8 checks

* Fix flake8 errors

* test pretrained checkpoint

* remove unused par

* Add phrase grounding notebook

* pretrained path -- image model

* add ipython in the env for notebooks

* Fix incorrect install command in CI workflow

* move visualisation and update the notebook

* update the import

* Remove phrase grounding example Python file

* Use linear interpolation for similarity maps

* Stop trying to compute and upload coverage

* Fix float comparison in test

* add another demo example, unifying plotting

* Fix some mypy errors

* Fix similarity map indexing

* Fix warning about interpolation

* Remove hard-coded strings from notebook

* Ignore mypy error in test

* Add typing hints to visualization function

* flake8

* Start adding docs

* Move notebook outside Python package

* Run notebook in CI

* Move requirement

* Add call to pytest and install kernel for papermill

* Move requirements to file

* Improve docstring by adding info on percentiles

* Use Path class for path parameter

* Add typing hints in notebook functions

* Add clarifying comment

* Move notebooks command to local Makefile

* Remove unused variables

* Add docstrings to visualization functions

* Add Microsoft file header

* Launch CI tests also when hi-ml changes

* Pass list of transforms to validation function

* Fix docstring

* Update typing hints

* Update kwarg name

* Fix docstring

* Replace hard-coded variable with kwarg

* Declare variable before the line in which it is used

* Fix type of returned value

* Rename class and fix docstrings

* Fix docstring

* Improve some docstrings

* Improve assertion

* Replace import with type definition

* Stop initializing list in kwarg

* Fix docstring

* Fix docstring

* Fix docstring

* Fix docstring

* Remove old comment from ImageTextInferenceEngine docstring

* Update docstring for get_patch_embeddings_from_image

* Apply docstring suggestion for get_patch_embeddings_from_image

Co-authored-by: Fernando Pérez-García <fepegar@gmail.com>

* Stop running multimodal tests when the hi-ml libraries change.

* Fix docstring

* Use double quotes for consistency

* Improve error message and add ResNet enum

* Update docstring for tokenize_input_prompts

* Refactor some requirements

* Add missing requirement

* Stop installing hi-ml-azure in CI

* Install all requirements in CI

* Install all requirements in CI before pytest

* Add missing run requirements

* Ignore mypy error

* Upgrade Pillow to 9.0.1 to avoid security flag

* Add some VLP unit tests

* Improve check for special tokens

* Add missing Microsoft headers

* Remove some unused test requirements

* Allow [MASK] token in TextInput prompts

* Update docstring for MultiTaskModel

* Add __init__ file

* Fix docstring

* Read requirements from files in Conda YAML

* Enable members documentation by default

* Add multimodal folder to docs conf path

* Add multimodal run requirements to RTD requirements

* Stop ignoring multimodal API rst file

* Git-add multimodal API file

* Fix docstrings

* Refactor API documentation

* Use package description file for docs

Co-authored-by: Fernando Pérez-García <fperezgarcia@microsoft.com>
Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
Co-authored-by: Fernando Pérez-García <fepegar@gmail.com>
This commit is contained in:
Ozan Oktay 2022-06-27 12:02:55 +01:00 committed by GitHub
Parent a78c39f0e9
Commit a618a844aa
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
42 changed files: 1880 additions and 63 deletions

.github/workflows/multimodal-pr.yml (vendored)

@ -58,29 +58,6 @@ jobs:
make pip_test
make mypy
pyright:
runs-on: ubuntu-18.04
steps:
- uses: actions/checkout@v2
with:
lfs: true
- uses: actions/setup-node@v2
with:
node-version: '14'
- uses: conda-incubator/setup-miniconda@v2
with:
environment-file: ${{ env.folder }}/environment.yml
- name: pyright
shell: bash -l {0}
run: |
conda info
cd ${{ env.folder }}
make pyright_install
make pyright
pytest:
runs-on: ubuntu-18.04
steps:
@ -93,13 +70,20 @@ jobs:
with:
python-version: ${{ env.pythonVersion }}
- name: Upgrade PIP
run: |
cd ${{ env.folder }}
make pip_upgrade
- name: Test with pytest
run: |
cd ${{ env.folder }}
# Install local package in editable mode
make pip_local
# Run tests
make pip_test
make pip
make pytest
- name: Run Jupyter notebooks
run: |
cd ${{ env.folder }}
make notebooks

.gitignore (vendored)

@ -143,6 +143,7 @@ tensorboard_logs/
docs/source/api
!docs/source/api/api.rst
!docs/source/api/multimodal.rst
# This file is copied from repository root to docs/source by the makefile
/docs/source/CHANGELOG.md
/docs/source/CONTRIBUTING.md


@ -44,6 +44,11 @@ repos:
alias: flake8-hi-ml-histopathology
files: ^hi-ml-histopathology/
args: [--config, hi-ml-histopathology/.flake8]
- id: flake8
name: flake8 ./multimodal/
alias: flake8-multimodal
files: ^multimodal/
args: [--config, multimodal/.flake8]
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v1.5.7


@ -0,0 +1,40 @@
API
###
.. currentmodule:: health_multimodal
Vision-language processing (VLP)
--------------------------------
.. autosummary::
:toctree:
vlp
Image processing
----------------
.. autosummary::
:toctree:
image
Text processing
----------------
.. autosummary::
:toctree:
text
Common utils
------------
.. autosummary::
:toctree:
common


@ -22,6 +22,7 @@ import sys
sys.path.insert(0, os.path.abspath('../../hi-ml/src'))
sys.path.insert(0, os.path.abspath('../../hi-ml-azure/src'))
sys.path.insert(0, os.path.abspath('../../hi-ml-histopathology/src'))
sys.path.insert(0, os.path.abspath('../../multimodal'))
# -- Project information -----------------------------------------------------
@ -87,3 +88,7 @@ highlight_language = "python"
# For classes, insert documentation from the class itself AND the constructor
autoclass_content = "both"
autodoc_default_options = {
'members': True,
}


@ -46,6 +46,13 @@ The `hi-ml` toolbox provides
azure_setup.md
dsa.md
.. toctree::
:maxdepth: 1
:caption: Multimodal learning
package_description_multimodal.md
api/multimodal.rst
.. toctree::
:maxdepth: 1
:caption: Self supervised learning


@ -0,0 +1,9 @@
# Multimodal learning
Introduction
## Installation
## Example
## Credits


@ -0,0 +1 @@
../../multimodal/package_description.md


@ -57,3 +57,8 @@ pyright:
# run basic checks
check: flake8 mypy pyright
# run notebooks to ensure there are no errors
notebooks: pip_test
ipython kernel install --name "python3" --user
papermill notebooks/phrase_grounding.ipynb /tmp/phrase_grounding_output.ipynb
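The notebooks target above can also be reproduced from Python through papermill's API; a minimal sketch, assuming papermill and the "python3" kernel are installed as in requirements_test.txt:
import papermill as pm

# Programmatic equivalent of the Makefile "notebooks" target (sketch).
pm.execute_notebook(
    "notebooks/phrase_grounding.ipynb",
    "/tmp/phrase_grounding_output.ipynb",
    kernel_name="python3",
)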


@ -1,24 +1,14 @@
name: multimodal
channels:
- defaults
- conda-forge
- pytorch
dependencies:
- pip=20.1.1
- python=3.7.3
- cudatoolkit=11.1
- pytorch=1.9.0
- torchvision=0.10.0
- pip:
# This file should contain pinned versions of all packages in requirements_run.txt
# and requirements_test.txt.
# It is not possible to reference the requirements files here with "-r" because AzureML's
# environment creation can't handle those.
# Run requirements: presently empty
# Test requirements
- black==22.1.0
- coverage==6.3.2
- flake8==4.0.1
- mypy==0.931
- pylint==2.12.2
- pycobertura==2.0.1
- pytest==6.2.2
- pytest-cov==2.11.1
- pytest-rerunfailures==10.2
- pytest-timeout==2.0.1
- -r requirements_run.txt
- -r requirements_test.txt


@ -2,6 +2,3 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
def dummy() -> int:
return 1


@ -2,11 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import pytest
from health_multimodal.dummy import dummy
@pytest.mark.gpu
def test_dummy() -> None:
assert dummy() == 1
"""General utils
.. currentmodule:: health_multimodal.common
.. autosummary::
:toctree:
visualization
"""


@ -0,0 +1,119 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Union, Optional
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from health_multimodal.image.data.io import load_image
TypeArrayImage = Union[np.ndarray, Image.Image]
def _plot_image(
image: TypeArrayImage,
axis: plt.Axes,
title: Optional[str] = None,
) -> None:
"""Plot an image on a given axis, deleting the axis ticks and axis labels.
:param image: Input image.
:param axis: Axis to plot the image on.
:param title: Title used for the axis.
"""
axis.imshow(image)
axis.axis('off')
if title is not None:
axis.set_title(title)
def _get_isolines_levels(step_size: float) -> np.ndarray:
num_steps = np.floor(round(1 / step_size)).astype(int)
levels = np.linspace(step_size, 1, num_steps)
return levels
def _plot_isolines(
image: TypeArrayImage,
heatmap: np.ndarray,
axis: plt.Axes,
title: Optional[str] = None,
colormap: str = 'RdBu_r',
step: float = 0.25,
) -> None:
"""Plot an image and overlay heatmap isolines on it.
:param image: Input image.
:param heatmap: Heatmap of the same size as the image.
:param axis: Axis to plot the image on.
:param title: Title used for the axis.
:param colormap: Name of the Matplotlib colormap used for the isolines.
:param step: Step size between the isolines levels. The levels are in :math:`(0, 1]`.
For example, a step size of 0.25 will result in isolines levels of 0.25, 0.5, 0.75 and 1.
"""
axis.imshow(image)
levels = _get_isolines_levels(step)
contours = axis.contour(
heatmap,
cmap=colormap,
vmin=-1,
vmax=1,
levels=levels,
)
axis.clabel(contours, inline=True, fontsize=10)
axis.axis('off')
if title is not None:
axis.set_title(title)
def _plot_heatmap(
image: TypeArrayImage,
heatmap: np.ndarray,
figure: plt.Figure,
axis: plt.Axes,
colormap: str = 'RdBu_r',
title: Optional[str] = None,
alpha: float = 0.5,
) -> None:
"""Plot a heatmap overlaid on an image.
:param image: Input image.
:param heatmap: Input heatmap of the same size as the image.
:param figure: Figure to plot the images on.
:param axis: Axis to plot the images on.
:param colormap: Name of the Matplotlib colormap for the heatmap.
:param title: Title used for the axis.
:param alpha: Heatmap opacity. Must be in :math:`[0, 1]`.
"""
axis.imshow(image)
axes_image = axis.matshow(heatmap, alpha=alpha, cmap=colormap, vmin=-1, vmax=1)
# https://www.geeksforgeeks.org/how-to-change-matplotlib-color-bar-size-in-python/
divider = make_axes_locatable(axis)
colorbar_axes = divider.append_axes('right', size='10%', pad=0.1)
colorbar = figure.colorbar(axes_image, cax=colorbar_axes)
# https://stackoverflow.com/a/50671487/3956024
colorbar.ax.tick_params(pad=35)
plt.setp(colorbar.ax.get_yticklabels(), ha='right')
axis.axis('off')
if title is not None:
axis.set_title(title)
def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.ndarray) -> None:
"""Plot visualization of the input image, the similarity heatmap and the heatmap isolines.
:param image_path: Path to the input image.
:param similarity_map: Phrase grounding similarity map of the same size as the image.
"""
fig, axes = plt.subplots(1, 3, figsize=(15, 6))
image = load_image(image_path).convert('RGB')
_plot_image(image, axis=axes[0], title='Input image')
_plot_isolines(image, similarity_map, axis=axes[1], title='Similarity isolines')
_plot_heatmap(image, similarity_map, figure=fig, axis=axes[2], title='Similarity heatmap')
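A minimal usage sketch of ``plot_phrase_grounding_similarity_map``; the image path and the random similarity map below are illustrative placeholders:
from pathlib import Path
import numpy as np
from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map

image_path = Path("chest_xray.jpg")                          # placeholder path to a local image
similarity_map = np.random.uniform(-1, 1, size=(480, 480))   # stand-in for a real phrase grounding map
plot_phrase_grounding_similarity_map(image_path, similarity_map)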


@ -0,0 +1,33 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
"""Image-related tools
.. currentmodule:: health_multimodal.image.data
.. autosummary::
:toctree:
io
transforms
.. currentmodule:: health_multimodal.image.model
.. autosummary::
:toctree:
model
modules
"""
from .model.model import ImageModel, ResnetType
from .inference_engine import ImageInferenceEngine
__all__ = [
'ImageModel',
'ResnetType',
'ImageInferenceEngine',
]


@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------


@ -0,0 +1,71 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import pydicom as dicom
from PIL import Image
import SimpleITK as sitk
from skimage import io
def remap_to_uint8(array: np.ndarray, percentiles: Optional[Tuple[float, float]] = None) -> np.ndarray:
"""Remap values in input so the output range is :math:`[0, 255]`.
Percentiles can be used to specify the range of values to remap.
This is useful to discard outliers in the input data.
:param array: Input array.
:param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
:returns: Array with ``0`` and ``255`` as minimum and maximum values.
"""
array = array.astype(float)
if percentiles is not None:
len_percentiles = len(percentiles)
if len_percentiles != 2:
message = (
'The value for percentiles should be a sequence of length 2,'
f' but has length {len_percentiles}'
)
raise ValueError(message)
a, b = percentiles
if a >= b:
raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
if a < 0 or b > 100:
raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
cutoff: np.ndarray = np.percentile(array, percentiles)
array = np.clip(array, *cutoff)
array -= array.min()
array /= array.max()
array *= 255
return array.astype(np.uint8)
def load_image(path: Path) -> Image.Image:
"""Load an image from disk.
The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.
:param path: Path to image.
:returns: Image as ``Pillow`` ``Image``.
"""
# Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
if path.suffix in [".jpg", ".jpeg", ".png"]:
image = io.imread(path)
elif path.suffixes == [".nii", ".gz"]:
image = sitk.GetArrayFromImage(sitk.ReadImage(str(path)))
if image.shape[0] == 1:
image = np.squeeze(image, axis=0)
assert image.ndim == 2
elif path.suffix == ".dcm":
image = dicom.dcmread(path).pixel_array
else:
raise ValueError(f"Image type not supported, filename was: {path}")
image = remap_to_uint8(image)
return Image.fromarray(image).convert("L")
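A small sketch of the percentile-based remapping on synthetic data (values and sizes are illustrative):
import numpy as np
from health_multimodal.image.data.io import remap_to_uint8

# Synthetic 16-bit input, mimicking a DICOM pixel array.
array = np.arange(100, dtype=np.uint16).reshape(10, 10)
# Values below the 5th and above the 95th percentile are clipped before rescaling to [0, 255].
remapped = remap_to_uint8(array, percentiles=(5, 95))
assert remapped.dtype == np.uint8
assert remapped.min() == 0 and remapped.max() == 255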


@ -0,0 +1,70 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from typing import Callable, Sequence, Optional, Tuple
import torch
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
class ExpandChannels:
"""
Transforms an image with one channel to an image with three channels by copying
the pixel intensities of the image along the channel (first) dimension.
"""
def __call__(self, data: torch.Tensor) -> torch.Tensor:
"""
:param data: Tensor of shape [1, H, W].
:return: Tensor with channel copied three times, shape [3, H, W].
"""
if data.shape[0] != 1:
raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
return torch.repeat_interleave(data, 3, dim=0)
def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int) -> Compose:
"""
Defines the image transformation pipeline for Chest-Xray datasets.
:param resize: The size to resize the image to. Linear resampling is used.
Resizing is applied to the axis with the smaller shape, i.e. the smaller edge of the image is matched to this size.
:param center_crop_size: The size to center crop the image to. Square crop is applied.
"""
transforms = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
return Compose(transforms)
def infer_resize_params(val_img_transforms: Sequence[Callable]) -> Tuple[Optional[int], Optional[int]]:
"""
Given the validation transforms pipeline, extract the sizes to which the image was resized and cropped, if any.
"""
resize_size_from_transforms = None
crop_size_from_transforms = None
supported_types = Resize, CenterCrop, ToTensor, ExpandChannels
for transform in val_img_transforms:
trsf_type = type(transform)
if trsf_type not in supported_types:
raise ValueError(f"Unsupported transform type {trsf_type}. Supported types are {supported_types}")
if isinstance(transform, Resize):
if resize_size_from_transforms is None and crop_size_from_transforms is None:
assert transform.max_size is None
assert isinstance(transform.size, int), f"Expected int, got {transform.size}"
resize_size_from_transforms = transform.size
else:
raise ValueError("Expected Resize to be the first transform if present in val_img_transforms")
elif isinstance(transform, CenterCrop):
if crop_size_from_transforms is None:
two_dims = len(transform.size) == 2
same_sizes = transform.size[0] == transform.size[1]
is_square = two_dims and same_sizes
assert is_square, "Only square center crop supported"
crop_size_from_transforms = transform.size[0]
else:
raise ValueError(
f"Crop size has already been set to {crop_size_from_transforms} in a previous transform")
return resize_size_from_transforms, crop_size_from_transforms
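A short sketch showing the transform factory and ``infer_resize_params`` round-tripping the sizes used in the phrase grounding notebook (512 and 480; any integers would do):
from health_multimodal.image.data.transforms import (
    create_chest_xray_transform_for_inference,
    infer_resize_params,
)

transform = create_chest_xray_transform_for_inference(resize=512, center_crop_size=480)
resize_size, crop_size = infer_resize_params(transform.transforms)
assert (resize_size, crop_size) == (512, 480)
# The pipeline maps a single-channel PIL image to a 3-channel tensor of shape [3, 480, 480].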


@ -0,0 +1,68 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Callable, Tuple
import torch
from torchvision.transforms import Compose
from health_multimodal.image.data.io import load_image
from health_multimodal.image.data.transforms import infer_resize_params
from health_multimodal.image.model.model import ImageModel
TypeShape2D = Tuple[int, int]
class ImageInferenceEngine:
"""
Encapsulate inference-time operations on an image model.
"""
def __init__(self, image_model: ImageModel, transform: Compose):
"""
:param image_model: Trained image model.
:param transform: Transform to apply to the image after loading. Must return a torch.Tensor that can be
input directly to the image model.
"""
assert isinstance(image_model, ImageModel), f"Expected an ImageModel, got {type(image_model)}"
self.model = image_model
self.transform = transform
self.device = next(self.model.parameters()).device
self.model.eval()
self.resize_size, self.crop_size = infer_resize_params(self.transform.transforms)
def load_and_transform_input_image(self, image_path: Path, transform: Callable) -> Tuple[torch.Tensor, TypeShape2D]:
"""Read an image and apply the transform to it.
1. Read the image from the given path
2. Apply transform
3. Add the batch dimension
4. Move to the correct device
:param image_path: Path to the image to load.
:param transform: Transform to apply to the loaded image.
:return: A tuple of the transformed image tensor and the original image size (width, height)
before the transforms were applied.
"""
image = load_image(image_path)
transformed_image = transform(image).unsqueeze(0).to(self.device)
return transformed_image, image.size
@torch.no_grad()
def get_patch_embeddings_from_image(self, image_path: Path) -> Tuple[torch.Tensor, TypeShape2D]:
"""Compute image embeddings in the joint latent space, preserving the image grid.
:param image_path: Path to the image to compute embeddings for.
:return: A tuple containing the image patch embeddings and
the shape of the original image (width, height) before applying transforms.
"""
input_image, img_shape = self.load_and_transform_input_image(image_path, self.transform)
projected_img_emb = self.model.get_patchwise_projected_embeddings(input_image, normalize=True)
assert projected_img_emb.shape[0] == 1
return projected_img_emb[0], img_shape
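A usage sketch of the image inference engine; the image path below is a placeholder, and a trained checkpoint would normally be passed via ``pretrained_model_path``:
from pathlib import Path
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
from health_multimodal.image.inference_engine import ImageInferenceEngine
from health_multimodal.image.model.model import ImageModel

# Checkpoint omitted here for brevity; pretrained_model_path would point to the trained image model.
image_model = ImageModel(img_model_type="resnet50", joint_feature_size=128)
engine = ImageInferenceEngine(
    image_model=image_model,
    transform=create_chest_xray_transform_for_inference(resize=512, center_crop_size=480),
)
# "chest_xray.jpg" is a placeholder path to a local image.
patch_embeddings, (width, height) = engine.get_patch_embeddings_from_image(Path("chest_xray.jpg"))
# patch_embeddings: [n_patches_h, n_patches_w, 128]; (width, height) is the original image size.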


@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------


@ -0,0 +1,173 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import enum
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
from .modules import MLP, MultiTaskModel
TypeImageEncoder = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
@enum.unique
class ResnetType(str, enum.Enum):
RESNET18 = "resnet18"
RESNET50 = "resnet50"
@dataclass
class ImageModelOutput():
img_embedding: torch.Tensor
patch_embedding: torch.Tensor
projected_global_embedding: torch.Tensor
class_logits: torch.Tensor
projected_patch_embeddings: torch.Tensor
class ImageModel(nn.Module):
"""Image encoder module"""
def __init__(self,
img_model_type: str,
joint_feature_size: int,
freeze_encoder: bool = False,
pretrained_model_path: Optional[str] = None,
**downstream_classifier_kwargs: Any):
super().__init__()
# Initiate encoder, projector, and classifier
self.encoder = ImageEncoder(img_model_type)
self.feature_size = get_encoder_output_dim(self.encoder)
self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
hidden_dim=joint_feature_size, use_1x1_convs=True)
self.downstream_classifier_kwargs = downstream_classifier_kwargs
self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None
# Initialise the mode of modules
self.freeze_encoder = freeze_encoder
self.train()
if pretrained_model_path is not None:
assert isinstance(pretrained_model_path, str), f"Expected a string, got {type(pretrained_model_path)}"
self.load_state_dict(torch.load(pretrained_model_path))
def train(self, mode: bool = True) -> Any:
"""Switch the model between training and evaluation modes."""
super().train(mode=mode)
if self.freeze_encoder:
self.encoder.train(mode=False)
self.projector.train(mode=False)
return self
def forward(self, x: torch.Tensor) -> ImageModelOutput:
with torch.set_grad_enabled(not self.freeze_encoder):
patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
projected_patch_embeddings = self.projector(patch_x)
projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))
logits = self.classifier(pooled_x) if self.classifier else None
return ImageModelOutput(img_embedding=pooled_x,
patch_embedding=patch_x,
class_logits=logits,
projected_patch_embeddings=projected_patch_embeddings,
projected_global_embedding=projected_global_embedding)
def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
"""Create the classification module for the downstream task."""
downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)
@torch.no_grad()
def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
"""Get patch-wise projected embeddings from the CNN model.
:param input_img: input tensor image [B, C, H, W].
:param normalize: If ``True``, the embeddings are L2-normalized.
:returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size].
"""
assert not self.training, "This function is only implemented for evaluation mode"
outputs = self.forward(input_img)
projected_embeddings = outputs.projected_patch_embeddings.detach() # type: ignore
if normalize:
projected_embeddings = F.normalize(projected_embeddings, dim=1)
projected_embeddings = projected_embeddings.permute([0, 2, 3, 1]) # B D H W -> B H W D (D: Features)
return projected_embeddings
class ImageEncoder(nn.Module):
"""Image encoder trunk module for the ``ImageModel`` class.
:param img_model_type: Type of image model to use: either ``"resnet18"`` or ``"resnet50"``.
"""
def __init__(self, img_model_type: str):
super().__init__()
self.img_model_type = img_model_type
self.encoder = self._create_encoder()
def _create_encoder(self, **kwargs: Any) -> nn.Module:
supported = ResnetType.RESNET18, ResnetType.RESNET50
if self.img_model_type not in supported:
raise NotImplementedError(f"Image model type \"{self.img_model_type}\" must be in {supported}")
encoder_class = resnet18 if self.img_model_type == ResnetType.RESNET18 else resnet50
encoder = encoder_class(return_all_feature_maps=True, pretrained=True, **kwargs)
return encoder
def forward(self, x: torch.Tensor, return_patch_embeddings: bool = False) -> TypeImageEncoder:
"""Image encoder forward pass."""
x = self.encoder(x)
x = x[-1] if isinstance(x, list) else x
avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)), 1)
if return_patch_embeddings:
return x, avg_pooled_emb
return avg_pooled_emb
def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
"""Workaround for enabling dilated convolutions after model initialization.
:param replace_stride_with_dilation: Flags indicating, for each of the selectable layers, whether the 2x2 stride should be replaced with a dilated convolution.
"""
if self.img_model_type == "resnet18":
# resnet18 uses BasicBlock implementation, which does not support dilated convolutions.
raise NotImplementedError("resnet18 does not support dilated convolutions")
if replace_stride_with_dilation is None:
replace_stride_with_dilation = False, False, True
device = next(self.encoder.parameters()).device
new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)
if self.encoder.training:
new_encoder.train()
else:
new_encoder.eval()
new_encoder.load_state_dict(self.encoder.state_dict())
self.encoder = new_encoder
@torch.no_grad()
def get_encoder_output_dim(module: torch.nn.Module) -> int:
"""Calculate the output dimension of ssl encoder by making a single forward pass.
:param module: Encoder module.
"""
# Target device
device = next(module.parameters()).device # type: ignore
assert isinstance(device, torch.device)
x = torch.rand((1, 3, 448, 448)).to(device)
# Extract the number of output feature dimensions
representations = module(x)
return representations.shape[1]
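A sketch of a forward pass on random data, showing the fields of ``ImageModelOutput``; sizes are illustrative, and the extra keyword arguments are forwarded to ``MultiTaskModel``:
import torch
from health_multimodal.image.model.model import ImageModel

model = ImageModel(
    img_model_type="resnet18",
    joint_feature_size=128,
    num_classes=2,               # downstream classifier kwargs, forwarded to MultiTaskModel
    num_tasks=1,
    classifier_hidden_dim=None,
)
model.eval()
with torch.no_grad():
    output = model(torch.rand(1, 3, 448, 448))
print(output.img_embedding.shape)               # [1, 512] pooled encoder features (resnet18)
print(output.projected_patch_embeddings.shape)  # [1, 128, n_patches_h, n_patches_w]
print(output.projected_global_embedding.shape)  # [1, 128]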


@ -0,0 +1,86 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from typing import Callable, Optional
import torch
import torch.nn as nn
class MLP(nn.Module):
"""
Fully connected layers to map between image embeddings and projection space where pairs of images are compared.
:param input_dim: Input embedding feature size
:param hidden_dim: Hidden layer size in MLP
:param output_dim: Output projection size
:param use_1x1_convs: Use 1x1 convolutional kernels instead of linear layers, for speed and memory efficiency.
"""
def __init__(self,
input_dim: int,
output_dim: int,
hidden_dim: Optional[int] = None,
use_1x1_convs: bool = False) -> None:
super().__init__()
if use_1x1_convs:
linear_proj_1_args = {'in_channels': input_dim, 'out_channels': hidden_dim, 'kernel_size': 1, 'bias': False}
linear_proj_2_args = {'in_channels': hidden_dim, 'out_channels': output_dim, 'kernel_size': 1, 'bias': True}
normalisation_layer: Callable = nn.BatchNorm2d
projection_layer: Callable = nn.Conv2d
else:
linear_proj_1_args = {'in_features': input_dim, 'out_features': hidden_dim, 'bias': False}
linear_proj_2_args = {'in_features': hidden_dim, 'out_features': output_dim, 'bias': True}
normalisation_layer = nn.BatchNorm1d
projection_layer = nn.Linear
self.output_dim = output_dim
self.input_dim = input_dim
if hidden_dim is not None:
self.model = nn.Sequential(
projection_layer(**linear_proj_1_args),
normalisation_layer(hidden_dim),
nn.ReLU(inplace=True),
projection_layer(**linear_proj_2_args))
else:
self.model = nn.Linear(input_dim, output_dim) # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""forward pass of the multi-layer perceptron"""
x = self.model(x)
return x
class MultiTaskModel(nn.Module):
"""Torch module for multi-task classification heads. We create a separate classification head
for each task and perform a forward pass on each head independently in forward(). Classification
heads are instances of `MLP`.
:param input_dim: Number of dimensions of the input feature map.
:param classifier_hidden_dim: Number of dimensions of hidden features in the MLP.
:param num_classes: Number of output classes per task.
:param num_tasks: Number of classification tasks or heads required.
"""
def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):
super().__init__()
self.num_classes = num_classes
self.num_tasks = num_tasks
for task in range(num_tasks):
# TODO check if softmax not needed here.
setattr(self, "fc_" + str(task), MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Returns [batch_size, num_tasks, num_classes] tensor of logits."""
batch_size = x.shape[0]
out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
for task in range(self.num_tasks):
classifier = getattr(self, "fc_" + str(task))
out[:, :, task] = classifier(x)
return out
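A sketch of the two modules above with illustrative sizes:
import torch
from health_multimodal.image.model.modules import MLP, MultiTaskModel

# 1x1-convolutional MLP as used by ImageModel.projector: maps a [B, C, H, W] feature map
# patch-wise into the joint latent space.
projector = MLP(input_dim=512, output_dim=128, hidden_dim=128, use_1x1_convs=True)
patches = torch.rand(2, 512, 14, 14)
projected = projector(patches)            # -> [2, 128, 14, 14]

# Multi-task classification head: one MLP per task, stacked along the last dimension.
classifier = MultiTaskModel(input_dim=512, classifier_hidden_dim=64, num_classes=2, num_tasks=3)
logits = classifier(torch.rand(2, 512))   # -> [2, 2, 3] = [batch_size, num_classes, num_tasks]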


@ -0,0 +1,41 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
"""Text-related tools
.. currentmodule:: health_multimodal.text
.. autosummary::
:toctree:
inference_engine
.. currentmodule:: health_multimodal.text.data
.. autosummary::
:toctree:
io
.. currentmodule:: health_multimodal.text.model
.. autosummary::
:toctree:
configuration_cxrbert
modelling_cxrbert
"""
from .data.io import TypePrompts
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
__all__ = [
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
"TypePrompts",
]


@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------


@ -0,0 +1,56 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import logging
from typing import Any, List, Union
from transformers import BertTokenizer
TypePrompts = Union[str, List[str]]
class TextInput:
"""Text input class that can be used for inference and deployment.
Implements tokenizer-related operations and ensures that input strings
conform to the format expected by a BERT model.
:param tokenizer: A BertTokenizer object.
"""
def __init__(self, tokenizer: BertTokenizer) -> None:
self.tokenizer = tokenizer
def tokenize_input_prompts(self, prompts: TypePrompts, verbose: bool) -> Any:
"""
Tokenizes the input sentence(s) and adds special tokens as defined by the tokenizer.
:param prompts: Either a string containing a single sentence, or a list of strings each containing
a single sentence. Note that this method will not correctly tokenize multiple sentences if they
are input as a single string.
:param verbose: If set to True, will log the sentence after tokenization.
:return: The tokenizer output containing the token ids and attention masks of the tokenized sentences.
"""
prompts = [prompts] if isinstance(prompts, str) else prompts
self.assert_special_tokens_not_present(" ".join(prompts))
prompts = [prompt.rstrip("!?.") for prompt in prompts] # removes punctuation from end of prompt
tokenizer_output = self.tokenizer.batch_encode_plus(batch_text_or_text_pairs=prompts,
add_special_tokens=True,
padding='longest',
return_tensors='pt')
if verbose:
for prompt in tokenizer_output.input_ids:
input_tokens = self.tokenizer.convert_ids_to_tokens(prompt.tolist())
logging.info(f"Input tokens: {input_tokens}")
return tokenizer_output
def assert_special_tokens_not_present(self, prompt: str) -> None:
"""Check if the input prompts contain special tokens."""
special_tokens = self.tokenizer.all_special_tokens
special_tokens.remove(self.tokenizer.mask_token) # [MASK] is allowed
if any(map(lambda token: token in prompt, special_tokens)):
raise ValueError(f"The input \"{prompt}\" contains at least one special token ({special_tokens})")


@ -0,0 +1,113 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from typing import Any, List, Union
import torch
from transformers import BertForMaskedLM, BertTokenizer
from health_multimodal.text.data.io import TextInput
class TextInferenceEngine(TextInput):
"""
Text inference class that implements the functionality required for sentence embedding extraction,
sentence similarity, and masked language modelling (MLM) prediction tasks.
:param tokenizer: A BertTokenizer object.
:param text_model: Text model; must be an instance of ``BertForMaskedLM`` (e.g. the default HuggingFace
class or a subclass such as ``CXRBertModel``).
"""
def __init__(self, tokenizer: BertTokenizer, text_model: BertForMaskedLM) -> None:
super().__init__(tokenizer=tokenizer)
assert isinstance(text_model, BertForMaskedLM), f"Expected a BertForMaskedLM, got {type(text_model)}"
self.model = text_model
self.device = next(self.model.parameters()).device
self.max_allowed_input_length = self.model.config.max_position_embeddings
def is_in_eval(self) -> bool:
"""Returns True if the model is in eval mode."""
return not self.model.training
def tokenize_input_prompts(self, prompts: Union[str, List[str]], verbose: bool = True) -> Any:
tokenizer_output = super().tokenize_input_prompts(prompts, verbose=verbose)
tokenizer_output.input_ids = tokenizer_output.input_ids.to(self.device)
tokenizer_output.attention_mask = tokenizer_output.attention_mask.to(self.device)
max_length = tokenizer_output.input_ids.shape[1]
if tokenizer_output.input_ids.shape[1] > self.max_allowed_input_length:
raise ValueError(f"The sequence length of the input ({max_length}) is "
f"longer than the maximum allowed sequence length ({self.max_allowed_input_length}).")
return tokenizer_output
@torch.no_grad()
def get_embeddings_from_prompt(self, prompts: Union[str, List[str]], verbose: bool = True) -> torch.Tensor:
"""Generate L2-normalised embeddings for a list of input text prompts.
:param prompts: Input text prompt(s) either in string or list of string format.
:param verbose: If set to True, tokenized words are displayed in the console.
:return: Tensor of shape (batch_size, embedding_size).
"""
assert self.is_in_eval()
tokenizer_output = self.tokenize_input_prompts(prompts=prompts, verbose=verbose)
txt_emb = self.model.get_projected_text_embeddings( # type: ignore
input_ids=tokenizer_output.input_ids,
attention_mask=tokenizer_output.attention_mask)
return txt_emb
@torch.no_grad()
def get_pairwise_similarities(self,
prompt_set_1: Union[str, List[str]],
prompt_set_2: Union[str, List[str]]) -> torch.Tensor:
"""Compute pairwise cosine similarities between the embeddings of the given prompts."""
emb_1 = self.get_embeddings_from_prompt(prompts=prompt_set_1, verbose=False)
emb_2 = self.get_embeddings_from_prompt(prompts=prompt_set_2, verbose=False)
sim = torch.diag(torch.mm(emb_1, emb_2.t())).detach()
return sim
@torch.no_grad()
def predict_masked_tokens(self, prompts: Union[str, List[str]]) -> List[List[str]]:
"""Predict masked tokens for a single or list of input text prompts.
Requires models to be trained with a MLM prediction head.
:param prompts: Input text prompt(s) either in string or list of string format.
:return: Predicted token candidates (Top-1) at masked position.
"""
assert self.is_in_eval()
# Tokenize the input prompts
tokenized_prompts = self.tokenize_input_prompts(prompts)
# Collect all token predictions
text_model_output = self.model.forward(input_ids=tokenized_prompts.input_ids,
attention_mask=tokenized_prompts.attention_mask)
logits = text_model_output.logits
logits = logits.detach()
predicted_token_ids = torch.argmax(logits, dim=-1) # Batch x Seq
# Identify the masked token indices
batch_size = predicted_token_ids.shape[0]
mask_token_id = self.tokenizer.mask_token_id
mlm_mask = tokenized_prompts.input_ids == mask_token_id # Batch x Seq
# Convert the predicted token ids to token strings
output = list()
for b in range(batch_size):
_ids = predicted_token_ids[b, mlm_mask[b]].cpu().tolist()
_tokens = self.tokenizer.convert_ids_to_tokens(_ids)
output.append(_tokens)
return output
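A usage sketch of the text inference engine, loading the same Hugging Face checkpoint as the phrase grounding notebook (prompts are illustrative):
from transformers import AutoModel, AutoTokenizer
from health_multimodal.text.inference_engine import TextInferenceEngine

url = "microsoft/BiomedVLP-CXR-BERT-specialized"  # downloads weights on first use
engine = TextInferenceEngine(
    tokenizer=AutoTokenizer.from_pretrained(url, trust_remote_code=True),
    text_model=AutoModel.from_pretrained(url, trust_remote_code=True),
)
engine.model.eval()

# Element-wise cosine similarities between two sets of prompts.
sims = engine.get_pairwise_similarities(
    ["There is a pleural effusion", "Heart size is normal"],
    ["Pleural effusion is present", "Cardiomegaly is seen"],
)
print(sims)  # tensor of shape [2]

# Top-1 prediction for each [MASK] token in the prompt.
print(engine.predict_masked_tokens("There is no [MASK] effusion"))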


@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------


@ -0,0 +1,27 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any
from transformers import BertConfig, BertTokenizer
class CXRBertConfig(BertConfig):
"""
Config class for CXR-BERT model.
:param projection_size: Dimensionality of the joint latent space.
"""
model_type = "cxr-bert"
def __init__(self, projection_size: int = 128, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.projection_size = projection_size
class CXRBertTokenizer(BertTokenizer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)


@ -0,0 +1,133 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor as T
from transformers import BertForMaskedLM
from transformers.modeling_outputs import ModelOutput
from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
BERTTupleOutput = Tuple[T, T, T, T, T]
class CXRBertOutput(ModelOutput):
last_hidden_state: torch.FloatTensor
logits: torch.FloatTensor
cls_projected_embedding: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class BertProjectionHead(nn.Module):
"""Projection head to be used with BERT CLS token.
This is similar to ``BertPredictionHeadTransform`` in HuggingFace.
:param config: Configuration for BERT.
"""
def __init__(self, config: CXRBertConfig) -> None:
super().__init__()
self.dense_to_hidden = nn.Linear(config.hidden_size, config.projection_size)
self.transform_act_fn = nn.functional.gelu
self.LayerNorm = nn.LayerNorm(config.projection_size, eps=1e-12)
self.dense_to_output = nn.Linear(config.projection_size, config.projection_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense_to_hidden(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dense_to_output(hidden_states)
return hidden_states
class CXRBertModel(BertForMaskedLM):
"""
Implements the CXR-BERT model outlined in the manuscript:
Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
https://arxiv.org/abs/2204.09817
Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projected "[CLS]" token
embedding is used to align the latent vectors of the image and text modalities.
"""
config_class = CXRBertConfig # type: ignore
def __init__(self, config: CXRBertConfig):
super().__init__(config)
self.cls_projection_head = BertProjectionHead(config)
self.init_weights()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_cls_projected_embedding: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Any
) -> Union[BERTTupleOutput, CXRBertOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
bert_for_masked_lm_output = super().forward(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=True)
last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
cls_projected_embedding = self.cls_projection_head(
last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
if return_dict:
return CXRBertOutput(
last_hidden_state=last_hidden_state,
logits=bert_for_masked_lm_output.logits,
cls_projected_embedding=cls_projected_embedding,
hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
attentions=bert_for_masked_lm_output.attentions,
)
else:
return (
last_hidden_state,
bert_for_masked_lm_output.logits,
cls_projected_embedding,
bert_for_masked_lm_output.hidden_states,
bert_for_masked_lm_output.attentions,)
def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
The joint latent space is trained using a contrastive objective between image and text data modalities.
:param input_ids: (batch_size, sequence_length)
:param attention_mask: (batch_size, sequence_length)
:return: (batch_size, projection_size)
"""
outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask,
output_cls_projected_embedding=True, return_dict=True)
assert isinstance(outputs, CXRBertOutput)
assert outputs.cls_projected_embedding is not None
normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
return normalized_cls_embedding
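A sketch that exercises ``get_projected_text_embeddings`` with a tiny randomly initialised configuration (illustrative sizes, not the published checkpoint):
import torch
from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
from health_multimodal.text.model.modelling_cxrbert import CXRBertModel

config = CXRBertConfig(projection_size=128, hidden_size=64, num_hidden_layers=2,
                       num_attention_heads=2, intermediate_size=128, vocab_size=1000)
model = CXRBertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 10))
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    embeddings = model.get_projected_text_embeddings(input_ids, attention_mask)
print(embeddings.shape)        # [2, 128], one projected [CLS] embedding per prompt
print(embeddings.norm(dim=1))  # ~1.0 for each row (L2-normalised)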


@ -0,0 +1,14 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
"""Visual-language processing tools
.. currentmodule:: health_multimodal.vlp
.. autosummary::
:toctree:
inference_engine
"""


@ -0,0 +1,117 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
"""Tools related to joint image and text inference"""
from math import ceil, floor
from pathlib import Path
from typing import Callable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy import ndimage
from health_multimodal.image.inference_engine import ImageInferenceEngine
from health_multimodal.text.inference_engine import TextInferenceEngine
class ImageTextInferenceEngine:
"""
Encapsulate functions related to inference on ImageTextModels.
"""
def __init__(self,
image_inference_engine: ImageInferenceEngine,
text_inference_engine: TextInferenceEngine) -> None:
"""
:param image_inference_engine: Image inference engine used to compute patch embeddings.
:param text_inference_engine: Text inference engine used to compute text embeddings.
"""
self.image_inference_engine = image_inference_engine
self.text_inference_engine = text_inference_engine
def get_similarity_map_from_raw_data(self, image_path: Path, query_text: str) -> np.ndarray:
"""
Return a heatmap of the similarities between each patch embedding from the image and the text embedding.
"""
assert not self.image_inference_engine.model.training
assert not self.text_inference_engine.model.training
# TODO: Add checks in here regarding the text query, etc.
image_embedding, (width, height) = self.image_inference_engine.get_patch_embeddings_from_image(image_path)
text_embedding = self.text_inference_engine.get_embeddings_from_prompt(query_text)
sim = self._get_similarity_map_from_embeddings(image_embedding, text_embedding)
resized_sim_map = self.convert_similarity_to_image_size(
sim,
width=width,
height=height,
resize_size=self.image_inference_engine.resize_size,
crop_size=self.image_inference_engine.crop_size,
val_img_transform=self.image_inference_engine.transform,
)
return resized_sim_map
@staticmethod
def _get_similarity_map_from_embeddings(projected_patch_embeddings: torch.Tensor,
projected_text_embeddings: torch.Tensor,
sigma: float = 1.5) -> torch.Tensor:
"""
Get smoothed similarity map for a given image patch embeddings and text embeddings.
:param projected_patch_embeddings: [n_patches_h, n_patches_w, feature_size]
:param projected_text_embeddings: [1, feature_size]
:param sigma: Standard deviation of the Gaussian kernel used to smooth the similarity map.
:return: similarity_map: similarity map of shape [n_patches_h, n_patches_w]
"""
n_patches_h, n_patches_w, feature_size = projected_patch_embeddings.shape
assert feature_size == projected_text_embeddings.shape[1]
assert projected_text_embeddings.shape[0] == 1
assert projected_text_embeddings.dim() == 2
patch_wise_similarity = projected_patch_embeddings.view(-1, feature_size) @ projected_text_embeddings.t()
patch_wise_similarity = patch_wise_similarity.reshape(n_patches_h, n_patches_w).cpu().numpy()
smoothed_similarity_map = torch.tensor(ndimage.gaussian_filter(
patch_wise_similarity, sigma=(sigma, sigma), order=0))
return smoothed_similarity_map
@staticmethod
def convert_similarity_to_image_size(
similarity_map: torch.Tensor, width: int, height: int, resize_size: Optional[int],
crop_size: Optional[int], val_img_transform: Optional[Callable] = None) -> np.ndarray:
"""
Convert similarity map from raw patch grid to original image size,
taking into account whether the image has been resized and/or cropped prior to entering the network.
"""
n_patches_h, n_patches_w = similarity_map.shape[0], similarity_map.shape[1]
target_shape = 1, 1, n_patches_h, n_patches_w
smallest_dimension = min(height, width)
# TODO:
# verify_resize_params(val_img_transforms, resize_size, crop_size)
if crop_size is not None:
if resize_size is not None:
cropped_size_orig_space = int(crop_size * smallest_dimension / resize_size)
target_size = cropped_size_orig_space, cropped_size_orig_space
else:
target_size = crop_size, crop_size
similarity_map = F.interpolate(
similarity_map.reshape(target_shape),
size=target_size,
mode='bilinear',
align_corners=False,
)
margin_w, margin_h = (width - target_size[0]), (height - target_size[1])
margins_for_pad = (floor(margin_w / 2), ceil(margin_w / 2), floor(margin_h / 2), ceil(margin_h / 2))
similarity_map = F.pad(similarity_map[0, 0], margins_for_pad, value=float("NaN"))
else:
similarity_map = F.interpolate(
similarity_map.reshape(target_shape),
size=(height, width),
mode='bilinear',
align_corners=False,
)[0, 0]
return similarity_map.numpy()
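A worked example of the resizing logic in ``convert_similarity_to_image_size``, with illustrative numbers:
import torch
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine

# A 15x15 patch-level map, original image of width 2000 and height 2500,
# network input produced by Resize(512) followed by CenterCrop(480).
similarity_map = torch.rand(15, 15)
full_size_map = ImageTextInferenceEngine.convert_similarity_to_image_size(
    similarity_map, width=2000, height=2500, resize_size=512, crop_size=480,
)
# The crop covers int(480 * 2000 / 512) = 1875 pixels of the original image, so the map is
# upsampled to 1875x1875 and then NaN-padded to the full (height, width) = (2500, 2000) canvas.
print(full_size_map.shape)  # (2500, 2000)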


@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# -------------------------------------------------------------------------------------------\n",
"# Copyright (c) Microsoft Corporation. All rights reserved.\n",
"# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.\n",
"# -------------------------------------------------------------------------------------------"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"import torch\n",
"from transformers import AutoModel, AutoTokenizer\n",
"\n",
"from health_multimodal.text.inference_engine import TextInferenceEngine\n",
"from health_multimodal.image.inference_engine import ImageInferenceEngine\n",
"from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine\n",
"\n",
"from health_multimodal.image.model.model import ImageModel\n",
"from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference\n",
"\n",
"from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the text inference engine\n",
"HUGGING_FACE_URL = \"microsoft/BiomedVLP-CXR-BERT-specialized\"\n",
"text_inference = TextInferenceEngine(\n",
" tokenizer=AutoTokenizer.from_pretrained(HUGGING_FACE_URL, trust_remote_code=True),\n",
" text_model=AutoModel.from_pretrained(HUGGING_FACE_URL, trust_remote_code=True),\n",
")\n",
"\n",
"# Load the image inference engine\n",
"resnet_checkpoint_path = \"\" # add path to checkpoint here\n",
"if not Path(resnet_checkpoint_path).is_file():\n",
" print(\"Checkpoint file not found!\")\n",
" resnet_checkpoint_path = None\n",
"image_inference = ImageInferenceEngine(\n",
" image_model=ImageModel(img_model_type=\"resnet50\", joint_feature_size=128, pretrained_model_path=resnet_checkpoint_path),\n",
" transform=create_chest_xray_transform_for_inference(resize=512, center_crop_size=480))\n",
"\n",
"# Instantiate the joint inference engine\n",
"image_text_inference = ImageTextInferenceEngine(\n",
" image_inference_engine=image_inference,\n",
" text_inference_engine=text_inference,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def plot_phrase_grounding(image_path: Path, text_prompt: str) -> None:\n",
" sim_map = image_text_inference.get_similarity_map_from_raw_data(image_path=image_path, query_text=text_prompt)\n",
" plot_phrase_grounding_similarity_map(image_path=image_path, similarity_map=sim_map)\n",
"\n",
"def plot_phrase_grounding_from_url(image_url: str, text_prompt: str) -> None:\n",
" image_path = Path(tempfile.tempdir, 'downloaded_chest_xray.jpg')\n",
" !curl -s -L -o {image_path} {image_url}\n",
" plot_phrase_grounding(image_path, text_prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_prompt = \"Pneumonia in the right lung\"\n",
"image_url = \"https://prod-images-static.radiopaedia.org/images/1371188/0a1f5edc85aa58d5780928cb39b08659c1fc4d6d7c7dce2f8db1d63c7c737234_gallery.jpeg\"\n",
"plot_phrase_grounding_from_url(image_url, text_prompt)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.3 ('himl-multimodal')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "b10e2d33e98f46e002b38decbb3115032da80ae497861a1d67d5527569b17994"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@ -1,8 +1,5 @@
{
"include": [
"health_multimodal",
"test_multimodal",
],
"include": [],
"useLibraryCodeForTypes": false,
"reportMissingImports": true,
"reportMissingTypeStubs": false,


@ -0,0 +1,10 @@
huggingface-hub==0.6.0
lightning-bolts==0.3.4
pillow==9.0.1
pydicom==2.2.2
pytorch-lightning==1.5.5
scikit-image==0.18.1
SimpleITK==2.1.1
torch==1.9.0
torchvision==0.10.0
transformers==4.17.0


@ -1,10 +1,7 @@
black==22.1.0
coverage==6.3.2
flake8==4.0.1
ipykernel==6.15.0
ipython==7.34.0
mypy==0.931
pylint==2.12.2
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1
pytest-rerunfailures==10.2
pytest-timeout==2.0.1
papermill==2.3.4
setuptools==59.5.0


@ -62,7 +62,7 @@ install_requires = (here / 'requirements_run.txt').read_text().split("\n")
# Remove any whitespace and blank lines
install_requires = [line.strip() for line in install_requires if line.strip()]
description = 'Microsoft Health Intelligence package to work with multi-modal health data'
description = 'Microsoft Health Futures package to work with multi-modal health data'
setup(
name=package_name,


@@ -0,0 +1,66 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from pathlib import Path
from tempfile import NamedTemporaryFile

import numpy as np
import pytest
import SimpleITK as sitk
from PIL import Image

from health_multimodal.image.data.io import load_image, remap_to_uint8


def _assert_min_max_dtype(array: np.ndarray) -> None:
    assert array.min() == 0
    assert array.max() == 255
    assert array.dtype == np.uint8


def test_load_image() -> None:
    """Test the image loading function using dummy files."""
    size = 4, 4

    def _assertions(path: Path) -> None:
        img = load_image(path)
        assert img.size == size
        array = np.asarray(img)
        _assert_min_max_dtype(array)

    array = np.arange(np.prod(size), dtype=np.uint8).reshape(*size)
    image = Image.fromarray(array).convert('RGB')
    for suffix in '.jpg', '.jpeg', '.png':
        with NamedTemporaryFile(suffix=suffix) as file:
            image.save(file)
            _assertions(Path(file.name))

    nifti_img = sitk.GetImageFromArray(np.arange(16, dtype=np.uint16).reshape(*size) + 100)
    with NamedTemporaryFile(suffix='.nii.gz') as file:
        sitk.WriteImage(nifti_img, file.name)
        _assertions(Path(file.name))


def test_remap_to_uint8() -> None:
    """Test the intensity casting function using different percentiles."""
    array = np.arange(10).astype(np.uint16)  # mimic DICOM data type

    with pytest.raises(ValueError):
        remap_to_uint8(array, (1, 2, 3))  # type: ignore[arg-type]
    with pytest.raises(ValueError):
        remap_to_uint8(array, (-1, 50))
    with pytest.raises(ValueError):
        remap_to_uint8(array, (1, 150))
    with pytest.raises(ValueError):
        remap_to_uint8(array, (5, 2))

    normalized = remap_to_uint8(array)
    _assert_min_max_dtype(normalized)
    normalized = remap_to_uint8(array, (1, 99))
    _assert_min_max_dtype(normalized)

    array_positive_min = array + 5
    normalized = remap_to_uint8(array_positive_min)
    _assert_min_max_dtype(normalized)
    normalized = remap_to_uint8(array_positive_min, (1, 99))
    _assert_min_max_dtype(normalized)

@@ -0,0 +1,135 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import pytest
import torch
from pl_bolts.models.self_supervised.resnets import resnet50

from health_multimodal.image.model.model import ImageEncoder, ImageModel
from health_multimodal.image.model.modules import MultiTaskModel


def test_frozen_cnn_model() -> None:
    """
    Checks if the mode of module parameters is set correctly.
    """
    model = ImageModel(img_model_type='resnet18',
                       joint_feature_size=4,
                       num_classes=2,
                       freeze_encoder=True,
                       classifier_hidden_dim=24,
                       num_tasks=1)
    assert not model.encoder.training
    assert not model.projector.training
    assert isinstance(model.classifier, MultiTaskModel)
    assert model.classifier.training

    model.train()
    assert not model.encoder.training
    assert not model.projector.training
    assert isinstance(model.classifier, MultiTaskModel)
    assert model.classifier.training

    model.eval()
    assert not model.encoder.training
    assert not model.projector.training
    assert isinstance(model.classifier, MultiTaskModel)
    assert not model.classifier.training

    model = ImageModel(img_model_type='resnet18',
                       joint_feature_size=4,
                       num_classes=2,
                       freeze_encoder=False,
                       classifier_hidden_dim=24,
                       num_tasks=1)
    assert model.encoder.training
    assert model.projector.training
    assert model.classifier.training  # type: ignore


def test_image_get_patchwise_projected_embeddings() -> None:
    """
    Checks if the image patch embeddings are correctly computed and projected to the latent space.
    """
    num_classes = 2
    num_tasks = 1
    joint_feature_size = 4
    model = ImageModel(img_model_type='resnet18',
                       joint_feature_size=joint_feature_size,
                       num_classes=num_classes,
                       freeze_encoder=True,
                       classifier_hidden_dim=None,
                       num_tasks=num_tasks)

    model.train()
    with pytest.raises(AssertionError) as ex:
        model.get_patchwise_projected_embeddings(torch.rand(size=(2, 3, 448, 448)), normalize=True)
    assert "This function is only implemented for evaluation mode" in str(ex)

    model.eval()
    batch_size = 2
    image = torch.rand(size=(batch_size, 3, 64, 64))
    encoder_output, _ = model.encoder.forward(image, return_patch_embeddings=True)
    h, w = encoder_output.shape[2:]

    # First check the model output is in the expected shape,
    # since this is used internally by get_patchwise_projected_embeddings
    model_output = model.forward(image)
    assert model_output.projected_patch_embeddings.shape == (batch_size, joint_feature_size, h, w)
    assert model_output.projected_global_embedding.shape == (batch_size, joint_feature_size)
    projected_global_embedding = model_output.projected_global_embedding

    unnormalized_patch_embeddings = model.get_patchwise_projected_embeddings(image, normalize=False)
    # Make sure the projected embeddings returned are the right shape
    assert unnormalized_patch_embeddings.shape == (batch_size, h, w, joint_feature_size)
    result_1 = torch.mean(unnormalized_patch_embeddings, dim=(1, 2))  # B x H x W x D
    result_2 = projected_global_embedding
    assert torch.allclose(result_1, result_2)

    # test normalized version
    normalized_patch_embeddings = model.get_patchwise_projected_embeddings(image, normalize=True)
    assert normalized_patch_embeddings.shape == (batch_size, h, w, joint_feature_size)
    # Make sure the norm is 1 along the embedding dimension
    norm = normalized_patch_embeddings.norm(p=2, dim=-1)
    assert torch.all(torch.abs(norm - 1.0) < 1e-5)


def test_reload_resnet_with_dilation() -> None:
    """
    Tests if the resnet model can be switched from pooling to dilated convolutions.
    """
    replace_stride_with_dilation = [False, False, True]

    # resnet18 does not support dilation
    model_with_dilation = ImageEncoder(img_model_type="resnet18")
    with pytest.raises(NotImplementedError):
        model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)

    # resnet50
    original_model = ImageEncoder(img_model_type="resnet50").eval()
    model_with_dilation = ImageEncoder(img_model_type="resnet50").eval()
    model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)

    assert not model_with_dilation.training

    batch_size = 2
    image = torch.rand(size=(batch_size, 3, 64, 64))
    with torch.no_grad():
        outputs_dilation, _ = model_with_dilation(image, return_patch_embeddings=True)
        outputs_original, _ = original_model(image, return_patch_embeddings=True)

    assert outputs_dilation.shape[2] == 2 * outputs_original.shape[2], \
        "The dilation model should return larger feature maps."

    expected_model = resnet50(return_all_feature_maps=True, pretrained=True,
                              replace_stride_with_dilation=replace_stride_with_dilation)
    expected_model.eval()
    with torch.no_grad():
        expected_output = expected_model(image)[-1]

    assert torch.allclose(outputs_dilation, expected_output)

@@ -0,0 +1,27 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import pytest
from transformers import AutoModel, AutoTokenizer

from health_multimodal.text import BIOMED_VLP_CXR_BERT_SPECIALIZED, TypePrompts
from health_multimodal.text.inference_engine import TextInferenceEngine

text_inference = TextInferenceEngine(
    tokenizer=AutoTokenizer.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
    text_model=AutoModel.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
)


@pytest.mark.parametrize("prompts", ("", "hello", "this is a test", ["this is", "also a test"]))
def test_good_prompts(prompts: TypePrompts) -> None:
    text_inference.tokenize_input_prompts(prompts)


@pytest.mark.parametrize("prompts", ("[CLS]", "hello [PAD]"))
def test_bad_prompts(prompts: TypePrompts) -> None:
    with pytest.raises(ValueError, match="The input .* contains at least one special token"):
        text_inference.tokenize_input_prompts(prompts)

@@ -0,0 +1,75 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import json
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
import pytest
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME

from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
from health_multimodal.text.model.modelling_cxrbert import CXRBertModel


def test_model_instantiation() -> None:

    def _test_model_forward(model: CXRBertModel) -> None:
        batch_size = 2
        seq_length = 5
        input_ids = torch.randint(0, 512, (batch_size, seq_length))
        attention_mask = torch.ones_like(input_ids)

        outputs = model(input_ids, attention_mask, output_cls_projected_embedding=True)
        assert outputs.last_hidden_state.shape == (batch_size, seq_length, config.hidden_size)
        assert outputs.cls_projected_embedding.shape == (batch_size, config.projection_size)

        projected_embeddings = model.get_projected_text_embeddings(input_ids, attention_mask)
        assert projected_embeddings.shape == (batch_size, config.projection_size)
        norm = torch.norm(projected_embeddings[0], p=2).item()
        assert pytest.approx(norm) == 1

        outputs = model(input_ids, attention_mask, output_hidden_states=False)
        assert outputs.hidden_states is None

        outputs_in_tuple = model(input_ids, attention_mask, return_dict=False)
        assert outputs.cls_projected_embedding == outputs_in_tuple[2]
        assert torch.allclose(outputs.last_hidden_state, outputs_in_tuple[0])

    config = CXRBertConfig(hidden_size=6,
                           projection_size=4,
                           num_hidden_layers=1,
                           num_attention_heads=2,
                           output_attentions=True,
                           return_dict=True)
    model = CXRBertModel(config)
    model = model.eval()
    _test_model_forward(model=model)

    # Try saving this model and check the saved model/config
    with TemporaryDirectory() as save_dir_as_str:
        save_dir = Path(save_dir_as_str)
        model.save_pretrained(save_dir)

        weights_file = save_dir / WEIGHTS_NAME
        assert weights_file.exists()
        saved_weights = torch.load(weights_file)
        # Make sure the MLM head was saved
        assert "cls.predictions.bias" in saved_weights
        # Make sure the projection head was saved
        assert "cls_projection_head.dense_to_hidden.weight" in saved_weights

        # Check the config file
        config_file = save_dir / CONFIG_NAME
        assert config_file.exists()
        with config_file.open() as f:
            config_json = json.load(f)
        assert "projection_size" in config_json
        assert config_json["projection_size"] == config.projection_size
        assert config_json["architectures"] == ["CXRBertModel"]

        # Make sure we can load from the saved model
        model_from_pretrained = CXRBertModel.from_pretrained(save_dir)
        _test_model_forward(model=model_from_pretrained)

@@ -0,0 +1,119 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from typing import Tuple

import pytest
import torch

from health_multimodal.text import BIOMED_VLP_CXR_BERT_SPECIALIZED
from health_multimodal.text.inference_engine import TextInferenceEngine
from health_multimodal.text.model.modelling_cxrbert import CXRBertModel
from health_multimodal.text.model.configuration_cxrbert import CXRBertTokenizer


def test_text_inference_init_model_type() -> None:
    """
    Test that init fails if the wrong model type is passed in.
    """
    tokenizer, _ = _get_cxr_bert()
    false_model = torch.nn.Linear(4, 4)
    with pytest.raises(AssertionError) as ex:
        TextInferenceEngine(tokenizer=tokenizer, text_model=false_model)  # type: ignore[arg-type]
    assert f"Expected a BertForMaskedLM, got {type(false_model)}" in str(ex)


def test_l2_normalization() -> None:
    """
    Test that the text embeddings (CLS token) are l2 normalized.
    """
    tokenizer, text_model = _get_cxr_bert()
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
    input_query = ["There is a tumor in the left lung", "Lungs are all clear"]
    embedding = text_inference.get_embeddings_from_prompt(prompts=input_query)
    norm = torch.norm(embedding, p=2, dim=-1)
    assert torch.allclose(norm, torch.ones_like(norm))


def test_sentence_semantic_similarity() -> None:
    """
    Test that the sentence embedding similarity computed by the text model is meaningful.
    """
    tokenizer, text_model = _get_cxr_bert()
    # CLS token has no dedicated meaning, but we can expect vector similarity due to token overlap between the sentences
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
    input_query = ["There is a tumor in the left lung", "Tumor is present", "Patient is admitted to the hospital today"]
    embedding = text_inference.get_embeddings_from_prompt(input_query)
    pos_sim = torch.dot(embedding[0], embedding[1])
    neg_sim_1 = torch.dot(embedding[0], embedding[2])
    neg_sim_2 = torch.dot(embedding[1], embedding[2])
    assert pos_sim > neg_sim_1
    assert pos_sim > neg_sim_2


def _get_cxr_bert() -> Tuple[CXRBertTokenizer, CXRBertModel]:
    model_name = BIOMED_VLP_CXR_BERT_SPECIALIZED
    tokenizer = CXRBertTokenizer.from_pretrained(model_name)
    text_model = CXRBertModel.from_pretrained(model_name)
    return tokenizer, text_model


def test_triplet_similarities_with_inference_engine() -> None:
    """
    Test that the triplet sentence similarities computed by the text model are meaningful.
    """
    tokenizer, text_model = _get_cxr_bert()
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
    reference = \
        ["Heart size is top normal.", "There is no pneumothorax or pleural effusion",
         "The patient has been extubated.", "The lungs are clear bilaterally.",
         "No pleural effusions."]
    synonyms = \
        ["The cardiac silhouette is normal in size.", "No pleural effusion or pneumothorax is seen",
         "There has been interval extubation.", "The lungs are unremarkable.",
         "Also, the lateral pleural sinuses are free, which excludes major pleural effusion."]
    contradictions = \
        ["The heart is largely enlarged.", "The extent of the pleural effusion is constant.",
         "The patient is intubated", "The lungs are mostly clear aside from lower lung atelectasis.",
         "The loculated right pleural effusion has increased, and is now moderate in size."]

    synonym_score = text_inference.get_pairwise_similarities(reference, synonyms)
    contradictory_score = text_inference.get_pairwise_similarities(reference, contradictions)
    print("Synonym score:", synonym_score)
    print("Contradictory score:", contradictory_score)

    assert torch.all(synonym_score > contradictory_score)
    assert torch.all(synonym_score > 0.5)
    assert torch.all(1.0 >= synonym_score)
    assert torch.all(contradictory_score < 0.5)
    assert torch.all(contradictory_score >= -1.0)


def test_mlm_with_inference_engine_with_hf_hub() -> None:
    """
    Test that the MLM model can be used with the inference engine and the filled masked tokens are correct.
    """
    tokenizer, text_model = _get_cxr_bert()
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)

    # ##### Test Masked Language Modelling ######
    mlm_prompts = ["Moderate [MASK] pleural effusions and associated [MASK]",
                   "Right basilar [MASK], potentially due to infiltrate in the proper clinical setting",
                   "The right basilar chest [MASK] appears to be in unchanged position",
                   "The small basal pneumothorax has slightly [MASK] compared to the prior",
                   "Poorly defined [MASK] in the right lung is concerning for aspiration",
                   "Retrocardiac opacity likely reflects known hiatal [MASK]"]
    target_tokens = [['bilateral', 'atelectasis'], ['opacity'], ['tube'], ['increased'], ['opacity'], ['hernia']]
    output_top_1 = text_inference.predict_masked_tokens(mlm_prompts)
    assert output_top_1 == target_tokens

@@ -0,0 +1,76 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import tempfile
from pathlib import Path

import torch
import pytest
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from health_multimodal.text import BIOMED_VLP_CXR_BERT_SPECIALIZED
from health_multimodal.text.inference_engine import TextInferenceEngine
from health_multimodal.image import ImageModel, ResnetType, ImageInferenceEngine
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine

text_inference = TextInferenceEngine(
    tokenizer=AutoTokenizer.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
    text_model=AutoModel.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
)


@pytest.mark.parametrize("height", (400, 500, 650))
@pytest.mark.parametrize("query_text", ("", "hello", "this is a test"))
def test_vlp_inference(height: int, query_text: str) -> None:
    image_embedding_shapes = {
        480: (15, 15),
    }

    joint_feature_size = 128
    resize = 512
    center_crop_size = 480
    width = 600

    image_inference = ImageInferenceEngine(
        image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=joint_feature_size),
        transform=create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size))
    img_txt_inference = ImageTextInferenceEngine(
        image_inference_engine=image_inference,
        text_inference_engine=text_inference,
    )

    with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
        image_path = Path(f.name)
        image = Image.new('RGB', (width, height))
        image.save(image_path)

        # Test integrated VLP inference engine
        resampled_similarity_map = img_txt_inference.get_similarity_map_from_raw_data(
            image_path=image_path,
            query_text=query_text,
        )
        assert resampled_similarity_map.shape == (height, width)
        np.nan_to_num(resampled_similarity_map, copy=False)
        assert resampled_similarity_map.min() >= -1
        assert resampled_similarity_map.max() <= 1

        # Test individual components
        image_embedding, size = img_txt_inference.image_inference_engine.get_patch_embeddings_from_image(image_path)
        assert (width, height) == size
        expected_image_embedding_size = image_embedding_shapes[center_crop_size]
        assert image_embedding.shape == (*expected_image_embedding_size, joint_feature_size)
        normalized_image_embedding = torch.norm(image_embedding, p=2, dim=-1)
        assert torch.allclose(normalized_image_embedding, torch.ones_like(normalized_image_embedding))

        text_embedding = img_txt_inference.text_inference_engine.get_embeddings_from_prompt(query_text)
        assert text_embedding.shape == (1, joint_feature_size)

        similarity_map = img_txt_inference._get_similarity_map_from_embeddings(image_embedding, text_embedding)
        assert similarity_map.shape == expected_image_embedding_size

@@ -1,3 +1,4 @@
 -r build_requirements.txt
 -r hi-ml/run_requirements.txt
 -r hi-ml-azure/run_requirements.txt
+-r multimodal/requirements_run.txt