MNT: Update `hi-ml-multimodal` (#928)

Users (including ourselves) keep running into issues when installing
`hi-ml-multimodal`. I've unpinned the requirements so that the package is
easy to install and use. This required updating some code that relied on
APIs from very old versions of dependencies, e.g. TorchVision. I've checked
that the notebook still works fine.

Related issues:
- https://github.com/microsoft/hi-ml/issues/927
- https://github.com/microsoft/hi-ml/issues/916
- https://github.com/microsoft/hi-ml/issues/850

Fixes #850.
Fixes #916.
Fixes #927.
This commit is contained in:
Fernando Pérez-García 2024-04-11 10:16:03 +01:00 коммит произвёл GitHub
Родитель f4e21b78af
Коммит 61a2c4d330
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
5 изменённых файлов: 28 добавлений и 21 удалений

Просмотреть файл

@ -1,9 +1,10 @@
huggingface-hub==0.6.0
pillow==10.0.1
pydicom==2.2.2
scikit-image==0.18.1
SimpleITK==2.1.1
timm==0.6.5
torch==1.9.0
torchvision>0.9,<=0.10.0
transformers==4.17.0
huggingface-hub
matplotlib
Pillow
pydicom
scikit-image
SimpleITK
timm
torch
torchvision
transformers

Просмотреть файл

@ -10,7 +10,7 @@ from pathlib import Path
from torch.hub import load_state_dict_from_url
from torchvision.datasets.utils import download_url
from torchvision.models.resnet import model_urls
from torchvision.models.resnet import ResNet50_Weights
from .model import ImageModel
from .types import ImageEncoderType, ImageEncoderWeightTypes
@ -75,7 +75,7 @@ def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
return image_model
def get_biovil_t_image_encoder() -> ImageModel:
def get_biovil_t_image_encoder(**kwargs) -> ImageModel:
"""Download weights from Hugging Face and instantiate the image model."""
biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
@ -84,14 +84,15 @@ def get_biovil_t_image_encoder() -> ImageModel:
img_encoder_type=model_type,
joint_feature_size=JOINT_FEATURE_SIZE,
pretrained_model_path=biovilt_checkpoint_path,
**kwargs,
)
return image_model
def get_imagenet_init_encoder() -> ImageModel:
"""Download ImageNet pre-trained weights and instantiate the image model."""
state_dict = load_state_dict_from_url(model_urls[ImageEncoderType.RESNET50])
url = ResNet50_Weights.IMAGENET1K_V1.url
state_dict = load_state_dict_from_url(url)
image_model = ImageModel(
img_encoder_type=ImageEncoderType.RESNET50,
joint_feature_size=JOINT_FEATURE_SIZE,

Просмотреть файл

@ -7,7 +7,9 @@ from typing import Any, List, Tuple, Type, Union
import torch
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
from torchvision.models.resnet import ResNet18_Weights
from torchvision.models.resnet import ResNet50_Weights
TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
@ -49,7 +51,7 @@ class ResNetHIML(ResNet):
def _resnet(
arch: str,
url: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
@ -62,7 +64,7 @@ def _resnet(
"""
model = ResNetHIML(block=block, layers=layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
state_dict = load_state_dict_from_url(url, progress=progress)
model.load_state_dict(state_dict)
return model
@ -74,7 +76,8 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
:param pretrained: If ``True``, returns a model pre-trained on ImageNet.
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
url = ResNet18_Weights.IMAGENET1K_V1.url
return _resnet(url, BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
@ -84,4 +87,5 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
:param pretrained: If ``True``, returns a model pre-trained on ImageNet
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
url = ResNet50_Weights.IMAGENET1K_V1.url
return _resnet(url, Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)

Просмотреть файл

@ -11,7 +11,6 @@ from typing import Any, Callable, Optional, Set, Tuple
import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp, trunc_normal_
from transformers.pytorch_utils import torch_int_div
@dataclass
@ -265,7 +264,7 @@ class SinePositionEmbedding:
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t

Просмотреть файл

@ -3,6 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import torch
@ -17,9 +18,10 @@ from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
BERTTupleOutput = Tuple[T, T, T, T, T]
@dataclass
class CXRBertOutput(ModelOutput):
last_hidden_state: torch.FloatTensor
logits: torch.FloatTensor
logits: Optional[torch.FloatTensor] = None
cls_projected_embedding: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None