зеркало из https://github.com/microsoft/hi-ml.git
MNT: Update `hi-ml-multimodal` (#928)
Users (including ourselves) keep having issues installing `hi-ml-multimodal`. I've unpinned the requirements so that it's easy to install and use. I've had to update some bits of code that relied on ancient versions of dependencies, e.g. TorchVision. I've checked the notebook and it's working fine.

Related issues:
- https://github.com/microsoft/hi-ml/issues/927
- https://github.com/microsoft/hi-ml/issues/916
- https://github.com/microsoft/hi-ml/issues/850

Fixes #850. Fixes #916. Fixes #927.
This commit is contained in:
Родитель
f4e21b78af
Коммит
61a2c4d330
|
@ -1,9 +1,10 @@
|
||||||
huggingface-hub==0.6.0
|
huggingface-hub
|
||||||
pillow==10.0.1
|
matplotlib
|
||||||
pydicom==2.2.2
|
Pillow
|
||||||
scikit-image==0.18.1
|
pydicom
|
||||||
SimpleITK==2.1.1
|
scikit-image
|
||||||
timm==0.6.5
|
SimpleITK
|
||||||
torch==1.9.0
|
timm
|
||||||
torchvision>0.9,<=0.10.0
|
torch
|
||||||
transformers==4.17.0
|
torchvision
|
||||||
|
transformers
|
||||||
|
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||||
|
|
||||||
from torch.hub import load_state_dict_from_url
|
from torch.hub import load_state_dict_from_url
|
||||||
from torchvision.datasets.utils import download_url
|
from torchvision.datasets.utils import download_url
|
||||||
from torchvision.models.resnet import model_urls
|
from torchvision.models.resnet import ResNet50_Weights
|
||||||
|
|
||||||
from .model import ImageModel
|
from .model import ImageModel
|
||||||
from .types import ImageEncoderType, ImageEncoderWeightTypes
|
from .types import ImageEncoderType, ImageEncoderWeightTypes
|
||||||
|
@ -75,7 +75,7 @@ def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
|
||||||
return image_model
|
return image_model
|
||||||
|
|
||||||
|
|
||||||
def get_biovil_t_image_encoder() -> ImageModel:
|
def get_biovil_t_image_encoder(**kwargs) -> ImageModel:
|
||||||
"""Download weights from Hugging Face and instantiate the image model."""
|
"""Download weights from Hugging Face and instantiate the image model."""
|
||||||
|
|
||||||
biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
|
biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
|
||||||
|
@ -84,14 +84,15 @@ def get_biovil_t_image_encoder() -> ImageModel:
|
||||||
img_encoder_type=model_type,
|
img_encoder_type=model_type,
|
||||||
joint_feature_size=JOINT_FEATURE_SIZE,
|
joint_feature_size=JOINT_FEATURE_SIZE,
|
||||||
pretrained_model_path=biovilt_checkpoint_path,
|
pretrained_model_path=biovilt_checkpoint_path,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return image_model
|
return image_model
|
||||||
|
|
||||||
|
|
||||||
def get_imagenet_init_encoder() -> ImageModel:
|
def get_imagenet_init_encoder() -> ImageModel:
|
||||||
"""Download ImageNet pre-trained weights and instantiate the image model."""
|
"""Download ImageNet pre-trained weights and instantiate the image model."""
|
||||||
|
url = ResNet50_Weights.IMAGENET1K_V1.url
|
||||||
state_dict = load_state_dict_from_url(model_urls[ImageEncoderType.RESNET50])
|
state_dict = load_state_dict_from_url(url)
|
||||||
image_model = ImageModel(
|
image_model = ImageModel(
|
||||||
img_encoder_type=ImageEncoderType.RESNET50,
|
img_encoder_type=ImageEncoderType.RESNET50,
|
||||||
joint_feature_size=JOINT_FEATURE_SIZE,
|
joint_feature_size=JOINT_FEATURE_SIZE,
|
||||||
|
|
|
@ -7,7 +7,9 @@ from typing import Any, List, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.hub import load_state_dict_from_url
|
from torch.hub import load_state_dict_from_url
|
||||||
from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
|
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
|
||||||
|
from torchvision.models.resnet import ResNet18_Weights
|
||||||
|
from torchvision.models.resnet import ResNet50_Weights
|
||||||
|
|
||||||
TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
|
TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
@ -49,7 +51,7 @@ class ResNetHIML(ResNet):
|
||||||
|
|
||||||
|
|
||||||
def _resnet(
|
def _resnet(
|
||||||
arch: str,
|
url: str,
|
||||||
block: Type[Union[BasicBlock, Bottleneck]],
|
block: Type[Union[BasicBlock, Bottleneck]],
|
||||||
layers: List[int],
|
layers: List[int],
|
||||||
pretrained: bool,
|
pretrained: bool,
|
||||||
|
@ -62,7 +64,7 @@ def _resnet(
|
||||||
"""
|
"""
|
||||||
model = ResNetHIML(block=block, layers=layers, **kwargs)
|
model = ResNetHIML(block=block, layers=layers, **kwargs)
|
||||||
if pretrained:
|
if pretrained:
|
||||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
state_dict = load_state_dict_from_url(url, progress=progress)
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -74,7 +76,8 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
||||||
:param pretrained: If ``True``, returns a model pre-trained on ImageNet.
|
:param pretrained: If ``True``, returns a model pre-trained on ImageNet.
|
||||||
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
|
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
|
||||||
"""
|
"""
|
||||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
url = ResNet18_Weights.IMAGENET1K_V1.url
|
||||||
|
return _resnet(url, BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
|
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
|
||||||
|
@ -84,4 +87,5 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
||||||
:param pretrained: If ``True``, returns a model pre-trained on ImageNet
|
:param pretrained: If ``True``, returns a model pre-trained on ImageNet
|
||||||
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
|
:param progress: If ``True``, displays a progress bar of the download to ``stderr``.
|
||||||
"""
|
"""
|
||||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
url = ResNet50_Weights.IMAGENET1K_V1.url
|
||||||
|
return _resnet(url, Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
||||||
|
|
|
@ -11,7 +11,6 @@ from typing import Any, Callable, Optional, Set, Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from timm.models.layers import DropPath, Mlp, trunc_normal_
|
from timm.models.layers import DropPath, Mlp, trunc_normal_
|
||||||
from transformers.pytorch_utils import torch_int_div
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -265,7 +264,7 @@ class SinePositionEmbedding:
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
|
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
|
||||||
|
|
||||||
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
|
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
|
||||||
dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
|
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
pos_x = x_embed[:, :, :, None] / dim_t
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
pos_y = y_embed[:, :, :, None] / dim_t
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Tuple, Union
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -17,9 +18,10 @@ from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
|
||||||
BERTTupleOutput = Tuple[T, T, T, T, T]
|
BERTTupleOutput = Tuple[T, T, T, T, T]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class CXRBertOutput(ModelOutput):
|
class CXRBertOutput(ModelOutput):
|
||||||
last_hidden_state: torch.FloatTensor
|
last_hidden_state: torch.FloatTensor
|
||||||
logits: torch.FloatTensor
|
logits: Optional[torch.FloatTensor] = None
|
||||||
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
Загрузка…
Ссылка в новой задаче