зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add hubconf to load models without installing (#543)
* Add hubconf file
* Refactor to minimise hubconf dependencies
* Pin hubconf dependencies
* Revert "Pin hubconf dependencies"
This reverts commit bc904a963e
as it
didn't seem to work.
* Add support for newer versions of torch
* Add only the model folder to path
* Remove unnecessary try-except block
* Avoid duplicate definition of Hugging Face strings
* Import from a more appropriate module
* Add test to compare package and PyTorch Hub models
* Add version number to package __init__
* Remove branch name from PyTorch Hub repo string
* Check only fields from package model
* Remove unnecessary zip wrap
Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
This commit is contained in:
Родитель
556c39594a
Коммит
bdbbf3e812
|
@ -3,20 +3,4 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
|
||||
REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
|
||||
CXR_BERT_COMMIT_TAG = "v1.1"
|
||||
|
||||
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
|
||||
BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
|
||||
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
|
||||
"REPO_URL",
|
||||
"CXR_BERT_COMMIT_TAG",
|
||||
"BIOVIL_IMAGE_WEIGHTS_NAME",
|
||||
"BIOVIL_IMAGE_WEIGHTS_URL",
|
||||
"BIOVIL_IMAGE_WEIGHTS_MD5",
|
||||
]
|
||||
__version__ = "0.1.0"
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
|
||||
from .model import ImageModel
|
||||
from .model import ResnetType
|
||||
from .model import get_biovil_resnet
|
||||
from .inference_engine import ImageInferenceEngine
|
||||
from .utils import get_biovil_resnet_inference
|
||||
|
||||
|
@ -44,5 +45,6 @@ __all__ = [
|
|||
"ImageModel",
|
||||
"ResnetType",
|
||||
"ImageInferenceEngine",
|
||||
"get_biovil_resnet",
|
||||
"get_biovil_resnet_inference",
|
||||
]
|
||||
|
|
|
@ -5,9 +5,14 @@
|
|||
|
||||
from .model import ImageModel
|
||||
from .model import ResnetType
|
||||
|
||||
from .model import get_biovil_resnet
|
||||
from .model import CXR_BERT_COMMIT_TAG
|
||||
from .model import BIOMED_VLP_CXR_BERT_SPECIALIZED
|
||||
|
||||
__all__ = [
|
||||
"ImageModel",
|
||||
"ResnetType",
|
||||
"get_biovil_resnet",
|
||||
"CXR_BERT_COMMIT_TAG",
|
||||
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
|
||||
]
|
||||
|
|
|
@ -3,7 +3,10 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union, Sequence
|
||||
|
@ -11,11 +14,49 @@ from typing import Any, Optional, Tuple, Union, Sequence
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.datasets.utils import download_url
|
||||
|
||||
from .resnet import resnet18, resnet50
|
||||
from .modules import MLP, MultiTaskModel
|
||||
|
||||
TypeImageEncoder = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
||||
MODEL_TYPE = "resnet50"
|
||||
JOINT_FEATURE_SIZE = 128
|
||||
|
||||
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
|
||||
REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
|
||||
CXR_BERT_COMMIT_TAG = "v1.1"
|
||||
|
||||
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
|
||||
BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
|
||||
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"
|
||||
|
||||
|
||||
def _download_biovil_image_model_weights() -> Path:
|
||||
"""Download image model weights from Hugging Face.
|
||||
|
||||
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
|
||||
"""
|
||||
root_dir = tempfile.gettempdir()
|
||||
download_url(
|
||||
BIOVIL_IMAGE_WEIGHTS_URL,
|
||||
root=root_dir,
|
||||
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
|
||||
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
|
||||
)
|
||||
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
|
||||
|
||||
|
||||
def get_biovil_resnet(pretrained: bool = True) -> ImageModel:
|
||||
"""Download weights from Hugging Face and instantiate the image model."""
|
||||
resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None
|
||||
|
||||
image_model = ImageModel(
|
||||
img_model_type=MODEL_TYPE,
|
||||
joint_feature_size=JOINT_FEATURE_SIZE,
|
||||
pretrained_model_path=resnet_checkpoint_path,
|
||||
)
|
||||
return image_model
|
||||
|
||||
|
||||
@enum.unique
|
||||
|
|
|
@ -6,9 +6,8 @@
|
|||
from typing import Any, List, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
|
||||
|
||||
class ResNetHIML(ResNet):
|
||||
|
|
|
@ -3,51 +3,15 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from torchvision.datasets.utils import download_url
|
||||
|
||||
from .. import BIOVIL_IMAGE_WEIGHTS_NAME
|
||||
from .. import BIOVIL_IMAGE_WEIGHTS_URL
|
||||
from .. import BIOVIL_IMAGE_WEIGHTS_MD5
|
||||
from .model import ImageModel
|
||||
from .inference_engine import ImageInferenceEngine
|
||||
from .data.transforms import create_chest_xray_transform_for_inference
|
||||
from .model import get_biovil_resnet
|
||||
|
||||
|
||||
MODEL_TYPE = "resnet50"
|
||||
JOINT_FEATURE_SIZE = 128
|
||||
TRANSFORM_RESIZE = 512
|
||||
TRANSFORM_CENTER_CROP_SIZE = 480
|
||||
|
||||
|
||||
def _download_biovil_image_model_weights() -> Path:
|
||||
"""Download image model weights from Hugging Face.
|
||||
|
||||
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
|
||||
"""
|
||||
root_dir = tempfile.gettempdir()
|
||||
download_url(
|
||||
BIOVIL_IMAGE_WEIGHTS_URL,
|
||||
root=root_dir,
|
||||
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
|
||||
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
|
||||
)
|
||||
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
|
||||
|
||||
|
||||
def get_biovil_resnet() -> ImageModel:
|
||||
"""Download weights from Hugging Face and instantiate the image model."""
|
||||
resnet_checkpoint_path = _download_biovil_image_model_weights()
|
||||
image_model = ImageModel(
|
||||
img_model_type=MODEL_TYPE,
|
||||
joint_feature_size=JOINT_FEATURE_SIZE,
|
||||
pretrained_model_path=resnet_checkpoint_path,
|
||||
)
|
||||
return image_model
|
||||
|
||||
|
||||
def get_biovil_resnet_inference() -> ImageInferenceEngine:
|
||||
"""Create a :class:`ImageInferenceEngine` for the image model.
|
||||
|
||||
|
|
|
@ -6,9 +6,8 @@
|
|||
|
||||
from typing import Tuple
|
||||
|
||||
from .. import BIOMED_VLP_CXR_BERT_SPECIALIZED
|
||||
from .. import CXR_BERT_COMMIT_TAG
|
||||
|
||||
from ..image.model import CXR_BERT_COMMIT_TAG
|
||||
from ..image.model import BIOMED_VLP_CXR_BERT_SPECIALIZED
|
||||
from .inference_engine import TextInferenceEngine
|
||||
from .model import CXRBertModel
|
||||
from .model import CXRBertTokenizer
|
||||
|
|
|
@ -3,11 +3,15 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
from dataclasses import fields
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from health_multimodal.image.model.model import ImageEncoder, ImageModel
|
||||
from health_multimodal.image.model.model import ImageModel
|
||||
from health_multimodal.image.model.model import ImageEncoder
|
||||
from health_multimodal.image.model.model import ImageModelOutput
|
||||
from health_multimodal.image.model.model import get_biovil_resnet
|
||||
from health_multimodal.image.model.modules import MultiTaskModel
|
||||
from health_multimodal.image.model.resnet import resnet50
|
||||
|
||||
|
@ -132,3 +136,23 @@ def test_reload_resnet_with_dilation() -> None:
|
|||
with torch.no_grad():
|
||||
expected_output = expected_model(image)
|
||||
assert torch.allclose(outputs_dilation, expected_output)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_hubconf() -> None:
|
||||
"""Test that instantiating the image model using the PyTorch Hub is consistent with older methods."""
|
||||
image = torch.rand(1, 3, 480, 480)
|
||||
|
||||
github = 'microsoft/hi-ml'
|
||||
model_hub = torch.hub.load(github, 'biovil_resnet', pretrained=True)
|
||||
model_himl = get_biovil_resnet()
|
||||
|
||||
output_hub: ImageModelOutput = model_hub(image)
|
||||
output_himl: ImageModelOutput = model_himl(image)
|
||||
|
||||
for field_himl in fields(output_himl):
|
||||
value_hub = getattr(output_hub, field_himl.name)
|
||||
value_himl = getattr(output_himl, field_himl.name)
|
||||
if value_hub is None and value_himl is None: # for example, class_logits
|
||||
continue
|
||||
assert torch.allclose(value_hub, value_himl)
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# autopep8: off
|
||||
dependencies = ["torch", "torchvision"]
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
repo_dir = Path(__file__).parent
|
||||
multimodal_src_dir = repo_dir / "hi-ml-multimodal" / "src" / "health_multimodal" / "image"
|
||||
sys.path.append(str(multimodal_src_dir))
|
||||
|
||||
from model import ImageModel
|
||||
from model import get_biovil_resnet as _biovil_resnet
|
||||
# autopep8: on
|
||||
|
||||
|
||||
def biovil_resnet(pretrained: bool = False) -> ImageModel:
|
||||
"""Get BioViL image encoder.
|
||||
|
||||
:param pretrained: Load pretrained weights from a checkpoint.
|
||||
"""
|
||||
model = _biovil_resnet(pretrained=pretrained)
|
||||
return model
|
Загрузка…
Ссылка в новой задаче