ENH: Add hubconf to load models without installing (#543)

* Add hubconf file

* Refactor to minimise hubconf dependencies

* Pin hubconf dependencies

* Revert "Pin hubconf dependencies"

This reverts commit bc904a963e as it
didn't seem to work.

* Add support for newer versions of torch

* Add only the model folder to path

* Remove unnecessary try-except block

* Avoid duplicate definition of Hugging Face strings

* Import from a more appropriate module

* Add test to compare package and PyTorch Hub models

* Add version number to package __init__

* Remove branch name from PyTorch Hub repo string

* Check only fields from package model

* Remove unnecessary zip wrap

Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>

Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
This commit is contained in:
Fernando Pérez-García 2022-08-04 11:22:54 +02:00 коммит произвёл GitHub
Родитель 556c39594a
Коммит bdbbf3e812
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 100 добавлений и 61 удалений

Просмотреть файл

@ -3,20 +3,4 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
CXR_BERT_COMMIT_TAG = "v1.1"
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"
__all__ = [
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
"REPO_URL",
"CXR_BERT_COMMIT_TAG",
"BIOVIL_IMAGE_WEIGHTS_NAME",
"BIOVIL_IMAGE_WEIGHTS_URL",
"BIOVIL_IMAGE_WEIGHTS_MD5",
]
__version__ = "0.1.0"

Просмотреть файл

@ -36,6 +36,7 @@
from .model import ImageModel
from .model import ResnetType
from .model import get_biovil_resnet
from .inference_engine import ImageInferenceEngine
from .utils import get_biovil_resnet_inference
@ -44,5 +45,6 @@ __all__ = [
"ImageModel",
"ResnetType",
"ImageInferenceEngine",
"get_biovil_resnet",
"get_biovil_resnet_inference",
]

Просмотреть файл

@ -5,9 +5,14 @@
from .model import ImageModel
from .model import ResnetType
from .model import get_biovil_resnet
from .model import CXR_BERT_COMMIT_TAG
from .model import BIOMED_VLP_CXR_BERT_SPECIALIZED
__all__ = [
"ImageModel",
"ResnetType",
"get_biovil_resnet",
"CXR_BERT_COMMIT_TAG",
"BIOMED_VLP_CXR_BERT_SPECIALIZED",
]

Просмотреть файл

@ -3,7 +3,10 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from __future__ import annotations
import enum
import tempfile
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union, Sequence
@ -11,11 +14,49 @@ from typing import Any, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from .resnet import resnet18, resnet50
from .modules import MLP, MultiTaskModel
TypeImageEncoder = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
MODEL_TYPE = "resnet50"
JOINT_FEATURE_SIZE = 128
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
REPO_URL = f"https://huggingface.co/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
CXR_BERT_COMMIT_TAG = "v1.1"
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = f"{REPO_URL}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"
def _download_biovil_image_model_weights() -> Path:
"""Download image model weights from Hugging Face.
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
"""
root_dir = tempfile.gettempdir()
download_url(
BIOVIL_IMAGE_WEIGHTS_URL,
root=root_dir,
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
)
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
def get_biovil_resnet(pretrained: bool = True) -> ImageModel:
"""Download weights from Hugging Face and instantiate the image model."""
resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None
image_model = ImageModel(
img_model_type=MODEL_TYPE,
joint_feature_size=JOINT_FEATURE_SIZE,
pretrained_model_path=resnet_checkpoint_path,
)
return image_model
@enum.unique

Просмотреть файл

@ -6,9 +6,8 @@
from typing import Any, List, Type, Union
import torch
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
from torchvision.models.utils import load_state_dict_from_url
class ResNetHIML(ResNet):

Просмотреть файл

@ -3,51 +3,15 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import tempfile
from pathlib import Path
from torchvision.datasets.utils import download_url
from .. import BIOVIL_IMAGE_WEIGHTS_NAME
from .. import BIOVIL_IMAGE_WEIGHTS_URL
from .. import BIOVIL_IMAGE_WEIGHTS_MD5
from .model import ImageModel
from .inference_engine import ImageInferenceEngine
from .data.transforms import create_chest_xray_transform_for_inference
from .model import get_biovil_resnet
MODEL_TYPE = "resnet50"
JOINT_FEATURE_SIZE = 128
TRANSFORM_RESIZE = 512
TRANSFORM_CENTER_CROP_SIZE = 480
def _download_biovil_image_model_weights() -> Path:
"""Download image model weights from Hugging Face.
More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
"""
root_dir = tempfile.gettempdir()
download_url(
BIOVIL_IMAGE_WEIGHTS_URL,
root=root_dir,
filename=BIOVIL_IMAGE_WEIGHTS_NAME,
md5=BIOVIL_IMAGE_WEIGHTS_MD5,
)
return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
def get_biovil_resnet() -> ImageModel:
"""Download weights from Hugging Face and instantiate the image model."""
resnet_checkpoint_path = _download_biovil_image_model_weights()
image_model = ImageModel(
img_model_type=MODEL_TYPE,
joint_feature_size=JOINT_FEATURE_SIZE,
pretrained_model_path=resnet_checkpoint_path,
)
return image_model
def get_biovil_resnet_inference() -> ImageInferenceEngine:
"""Create a :class:`ImageInferenceEngine` for the image model.

Просмотреть файл

@ -6,9 +6,8 @@
from typing import Tuple
from .. import BIOMED_VLP_CXR_BERT_SPECIALIZED
from .. import CXR_BERT_COMMIT_TAG
from ..image.model import CXR_BERT_COMMIT_TAG
from ..image.model import BIOMED_VLP_CXR_BERT_SPECIALIZED
from .inference_engine import TextInferenceEngine
from .model import CXRBertModel
from .model import CXRBertTokenizer

Просмотреть файл

@ -3,11 +3,15 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from dataclasses import fields
import pytest
import torch
from health_multimodal.image.model.model import ImageEncoder, ImageModel
from health_multimodal.image.model.model import ImageModel
from health_multimodal.image.model.model import ImageEncoder
from health_multimodal.image.model.model import ImageModelOutput
from health_multimodal.image.model.model import get_biovil_resnet
from health_multimodal.image.model.modules import MultiTaskModel
from health_multimodal.image.model.resnet import resnet50
@ -132,3 +136,23 @@ def test_reload_resnet_with_dilation() -> None:
with torch.no_grad():
expected_output = expected_model(image)
assert torch.allclose(outputs_dilation, expected_output)
@torch.no_grad()
def test_hubconf() -> None:
"""Test that instantiating the image model using the PyTorch Hub is consistent with older methods."""
image = torch.rand(1, 3, 480, 480)
github = 'microsoft/hi-ml'
model_hub = torch.hub.load(github, 'biovil_resnet', pretrained=True)
model_himl = get_biovil_resnet()
output_hub: ImageModelOutput = model_hub(image)
output_himl: ImageModelOutput = model_himl(image)
for field_himl in fields(output_himl):
value_hub = getattr(output_hub, field_himl.name)
value_himl = getattr(output_himl, field_himl.name)
if value_hub is None and value_himl is None: # for example, class_logits
continue
assert torch.allclose(value_hub, value_himl)

21
hubconf.py Normal file
Просмотреть файл

@ -0,0 +1,21 @@
# autopep8: off
dependencies = ["torch", "torchvision"]
import sys
from pathlib import Path
repo_dir = Path(__file__).parent
multimodal_src_dir = repo_dir / "hi-ml-multimodal" / "src" / "health_multimodal" / "image"
sys.path.append(str(multimodal_src_dir))
from model import ImageModel
from model import get_biovil_resnet as _biovil_resnet
# autopep8: on
def biovil_resnet(pretrained: bool = False) -> ImageModel:
"""Get BioViL image encoder.
:param pretrained: Load pretrained weights from a checkpoint.
"""
model = _biovil_resnet(pretrained=pretrained)
return model