зеркало из https://github.com/microsoft/hi-ml.git
ENH: Improve support for choosing image model weights (#807)
This PR allows the user to determine whether to use ResNet weights pre-trained on ImageNet or BioViL, or whether to have a randomly initialised model: fixes #617 .
This commit is contained in:
Родитель
d6ea7acdde
Коммит
8dbd3249cc
|
@ -42,7 +42,7 @@ class ImageEncoder(nn.Module):
|
|||
supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
|
||||
raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")
|
||||
|
||||
encoder = encoder_class(pretrained=True, **kwargs)
|
||||
encoder = encoder_class(pretrained=False, **kwargs)
|
||||
|
||||
return encoder
|
||||
|
||||
|
|
|
@ -8,10 +8,12 @@ from __future__ import annotations
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.datasets.utils import download_url
|
||||
from torchvision.models.resnet import model_urls
|
||||
|
||||
from .model import ImageModel
|
||||
from .types import ImageEncoderType
|
||||
from .types import ImageEncoderType, ImageEncoderWeightTypes
|
||||
|
||||
|
||||
JOINT_FEATURE_SIZE = 128
|
||||
|
@ -84,3 +86,39 @@ def get_biovil_t_image_encoder() -> ImageModel:
|
|||
pretrained_model_path=biovilt_checkpoint_path,
|
||||
)
|
||||
return image_model
|
||||
|
||||
|
||||
def get_imagenet_init_encoder() -> ImageModel:
|
||||
"""Download ImageNet pre-trained weights and instantiate the image model."""
|
||||
|
||||
state_dict = load_state_dict_from_url(model_urls[ImageEncoderType.RESNET50])
|
||||
image_model = ImageModel(
|
||||
img_encoder_type=ImageEncoderType.RESNET50,
|
||||
joint_feature_size=JOINT_FEATURE_SIZE,
|
||||
pretrained_model_path=None,
|
||||
)
|
||||
image_model.encoder.encoder.load_state_dict(state_dict)
|
||||
return image_model
|
||||
|
||||
|
||||
def get_image_encoder(weights: str) -> ImageModel:
|
||||
"""Instantiate image model with random or pre-trained weights.
|
||||
:param weights: Select one of `random`, `imagenet`, `biovil`, `biovil_t`
|
||||
"""
|
||||
|
||||
if weights == ImageEncoderWeightTypes.RANDOM:
|
||||
image_model = ImageModel(
|
||||
img_encoder_type=ImageEncoderType.RESNET50,
|
||||
joint_feature_size=JOINT_FEATURE_SIZE,
|
||||
pretrained_model_path=None,
|
||||
)
|
||||
elif weights == ImageEncoderWeightTypes.IMAGENET:
|
||||
image_model = get_imagenet_init_encoder()
|
||||
elif weights == ImageEncoderWeightTypes.BIOVIL:
|
||||
image_model = get_biovil_image_encoder()
|
||||
elif weights == ImageEncoderWeightTypes.BIOVIL_T:
|
||||
image_model = get_biovil_t_image_encoder()
|
||||
else:
|
||||
raise ValueError(f"Weights option not found: {weights}")
|
||||
|
||||
return image_model
|
||||
|
|
|
@ -35,3 +35,11 @@ class ImageEncoderType(str, Enum):
|
|||
return [cls.RESNET18_MULTI_IMAGE, cls.RESNET50_MULTI_IMAGE]
|
||||
else:
|
||||
return [member for member in cls]
|
||||
|
||||
|
||||
@unique
|
||||
class ImageEncoderWeightTypes(str, Enum):
|
||||
RANDOM = "random"
|
||||
IMAGENET = "imagenet"
|
||||
BIOVIL = "biovil"
|
||||
BIOVIL_T = "biovil_t"
|
||||
|
|
|
@ -13,6 +13,7 @@ from health_multimodal.image.model.encoder import (
|
|||
MultiImageEncoder,
|
||||
restore_training_mode,
|
||||
)
|
||||
from health_multimodal.image.model.pretrained import get_imagenet_init_encoder
|
||||
from health_multimodal.image.model.resnet import resnet50
|
||||
from health_multimodal.image.model.types import ImageEncoderType
|
||||
|
||||
|
@ -29,8 +30,8 @@ def test_reload_resnet_with_dilation(replace_stride_with_dilation: Sequence[bool
|
|||
model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)
|
||||
|
||||
# resnet50
|
||||
original_model = ImageEncoder(img_encoder_type=ImageEncoderType.RESNET50).eval()
|
||||
model_with_dilation = ImageEncoder(img_encoder_type=ImageEncoderType.RESNET50).eval()
|
||||
original_model = get_imagenet_init_encoder().encoder.eval()
|
||||
model_with_dilation = get_imagenet_init_encoder().encoder.eval()
|
||||
model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)
|
||||
assert not model_with_dilation.training
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
import torch
|
||||
|
||||
from health_multimodal.image.model.pretrained import get_imagenet_init_encoder
|
||||
from health_multimodal.image.model.resnet import resnet50
|
||||
|
||||
|
||||
def test_get_imagenet_init_encoder() -> None:
|
||||
"""Test that the ``imagenet`` option loads weights correctly."""
|
||||
expected_model = resnet50(pretrained=True)
|
||||
imagenet_model = get_imagenet_init_encoder()
|
||||
|
||||
for imagenet_param, expected_param in zip(
|
||||
imagenet_model.encoder.encoder.named_parameters(), expected_model.named_parameters()
|
||||
):
|
||||
assert imagenet_param[0] == expected_param[0]
|
||||
assert torch.isclose(imagenet_param[1], expected_param[1]).all()
|
Загрузка…
Ссылка в новой задаче