ENH: Improve support for choosing image model weights (#807)

This PR allows the user to choose between ResNet weights
pre-trained on ImageNet or BioViL and a randomly
initialised model. Fixes #617.
This commit is contained in:
markpinnock 2023-04-12 15:43:49 +01:00 коммит произвёл GitHub
Родитель d6ea7acdde
Коммит 8dbd3249cc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 67 добавлений и 4 удалений

Просмотреть файл

@ -42,7 +42,7 @@ class ImageEncoder(nn.Module):
supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")
encoder = encoder_class(pretrained=True, **kwargs)
encoder = encoder_class(pretrained=False, **kwargs)
return encoder

Просмотреть файл

@ -8,10 +8,12 @@ from __future__ import annotations
import tempfile
from pathlib import Path
from torch.hub import load_state_dict_from_url
from torchvision.datasets.utils import download_url
from torchvision.models.resnet import model_urls
from .model import ImageModel
from .types import ImageEncoderType
from .types import ImageEncoderType, ImageEncoderWeightTypes
JOINT_FEATURE_SIZE = 128
@ -84,3 +86,39 @@ def get_biovil_t_image_encoder() -> ImageModel:
pretrained_model_path=biovilt_checkpoint_path,
)
return image_model
def get_imagenet_init_encoder() -> ImageModel:
    """Build a ResNet-50 image model initialised from ImageNet pre-trained weights.

    The state dict is fetched from the URL registered for ResNet-50 in
    ``model_urls`` and loaded into the inner encoder of a newly constructed
    :class:`ImageModel` (created with ``pretrained_model_path=None``).

    :return: Image model whose ``encoder.encoder`` holds the downloaded weights.
    """
    resnet50_type = ImageEncoderType.RESNET50

    # Fetch the weights before building the model, mirroring the original order.
    imagenet_weights = load_state_dict_from_url(model_urls[resnet50_type])

    model = ImageModel(
        img_encoder_type=resnet50_type,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=None,
    )
    # Only the inner ResNet receives the downloaded weights; the rest of the
    # image model keeps whatever initialisation the constructor gave it.
    model.encoder.encoder.load_state_dict(imagenet_weights)
    return model
def get_image_encoder(weights: str) -> ImageModel:
    """Instantiate the image model with the requested weight initialisation.

    :param weights: Select one of `random`, `imagenet`, `biovil`, `biovil_t`
    :return: The image model built for the selected option.
    :raises ValueError: If ``weights`` matches none of the supported options.
    """
    if weights == ImageEncoderWeightTypes.RANDOM:
        # No checkpoint path and no downloaded state dict are supplied here.
        return ImageModel(
            img_encoder_type=ImageEncoderType.RESNET50,
            joint_feature_size=JOINT_FEATURE_SIZE,
            pretrained_model_path=None,
        )

    if weights == ImageEncoderWeightTypes.IMAGENET:
        return get_imagenet_init_encoder()

    if weights == ImageEncoderWeightTypes.BIOVIL:
        return get_biovil_image_encoder()

    if weights == ImageEncoderWeightTypes.BIOVIL_T:
        return get_biovil_t_image_encoder()

    raise ValueError(f"Weights option not found: {weights}")

Просмотреть файл

@ -35,3 +35,11 @@ class ImageEncoderType(str, Enum):
return [cls.RESNET18_MULTI_IMAGE, cls.RESNET50_MULTI_IMAGE]
else:
return [member for member in cls]
@unique
class ImageEncoderWeightTypes(str, Enum):
    """Options for how the image encoder weights are initialised.

    Members are ``str`` subclasses, so each one compares equal to its plain
    string value (e.g. ``ImageEncoderWeightTypes.RANDOM == "random"``).
    ``@unique`` rejects two members sharing the same value.
    """

    # Model constructed without any pre-trained weights.
    RANDOM = "random"
    # Weights pre-trained on ImageNet.
    IMAGENET = "imagenet"
    # Weights from the BioViL image encoder.
    BIOVIL = "biovil"
    # Weights from the BioViL-T image encoder.
    BIOVIL_T = "biovil_t"

Просмотреть файл

@ -13,6 +13,7 @@ from health_multimodal.image.model.encoder import (
MultiImageEncoder,
restore_training_mode,
)
from health_multimodal.image.model.pretrained import get_imagenet_init_encoder
from health_multimodal.image.model.resnet import resnet50
from health_multimodal.image.model.types import ImageEncoderType
@ -29,8 +30,8 @@ def test_reload_resnet_with_dilation(replace_stride_with_dilation: Sequence[bool
model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)
# resnet50
original_model = ImageEncoder(img_encoder_type=ImageEncoderType.RESNET50).eval()
model_with_dilation = ImageEncoder(img_encoder_type=ImageEncoderType.RESNET50).eval()
original_model = get_imagenet_init_encoder().encoder.eval()
model_with_dilation = get_imagenet_init_encoder().encoder.eval()
model_with_dilation.reload_encoder_with_dilation(replace_stride_with_dilation)
assert not model_with_dilation.training

Просмотреть файл

@ -0,0 +1,16 @@
import torch
from health_multimodal.image.model.pretrained import get_imagenet_init_encoder
from health_multimodal.image.model.resnet import resnet50
def test_get_imagenet_init_encoder() -> None:
    """Test that the ``imagenet`` option loads weights correctly.

    Every named parameter of the inner encoder returned by
    ``get_imagenet_init_encoder`` is compared, by name and by value, against
    the corresponding parameter of ``resnet50(pretrained=True)``.
    """
    expected_model = resnet50(pretrained=True)
    imagenet_model = get_imagenet_init_encoder()

    expected_params = list(expected_model.named_parameters())
    imagenet_params = list(imagenet_model.encoder.encoder.named_parameters())

    # ``zip`` stops at the shorter input, so without this check a model that
    # exposes fewer parameters than the reference would pass vacuously.
    assert len(imagenet_params) == len(expected_params)

    for (imagenet_name, imagenet_value), (expected_name, expected_value) in zip(
        imagenet_params, expected_params
    ):
        assert imagenet_name == expected_name
        assert torch.isclose(imagenet_value, expected_value).all()