MNT: Update `hi-ml-multimodal` (#928)

Users (including ourselves) keep running into problems installing
`hi-ml-multimodal`. I've unpinned the requirements so the package is easy to
install and use, and updated the bits of code that relied on ancient versions
of TorchVision and Transformers. I've checked the notebook and it's working
fine; a minimal smoke-test sketch is included below.

Related issues:
- https://github.com/microsoft/hi-ml/issues/927
- https://github.com/microsoft/hi-ml/issues/916
- https://github.com/microsoft/hi-ml/issues/850

Fixes #850.
Fixes #916.
Fixes #927.
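
As a quick sanity check after installing from the unpinned requirements, something along these lines should work. This is a minimal sketch, not part of the PR; it assumes the helper keeps living in `health_multimodal.image.model.pretrained` as in the current repo layout.

```python
# Minimal smoke test (sketch): `pip install hi-ml-multimodal`, then
# instantiate the BioViL-T image encoder and count its parameters.
from health_multimodal.image.model.pretrained import get_biovil_t_image_encoder

model = get_biovil_t_image_encoder()  # downloads the checkpoint from Hugging Face
n_params = sum(p.numel() for p in model.parameters())
print(f"BioViL-T image encoder parameters: {n_params:,}")
```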
Fernando Pérez-García committed on 2024-04-11 10:16:03 +01:00 (committed by GitHub)
Parent: f4e21b78af
Commit: 61a2c4d330
5 changed files with 28 additions and 21 deletions


```diff
@@ -1,9 +1,10 @@
-huggingface-hub==0.6.0
-pillow==10.0.1
-pydicom==2.2.2
-scikit-image==0.18.1
-SimpleITK==2.1.1
-timm==0.6.5
-torch==1.9.0
-torchvision>0.9,<=0.10.0
-transformers==4.17.0
+huggingface-hub
+matplotlib
+Pillow
+pydicom
+scikit-image
+SimpleITK
+timm
+torch
+torchvision
+transformers
```


```diff
@@ -10,7 +10,7 @@ from pathlib import Path
 from torch.hub import load_state_dict_from_url
 from torchvision.datasets.utils import download_url
-from torchvision.models.resnet import model_urls
+from torchvision.models.resnet import ResNet50_Weights
 from .model import ImageModel
 from .types import ImageEncoderType, ImageEncoderWeightTypes
@@ -75,7 +75,7 @@ def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
     return image_model
-def get_biovil_t_image_encoder() -> ImageModel:
+def get_biovil_t_image_encoder(**kwargs) -> ImageModel:
     """Download weights from Hugging Face and instantiate the image model."""
     biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
@@ -84,14 +84,15 @@ def get_biovil_t_image_encoder() -> ImageModel:
         img_encoder_type=model_type,
         joint_feature_size=JOINT_FEATURE_SIZE,
         pretrained_model_path=biovilt_checkpoint_path,
+        **kwargs,
     )
     return image_model
 def get_imagenet_init_encoder() -> ImageModel:
     """Download ImageNet pre-trained weights and instantiate the image model."""
-    state_dict = load_state_dict_from_url(model_urls[ImageEncoderType.RESNET50])
+    url = ResNet50_Weights.IMAGENET1K_V1.url
+    state_dict = load_state_dict_from_url(url)
     image_model = ImageModel(
         img_encoder_type=ImageEncoderType.RESNET50,
         joint_feature_size=JOINT_FEATURE_SIZE,
```
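
The new `**kwargs` are forwarded straight to the `ImageModel` constructor. A hedged sketch of what that enables; the `freeze_encoder` argument below is illustrative and assumes `ImageModel` accepts it:

```python
from health_multimodal.image.model.pretrained import get_biovil_t_image_encoder

# Extra keyword arguments are now passed through to ImageModel(...).
# `freeze_encoder=True` is an assumed/illustrative argument name.
model = get_biovil_t_image_encoder(freeze_encoder=True)
```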


```diff
@@ -7,7 +7,9 @@ from typing import Any, List, Tuple, Type, Union
 import torch
 from torch.hub import load_state_dict_from_url
-from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
+from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
+from torchvision.models.resnet import ResNet18_Weights
+from torchvision.models.resnet import ResNet50_Weights
 TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
@@ -49,7 +51,7 @@ class ResNetHIML(ResNet):
 def _resnet(
-    arch: str,
+    url: str,
     block: Type[Union[BasicBlock, Bottleneck]],
     layers: List[int],
     pretrained: bool,
@@ -62,7 +64,7 @@ def _resnet(
     """
     model = ResNetHIML(block=block, layers=layers, **kwargs)
     if pretrained:
-        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+        state_dict = load_state_dict_from_url(url, progress=progress)
         model.load_state_dict(state_dict)
     return model
@@ -74,7 +76,8 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
     :param pretrained: If ``True``, returns a model pre-trained on ImageNet.
     :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
     """
-    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+    url = ResNet18_Weights.IMAGENET1K_V1.url
+    return _resnet(url, BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
 def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
@@ -84,4 +87,5 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
     :param pretrained: If ``True``, returns a model pre-trained on ImageNet
     :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
     """
-    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+    url = ResNet50_Weights.IMAGENET1K_V1.url
+    return _resnet(url, Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
```
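
For reference, recent TorchVision releases removed the `model_urls` dict; the weight enums expose the same checkpoint URL, which is what the updated `_resnet` helper now receives. A small sketch of the equivalence, using only public TorchVision API:

```python
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import ResNet50_Weights

# ResNet50_Weights.IMAGENET1K_V1.url points at the checkpoint that
# model_urls["resnet50"] used to reference in older TorchVision versions.
url = ResNet50_Weights.IMAGENET1K_V1.url
state_dict = load_state_dict_from_url(url, progress=True)
print(url)
print(len(state_dict), "tensors in the ImageNet ResNet-50 checkpoint")
```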


```diff
@@ -11,7 +11,6 @@ from typing import Any, Callable, Optional, Set, Tuple
 import torch
 import torch.nn as nn
 from timm.models.layers import DropPath, Mlp, trunc_normal_
-from transformers.pytorch_utils import torch_int_div
 @dataclass
@@ -265,7 +264,7 @@ class SinePositionEmbedding:
         x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
         dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
-        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
```
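
`torch_int_div` was a compatibility shim that newer `transformers` releases no longer ship; floor division via `torch.div` computes the same thing, as this tiny check illustrates:

```python
import torch

# The removed helper was equivalent to element-wise floor division.
dim_t = torch.arange(8, dtype=torch.float32)
print(torch.div(dim_t, 2, rounding_mode="floor"))
# tensor([0., 0., 1., 1., 2., 2., 3., 3.])
```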


```diff
@@ -3,6 +3,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
+from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
 import torch
@@ -17,9 +18,10 @@ from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
 BERTTupleOutput = Tuple[T, T, T, T, T]
+@dataclass
 class CXRBertOutput(ModelOutput):
     last_hidden_state: torch.FloatTensor
-    logits: torch.FloatTensor
+    logits: Optional[torch.FloatTensor] = None
     cls_projected_embedding: Optional[torch.FloatTensor] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
```
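
The `@dataclass` decorator is needed because recent `transformers` versions expect `ModelOutput` subclasses to be proper dataclasses (the field handling in `ModelOutput` relies on dataclass fields). A self-contained sketch of the pattern, not taken from the repo:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class ToyOutput(ModelOutput):  # illustrative stand-in for CXRBertOutput
    last_hidden_state: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


out = ToyOutput(last_hidden_state=torch.zeros(1, 4, 8))
print(list(out.keys()))  # ['last_hidden_state']; fields left as None are omitted
```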