зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add transformer pooling (#227)
The main thing this PR does is to add TransformerPooling. However, I moved the constructer of the pooling layers outside of DeepMIL to allow for more flexibility in the constructor arguments. This required changing some tests. The transformer pooling is based on the pytorch Transformer class, however, in the pytorch implementation you have to pass a flag (need_weights) to the forward (??? like why not the init???) of the MultiheadAttention class. See line 186 of attention_layer.py In the forward() of TransformerPooling (see line 244 of attention_layer.py) we discard the self attention of the cls token and rescale the remaining attention scores, so they sum to one.
This commit is contained in:
Родитель
f6c5aadd24
Коммит
54632a895f
|
@ -10,6 +10,7 @@ Each release contains a link for "Full Changelog"
|
|||
## 0.1.14
|
||||
|
||||
### Added
|
||||
- ([#227](https://github.com/microsoft/hi-ml/pull/227)) Add TransformerPooling.
|
||||
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
|
||||
the environment file since it is necessary for the augmentations.
|
||||
- ([#193](https://github.com/microsoft/hi-ml/pull/193)) Add transformation adaptor to hi-ml-histopathology.
|
||||
|
@ -21,6 +22,7 @@ the environment file since it is necessary for the augmentations.
|
|||
- ([#198](https://github.com/microsoft/hi-ml/pull/198)) Improved editor setup for VSCode.
|
||||
|
||||
### Changed
|
||||
- ([#227](https://github.com/microsoft/hi-ml/pull/227)) Pooling constructor is outside of DeepMIL and inside of BaseMIL now.
|
||||
- ([#198](https://github.com/microsoft/hi-ml/pull/198)) Model config loader is now more flexible, can accept fully qualified class name or just top-level module name and class (like histopathology.DeepSMILECrck)
|
||||
- ([#198](https://github.com/microsoft/hi-ml/pull/198)) Runner raises an error when Conda environment file contains a pip include (-r) statement
|
||||
|
||||
|
|
|
@ -8,14 +8,15 @@ It is responsible for instantiating the encoder and full DeepMIL model. Subclass
|
|||
their datamodules and configure experiment-specific parameters.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type # noqa
|
||||
from typing import Optional, Tuple, Type # noqa
|
||||
|
||||
import param
|
||||
from torch import nn
|
||||
from torchvision.models import resnet18
|
||||
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer,\
|
||||
TransformerPooling
|
||||
|
||||
from histopathology.datasets.base_dataset import SlidesDataset
|
||||
from histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
|
||||
|
@ -27,7 +28,16 @@ from histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
|
|||
|
||||
class BaseMIL(LightningContainer):
|
||||
# Model parameters:
|
||||
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
|
||||
pool_type: str = param.String(doc="Name of the pooling layer class to use.")
|
||||
pool_hidden_dim: int = param.Integer(128, doc="If pooling has a learnable part, this defines the number of the\
|
||||
hidden dimensions.")
|
||||
pool_out_dim: int = param.Integer(1, doc="Dimension of the pooled representation.")
|
||||
num_transformer_pool_layers: int = param.Integer(4, doc="If transformer pooling is chosen, this defines the number\
|
||||
of encoding layers.")
|
||||
num_transformer_pool_heads: int = param.Integer(4, doc="If transformer pooling is chosen, this defines the number\
|
||||
of attention heads.")
|
||||
is_finetune: bool = param.Boolean(False, doc="If True, fine-tune the encoder during training. If False (default), "
|
||||
"keep the encoder frozen.")
|
||||
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
|
||||
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
|
||||
|
||||
|
@ -52,8 +62,6 @@ class BaseMIL(LightningContainer):
|
|||
"`none` (default),`cpu`, `gpu`")
|
||||
encoding_chunk_size: int = param.Integer(0, doc="If > 0 performs encoding in chunks, by loading"
|
||||
"enconding_chunk_size tiles per chunk")
|
||||
is_finetune: bool = param.Boolean(False, doc="If True, fine-tune the encoder during training. If False (default), "
|
||||
"keep the encoder frozen.")
|
||||
# local_dataset (used as data module root_path) is declared in DatasetParams superclass
|
||||
|
||||
@property
|
||||
|
@ -62,7 +70,7 @@ class BaseMIL(LightningContainer):
|
|||
|
||||
def setup(self) -> None:
|
||||
if self.encoder_type == SSLEncoder.__name__:
|
||||
raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")
|
||||
raise NotImplementedError("SSLEncoder requires a pre-trained checkpoint.")
|
||||
|
||||
self.encoder = self.get_encoder()
|
||||
if not self.is_finetune:
|
||||
|
@ -86,16 +94,30 @@ class BaseMIL(LightningContainer):
|
|||
else:
|
||||
raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
|
||||
|
||||
def get_pooling_layer(self) -> Type[nn.Module]:
|
||||
if self.pooling_type == AttentionLayer.__name__:
|
||||
return AttentionLayer
|
||||
elif self.pooling_type == GatedAttentionLayer.__name__:
|
||||
return GatedAttentionLayer
|
||||
elif self.pooling_type == MeanPoolingLayer.__name__:
|
||||
return MeanPoolingLayer
|
||||
def get_pooling_layer(self) -> Tuple[nn.Module, int]:
|
||||
num_encoding = self.encoder.num_encoding
|
||||
|
||||
if self.pool_type == AttentionLayer.__name__:
|
||||
pooling_layer = AttentionLayer(num_encoding,
|
||||
self.pool_hidden_dim,
|
||||
self.pool_out_dim)
|
||||
elif self.pool_type == GatedAttentionLayer.__name__:
|
||||
pooling_layer = GatedAttentionLayer(num_encoding,
|
||||
self.pool_hidden_dim,
|
||||
self.pool_out_dim)
|
||||
elif self.pool_type == MeanPoolingLayer.__name__:
|
||||
pooling_layer = MeanPoolingLayer()
|
||||
elif self.pool_type == TransformerPooling.__name__:
|
||||
pooling_layer = TransformerPooling(self.num_transformer_pool_layers,
|
||||
self.num_transformer_pool_heads,
|
||||
num_encoding)
|
||||
self.pool_out_dim = 1 # currently this is hardcoded in forward of the TransformerPooling
|
||||
else:
|
||||
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
|
||||
|
||||
num_features = num_encoding * self.pool_out_dim
|
||||
return pooling_layer, num_features
|
||||
|
||||
def create_model(self) -> DeepMILModule:
|
||||
self.data_module = self.get_data_module()
|
||||
# Encoding is done in the datamodule, so here we provide instead a dummy
|
||||
|
@ -106,10 +128,15 @@ class BaseMIL(LightningContainer):
|
|||
params.requires_grad = True
|
||||
else:
|
||||
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
|
||||
|
||||
# Construct pooling layer
|
||||
pooling_layer, num_features = self.get_pooling_layer()
|
||||
|
||||
return DeepMILModule(encoder=self.model_encoder,
|
||||
label_column=self.data_module.train_dataset.LABEL_COLUMN,
|
||||
n_classes=self.data_module.train_dataset.N_CLASSES,
|
||||
pooling_layer=self.get_pooling_layer(),
|
||||
pooling_layer=pooling_layer,
|
||||
num_features=num_features,
|
||||
dropout_rate=self.dropout_rate,
|
||||
class_weights=self.data_module.class_weights,
|
||||
l_rate=self.l_rate,
|
||||
|
|
|
@ -47,7 +47,9 @@ class DeepSMILECrck(BaseMIL):
|
|||
# Define dictionary with default params that can be overridden from subclasses or CLI
|
||||
default_kwargs = dict(
|
||||
# declared in BaseMIL:
|
||||
pooling_type=AttentionLayer.__name__,
|
||||
pool_type=AttentionLayer.__name__,
|
||||
num_transformer_pool_layers=4,
|
||||
num_transformer_pool_heads=4,
|
||||
encoding_chunk_size=60,
|
||||
cache_mode=CacheMode.MEMORY,
|
||||
precache_location=CacheLocation.CPU,
|
||||
|
|
|
@ -12,7 +12,7 @@ from pytorch_lightning.callbacks import Callback
|
|||
|
||||
from health_azure.utils import CheckpointDownloader
|
||||
from health_azure.utils import get_workspace, is_running_in_azure_ml
|
||||
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer
|
||||
from health_ml.utils import fixed_paths
|
||||
from histopathology.datamodules.base_module import CacheMode, CacheLocation
|
||||
from histopathology.datamodules.panda_module import PandaTilesDataModule
|
||||
|
@ -44,7 +44,9 @@ class DeepSMILEPanda(BaseMIL):
|
|||
def __init__(self, **kwargs: Any) -> None:
|
||||
default_kwargs = dict(
|
||||
# declared in BaseMIL:
|
||||
pooling_type=GatedAttentionLayer.__name__,
|
||||
pool_type=AttentionLayer.__name__,
|
||||
num_transformer_pool_layers=4,
|
||||
num_transformer_pool_heads=4,
|
||||
# average number of tiles is 56 for PANDA
|
||||
encoding_chunk_size=60,
|
||||
cache_mode=CacheMode.MEMORY,
|
||||
|
|
|
@ -27,7 +27,6 @@ from histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_t
|
|||
from histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey
|
||||
from histopathology.utils.viz_utils import load_image_dict
|
||||
|
||||
|
||||
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
|
||||
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
|
||||
|
||||
|
@ -45,9 +44,8 @@ class DeepMILModule(LightningModule):
|
|||
label_column: str,
|
||||
n_classes: int,
|
||||
encoder: TileEncoder,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
pool_hidden_dim: int = 128,
|
||||
pool_out_dim: int = 1,
|
||||
pooling_layer: Callable[[Tensor], Tuple[Tensor, Tensor]],
|
||||
num_features: int,
|
||||
dropout_rate: Optional[float] = None,
|
||||
class_weights: Optional[Tensor] = None,
|
||||
l_rate: float = 5e-4,
|
||||
|
@ -61,14 +59,12 @@ class DeepMILModule(LightningModule):
|
|||
is_finetune: bool = False) -> None:
|
||||
"""
|
||||
:param label_column: Label key for input batch dictionary.
|
||||
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes
|
||||
should be set to 1.
|
||||
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
|
||||
set to 1.
|
||||
:param encoder: The tile encoder to use for feature extraction. If no encoding is needed,
|
||||
you should use `IdentityEncoder`.
|
||||
:param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a
|
||||
`torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
|
||||
:param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
|
||||
:param pool_out_dim: Output dimension of pooling layer (default=1).
|
||||
:param pooling_layer: A pooling layer nn.module
|
||||
:param num_features: Dimensions of the input encoding features * attention dim outputs
|
||||
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
|
||||
:param class_weights: Tensor containing class weights (default=None).
|
||||
:param l_rate: Optimiser learning rate.
|
||||
|
@ -86,13 +82,11 @@ class DeepMILModule(LightningModule):
|
|||
# Dataset specific attributes
|
||||
self.label_column = label_column
|
||||
self.n_classes = n_classes
|
||||
self.pool_hidden_dim = pool_hidden_dim
|
||||
self.pool_out_dim = pool_out_dim
|
||||
self.pooling_layer = pooling_layer
|
||||
self.dropout_rate = dropout_rate
|
||||
self.class_weights = class_weights
|
||||
self.encoder = encoder
|
||||
self.num_encoding = self.encoder.num_encoding
|
||||
self.aggregation_fn = pooling_layer
|
||||
self.num_pooling = num_features
|
||||
|
||||
if class_names is not None:
|
||||
self.class_names = class_names
|
||||
|
@ -125,7 +119,6 @@ class DeepMILModule(LightningModule):
|
|||
# Finetuning attributes
|
||||
self.is_finetune = is_finetune
|
||||
|
||||
self.aggregation_fn, self.num_pooling = self.get_pooling()
|
||||
self.classifier_fn = self.get_classifier()
|
||||
self.loss_fn = self.get_loss()
|
||||
self.activation_fn = self.get_activation()
|
||||
|
@ -135,13 +128,6 @@ class DeepMILModule(LightningModule):
|
|||
self.val_metrics = self.get_metrics()
|
||||
self.test_metrics = self.get_metrics()
|
||||
|
||||
def get_pooling(self) -> Tuple[Callable, int]:
|
||||
pooling_layer = self.pooling_layer(self.num_encoding,
|
||||
self.pool_hidden_dim,
|
||||
self.pool_out_dim)
|
||||
num_features = self.num_encoding * self.pool_out_dim
|
||||
return pooling_layer, num_features
|
||||
|
||||
def get_classifier(self) -> Callable:
|
||||
classifier_layer = nn.Linear(in_features=self.num_pooling,
|
||||
out_features=self.n_classes)
|
||||
|
@ -284,7 +270,6 @@ class DeepMILModule(LightningModule):
|
|||
if is_global_rank_zero():
|
||||
logging.warning("Coordinates not found in batch. If this is not expected check your"
|
||||
"input tiles dataset.")
|
||||
|
||||
return results
|
||||
|
||||
def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -14,11 +14,8 @@ from torch.utils.data._utils.collate import default_collate
|
|||
from torchvision.models import resnet18
|
||||
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.networks.layers.attention_layers import (
|
||||
AttentionLayer,
|
||||
GatedAttentionLayer,
|
||||
MeanPoolingLayer,
|
||||
)
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer
|
||||
|
||||
|
||||
from histopathology.configs.classification.DeepSMILECrck import DeepSMILECrck
|
||||
from histopathology.configs.classification.DeepSMILEPanda import DeepSMILEPanda
|
||||
|
@ -34,12 +31,22 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
|
|||
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
|
||||
|
||||
|
||||
def get_attention_pooling_layer(num_encoding: int = 512,
|
||||
pool_out_dim: int = 1) -> Tuple[Type[nn.Module], int]:
|
||||
|
||||
pool_hidden_dim = 5 # different dimensions get tested in test_attentionlayers.py
|
||||
pooling_layer = AttentionLayer(num_encoding,
|
||||
pool_hidden_dim,
|
||||
pool_out_dim)
|
||||
|
||||
num_features = num_encoding * pool_out_dim
|
||||
return pooling_layer, num_features
|
||||
|
||||
|
||||
def _test_lightningmodule(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
pool_hidden_dim: int,
|
||||
pool_out_dim: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
|
@ -48,13 +55,16 @@ def _test_lightningmodule(
|
|||
|
||||
# hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere
|
||||
encoder = get_supervised_imagenet_encoder()
|
||||
|
||||
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
|
||||
pooling_layer, num_features = get_attention_pooling_layer(pool_out_dim=pool_out_dim)
|
||||
|
||||
module = DeepMILModule(
|
||||
encoder=encoder,
|
||||
label_column="label",
|
||||
n_classes=n_classes,
|
||||
pooling_layer=pooling_layer,
|
||||
pool_hidden_dim=pool_hidden_dim,
|
||||
pool_out_dim=pool_out_dim,
|
||||
num_features=num_features,
|
||||
dropout_rate=dropout_rate,
|
||||
)
|
||||
|
||||
|
@ -70,7 +80,7 @@ def _test_lightningmodule(
|
|||
bag_labels_list.append(module.get_bag_label(labels))
|
||||
logit, attn = module(bag)
|
||||
assert logit.shape == (1, n_classes)
|
||||
assert attn.shape == (module.pool_out_dim, max_bag_size)
|
||||
assert attn.shape == (pool_out_dim, max_bag_size)
|
||||
bag_logits_list.append(logit.view(-1))
|
||||
bag_attn_list.append(attn)
|
||||
|
||||
|
@ -110,49 +120,24 @@ def _test_lightningmodule(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
|
||||
@pytest.mark.parametrize("pool_out_dim", [1, 6])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule_attention(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
pool_hidden_dim: int,
|
||||
pool_out_dim: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
_test_lightningmodule(n_classes=n_classes,
|
||||
pooling_layer=pooling_layer,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=max_bag_size,
|
||||
pool_hidden_dim=pool_hidden_dim,
|
||||
pool_out_dim=pool_out_dim,
|
||||
dropout_rate=dropout_rate)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule_mean_pooling(
|
||||
n_classes: int,
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
_test_lightningmodule(n_classes=n_classes,
|
||||
pooling_layer=MeanPoolingLayer,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=max_bag_size,
|
||||
pool_hidden_dim=1,
|
||||
pool_out_dim=1,
|
||||
dropout_rate=dropout_rate)
|
||||
|
||||
|
||||
def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
|
||||
def is_integral(x: torch.Tensor) -> bool:
|
||||
return (x == x.long()).all() # type: ignore
|
||||
|
@ -172,11 +157,17 @@ def add_callback(fn: Callable, callback: Callable) -> Callable:
|
|||
|
||||
def test_metrics() -> None:
|
||||
input_dim = (128,)
|
||||
|
||||
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
|
||||
pooling_layer, num_features = get_attention_pooling_layer(num_encoding=input_dim[0],
|
||||
pool_out_dim=1)
|
||||
|
||||
module = DeepMILModule(
|
||||
encoder=IdentityEncoder(input_dim=input_dim),
|
||||
label_column=TilesDataset.LABEL_COLUMN,
|
||||
n_classes=1,
|
||||
pooling_layer=AttentionLayer,
|
||||
pooling_layer=pooling_layer,
|
||||
num_features=num_features
|
||||
)
|
||||
|
||||
# Patching to enable running the module without a Trainer object
|
||||
|
@ -326,13 +317,16 @@ def test_container(container_type: Type[LightningContainer], use_gpu: bool) -> N
|
|||
def test_class_weights_binary() -> None:
|
||||
class_weights = Tensor([0.5, 3.5])
|
||||
n_classes = 1
|
||||
|
||||
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
|
||||
pooling_layer, num_features = get_attention_pooling_layer(pool_out_dim=1)
|
||||
|
||||
module = DeepMILModule(
|
||||
encoder=get_supervised_imagenet_encoder(),
|
||||
label_column="label",
|
||||
n_classes=n_classes,
|
||||
pooling_layer=AttentionLayer,
|
||||
pool_hidden_dim=5,
|
||||
pool_out_dim=1,
|
||||
pooling_layer=pooling_layer,
|
||||
num_features=num_features,
|
||||
class_weights=class_weights,
|
||||
)
|
||||
logits = Tensor(randn(1, n_classes))
|
||||
|
@ -351,13 +345,16 @@ def test_class_weights_binary() -> None:
|
|||
def test_class_weights_multiclass() -> None:
|
||||
class_weights = Tensor([0.33, 0.33, 0.33])
|
||||
n_classes = 3
|
||||
|
||||
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
|
||||
pooling_layer, num_features = get_attention_pooling_layer(pool_out_dim=1)
|
||||
|
||||
module = DeepMILModule(
|
||||
encoder=get_supervised_imagenet_encoder(),
|
||||
label_column="label",
|
||||
n_classes=n_classes,
|
||||
pooling_layer=AttentionLayer,
|
||||
pool_hidden_dim=5,
|
||||
pool_out_dim=1,
|
||||
pooling_layer=pooling_layer,
|
||||
num_features=num_features,
|
||||
class_weights=class_weights,
|
||||
)
|
||||
logits = Tensor(randn(1, n_classes))
|
||||
|
|
|
@ -6,25 +6,22 @@
|
|||
Created using the original DeepMIL paper and code from Ilse et al., 2018
|
||||
https://github.com/AMLab-Amsterdam/AttentionDeepMIL (MIT License)
|
||||
"""
|
||||
from typing import Any, Tuple
|
||||
from typing import Tuple, Optional
|
||||
from torch import nn, Tensor, transpose, mm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, TransformerEncoderLayer
|
||||
|
||||
|
||||
class MeanPoolingLayer(nn.Module):
|
||||
"""Mean pooling returns uniform weights and the average feature vector over the first axis"""
|
||||
|
||||
# args/kwargs added here for compatibility with parametrised pooling modules
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
num_instances = features.shape[0]
|
||||
A = torch.full((1, num_instances), 1. / num_instances)
|
||||
M = features.mean(dim=0)
|
||||
M = M.view(1, -1)
|
||||
return (A, M)
|
||||
attention_weights = torch.full((1, num_instances), 1. / num_instances)
|
||||
pooled_features = features.mean(dim=0)
|
||||
pooled_features = pooled_features.view(1, -1)
|
||||
return (attention_weights, pooled_features)
|
||||
|
||||
|
||||
class AttentionLayer(nn.Module):
|
||||
|
@ -48,12 +45,12 @@ class AttentionLayer(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
H = features.view(-1, self.input_dims) # N x L
|
||||
A = self.attention(H) # N x K
|
||||
A = transpose(A, 1, 0) # K x N
|
||||
A = F.softmax(A, dim=1) # Softmax over N : K x N
|
||||
M = mm(A, H) # Matrix multiplication : K x L
|
||||
return(A, M)
|
||||
features = features.view(-1, self.input_dims) # N x L
|
||||
attention_weights = self.attention(features) # N x K
|
||||
attention_weights = transpose(attention_weights, 1, 0) # K x N
|
||||
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
|
||||
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
|
||||
return(attention_weights, pooled_features)
|
||||
|
||||
|
||||
class GatedAttentionLayer(nn.Module):
|
||||
|
@ -81,11 +78,137 @@ class GatedAttentionLayer(nn.Module):
|
|||
self.attention_weights = nn.Linear(self.hidden_dims, self.attention_dims)
|
||||
|
||||
def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
H = features.view(-1, self.input_dims) # N x L
|
||||
A_V = self.attention_V(H) # N x D
|
||||
A_U = self.attention_U(H) # N x D
|
||||
A = self.attention_weights(A_V * A_U) # Element-wise multiplication : N x K
|
||||
A = transpose(A, 1, 0) # K x N
|
||||
A = F.softmax(A, dim=1) # Softmax over N : K x N
|
||||
M = mm(A, H) # Matrix multiplication : K x L
|
||||
return(A, M)
|
||||
features = features.view(-1, self.input_dims) # N x L
|
||||
A_V = self.attention_V(features) # N x D
|
||||
A_U = self.attention_U(features) # N x D
|
||||
attention_weights = self.attention_weights(A_V * A_U) # Element-wise multiplication : N x K
|
||||
attention_weights = transpose(attention_weights, 1, 0) # K x N
|
||||
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
|
||||
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
|
||||
return(attention_weights, pooled_features)
|
||||
|
||||
|
||||
class CustomTransformerEncoderLayer(TransformerEncoderLayer):
|
||||
"""Adaptation of the pytorch TransformerEncoderLayer that always outputs the attention weights.
|
||||
|
||||
TransformerEncoderLayer is made up of self-attn and feedforward network.
|
||||
This standard encoder layer is based on the paper "Attention Is All You Need".
|
||||
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
||||
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
||||
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
||||
in a different way during application.
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
activation: the activation function of the intermediate layer, can be a string
|
||||
("relu" or "gelu") or a unary callable. Default: relu
|
||||
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False``.
|
||||
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
||||
operations, respectivaly. Otherwise it's done after. Default: ``False`` (after).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> out, attention_weights = encoder_layer(src)
|
||||
|
||||
Alternatively, when ``batch_first`` is ``True``:
|
||||
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
||||
>>> src = torch.rand(32, 10, 512)
|
||||
>>> out, attention_weights = encoder_layer(src)
|
||||
"""
|
||||
# new forward returns output as well as attention weights
|
||||
def forward(self, src: torch.Tensor, # type: ignore
|
||||
src_mask: Optional[torch.Tensor] = None,
|
||||
src_key_padding_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
see the docs in Transformer class.
|
||||
"""
|
||||
|
||||
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
||||
|
||||
x = src
|
||||
if self.norm_first:
|
||||
sa_block_out, a = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
||||
x = x + sa_block_out
|
||||
x = x + self._ff_block(self.norm2(x))
|
||||
else:
|
||||
sa_block_out, a = self._sa_block(x, src_mask, src_key_padding_mask)
|
||||
x = self.norm1(x + sa_block_out)
|
||||
x = self.norm2(x + self._ff_block(x))
|
||||
|
||||
return x, a
|
||||
|
||||
# new self-attention block, returns output as well as attention weights
|
||||
def _sa_block(self, x: Tensor, # type: ignore
|
||||
attn_mask: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x, a = self.self_attn(x, x, x,
|
||||
attn_mask=attn_mask,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=True) # Just because of this flag I had to copy all of the code...
|
||||
x = x[0]
|
||||
return self.dropout1(x), a
|
||||
|
||||
|
||||
class TransformerPooling(Module):
|
||||
"""Create a Transformer encoder module consisting of multiple Transformer encoder layers.
|
||||
|
||||
We use a additional classification token (cls token) for pooling like seen in ViT/Bert. First, the cls token is
|
||||
appended to the list of tiles encodings. Second, we perform self-attention between all tile encodings and the cls
|
||||
token. Last, we extract the cls token and use it for classification.
|
||||
|
||||
Args:
|
||||
num_layers: Number of Transformer encoder layers.
|
||||
num_heads: Number of attention heads per layer.
|
||||
dim_representation: Dimension of input encoding.
|
||||
"""
|
||||
def __init__(self, num_layers: int, num_heads: int, dim_representation: int) -> None:
|
||||
super(TransformerPooling, self).__init__()
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.dim_representation = dim_representation
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros([1, dim_representation]))
|
||||
|
||||
self.transformer_encoder_layers = []
|
||||
for _ in range(self.num_layers):
|
||||
self.transformer_encoder_layers.append(
|
||||
CustomTransformerEncoderLayer(self.dim_representation,
|
||||
self.num_heads,
|
||||
dim_feedforward=self.dim_representation,
|
||||
dropout=0.1,
|
||||
activation=F.gelu,
|
||||
batch_first=True))
|
||||
self.transformer_encoder_layers = torch.nn.ModuleList(self.transformer_encoder_layers) # type: ignore
|
||||
|
||||
def forward(self, features: Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Append cls token
|
||||
features = torch.vstack([self.cls_token, features]).unsqueeze(0)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
features, attention_weights = self.transformer_encoder_layers[i](features)
|
||||
|
||||
# Extract cls token
|
||||
pooled_features = features[:, 0]
|
||||
|
||||
# Get attention weights with respect to the cls token, without the element where it attends to itself
|
||||
|
||||
self_attention_cls_token = attention_weights[0, 0, 0] # type: ignore
|
||||
attention_weights = attention_weights[:, 0, 1:] # type: ignore
|
||||
|
||||
# We want A to sum to one, simple hack: add self_attention_cls_token/num_tiles to each element
|
||||
attention_weights += self_attention_cls_token / attention_weights.shape[-1]
|
||||
|
||||
return (attention_weights, pooled_features)
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Type, Union
|
|||
from torch import nn, rand, sum, allclose, ones_like
|
||||
|
||||
from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
|
||||
MeanPoolingLayer)
|
||||
MeanPoolingLayer, TransformerPooling)
|
||||
|
||||
|
||||
def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
|
||||
|
@ -18,8 +18,11 @@ def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
|
|||
row_sums = sum(attn_weights, dim=1, keepdim=True)
|
||||
assert allclose(row_sums, ones_like(row_sums))
|
||||
|
||||
pooled_features = attn_weights @ features.flatten(start_dim=1)
|
||||
assert allclose(pooled_features, output_features)
|
||||
if isinstance(attentionlayer, TransformerPooling):
|
||||
pass
|
||||
else:
|
||||
pooled_features = attn_weights @ features.flatten(start_dim=1)
|
||||
assert allclose(pooled_features, output_features)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim_in", [1, 3])
|
||||
|
@ -41,3 +44,14 @@ def test_attentionlayer(dim_in: int, dim_hid: int, dim_att: int, batch_size: int
|
|||
@pytest.mark.parametrize("batch_size", [1, 7])
|
||||
def test_mean_pooling(dim_in: int, batch_size: int,) -> None:
|
||||
_test_attention_layer(MeanPoolingLayer(), dim_in=dim_in, dim_att=1, batch_size=batch_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_layers", [1, 4])
|
||||
@pytest.mark.parametrize("num_heads", [1, 2])
|
||||
@pytest.mark.parametrize("dim_in", [4, 8]) # dim_in % num_heads must be 0
|
||||
@pytest.mark.parametrize("batch_size", [1, 7])
|
||||
def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int, batch_size: int) -> None:
|
||||
transformer_pooling = TransformerPooling(num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
dim_representation=dim_in).eval()
|
||||
_test_attention_layer(transformer_pooling, dim_in=dim_in, dim_att=1, batch_size=batch_size)
|
||||
|
|
Загрузка…
Ссылка в новой задаче