Mirror of https://github.com/microsoft/hi-ml.git
ENH: Add text and image inference code for CXR multimodal (#438)
* add image encoder and utils
* update the imports
* add file headers
* pin the library versions
* restructure the dir, add text related code, and update the env
* add text inference tests
* insert the missing headers
* Fix PEP8 issues
* Remove placeholder files
* flake8 fixes
* update the env and add img model tests
* fix image related tests and environment
* cosmetic changes related to linters
* add test image preprocessing
* update the test env
* update env
* pr comments -- env update
* fix mypy and add text module
* add tests for cxrbert
* revert
* mypy - pyright disable
* Add minor edits to io module
* Add __init__ files
* Fix some mypy errors
* add image and text inference engines
* add init
* add headers and update docstrings
* sort imports
* Unify imports
* Replace Health Intelligence with Health Futures
* Update GitHub workflow
* Add multimodal package to flake8 checks
* Fix flake8 errors
* test pretrained checkpoint
* remove unused par
* Add phrase grounding notebook
* pretrained path -- image model
* add ipython in the env for notebooks
* Fix incorrect install command in CI workflow
* move visualisation and update the notebook
* update the import
* Remove phrase grounding example Python file
* Use linear interpolation for similarity maps
* Stop trying to compute and upload coverage
* Fix float comparison in test
* add another demo example, unifying plotting
* Fix some mypy errors
* Fix similarity map indexing
* Fix warning about interpolation
* Remove hard-coded strings from notebook
* Ignore mypy error in test
* Add typing hints to visualization function
* flake8
* Start adding docs
* Move notebook outside Python package
* Run notebook in CI
* Move requirement
* Add call to pytest and install kernel for papermill
* Move requirements to file
* Improve docstring by adding info on percentiles
* Use Path class for path parameter
* Add typing hints in notebook functions
* Add clarifying comment
* Move notebooks command to local Makefile
* Remove unused variables
* Add docstrings to visualization functions
* Add Microsoft file header
* Launch CI tests also when hi-ml changes
* Pass list of transforms to validation function
* Fix docstring
* Update typing hints
* Update kwarg name
* Fix docstring
* Replace hard-coded variable with kwarg
* Declare variable before the line in which it is used
* Fix type of returned value
* Rename class and fix docstrings
* Fix docstring
* Improve some docstrings
* Improve assertion
* Replace import with type definition
* Stop initializing list in kwarg
* Fix docstring
* Fix docstring
* Fix docstring
* Fix docstring
* Remove old comment from ImageTextInferenceEngine docstring
* Update docstring for get_patch_embeddings_from_image
* Apply docstring suggestion for get_patch_embeddings_from_image

Co-authored-by: Fernando Pérez-García <fepegar@gmail.com>

* Stop running multimodal tests when the hi-ml libraries change.
* Fix docstring
* Use double quotes for consistency
* Improve error message and add ResNet enum
* Update docstring for tokenize_input_prompts
* Refactor some requirements
* Add missing requirement
* Stop installing hi-ml-azure in CI
* Install all requirements in CI
* Install all requirements in CI before pytest
* Add missing run requirements
* Ignore mypy error
* Upgrade Pillow to 9.0.1 to avoid security flag
* Add some VLP unit tests
* Improve check for special tokens
* Add missing Microsoft headers
* Remove some unused test requirements
* Allow [MASK] token in TextInput prompts
* Update docstring for MultiTaskModel
* Add __init__ file
* Fix docstring
* Read requirements from files in Conda YAML
* Enable members documentation by default
* Add multimodal folder to docs conf path
* Add multimodal run requirements to RTD requirements
* Stop ignoring multimodal API rst file
* Git-add multimodal API file
* Fix docstrings
* Refactor API documentation
* Use package description file for docs

Co-authored-by: Fernando Pérez-García <fperezgarcia@microsoft.com>
Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
Co-authored-by: Fernando Pérez-García <fepegar@gmail.com>
This commit is contained in:
Parent
a78c39f0e9
Commit
a618a844aa
@@ -58,29 +58,6 @@ jobs:
          make pip_test
          make mypy

  pyright:
    runs-on: ubuntu-18.04
    steps:
      - uses: actions/checkout@v2
        with:
          lfs: true

      - uses: actions/setup-node@v2
        with:
          node-version: '14'

      - uses: conda-incubator/setup-miniconda@v2
        with:
          environment-file: ${{ env.folder }}/environment.yml

      - name: pyright
        shell: bash -l {0}
        run: |
          conda info
          cd ${{ env.folder }}
          make pyright_install
          make pyright

  pytest:
    runs-on: ubuntu-18.04
    steps:
@@ -93,13 +70,20 @@ jobs:
        with:
          python-version: ${{ env.pythonVersion }}

      - name: Upgrade PIP
        run: |
          cd ${{ env.folder }}
          make pip_upgrade

      - name: Test with pytest
        run: |
          cd ${{ env.folder }}

          # Install local package in editable mode
          make pip_local

          # Run tests
          make pip_test
          make pip
          make pytest

      - name: Run Jupyter notebooks
        run: |
          cd ${{ env.folder }}
          make notebooks
@@ -143,6 +143,7 @@ tensorboard_logs/

docs/source/api
!docs/source/api/api.rst
!docs/source/api/multimodal.rst
# This file is copied from repository root to docs/source by the makefile
/docs/source/CHANGELOG.md
/docs/source/CONTRIBUTING.md
@@ -44,6 +44,11 @@ repos:
        alias: flake8-hi-ml-histopathology
        files: ^hi-ml-histopathology/
        args: [--config, hi-ml-histopathology/.flake8]
      - id: flake8
        name: flake8 ./multimodal/
        alias: flake8-multimodal
        files: ^multimodal/
        args: [--config, multimodal/.flake8]

  - repo: https://github.com/pre-commit/mirrors-autopep8
    rev: v1.5.7
@@ -0,0 +1,40 @@
API
###

.. currentmodule:: health_multimodal


Vision-language processing (VLP)
--------------------------------

.. autosummary::
    :toctree:

    vlp


Image processing
----------------

.. autosummary::
    :toctree:

    image


Text processing
---------------

.. autosummary::
    :toctree:

    text


Common utils
------------

.. autosummary::
    :toctree:

    common
@@ -22,6 +22,7 @@ import sys
sys.path.insert(0, os.path.abspath('../../hi-ml/src'))
sys.path.insert(0, os.path.abspath('../../hi-ml-azure/src'))
sys.path.insert(0, os.path.abspath('../../hi-ml-histopathology/src'))
sys.path.insert(0, os.path.abspath('../../multimodal'))

# -- Project information -----------------------------------------------------
@@ -87,3 +88,7 @@ highlight_language = "python"

# For classes, insert documentation from the class itself AND the constructor
autoclass_content = "both"

autodoc_default_options = {
    'members': True,
}
@@ -46,6 +46,13 @@ The `hi-ml` toolbox provides
   azure_setup.md
   dsa.md

.. toctree::
   :maxdepth: 1
   :caption: Multimodal learning

   package_description_multimodal.md
   api/multimodal.rst

.. toctree::
   :maxdepth: 1
   :caption: Self supervised learning
@@ -0,0 +1,9 @@
# Multimodal learning

Introduction

## Installation

## Example

## Credits
@@ -0,0 +1 @@
../../multimodal/package_description.md
@@ -57,3 +57,8 @@ pyright:

# run basic checks
check: flake8 mypy pyright

# run notebooks to ensure there are no errors
notebooks: pip_test
	ipython kernel install --name "python3" --user
	papermill notebooks/phrase_grounding.ipynb /tmp/phrase_grounding_output.ipynb
@@ -1,24 +1,14 @@
name: multimodal
channels:
  - defaults
  - conda-forge
  - pytorch
dependencies:
  - pip=20.1.1
  - python=3.7.3
  - cudatoolkit=11.1
  - pytorch=1.9.0
  - torchvision=0.10.0
  - pip:
      # This file should contain pinned versions of all packages in requirements_run.txt
      # and requirements_test.txt.
      # It is not possible to reference the requirements files here with "-r" because AzureML's
      # environment creation can't handle those.
      # Run requirements: presently empty
      # Test requirements
      - black==22.1.0
      - coverage==6.3.2
      - flake8==4.0.1
      - mypy==0.931
      - pylint==2.12.2
      - pycobertura==2.0.1
      - pytest==6.2.2
      - pytest-cov==2.11.1
      - pytest-rerunfailures==10.2
      - pytest-timeout==2.0.1
      - -r requirements_run.txt
      - -r requirements_test.txt
@@ -2,6 +2,3 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

def dummy() -> int:
    return 1
@@ -2,11 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import pytest

from health_multimodal.dummy import dummy


@pytest.mark.gpu
def test_dummy() -> None:
    assert dummy() == 1

"""General utils

.. currentmodule:: health_multimodal.common

.. autosummary::
    :toctree:

    visualization
"""
@@ -0,0 +1,119 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Union, Optional

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from health_multimodal.image.data.io import load_image


TypeArrayImage = Union[np.ndarray, Image.Image]


def _plot_image(
    image: TypeArrayImage,
    axis: plt.Axes,
    title: Optional[str] = None,
) -> None:
    """Plot an image on a given axis, deleting the axis ticks and axis labels.

    :param image: Input image.
    :param axis: Axis to plot the image on.
    :param title: Title used for the axis.
    """
    axis.imshow(image)
    axis.axis('off')
    if title is not None:
        axis.set_title(title)


def _get_isolines_levels(step_size: float) -> np.ndarray:
    num_steps = np.floor(round(1 / step_size)).astype(int)
    levels = np.linspace(step_size, 1, num_steps)
    return levels


def _plot_isolines(
    image: TypeArrayImage,
    heatmap: np.ndarray,
    axis: plt.Axes,
    title: Optional[str] = None,
    colormap: str = 'RdBu_r',
    step: float = 0.25,
) -> None:
    """Plot an image and overlay heatmap isolines on it.

    :param image: Input image.
    :param heatmap: Heatmap of the same size as the image.
    :param axis: Axis to plot the image on.
    :param title: Title used for the axis.
    :param colormap: Name of the Matplotlib colormap used for the isolines.
    :param step: Step size between the isolines levels. The levels are in :math:`(0, 1]`.
        For example, a step size of 0.25 will result in isolines levels of 0.25, 0.5, 0.75 and 1.
    """
    axis.imshow(image)
    levels = _get_isolines_levels(step)
    contours = axis.contour(
        heatmap,
        cmap=colormap,
        vmin=-1,
        vmax=1,
        levels=levels,
    )
    axis.clabel(contours, inline=True, fontsize=10)
    axis.axis('off')
    if title is not None:
        axis.set_title(title)


def _plot_heatmap(
    image: TypeArrayImage,
    heatmap: np.ndarray,
    figure: plt.Figure,
    axis: plt.Axes,
    colormap: str = 'RdBu_r',
    title: Optional[str] = None,
    alpha: float = 0.5,
) -> None:
    """Plot a heatmap overlaid on an image.

    :param image: Input image.
    :param heatmap: Input heatmap of the same size as the image.
    :param figure: Figure to plot the images on.
    :param axis: Axis to plot the images on.
    :param colormap: Name of the Matplotlib colormap for the heatmap.
    :param title: Title used for the axis.
    :param alpha: Heatmap opacity. Must be in :math:`[0, 1]`.
    """
    axis.imshow(image)
    axes_image = axis.matshow(heatmap, alpha=alpha, cmap=colormap, vmin=-1, vmax=1)
    # https://www.geeksforgeeks.org/how-to-change-matplotlib-color-bar-size-in-python/
    divider = make_axes_locatable(axis)
    colorbar_axes = divider.append_axes('right', size='10%', pad=0.1)
    colorbar = figure.colorbar(axes_image, cax=colorbar_axes)
    # https://stackoverflow.com/a/50671487/3956024
    colorbar.ax.tick_params(pad=35)
    plt.setp(colorbar.ax.get_yticklabels(), ha='right')
    axis.axis('off')
    if title is not None:
        axis.set_title(title)


def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.ndarray) -> None:
    """Plot visualization of the input image, the similarity heatmap and the heatmap isolines.

    :param image_path: Path to the input image.
    :param similarity_map: Phrase grounding similarity map of the same size as the image.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 6))
    image = load_image(image_path).convert('RGB')
    _plot_image(image, axis=axes[0], title='Input image')
    _plot_isolines(image, similarity_map, axis=axes[1], title='Similarity isolines')
    _plot_heatmap(image, similarity_map, figure=fig, axis=axes[2], title='Similarity heatmap')
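
# A hedged usage sketch: the path below is hypothetical, and the similarity map
# is assumed to come from ImageTextInferenceEngine.get_similarity_map_from_raw_data.
#     plot_phrase_grounding_similarity_map(Path("chest_xray.jpg"), similarity_map)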
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

"""Image-related tools

.. currentmodule:: health_multimodal.image.data

.. autosummary::
    :toctree:

    io
    transforms


.. currentmodule:: health_multimodal.image.model

.. autosummary::
    :toctree:

    model
    modules
"""

from .model.model import ImageModel, ResnetType
from .inference_engine import ImageInferenceEngine

__all__ = [
    'ImageModel',
    'ResnetType',
    'ImageInferenceEngine',
]
@@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
@@ -0,0 +1,71 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pydicom as dicom
from PIL import Image
import SimpleITK as sitk
from skimage import io


def remap_to_uint8(array: np.ndarray, percentiles: Optional[Tuple[float, float]] = None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    :param array: Input array.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :returns: Array with ``0`` and ``255`` as minimum and maximum values.
    """
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = (
                'The value for percentiles should be a sequence of length 2,'
                f' but has length {len_percentiles}'
            )
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    array /= array.max()
    array *= 255
    return array.astype(np.uint8)
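
# A minimal sketch of the remapping on synthetic data, clipping the darkest and
# brightest 1% of values (the input array is an illustrative assumption):
#     example = np.random.randint(0, 4096, size=(8, 8)).astype(np.uint16)
#     remapped = remap_to_uint8(example, percentiles=(1, 99))
#     assert remapped.dtype == np.uint8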


def load_image(path: Path) -> Image.Image:
    """Load an image from disk.

    The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.

    :param path: Path to image.
    :returns: Image as ``Pillow`` ``Image``.
    """
    # Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
    if path.suffix in [".jpg", ".jpeg", ".png"]:
        image = io.imread(path)
    elif path.suffixes == [".nii", ".gz"]:
        image = sitk.GetArrayFromImage(sitk.ReadImage(str(path)))
        if image.shape[0] == 1:
            image = np.squeeze(image, axis=0)
        assert image.ndim == 2
    elif path.suffix == ".dcm":
        image = dicom.dcmread(path).pixel_array
    else:
        raise ValueError(f"Image type not supported, filename was: {path}")

    image = remap_to_uint8(image)
    return Image.fromarray(image).convert("L")
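
# A hedged usage sketch; the file name is hypothetical:
#     image = load_image(Path("chest_xray.dcm"))
#     image.mode, image.size  # 'L' (single channel), (width, height)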
@@ -0,0 +1,70 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from typing import Callable, Sequence, Optional, Tuple

import torch
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop


class ExpandChannels:
    """
    Transforms an image with one channel to an image with three channels by copying
    pixel intensities of the image along the 1st dimension.
    """

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        :param data: Tensor of shape [1, H, W].
        :return: Tensor with channel copied three times, shape [3, H, W].
        """
        if data.shape[0] != 1:
            raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
        return torch.repeat_interleave(data, 3, dim=0)


def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int) -> Compose:
    """
    Defines the image transformation pipeline for chest X-ray datasets.

    :param resize: The size to resize the image to. Linear resampling is used.
        Resizing is applied on the axis with smaller shape.
    :param center_crop_size: The size to center crop the image to. Square crop is applied.
    """
    transforms = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    return Compose(transforms)
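
# A minimal sketch of the pipeline on a synthetic single-channel image; the
# sizes (512/480) are illustrative assumptions, not model requirements:
#     from PIL import Image
#     transform = create_chest_xray_transform_for_inference(resize=512, center_crop_size=480)
#     tensor = transform(Image.new("L", (600, 700)))
#     assert tensor.shape == (3, 480, 480)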


def infer_resize_params(val_img_transforms: Sequence[Callable]) -> Tuple[Optional[int], Optional[int]]:
    """
    Given the validation transforms pipeline, extract the sizes to which the image was resized and cropped, if any.
    """
    resize_size_from_transforms = None
    crop_size_from_transforms = None
    supported_types = Resize, CenterCrop, ToTensor, ExpandChannels
    for transform in val_img_transforms:
        trsf_type = type(transform)
        if trsf_type not in supported_types:
            raise ValueError(f"Unsupported transform type {trsf_type}. Supported types are {supported_types}")
        if isinstance(transform, Resize):
            if resize_size_from_transforms is None and crop_size_from_transforms is None:
                assert transform.max_size is None
                assert isinstance(transform.size, int), f"Expected int, got {transform.size}"
                resize_size_from_transforms = transform.size
            else:
                raise ValueError("Expected Resize to be the first transform if present in val_img_transforms")
        elif isinstance(transform, CenterCrop):
            if crop_size_from_transforms is None:
                two_dims = len(transform.size) == 2
                same_sizes = transform.size[0] == transform.size[1]
                is_square = two_dims and same_sizes
                assert is_square, "Only square center crop supported"
                crop_size_from_transforms = transform.size[0]
            else:
                raise ValueError(
                    f"Crop size has already been set to {crop_size_from_transforms} in a previous transform")

    return resize_size_from_transforms, crop_size_from_transforms
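
# Sketch: the sizes can be recovered from the pipeline built in the sketch above
# (values are the illustrative 512/480 from that sketch):
#     assert infer_resize_params(transform.transforms) == (512, 480)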
@@ -0,0 +1,68 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Callable, Tuple

import torch
from torchvision.transforms import Compose

from health_multimodal.image.data.io import load_image
from health_multimodal.image.data.transforms import infer_resize_params
from health_multimodal.image.model.model import ImageModel

TypeShape2D = Tuple[int, int]


class ImageInferenceEngine:
    """
    Encapsulate inference-time operations on an image model.
    """

    def __init__(self, image_model: ImageModel, transform: Compose):
        """
        :param image_model: Trained image model.
        :param transform: Transform to apply to the image after loading. Must return a torch.Tensor that can be
            input directly to the image model.
        """
        assert isinstance(image_model, ImageModel), f"Expected an ImageModel, got {type(image_model)}"

        self.model = image_model
        self.transform = transform
        self.device = next(self.model.parameters()).device

        self.model.eval()
        self.resize_size, self.crop_size = infer_resize_params(self.transform.transforms)

    def load_and_transform_input_image(self, image_path: Path, transform: Callable) -> Tuple[torch.Tensor, TypeShape2D]:
        """Read an image and apply the transform to it.

        1. Read the image from the given path
        2. Apply transform
        3. Add the batch dimension
        4. Move to the correct device

        :param image_path: Path to the image to load.
        :param transform: Transform to apply to the loaded image.
        :return: A tuple containing the transformed image tensor and the original shape of the image
            (width, height) before the transforms.
        """
        image = load_image(image_path)
        transformed_image = transform(image).unsqueeze(0).to(self.device)
        return transformed_image, image.size

    @torch.no_grad()
    def get_patch_embeddings_from_image(self, image_path: Path) -> Tuple[torch.Tensor, TypeShape2D]:
        """Compute image embeddings in the joint latent space, preserving the image grid.

        :param image_path: Path to the image to compute embeddings for.
        :return: A tuple containing the image patch embeddings and
            the shape of the original image (width, height) before applying transforms.
        """
        input_image, img_shape = self.load_and_transform_input_image(image_path, self.transform)
        projected_img_emb = self.model.get_patchwise_projected_embeddings(input_image, normalize=True)
        assert projected_img_emb.shape[0] == 1

        return projected_img_emb[0], img_shape
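
# A hedged construction sketch; the image path and transform sizes below are
# illustrative assumptions, and building the model downloads pretrained weights:
#     from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
#     model = ImageModel(img_model_type="resnet50", joint_feature_size=128)
#     engine = ImageInferenceEngine(
#         image_model=model,
#         transform=create_chest_xray_transform_for_inference(resize=512, center_crop_size=480),
#     )
#     embeddings, (width, height) = engine.get_patch_embeddings_from_image(Path("chest_xray.jpg"))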
@@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
@@ -0,0 +1,173 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import enum
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50

from .modules import MLP, MultiTaskModel

TypeImageEncoder = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


@enum.unique
class ResnetType(str, enum.Enum):
    RESNET18 = "resnet18"
    RESNET50 = "resnet50"


@dataclass
class ImageModelOutput:
    img_embedding: torch.Tensor
    patch_embedding: torch.Tensor
    projected_global_embedding: torch.Tensor
    class_logits: torch.Tensor
    projected_patch_embeddings: torch.Tensor


class ImageModel(nn.Module):
    """Image encoder module"""

    def __init__(self,
                 img_model_type: str,
                 joint_feature_size: int,
                 freeze_encoder: bool = False,
                 pretrained_model_path: Optional[str] = None,
                 **downstream_classifier_kwargs: Any):
        super().__init__()

        # Initialize the encoder, projector, and classifier
        self.encoder = ImageEncoder(img_model_type)
        self.feature_size = get_encoder_output_dim(self.encoder)
        self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
                             hidden_dim=joint_feature_size, use_1x1_convs=True)
        self.downstream_classifier_kwargs = downstream_classifier_kwargs
        self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None

        # Initialise the mode of the modules
        self.freeze_encoder = freeze_encoder
        self.train()

        if pretrained_model_path is not None:
            assert isinstance(pretrained_model_path, str), f"Expected a string, got {type(pretrained_model_path)}"
            self.load_state_dict(torch.load(pretrained_model_path))

    def train(self, mode: bool = True) -> Any:
        """Switch the model between training and evaluation modes."""
        super().train(mode=mode)
        if self.freeze_encoder:
            self.encoder.train(mode=False)
            self.projector.train(mode=False)
        return self

    def forward(self, x: torch.Tensor) -> ImageModelOutput:
        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
            projected_patch_embeddings = self.projector(patch_x)
            projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))

        logits = self.classifier(pooled_x) if self.classifier else None
        return ImageModelOutput(img_embedding=pooled_x,
                                patch_embedding=patch_x,
                                class_logits=logits,
                                projected_patch_embeddings=projected_patch_embeddings,
                                projected_global_embedding=projected_global_embedding)

    def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
        """Create the classification module for the downstream task."""
        downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
        return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)

    @torch.no_grad()
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        """Get patch-wise projected embeddings from the CNN model.

        :param input_img: Input tensor image [B, C, H, W].
        :param normalize: If ``True``, the embeddings are L2-normalized.
        :return: Tensor of embeddings of shape [batch, n_patches_h, n_patches_w, feature_size].
        """
        assert not self.training, "This function is only implemented for evaluation mode"
        outputs = self.forward(input_img)
        projected_embeddings = outputs.projected_patch_embeddings.detach()  # type: ignore
        if normalize:
            projected_embeddings = F.normalize(projected_embeddings, dim=1)
        projected_embeddings = projected_embeddings.permute([0, 2, 3, 1])  # B D H W -> B H W D (D: features)
        return projected_embeddings


class ImageEncoder(nn.Module):
    """Image encoder trunk module for the ``ImageModel`` class.

    :param img_model_type: Type of image model to use: either ``"resnet18"`` or ``"resnet50"``.
    """

    def __init__(self, img_model_type: str):
        super().__init__()
        self.img_model_type = img_model_type
        self.encoder = self._create_encoder()

    def _create_encoder(self, **kwargs: Any) -> nn.Module:
        supported = ResnetType.RESNET18, ResnetType.RESNET50
        if self.img_model_type not in supported:
            raise NotImplementedError(f"Image model type \"{self.img_model_type}\" must be in {supported}")
        encoder_class = resnet18 if self.img_model_type == ResnetType.RESNET18 else resnet50
        encoder = encoder_class(return_all_feature_maps=True, pretrained=True, **kwargs)
        return encoder

    def forward(self, x: torch.Tensor, return_patch_embeddings: bool = False) -> TypeImageEncoder:
        """Image encoder forward pass."""
        x = self.encoder(x)
        x = x[-1] if isinstance(x, list) else x
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)), 1)
        if return_patch_embeddings:
            return x, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        """Workaround for enabling dilated convolutions after model initialization.

        :param replace_stride_with_dilation: For each layer, whether to replace the 2x2 stride
            with a dilated convolution.
        """
        if self.img_model_type == "resnet18":
            # resnet18 uses the BasicBlock implementation, which does not support dilated convolutions.
            raise NotImplementedError("resnet18 does not support dilated convolutions")

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = False, False, True

        device = next(self.encoder.parameters()).device
        new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)

        if self.encoder.training:
            new_encoder.train()
        else:
            new_encoder.eval()

        new_encoder.load_state_dict(self.encoder.state_dict())
        self.encoder = new_encoder


@torch.no_grad()
def get_encoder_output_dim(module: torch.nn.Module) -> int:
    """Calculate the output dimension of an SSL encoder by making a single forward pass.

    :param module: Encoder module.
    """
    # Target device
    device = next(module.parameters()).device  # type: ignore
    assert isinstance(device, torch.device)

    x = torch.rand((1, 3, 448, 448)).to(device)

    # Extract the number of output feature dimensions
    representations = module(x)
    return representations.shape[1]
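
# Sketch: for a ResNet-50 trunk this returns the pooled feature width (2048 in
# torchvision-style ResNets; that value is an assumption about the pl_bolts
# implementation, and the call downloads pretrained weights):
#     feature_size = get_encoder_output_dim(ImageEncoder("resnet50"))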
@@ -0,0 +1,86 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from typing import Callable, Optional

import torch
import torch.nn as nn


class MLP(nn.Module):
    """
    Fully connected layers to map between image embeddings and projection space where pairs of images are compared.

    :param input_dim: Input embedding feature size.
    :param hidden_dim: Hidden layer size in MLP.
    :param output_dim: Output projection size.
    :param use_1x1_convs: Use 1x1 conv kernels instead of 2D linear transformations for speed and memory efficiency.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: Optional[int] = None,
                 use_1x1_convs: bool = False) -> None:
        super().__init__()

        if use_1x1_convs:
            linear_proj_1_args = {'in_channels': input_dim, 'out_channels': hidden_dim, 'kernel_size': 1, 'bias': False}
            linear_proj_2_args = {'in_channels': hidden_dim, 'out_channels': output_dim, 'kernel_size': 1, 'bias': True}
            normalisation_layer: Callable = nn.BatchNorm2d
            projection_layer: Callable = nn.Conv2d
        else:
            linear_proj_1_args = {'in_features': input_dim, 'out_features': hidden_dim, 'bias': False}
            linear_proj_2_args = {'in_features': hidden_dim, 'out_features': output_dim, 'bias': True}
            normalisation_layer = nn.BatchNorm1d
            projection_layer = nn.Linear

        self.output_dim = output_dim
        self.input_dim = input_dim
        if hidden_dim is not None:
            self.model = nn.Sequential(
                projection_layer(**linear_proj_1_args),
                normalisation_layer(hidden_dim),
                nn.ReLU(inplace=True),
                projection_layer(**linear_proj_2_args))
        else:
            self.model = nn.Linear(input_dim, output_dim)  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the multi-layer perceptron."""
        x = self.model(x)
        return x
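
# Shape sketch for the 1x1-convolution variant (numbers are illustrative):
#     projector = MLP(input_dim=2048, output_dim=128, hidden_dim=128, use_1x1_convs=True)
#     projector(torch.rand(2, 2048, 15, 15)).shape  # torch.Size([2, 128, 15, 15])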


class MultiTaskModel(nn.Module):
    """Torch module for multi-task classification heads. We create a separate classification head
    for each task and perform a forward pass on each head independently in forward(). Classification
    heads are instances of `MLP`.

    :param input_dim: Number of dimensions of the input feature map.
    :param classifier_hidden_dim: Number of dimensions of hidden features in the MLP.
    :param num_classes: Number of output classes per task.
    :param num_tasks: Number of classification tasks or heads required.
    """

    def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):
        super().__init__()

        self.num_classes = num_classes
        self.num_tasks = num_tasks

        for task in range(num_tasks):
            # TODO check if softmax not needed here.
            setattr(self, "fc_" + str(task), MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns a [batch_size, num_classes, num_tasks] tensor of logits."""
        batch_size = x.shape[0]
        out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
        for task in range(self.num_tasks):
            classifier = getattr(self, "fc_" + str(task))
            out[:, :, task] = classifier(x)
        return out
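
# Shape sketch (numbers are illustrative): two heads, three classes each.
#     heads = MultiTaskModel(input_dim=128, classifier_hidden_dim=64, num_classes=3, num_tasks=2)
#     heads(torch.rand(4, 128)).shape  # torch.Size([4, 3, 2])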
@@ -0,0 +1,41 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

"""Text-related tools

.. currentmodule:: health_multimodal.text

.. autosummary::
    :toctree:

    inference_engine


.. currentmodule:: health_multimodal.text.data

.. autosummary::
    :toctree:

    io


.. currentmodule:: health_multimodal.text.model

.. autosummary::
    :toctree:

    configuration_cxrbert
    modelling_cxrbert
"""

from .data.io import TypePrompts


BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"

__all__ = [
    "BIOMED_VLP_CXR_BERT_SPECIALIZED",
    "TypePrompts",
]
@@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
@@ -0,0 +1,56 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import logging
from typing import Any, List, Union

from transformers import BertTokenizer


TypePrompts = Union[str, List[str]]


class TextInput:
    """Text input class that can be used for inference and deployment.

    Implements tokenizer-related operations and ensures that input strings
    conform to the standards expected from a BERT model.

    :param tokenizer: A BertTokenizer object.
    """

    def __init__(self, tokenizer: BertTokenizer) -> None:
        self.tokenizer = tokenizer

    def tokenize_input_prompts(self, prompts: TypePrompts, verbose: bool) -> Any:
        """Tokenize the input sentence(s) and add special tokens as defined by the tokenizer.

        :param prompts: Either a string containing a single sentence, or a list of strings each containing
            a single sentence. Note that this method will not correctly tokenize multiple sentences if they
            are input as a single string.
        :param verbose: If set to True, will log the sentence after tokenization.
        :return: A 2D tensor containing the tokenized sentences.
        """
        prompts = [prompts] if isinstance(prompts, str) else prompts
        self.assert_special_tokens_not_present(" ".join(prompts))

        prompts = [prompt.rstrip("!?.") for prompt in prompts]  # removes punctuation from the end of each prompt
        tokenizer_output = self.tokenizer.batch_encode_plus(batch_text_or_text_pairs=prompts,
                                                            add_special_tokens=True,
                                                            padding='longest',
                                                            return_tensors='pt')
        if verbose:
            for prompt in tokenizer_output.input_ids:
                input_tokens = self.tokenizer.convert_ids_to_tokens(prompt.tolist())
                logging.info(f"Input tokens: {input_tokens}")

        return tokenizer_output

    def assert_special_tokens_not_present(self, prompt: str) -> None:
        """Check that the input prompt does not contain special tokens (``[MASK]`` is allowed)."""
        special_tokens = self.tokenizer.all_special_tokens
        special_tokens.remove(self.tokenizer.mask_token)  # [MASK] is allowed
        if any(map(lambda token: token in prompt, special_tokens)):
            raise ValueError(f"The input \"{prompt}\" contains at least one special token ({special_tokens})")
@@ -0,0 +1,113 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from typing import Any, List, Union

import torch
from transformers import BertForMaskedLM, BertTokenizer

from health_multimodal.text.data.io import TextInput


class TextInferenceEngine(TextInput):
    """
    Text inference class that implements functionalities required to extract
    sentence embeddings, similarities, and MLM predictions.

    :param tokenizer: A BertTokenizer object.
    :param text_model: Text model; the default HuggingFace ``BertForMaskedLM`` class or a subclass of it.
    """

    def __init__(self, tokenizer: BertTokenizer, text_model: BertForMaskedLM) -> None:
        super().__init__(tokenizer=tokenizer)

        assert isinstance(text_model, BertForMaskedLM), f"Expected a BertForMaskedLM, got {type(text_model)}"

        self.model = text_model
        self.device = next(self.model.parameters()).device
        self.max_allowed_input_length = self.model.config.max_position_embeddings

    def is_in_eval(self) -> bool:
        """Returns True if the model is in eval mode."""
        return not self.model.training

    def tokenize_input_prompts(self, prompts: Union[str, List[str]], verbose: bool = True) -> Any:
        tokenizer_output = super().tokenize_input_prompts(prompts, verbose=verbose)
        tokenizer_output.input_ids = tokenizer_output.input_ids.to(self.device)
        tokenizer_output.attention_mask = tokenizer_output.attention_mask.to(self.device)

        max_length = tokenizer_output.input_ids.shape[1]
        if max_length > self.max_allowed_input_length:
            raise ValueError(f"The sequence length of the input ({max_length}) is "
                             f"longer than the maximum allowed sequence length ({self.max_allowed_input_length}).")

        return tokenizer_output

    @torch.no_grad()
    def get_embeddings_from_prompt(self, prompts: Union[str, List[str]], verbose: bool = True) -> torch.Tensor:
        """Generate L2-normalised embeddings for a list of input text prompts.

        :param prompts: Input text prompt(s) either in string or list of string format.
        :param verbose: If set to True, tokenized words are displayed in the console.
        :return: Tensor of shape (batch_size, embedding_size).
        """
        assert self.is_in_eval()
        tokenizer_output = self.tokenize_input_prompts(prompts=prompts, verbose=verbose)
        txt_emb = self.model.get_projected_text_embeddings(  # type: ignore
            input_ids=tokenizer_output.input_ids,
            attention_mask=tokenizer_output.attention_mask)

        return txt_emb

    @torch.no_grad()
    def get_pairwise_similarities(self,
                                  prompt_set_1: Union[str, List[str]],
                                  prompt_set_2: Union[str, List[str]]) -> torch.Tensor:
        """Compute pairwise cosine similarities between the embeddings of the given prompts."""
        emb_1 = self.get_embeddings_from_prompt(prompts=prompt_set_1, verbose=False)
        emb_2 = self.get_embeddings_from_prompt(prompts=prompt_set_2, verbose=False)
        sim = torch.diag(torch.mm(emb_1, emb_2.t())).detach()

        return sim

    @torch.no_grad()
    def predict_masked_tokens(self, prompts: Union[str, List[str]]) -> List[List[str]]:
        """Predict masked tokens for a single or list of input text prompts.

        Requires models to be trained with an MLM prediction head.

        :param prompts: Input text prompt(s) either in string or list of string format.
        :return: Predicted token candidates (top-1) at each masked position.
        """
        assert self.is_in_eval()

        # Tokenize the input prompts
        tokenized_prompts = self.tokenize_input_prompts(prompts)

        # Collect all token predictions
        text_model_output = self.model.forward(input_ids=tokenized_prompts.input_ids,
                                               attention_mask=tokenized_prompts.attention_mask)
        logits = text_model_output.logits
        logits = logits.detach()

        predicted_token_ids = torch.argmax(logits, dim=-1)  # Batch x Seq

        # Identify the masked token indices
        batch_size = predicted_token_ids.shape[0]
        mask_token_id = self.tokenizer.mask_token_id
        mlm_mask = tokenized_prompts.input_ids == mask_token_id  # Batch x Seq

        # Convert the predicted token ids to token strings
        output = []
        for b in range(batch_size):
            _ids = predicted_token_ids[b, mlm_mask[b]].cpu().tolist()
            _tokens = self.tokenizer.convert_ids_to_tokens(_ids)
            output.append(_tokens)

        return output
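
# A hedged end-to-end sketch of masked-token prediction; the checkpoint
# identifier is illustrative and the weights are downloaded from the hub:
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = BertForMaskedLM.from_pretrained("bert-base-uncased").eval()
#     engine = TextInferenceEngine(tokenizer=tokenizer, text_model=model)
#     engine.predict_masked_tokens("There is no pleural [MASK]")  # e.g. [['effusion']]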
@@ -0,0 +1,4 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
@@ -0,0 +1,27 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Any

from transformers import BertConfig, BertTokenizer


class CXRBertConfig(BertConfig):
    """
    Config class for the CXR-BERT model.

    :param projection_size: Dimensionality of the joint latent space.
    """

    model_type = "cxr-bert"

    def __init__(self, projection_size: int = 128, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.projection_size = projection_size


class CXRBertTokenizer(BertTokenizer):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
@@ -0,0 +1,133 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Any, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor as T
from transformers import BertForMaskedLM
from transformers.modeling_outputs import ModelOutput

from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig

BERTTupleOutput = Tuple[T, T, T, T, T]


class CXRBertOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor
    logits: torch.FloatTensor
    cls_projected_embedding: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class BertProjectionHead(nn.Module):
    """Projection head to be used with the BERT CLS token.

    This is similar to ``BertPredictionHeadTransform`` in HuggingFace.

    :param config: Configuration for BERT.
    """

    def __init__(self, config: CXRBertConfig) -> None:
        super().__init__()
        self.dense_to_hidden = nn.Linear(config.hidden_size, config.projection_size)
        self.transform_act_fn = nn.functional.gelu
        self.LayerNorm = nn.LayerNorm(config.projection_size, eps=1e-12)
        self.dense_to_output = nn.Linear(config.projection_size, config.projection_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense_to_hidden(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dense_to_output(hidden_states)

        return hidden_states


class CXRBertModel(BertForMaskedLM):
    """
    Implements the CXR-BERT model outlined in the manuscript:
    Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
    https://arxiv.org/abs/2204.09817

    Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projected
    "[CLS]" token embedding is used to align the latent vectors of the image and text modalities.
    """

    config_class = CXRBertConfig  # type: ignore

    def __init__(self, config: CXRBertConfig):
        super().__init__(config)

        self.cls_projection_head = BertProjectionHead(config)
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_cls_projected_embedding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[BERTTupleOutput, CXRBertOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_for_masked_lm_output = super().forward(input_ids=input_ids,
                                                    attention_mask=attention_mask,
                                                    token_type_ids=token_type_ids,
                                                    position_ids=position_ids,
                                                    head_mask=head_mask,
                                                    inputs_embeds=inputs_embeds,
                                                    output_attentions=output_attentions,
                                                    output_hidden_states=True,
                                                    return_dict=True)

        last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
        cls_projected_embedding = self.cls_projection_head(
            last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None

        if return_dict:
            return CXRBertOutput(
                last_hidden_state=last_hidden_state,
                logits=bert_for_masked_lm_output.logits,
                cls_projected_embedding=cls_projected_embedding,
                hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
                attentions=bert_for_masked_lm_output.attentions,
            )
        else:
            return (
                last_hidden_state,
                bert_for_masked_lm_output.logits,
                cls_projected_embedding,
                bert_for_masked_lm_output.hidden_states,
                bert_for_masked_lm_output.attentions,
            )

    def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Returns the L2-normalised projected [CLS] token embeddings for the given input token ids and attention mask.
        The joint latent space is trained using a contrastive objective between image and text data modalities.

        :param input_ids: (batch_size, sequence_length)
        :param attention_mask: (batch_size, sequence_length)
        :return: (batch_size, projection_size)
        """
        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask,
                               output_cls_projected_embedding=True, return_dict=True)
        assert isinstance(outputs, CXRBertOutput)

        assert outputs.cls_projected_embedding is not None
        normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
        return normalized_cls_embedding
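
# A hedged loading sketch using the published checkpoint name defined in
# health_multimodal.text (downloading from the HuggingFace hub is assumed):
#     name = "microsoft/BiomedVLP-CXR-BERT-specialized"
#     tokenizer = CXRBertTokenizer.from_pretrained(name)
#     model = CXRBertModel.from_pretrained(name).eval()
#     tokens = tokenizer("Pneumonia", return_tensors="pt")
#     embedding = model.get_projected_text_embeddings(tokens.input_ids, tokens.attention_mask)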
@@ -0,0 +1,14 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

"""Visual-language processing tools

.. currentmodule:: health_multimodal.vlp

.. autosummary::
    :toctree:

    inference_engine
"""
@@ -0,0 +1,117 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

"""Tools related to joint image and text inference"""

from math import ceil, floor
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from scipy import ndimage

from health_multimodal.image.inference_engine import ImageInferenceEngine
from health_multimodal.text.inference_engine import TextInferenceEngine


class ImageTextInferenceEngine:
    """
    Encapsulate functions related to inference on ImageTextModels.
    """

    def __init__(self,
                 image_inference_engine: ImageInferenceEngine,
                 text_inference_engine: TextInferenceEngine) -> None:
        """
        :param image_inference_engine: Engine used to compute image embeddings.
        :param text_inference_engine: Engine used to compute text embeddings.
        """
        self.image_inference_engine = image_inference_engine
        self.text_inference_engine = text_inference_engine

    def get_similarity_map_from_raw_data(self, image_path: Path, query_text: str) -> np.ndarray:
        """
        Return a heatmap of the similarities between each patch embedding from the image and the text embedding.
        """
        assert not self.image_inference_engine.model.training
        assert not self.text_inference_engine.model.training

        # TODO: Add checks in here regarding the text query, etc.

        image_embedding, (width, height) = self.image_inference_engine.get_patch_embeddings_from_image(image_path)
        text_embedding = self.text_inference_engine.get_embeddings_from_prompt(query_text)

        sim = self._get_similarity_map_from_embeddings(image_embedding, text_embedding)

        resized_sim_map = self.convert_similarity_to_image_size(
            sim,
            width=width,
            height=height,
            resize_size=self.image_inference_engine.resize_size,
            crop_size=self.image_inference_engine.crop_size,
            val_img_transform=self.image_inference_engine.transform,
        )
        return resized_sim_map

    @staticmethod
    def _get_similarity_map_from_embeddings(projected_patch_embeddings: torch.Tensor,
                                            projected_text_embeddings: torch.Tensor,
                                            sigma: float = 1.5) -> torch.Tensor:
        """
        Get a smoothed similarity map for the given image patch embeddings and text embeddings.

        :param projected_patch_embeddings: [n_patches_h, n_patches_w, feature_size]
        :param projected_text_embeddings: [1, feature_size]
        :param sigma: Standard deviation of the Gaussian kernel used to smooth the similarity map.
        :return: similarity_map: Similarity map of shape [n_patches_h, n_patches_w].
        """
        n_patches_h, n_patches_w, feature_size = projected_patch_embeddings.shape
        assert feature_size == projected_text_embeddings.shape[1]
        assert projected_text_embeddings.shape[0] == 1
        assert projected_text_embeddings.dim() == 2
        patch_wise_similarity = projected_patch_embeddings.view(-1, feature_size) @ projected_text_embeddings.t()
        patch_wise_similarity = patch_wise_similarity.reshape(n_patches_h, n_patches_w).cpu().numpy()
        smoothed_similarity_map = torch.tensor(ndimage.gaussian_filter(
            patch_wise_similarity, sigma=(sigma, sigma), order=0))
        return smoothed_similarity_map

    @staticmethod
    def convert_similarity_to_image_size(
            similarity_map: torch.Tensor, width: int, height: int, resize_size: Optional[int],
            crop_size: Optional[int], val_img_transform: Optional[Callable] = None) -> np.ndarray:
        """
        Convert the similarity map from the raw patch grid to the original image size,
        taking into account whether the image has been resized and/or cropped prior to entering the network.
        """
        n_patches_h, n_patches_w = similarity_map.shape[0], similarity_map.shape[1]
        target_shape = 1, 1, n_patches_h, n_patches_w
        smallest_dimension = min(height, width)

        # TODO:
        # verify_resize_params(val_img_transforms, resize_size, crop_size)

        if crop_size is not None:
            if resize_size is not None:
                cropped_size_orig_space = int(crop_size * smallest_dimension / resize_size)
                target_size = cropped_size_orig_space, cropped_size_orig_space
            else:
                target_size = crop_size, crop_size
            similarity_map = F.interpolate(
                similarity_map.reshape(target_shape),
                size=target_size,
                mode='bilinear',
                align_corners=False,
            )
            margin_w, margin_h = (width - target_size[0]), (height - target_size[1])
            margins_for_pad = (floor(margin_w / 2), ceil(margin_w / 2), floor(margin_h / 2), ceil(margin_h / 2))
            similarity_map = F.pad(similarity_map[0, 0], margins_for_pad, value=float("NaN"))
        else:
            similarity_map = F.interpolate(
                similarity_map.reshape(target_shape),
                size=(height, width),
                mode='bilinear',
                align_corners=False,
            )[0, 0]
        return similarity_map.numpy()
|
|
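To make the crop/resize arithmetic above concrete, here is a worked example under the transform parameters used elsewhere in this PR (resize=512, center crop 480, a 15x15 patch grid) and a hypothetical 600x400 input: the crop corresponds to int(480 * 400 / 512) = 375 pixels of the original image's smallest dimension, so the patch map is upsampled to 375x375 and then NaN-padded back to 400x600:

import torch
from math import ceil, floor

n_patches = 15                    # patch grid for a 480x480 crop
width, height = 600, 400          # hypothetical original image size
resize_size, crop_size = 512, 480

cropped_size_orig_space = int(crop_size * min(height, width) / resize_size)  # 375
sim = torch.rand(1, 1, n_patches, n_patches)
sim = torch.nn.functional.interpolate(sim, size=(cropped_size_orig_space,) * 2,
                                      mode='bilinear', align_corners=False)
margin_w, margin_h = width - cropped_size_orig_space, height - cropped_size_orig_space
padded = torch.nn.functional.pad(sim[0, 0],
                                 (floor(margin_w / 2), ceil(margin_w / 2),
                                  floor(margin_h / 2), ceil(margin_h / 2)),
                                 value=float("NaN"))
print(padded.shape)  # torch.Size([400, 600]) -- NaNs mark pixels outside the crop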
@@ -0,0 +1,131 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------------------------------------------------------------------------------------------\n",
    "# Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.\n",
    "# -------------------------------------------------------------------------------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "from health_multimodal.text.inference_engine import TextInferenceEngine\n",
    "from health_multimodal.image.inference_engine import ImageInferenceEngine\n",
    "from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine\n",
    "\n",
    "from health_multimodal.image.model.model import ImageModel\n",
    "from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference\n",
    "\n",
    "from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the text inference engine\n",
    "HUGGING_FACE_URL = \"microsoft/BiomedVLP-CXR-BERT-specialized\"\n",
    "text_inference = TextInferenceEngine(\n",
    "    tokenizer=AutoTokenizer.from_pretrained(HUGGING_FACE_URL, trust_remote_code=True),\n",
    "    text_model=AutoModel.from_pretrained(HUGGING_FACE_URL, trust_remote_code=True),\n",
    ")\n",
    "\n",
    "# Load the image inference engine\n",
    "resnet_checkpoint_path = \"\"  # add path to checkpoint here\n",
    "if not Path(resnet_checkpoint_path).is_file():\n",
    "    print(\"Checkpoint file not found!\")\n",
    "    resnet_checkpoint_path = None\n",
    "image_inference = ImageInferenceEngine(\n",
    "    image_model=ImageModel(img_model_type=\"resnet50\", joint_feature_size=128, pretrained_model_path=resnet_checkpoint_path),\n",
    "    transform=create_chest_xray_transform_for_inference(resize=512, center_crop_size=480))\n",
    "\n",
    "# Instantiate the joint inference engine\n",
    "image_text_inference = ImageTextInferenceEngine(\n",
    "    image_inference_engine=image_inference,\n",
    "    text_inference_engine=text_inference,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_phrase_grounding(image_path: Path, text_prompt: str) -> None:\n",
    "    sim_map = image_text_inference.get_similarity_map_from_raw_data(image_path=image_path, query_text=text_prompt)\n",
    "    plot_phrase_grounding_similarity_map(image_path=image_path, similarity_map=sim_map)\n",
    "\n",
    "def plot_phrase_grounding_from_url(image_url: str, text_prompt: str) -> None:\n",
    "    image_path = Path(tempfile.gettempdir(), 'downloaded_chest_xray.jpg')\n",
    "    !curl -s -L -o {image_path} {image_url}\n",
    "    plot_phrase_grounding(image_path, text_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_prompt = \"Pneumonia in the right lung\"\n",
    "image_url = \"https://prod-images-static.radiopaedia.org/images/1371188/0a1f5edc85aa58d5780928cb39b08659c1fc4d6d7c7dce2f8db1d63c7c737234_gallery.jpeg\"\n",
    "plot_phrase_grounding_from_url(image_url, text_prompt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.3 ('himl-multimodal')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b10e2d33e98f46e002b38decbb3115032da80ae497861a1d67d5527569b17994"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -1,8 +1,5 @@
 {
-    "include": [
-        "health_multimodal",
-        "test_multimodal",
-    ],
+    "include": [],
     "useLibraryCodeForTypes": false,
     "reportMissingImports": true,
     "reportMissingTypeStubs": false,
@@ -0,0 +1,10 @@
huggingface-hub==0.6.0
lightning-bolts==0.3.4
pillow==9.0.1
pydicom==2.2.2
pytorch-lightning==1.5.5
scikit-image==0.18.1
SimpleITK==2.1.1
torch==1.9.0
torchvision==0.10.0
transformers==4.17.0
@@ -1,10 +1,7 @@
black==22.1.0
coverage==6.3.2
flake8==4.0.1
ipykernel==6.15.0
ipython==7.34.0
mypy==0.931
pylint==2.12.2
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1
pytest-rerunfailures==10.2
pytest-timeout==2.0.1
papermill==2.3.4
setuptools==59.5.0
@@ -62,7 +62,7 @@ install_requires = (here / 'requirements_run.txt').read_text().split("\n")
 # Remove any whitespace and blank lines
 install_requires = [line.strip() for line in install_requires if line.strip()]
 
-description = 'Microsoft Health Intelligence package to work with multi-modal health data'
+description = 'Microsoft Health Futures package to work with multi-modal health data'
 
 setup(
     name=package_name,
@@ -0,0 +1,66 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from pathlib import Path
from tempfile import NamedTemporaryFile

import numpy as np
import pytest
import SimpleITK as sitk
from PIL import Image

from health_multimodal.image.data.io import load_image, remap_to_uint8


def _assert_min_max_dtype(array: np.ndarray) -> None:
    assert array.min() == 0
    assert array.max() == 255
    assert array.dtype == np.uint8


def test_load_image() -> None:
    """Test the image loading function using dummy files."""
    size = 4, 4

    def _assertions(path: Path) -> None:
        img = load_image(path)
        assert img.size == size
        array = np.asarray(img)
        _assert_min_max_dtype(array)

    array = np.arange(np.prod(size), dtype=np.uint8).reshape(*size)
    image = Image.fromarray(array).convert('RGB')
    for suffix in '.jpg', '.jpeg', '.png':
        with NamedTemporaryFile(suffix=suffix) as file:
            image.save(file)
            _assertions(Path(file.name))

    nifti_img = sitk.GetImageFromArray(np.arange(16, dtype=np.uint16).reshape(*size) + 100)
    with NamedTemporaryFile(suffix='.nii.gz') as file:
        sitk.WriteImage(nifti_img, file.name)
        _assertions(Path(file.name))


def test_remap_to_uint8() -> None:
    """Test the intensity casting function using different percentiles."""
    array = np.arange(10).astype(np.uint16)  # mimic DICOM data type
    with pytest.raises(ValueError):
        remap_to_uint8(array, (1, 2, 3))  # type: ignore[arg-type]
    with pytest.raises(ValueError):
        remap_to_uint8(array, (-1, 50))
    with pytest.raises(ValueError):
        remap_to_uint8(array, (1, 150))
    with pytest.raises(ValueError):
        remap_to_uint8(array, (5, 2))
    normalized = remap_to_uint8(array)
    _assert_min_max_dtype(normalized)
    normalized = remap_to_uint8(array, (1, 99))
    _assert_min_max_dtype(normalized)

    array_positive_min = array + 5
    normalized = remap_to_uint8(array_positive_min)
    _assert_min_max_dtype(normalized)
    normalized = remap_to_uint8(array_positive_min, (1, 99))
    _assert_min_max_dtype(normalized)
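The remap_to_uint8 implementation itself is not part of this hunk; a minimal sketch consistent with the contract the tests above pin down (an optional (lower, upper) percentile window that is validated and used for clipping, then a rescale to the full uint8 range) could look like the following. The function name is illustrative, not the io module's actual code:

from typing import Optional, Tuple
import numpy as np

def remap_to_uint8_sketch(array: np.ndarray,
                          percentiles: Optional[Tuple[float, float]] = None) -> np.ndarray:
    """Rescale intensities to [0, 255], optionally clipping to the given percentiles first."""
    array = array.astype(float)
    if percentiles is not None:
        if len(percentiles) != 2:
            raise ValueError(f"Expected 2 percentiles, got {len(percentiles)}")
        lower, upper = percentiles
        if lower >= upper:
            raise ValueError(f"Percentiles must be in ascending order, got {percentiles}")
        if lower < 0 or upper > 100:
            raise ValueError(f"Percentiles must be in [0, 100], got {percentiles}")
        cutoff = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)  # clip to the percentile window
    array -= array.min()  # shift minimum to 0
    array /= array.max()  # scale maximum to 1 (assumes a non-constant image)
    return (array * 255).astype(np.uint8)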
@@ -0,0 +1,135 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import pytest
import torch
from pl_bolts.models.self_supervised.resnets import resnet50

from health_multimodal.image.model.model import ImageEncoder, ImageModel
from health_multimodal.image.model.modules import MultiTaskModel


def test_frozen_cnn_model() -> None:
    """
    Checks if the mode of module parameters is set correctly.
    """
    model = ImageModel(img_model_type='resnet18',
                       joint_feature_size=4,
                       num_classes=2,
                       freeze_encoder=True,
                       classifier_hidden_dim=24,
                       num_tasks=1)

    assert not model.encoder.training
    assert not model.projector.training
    assert isinstance(model.classifier, MultiTaskModel)
    assert model.classifier.training

    model.train()
    assert not model.encoder.training
    assert not model.projector.training
    assert isinstance(model.classifier, MultiTaskModel)
    assert model.classifier.training

    model.eval()
    assert not model.encoder.training
    assert not model.projector.training
    assert isinstance(model.classifier, MultiTaskModel)
    assert not model.classifier.training

    model = ImageModel(img_model_type='resnet18',
                       joint_feature_size=4,
                       num_classes=2,
                       freeze_encoder=False,
                       classifier_hidden_dim=24,
                       num_tasks=1)
    assert model.encoder.training
    assert model.projector.training
    assert model.classifier.training  # type: ignore


def test_image_get_patchwise_projected_embeddings() -> None:
    """
    Checks if the image patch embeddings are correctly computed and projected to the latent space.
    """
    num_classes = 2
    num_tasks = 1
    joint_feature_size = 4
    model = ImageModel(img_model_type='resnet18',
                       joint_feature_size=joint_feature_size,
                       num_classes=num_classes,
                       freeze_encoder=True,
                       classifier_hidden_dim=None,
                       num_tasks=num_tasks)
    model.train()
    with pytest.raises(AssertionError) as ex:
        model.get_patchwise_projected_embeddings(torch.rand(size=(2, 3, 448, 448)), normalize=True)
    assert "This function is only implemented for evaluation mode" in str(ex)
    model.eval()

    batch_size = 2
    image = torch.rand(size=(batch_size, 3, 64, 64))
    encoder_output, _ = model.encoder.forward(image, return_patch_embeddings=True)
    h, w = encoder_output.shape[2:]

    # First check the model output is in the expected shape,
    # since this is used internally by get_patchwise_projected_embeddings
    model_output = model.forward(image)
    assert model_output.projected_patch_embeddings.shape == (batch_size, joint_feature_size, h, w)
    assert model_output.projected_global_embedding.shape == (batch_size, joint_feature_size)
    projected_global_embedding = model_output.projected_global_embedding

    unnormalized_patch_embeddings = model.get_patchwise_projected_embeddings(image, normalize=False)
    # Make sure the projected embeddings returned are the right shape
    assert unnormalized_patch_embeddings.shape == (batch_size, h, w, joint_feature_size)
    result_1 = torch.mean(unnormalized_patch_embeddings, dim=(1, 2))  # B x H x W x D -> B x D
    result_2 = projected_global_embedding
    assert torch.allclose(result_1, result_2)

    # Test the normalized version
    normalized_patch_embeddings = model.get_patchwise_projected_embeddings(image, normalize=True)
    assert normalized_patch_embeddings.shape == (batch_size, h, w, joint_feature_size)
    # Make sure the norm is 1 along the embedding dimension
    norm = normalized_patch_embeddings.norm(p=2, dim=-1)
    assert torch.all(torch.abs(norm - 1.0) < 1e-5)


def test_reload_resnet_with_dilation() -> None:
    """
    Tests if the ResNet model can be switched from pooling to dilated convolutions.
    """
    replace_stride_with_dilation = [False, False, True]

    # resnet18 does not support dilation
    model_with_dilation = ImageEncoder(img_model_type="resnet18")
    with pytest.raises(NotImplementedError):
        model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)

    # resnet50
    original_model = ImageEncoder(img_model_type="resnet50").eval()
    model_with_dilation = ImageEncoder(img_model_type="resnet50").eval()
    model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)
    assert not model_with_dilation.training

    batch_size = 2
    image = torch.rand(size=(batch_size, 3, 64, 64))

    with torch.no_grad():
        outputs_dilation, _ = model_with_dilation(image, return_patch_embeddings=True)
        outputs_original, _ = original_model(image, return_patch_embeddings=True)
        assert outputs_original.shape[2] * 2 == outputs_dilation.shape[2], \
            "The dilation model should return larger feature maps."

    expected_model = resnet50(return_all_feature_maps=True, pretrained=True,
                              replace_stride_with_dilation=replace_stride_with_dilation)
    expected_model.eval()
    with torch.no_grad():
        expected_output = expected_model(image)[-1]
        assert torch.allclose(outputs_dilation, expected_output)
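For intuition on the dilation assertion above: replacing the stride of the last ResNet stage with dilation halves the output stride, so the feature map doubles in spatial size. A quick sketch with torchvision's ResNet-50, which exposes the same replace_stride_with_dilation flag as the pl_bolts variant used here (torchvision is an assumption for illustration only):

import torch
from torchvision.models import resnet50

def feature_map(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Run only the convolutional trunk, skipping the classification head
    for name, module in model.named_children():
        if name in ("avgpool", "fc"):
            break
        x = module(x)
    return x

x = torch.rand(1, 3, 64, 64)
standard = resnet50().eval()
dilated = resnet50(replace_stride_with_dilation=[False, False, True]).eval()
with torch.no_grad():
    print(feature_map(standard, x).shape)  # torch.Size([1, 2048, 2, 2]) -- output stride 32
    print(feature_map(dilated, x).shape)   # torch.Size([1, 2048, 4, 4]) -- output stride 16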
@@ -0,0 +1,27 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import pytest
from transformers import AutoModel, AutoTokenizer

from health_multimodal.text import BIOMED_VLP_CXR_BERT_SPECIALIZED, TypePrompts
from health_multimodal.text.inference_engine import TextInferenceEngine


text_inference = TextInferenceEngine(
    tokenizer=AutoTokenizer.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
    text_model=AutoModel.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
)


@pytest.mark.parametrize("prompts", ("", "hello", "this is a test", ["this is", "also a test"]))
def test_good_prompts(prompts: TypePrompts) -> None:
    text_inference.tokenize_input_prompts(prompts)


@pytest.mark.parametrize("prompts", ("[CLS]", "hello [PAD]"))
def test_bad_prompts(prompts: TypePrompts) -> None:
    with pytest.raises(ValueError, match="The input .* contains at least one special token"):
        text_inference.tokenize_input_prompts(prompts)
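The body of tokenize_input_prompts is not shown in this diff; a sketch of the validation these tests exercise might look like the following, assuming the standard Hugging Face tokenizer attributes (all_special_ids, mask_token_id). The helper name is illustrative, and [MASK] is allowed because the MLM tests later in this PR rely on it:

def check_no_special_tokens_sketch(tokenizer, prompt: str) -> None:
    token_ids = tokenizer(prompt).input_ids
    # [MASK] is allowed so that masked-language-model prompts still pass
    disallowed_ids = set(tokenizer.all_special_ids) - {tokenizer.mask_token_id}
    # Skip the [CLS]/[SEP] tokens that the tokenizer itself adds at the ends
    if any(token_id in disallowed_ids for token_id in token_ids[1:-1]):
        raise ValueError(f"The input {prompt} contains at least one special token")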
@@ -0,0 +1,75 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import json
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
import pytest
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME

from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
from health_multimodal.text.model.modelling_cxrbert import CXRBertModel


def test_model_instantiation() -> None:

    def _test_model_forward(model: CXRBertModel) -> None:
        batch_size = 2
        seq_length = 5
        input_ids = torch.randint(0, 512, (batch_size, seq_length))
        attention_mask = torch.ones_like(input_ids)
        outputs = model(input_ids, attention_mask, output_cls_projected_embedding=True)
        assert outputs.last_hidden_state.shape == (batch_size, seq_length, config.hidden_size)
        assert outputs.cls_projected_embedding.shape == (batch_size, config.projection_size)

        projected_embeddings = model.get_projected_text_embeddings(input_ids, attention_mask)
        assert projected_embeddings.shape == (batch_size, config.projection_size)
        norm = torch.norm(projected_embeddings[0], p=2).item()
        assert pytest.approx(norm) == 1

        outputs = model(input_ids, attention_mask, output_hidden_states=False)
        assert outputs.hidden_states is None

        outputs_in_tuple = model(input_ids, attention_mask, return_dict=False)
        assert outputs.cls_projected_embedding == outputs_in_tuple[2]
        assert torch.allclose(outputs.last_hidden_state, outputs_in_tuple[0])

    config = CXRBertConfig(hidden_size=6,
                           projection_size=4,
                           num_hidden_layers=1,
                           num_attention_heads=2,
                           output_attentions=True,
                           return_dict=True)
    model = CXRBertModel(config)
    model = model.eval()
    _test_model_forward(model=model)

    # Try saving this model and check the saved model/config
    with TemporaryDirectory() as save_dir_as_str:
        save_dir = Path(save_dir_as_str)
        model.save_pretrained(save_dir)

        weights_file = save_dir / WEIGHTS_NAME
        assert weights_file.exists()
        saved_weights = torch.load(weights_file)
        # Make sure the MLM head was saved
        assert "cls.predictions.bias" in saved_weights
        # Make sure the projection head was saved
        assert "cls_projection_head.dense_to_hidden.weight" in saved_weights

        # Check the config file
        config_file = save_dir / CONFIG_NAME
        assert config_file.exists()
        with config_file.open() as f:
            config_json = json.load(f)
        assert "projection_size" in config_json
        assert config_json["projection_size"] == config.projection_size
        assert config_json["architectures"] == ["CXRBertModel"]

        # Make sure we can load from the saved model
        model_from_pretrained = CXRBertModel.from_pretrained(save_dir)
        _test_model_forward(model=model_from_pretrained)
@@ -0,0 +1,119 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from typing import Tuple

import pytest
import torch

from health_multimodal.text import BIOMED_VLP_CXR_BERT_SPECIALIZED
from health_multimodal.text.inference_engine import TextInferenceEngine
from health_multimodal.text.model.modelling_cxrbert import CXRBertModel
from health_multimodal.text.model.configuration_cxrbert import CXRBertTokenizer


def test_text_inference_init_model_type() -> None:
    """
    Test that init fails if the wrong model type is passed in.
    """
    tokenizer, _ = _get_cxr_bert()
    false_model = torch.nn.Linear(4, 4)
    with pytest.raises(AssertionError) as ex:
        TextInferenceEngine(tokenizer=tokenizer, text_model=false_model)  # type: ignore[arg-type]
    assert f"Expected a BertForMaskedLM, got {type(false_model)}" in str(ex)


def test_l2_normalization() -> None:
    """
    Test that the text embeddings (CLS token) are L2-normalized.
    """
    tokenizer, text_model = _get_cxr_bert()

    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
    input_query = ["There is a tumor in the left lung", "Lungs are all clear"]
    embedding = text_inference.get_embeddings_from_prompt(prompts=input_query)
    norm = torch.norm(embedding, p=2, dim=-1)
    assert torch.allclose(norm, torch.ones_like(norm))


def test_sentence_semantic_similarity() -> None:
    """
    Test that the sentence embedding similarity computed by the text model is meaningful.
    """
    tokenizer, text_model = _get_cxr_bert()

    # The CLS token has no dedicated meaning, but we can expect vector similarity
    # due to token overlap between the sentences.
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
    input_query = ["There is a tumor in the left lung",
                   "Tumor is present",
                   "Patient is admitted to the hospital today"]
    embedding = text_inference.get_embeddings_from_prompt(input_query)
    pos_sim = torch.dot(embedding[0], embedding[1])
    neg_sim_1 = torch.dot(embedding[0], embedding[2])
    neg_sim_2 = torch.dot(embedding[1], embedding[2])
    assert pos_sim > neg_sim_1
    assert pos_sim > neg_sim_2


def _get_cxr_bert() -> Tuple[CXRBertTokenizer, CXRBertModel]:
    model_name = BIOMED_VLP_CXR_BERT_SPECIALIZED
    tokenizer = CXRBertTokenizer.from_pretrained(model_name)
    text_model = CXRBertModel.from_pretrained(model_name)
    return tokenizer, text_model


def test_triplet_similarities_with_inference_engine() -> None:
    """
    Test that the triplet sentence similarities computed by the text model are meaningful.
    """
    tokenizer, text_model = _get_cxr_bert()
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)

    reference = \
        ["Heart size is top normal.", "There is no pneumothorax or pleural effusion",
         "The patient has been extubated.", "The lungs are clear bilaterally.",
         "No pleural effusions."]
    synonyms = \
        ["The cardiac silhouette is normal in size.", "No pleural effusion or pneumothorax is seen",
         "There has been interval extubation.", "The lungs are unremarkable.",
         "Also, the lateral pleural sinuses are free, which excludes major pleural effusion."]
    contradictions = \
        ["The heart is largely enlarged.", "The extent of the pleural effusion is constant.",
         "The patient is intubated", "The lungs are mostly clear aside from lower lung atelectasis.",
         "The loculated right pleural effusion has increased, and is now moderate in size."]

    synonym_score = text_inference.get_pairwise_similarities(reference, synonyms)
    contradictory_score = text_inference.get_pairwise_similarities(reference, contradictions)

    print("Synonym score:", synonym_score)
    print("Contradictory score:", contradictory_score)

    assert torch.all(synonym_score > contradictory_score)
    assert torch.all(synonym_score > 0.5)
    assert torch.all(synonym_score <= 1.0)
    assert torch.all(contradictory_score < 0.5)
    assert torch.all(contradictory_score >= -1.0)


def test_mlm_with_inference_engine_with_hf_hub() -> None:
    """
    Test that the MLM model can be used with the inference engine and the filled masked tokens are correct.
    """
    tokenizer, text_model = _get_cxr_bert()
    text_inference = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)

    # Test masked language modelling
    mlm_prompts = ["Moderate [MASK] pleural effusions and associated [MASK]",
                   "Right basilar [MASK], potentially due to infiltrate in the proper clinical setting",
                   "The right basilar chest [MASK] appears to be in unchanged position",
                   "The small basal pneumothorax has slightly [MASK] compared to the prior",
                   "Poorly defined [MASK] in the right lung is concerning for aspiration",
                   "Retrocardiac opacity likely reflects known hiatal [MASK]"]

    target_tokens = [['bilateral', 'atelectasis'], ['opacity'], ['tube'], ['increased'], ['opacity'], ['hernia']]

    output_top_1 = text_inference.predict_masked_tokens(mlm_prompts)
    assert output_top_1 == target_tokens
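get_pairwise_similarities is likewise not part of this diff; given that get_embeddings_from_prompt returns L2-normalized rows (asserted above), a sketch consistent with the score bounds checked here is a per-pair dot product, i.e. a cosine similarity in [-1, 1] (the function name is illustrative):

from typing import List
import torch
from health_multimodal.text.inference_engine import TextInferenceEngine

def pairwise_similarities_sketch(engine: TextInferenceEngine,
                                 prompts_1: List[str],
                                 prompts_2: List[str]) -> torch.Tensor:
    emb_1 = engine.get_embeddings_from_prompt(prompts_1)  # [N, D], unit-norm rows
    emb_2 = engine.get_embeddings_from_prompt(prompts_2)  # [N, D]
    return torch.sum(emb_1 * emb_2, dim=1)  # cosine similarity per pair, in [-1, 1]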
@@ -0,0 +1,76 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

import tempfile
from pathlib import Path

import torch
import pytest
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from health_multimodal.text import BIOMED_VLP_CXR_BERT_SPECIALIZED
from health_multimodal.text.inference_engine import TextInferenceEngine
from health_multimodal.image import ImageModel, ResnetType, ImageInferenceEngine
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine


text_inference = TextInferenceEngine(
    tokenizer=AutoTokenizer.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
    text_model=AutoModel.from_pretrained(BIOMED_VLP_CXR_BERT_SPECIALIZED, trust_remote_code=True),
)


@pytest.mark.parametrize("height", (400, 500, 650))
@pytest.mark.parametrize("query_text", ("", "hello", "this is a test"))
def test_vlp_inference(height: int, query_text: str) -> None:
    image_embedding_shapes = {
        480: (15, 15),
    }

    joint_feature_size = 128
    resize = 512
    center_crop_size = 480

    width = 600
    image_inference = ImageInferenceEngine(
        image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=joint_feature_size),
        transform=create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size))
    img_txt_inference = ImageTextInferenceEngine(
        image_inference_engine=image_inference,
        text_inference_engine=text_inference,
    )

    with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
        image_path = Path(f.name)
        image = Image.new('RGB', (width, height))
        image.save(image_path)

        # Test the integrated VLP inference engine
        resampled_similarity_map = img_txt_inference.get_similarity_map_from_raw_data(
            image_path=image_path,
            query_text=query_text,
        )
        assert resampled_similarity_map.shape == (height, width)
        np.nan_to_num(resampled_similarity_map, copy=False)
        assert resampled_similarity_map.min() >= -1
        assert resampled_similarity_map.max() <= 1

        # Test the individual components
        image_embedding, size = img_txt_inference.image_inference_engine.get_patch_embeddings_from_image(image_path)

        assert (width, height) == size
        expected_image_embedding_size = image_embedding_shapes[center_crop_size]
        assert image_embedding.shape == (*expected_image_embedding_size, joint_feature_size)
        normalized_image_embedding = torch.norm(image_embedding, p=2, dim=-1)
        assert torch.allclose(normalized_image_embedding, torch.ones_like(normalized_image_embedding))

        text_embedding = img_txt_inference.text_inference_engine.get_embeddings_from_prompt(query_text)
        assert text_embedding.shape == (1, joint_feature_size)

        similarity_map = img_txt_inference._get_similarity_map_from_embeddings(image_embedding, text_embedding)
        assert similarity_map.shape == expected_image_embedding_size
@@ -1,3 +1,4 @@
 -r build_requirements.txt
 -r hi-ml/run_requirements.txt
 -r hi-ml-azure/run_requirements.txt
+-r multimodal/requirements_run.txt