The main thing this PR does is add TransformerPooling. In addition, I moved the construction of the pooling layers out of DeepMIL to allow for more flexibility in the constructor arguments. This required changing some tests.

The transformer pooling is based on the PyTorch Transformer classes. However, in the PyTorch implementation the need_weights flag has to be passed to forward() of the MultiheadAttention class rather than to its constructor, so the encoder layer had to be copied and adapted to expose the attention weights. See line 186 of attention_layers.py.
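For reference, a minimal standalone sketch (hypothetical shapes, not part of this PR) of the behaviour in question: nn.MultiheadAttention only returns attention weights when need_weights=True is passed to forward(), which the stock encoder layer does not surface.

```python
import torch
from torch import nn

# Hypothetical dimensions, just to show the need_weights flag on forward().
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.rand(1, 5, 8)  # (batch, 5 tile encodings, embedding dim)

attn_output, attn_weights = mha(x, x, x, need_weights=True)
print(attn_output.shape)   # torch.Size([1, 5, 8])
print(attn_weights.shape)  # torch.Size([1, 5, 5]), averaged over heads by default
```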

In the forward() of TransformerPooling (see line 244 of attention_layers.py) we discard the self-attention of the cls token and rescale the remaining attention scores so that they sum to one.
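A minimal sketch of that rescaling on a toy attention row (plain tensors, not the module itself): the cls token's self-attention mass is spread evenly over the remaining tiles so the row sums to one again.

```python
import torch

# Toy cls-token attention row: [cls->cls, cls->tile1, cls->tile2, cls->tile3]
attn = torch.tensor([[0.10, 0.50, 0.25, 0.15]])

self_attention_cls_token = attn[0, 0]   # attention of the cls token to itself
tile_attention = attn[:, 1:]            # keep only the attention over the tiles
tile_attention = tile_attention + self_attention_cls_token / tile_attention.shape[-1]

print(tile_attention)        # tensor([[0.5333, 0.2833, 0.1833]])
print(tile_attention.sum())  # tensor(1.0000), sums to one again
```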
This commit is contained in:
maxilse authored 2022-03-11 14:27:05 +01:00, committed by GitHub
Parent f6c5aadd24
Commit 54632a895f
8 changed files: 261 additions and 109 deletions

View file

@ -10,6 +10,7 @@ Each release contains a link for "Full Changelog"
## 0.1.14
### Added
- ([#227](https://github.com/microsoft/hi-ml/pull/227)) Add TransformerPooling.
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
the environment file since it is necessary for the augmentations.
- ([#193](https://github.com/microsoft/hi-ml/pull/193)) Add transformation adaptor to hi-ml-histopathology.
@ -21,6 +22,7 @@ the environment file since it is necessary for the augmentations.
- ([#198](https://github.com/microsoft/hi-ml/pull/198)) Improved editor setup for VSCode.
### Changed
- ([#227](https://github.com/microsoft/hi-ml/pull/227)) Pooling constructor is outside of DeepMIL and inside of BaseMIL now.
- ([#198](https://github.com/microsoft/hi-ml/pull/198)) Model config loader is now more flexible, can accept fully qualified class name or just top-level module name and class (like histopathology.DeepSMILECrck)
- ([#198](https://github.com/microsoft/hi-ml/pull/198)) Runner raises an error when Conda environment file contains a pip include (-r) statement

View file

@ -8,14 +8,15 @@ It is responsible for instantiating the encoder and full DeepMIL model. Subclass
their datamodules and configure experiment-specific parameters.
"""
from pathlib import Path
from typing import Optional, Type # noqa
from typing import Optional, Tuple, Type # noqa
import param
from torch import nn
from torchvision.models import resnet18
from health_ml.lightning_container import LightningContainer
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer,\
TransformerPooling
from histopathology.datasets.base_dataset import SlidesDataset
from histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
@ -27,7 +28,16 @@ from histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
class BaseMIL(LightningContainer):
# Model parameters:
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
pool_type: str = param.String(doc="Name of the pooling layer class to use.")
pool_hidden_dim: int = param.Integer(128, doc="If pooling has a learnable part, this defines the number of the\
hidden dimensions.")
pool_out_dim: int = param.Integer(1, doc="Dimension of the pooled representation.")
num_transformer_pool_layers: int = param.Integer(4, doc="If transformer pooling is chosen, this defines the number\
of encoding layers.")
num_transformer_pool_heads: int = param.Integer(4, doc="If transformer pooling is chosen, this defines the number\
of attention heads.")
is_finetune: bool = param.Boolean(False, doc="If True, fine-tune the encoder during training. If False (default), "
"keep the encoder frozen.")
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
@ -52,8 +62,6 @@ class BaseMIL(LightningContainer):
"`none` (default),`cpu`, `gpu`")
encoding_chunk_size: int = param.Integer(0, doc="If > 0 performs encoding in chunks, by loading"
"enconding_chunk_size tiles per chunk")
is_finetune: bool = param.Boolean(False, doc="If True, fine-tune the encoder during training. If False (default), "
"keep the encoder frozen.")
# local_dataset (used as data module root_path) is declared in DatasetParams superclass
@property
@ -62,7 +70,7 @@ class BaseMIL(LightningContainer):
def setup(self) -> None:
if self.encoder_type == SSLEncoder.__name__:
raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")
raise NotImplementedError("SSLEncoder requires a pre-trained checkpoint.")
self.encoder = self.get_encoder()
if not self.is_finetune:
@ -86,16 +94,30 @@ class BaseMIL(LightningContainer):
else:
raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
def get_pooling_layer(self) -> Type[nn.Module]:
if self.pooling_type == AttentionLayer.__name__:
return AttentionLayer
elif self.pooling_type == GatedAttentionLayer.__name__:
return GatedAttentionLayer
elif self.pooling_type == MeanPoolingLayer.__name__:
return MeanPoolingLayer
def get_pooling_layer(self) -> Tuple[nn.Module, int]:
num_encoding = self.encoder.num_encoding
if self.pool_type == AttentionLayer.__name__:
pooling_layer = AttentionLayer(num_encoding,
self.pool_hidden_dim,
self.pool_out_dim)
elif self.pool_type == GatedAttentionLayer.__name__:
pooling_layer = GatedAttentionLayer(num_encoding,
self.pool_hidden_dim,
self.pool_out_dim)
elif self.pool_type == MeanPoolingLayer.__name__:
pooling_layer = MeanPoolingLayer()
elif self.pool_type == TransformerPooling.__name__:
pooling_layer = TransformerPooling(self.num_transformer_pool_layers,
self.num_transformer_pool_heads,
num_encoding)
self.pool_out_dim = 1 # currently this is hardcoded in forward of the TransformerPooling
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
num_features = num_encoding * self.pool_out_dim
return pooling_layer, num_features
def create_model(self) -> DeepMILModule:
self.data_module = self.get_data_module()
# Encoding is done in the datamodule, so here we provide instead a dummy
@ -106,10 +128,15 @@ class BaseMIL(LightningContainer):
params.requires_grad = True
else:
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
# Construct pooling layer
pooling_layer, num_features = self.get_pooling_layer()
return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
pooling_layer=pooling_layer,
num_features=num_features,
dropout_rate=self.dropout_rate,
class_weights=self.data_module.class_weights,
l_rate=self.l_rate,

View file

@ -47,7 +47,9 @@ class DeepSMILECrck(BaseMIL):
# Define dictionary with default params that can be overridden from subclasses or CLI
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=AttentionLayer.__name__,
pool_type=AttentionLayer.__name__,
num_transformer_pool_layers=4,
num_transformer_pool_heads=4,
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,
precache_location=CacheLocation.CPU,

View file

@ -12,7 +12,7 @@ from pytorch_lightning.callbacks import Callback
from health_azure.utils import CheckpointDownloader
from health_azure.utils import get_workspace, is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_ml.utils import fixed_paths
from histopathology.datamodules.base_module import CacheMode, CacheLocation
from histopathology.datamodules.panda_module import PandaTilesDataModule
@ -44,7 +44,9 @@ class DeepSMILEPanda(BaseMIL):
def __init__(self, **kwargs: Any) -> None:
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=GatedAttentionLayer.__name__,
pool_type=AttentionLayer.__name__,
num_transformer_pool_layers=4,
num_transformer_pool_heads=4,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,

View file

@ -27,7 +27,6 @@ from histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_t
from histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey
from histopathology.utils.viz_utils import load_image_dict
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
@ -45,9 +44,8 @@ class DeepMILModule(LightningModule):
label_column: str,
n_classes: int,
encoder: TileEncoder,
pooling_layer: Callable[[int, int, int], nn.Module],
pool_hidden_dim: int = 128,
pool_out_dim: int = 1,
pooling_layer: Callable[[Tensor], Tuple[Tensor, Tensor]],
num_features: int,
dropout_rate: Optional[float] = None,
class_weights: Optional[Tensor] = None,
l_rate: float = 5e-4,
@ -61,14 +59,12 @@ class DeepMILModule(LightningModule):
is_finetune: bool = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes
should be set to 1.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
set to 1.
:param encoder: The tile encoder to use for feature extraction. If no encoding is needed,
you should use `IdentityEncoder`.
:param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a
`torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
:param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
:param pool_out_dim: Output dimension of pooling layer (default=1).
:param pooling_layer: A pooling nn.Module that returns attention weights and pooled features.
:param num_features: Dimension of the pooled features (encoding dimension * pooling output dimension).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param class_weights: Tensor containing class weights (default=None).
:param l_rate: Optimiser learning rate.
@ -86,13 +82,11 @@ class DeepMILModule(LightningModule):
# Dataset specific attributes
self.label_column = label_column
self.n_classes = n_classes
self.pool_hidden_dim = pool_hidden_dim
self.pool_out_dim = pool_out_dim
self.pooling_layer = pooling_layer
self.dropout_rate = dropout_rate
self.class_weights = class_weights
self.encoder = encoder
self.num_encoding = self.encoder.num_encoding
self.aggregation_fn = pooling_layer
self.num_pooling = num_features
if class_names is not None:
self.class_names = class_names
@ -125,7 +119,6 @@ class DeepMILModule(LightningModule):
# Finetuning attributes
self.is_finetune = is_finetune
self.aggregation_fn, self.num_pooling = self.get_pooling()
self.classifier_fn = self.get_classifier()
self.loss_fn = self.get_loss()
self.activation_fn = self.get_activation()
@ -135,13 +128,6 @@ class DeepMILModule(LightningModule):
self.val_metrics = self.get_metrics()
self.test_metrics = self.get_metrics()
def get_pooling(self) -> Tuple[Callable, int]:
pooling_layer = self.pooling_layer(self.num_encoding,
self.pool_hidden_dim,
self.pool_out_dim)
num_features = self.num_encoding * self.pool_out_dim
return pooling_layer, num_features
def get_classifier(self) -> Callable:
classifier_layer = nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
@ -284,7 +270,6 @@ class DeepMILModule(LightningModule):
if is_global_rank_zero():
logging.warning("Coordinates not found in batch. If this is not expected check your"
"input tiles dataset.")
return results
def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore

View file

@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Tuple
from unittest.mock import MagicMock
import pytest
@ -14,11 +14,8 @@ from torch.utils.data._utils.collate import default_collate
from torchvision.models import resnet18
from health_ml.lightning_container import LightningContainer
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
MeanPoolingLayer,
)
from health_ml.networks.layers.attention_layers import AttentionLayer
from histopathology.configs.classification.DeepSMILECrck import DeepSMILECrck
from histopathology.configs.classification.DeepSMILEPanda import DeepSMILEPanda
@ -34,12 +31,22 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
def get_attention_pooling_layer(num_encoding: int = 512,
pool_out_dim: int = 1) -> Tuple[Type[nn.Module], int]:
pool_hidden_dim = 5 # different dimensions get tested in test_attentionlayers.py
pooling_layer = AttentionLayer(num_encoding,
pool_hidden_dim,
pool_out_dim)
num_features = num_encoding * pool_out_dim
return pooling_layer, num_features
def _test_lightningmodule(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
@ -48,13 +55,16 @@ def _test_lightningmodule(
# hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere
encoder = get_supervised_imagenet_encoder()
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
pooling_layer, num_features = get_attention_pooling_layer(pool_out_dim=pool_out_dim)
module = DeepMILModule(
encoder=encoder,
label_column="label",
n_classes=n_classes,
pooling_layer=pooling_layer,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
num_features=num_features,
dropout_rate=dropout_rate,
)
@ -70,7 +80,7 @@ def _test_lightningmodule(
bag_labels_list.append(module.get_bag_label(labels))
logit, attn = module(bag)
assert logit.shape == (1, n_classes)
assert attn.shape == (module.pool_out_dim, max_bag_size)
assert attn.shape == (pool_out_dim, max_bag_size)
bag_logits_list.append(logit.view(-1))
bag_attn_list.append(attn)
@ -110,49 +120,24 @@ def _test_lightningmodule(
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=pooling_layer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
n_classes: int,
batch_size: int,
max_bag_size: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=MeanPoolingLayer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=1,
pool_out_dim=1,
dropout_rate=dropout_rate)
def validate_metric_inputs(scores: torch.Tensor, labels: torch.Tensor) -> None:
def is_integral(x: torch.Tensor) -> bool:
return (x == x.long()).all() # type: ignore
@ -172,11 +157,17 @@ def add_callback(fn: Callable, callback: Callable) -> Callable:
def test_metrics() -> None:
input_dim = (128,)
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
pooling_layer, num_features = get_attention_pooling_layer(num_encoding=input_dim[0],
pool_out_dim=1)
module = DeepMILModule(
encoder=IdentityEncoder(input_dim=input_dim),
label_column=TilesDataset.LABEL_COLUMN,
n_classes=1,
pooling_layer=AttentionLayer,
pooling_layer=pooling_layer,
num_features=num_features
)
# Patching to enable running the module without a Trainer object
@ -326,13 +317,16 @@ def test_container(container_type: Type[LightningContainer], use_gpu: bool) -> N
def test_class_weights_binary() -> None:
class_weights = Tensor([0.5, 3.5])
n_classes = 1
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
pooling_layer, num_features = get_attention_pooling_layer(pool_out_dim=1)
module = DeepMILModule(
encoder=get_supervised_imagenet_encoder(),
label_column="label",
n_classes=n_classes,
pooling_layer=AttentionLayer,
pool_hidden_dim=5,
pool_out_dim=1,
pooling_layer=pooling_layer,
num_features=num_features,
class_weights=class_weights,
)
logits = Tensor(randn(1, n_classes))
@ -351,13 +345,16 @@ def test_class_weights_binary() -> None:
def test_class_weights_multiclass() -> None:
class_weights = Tensor([0.33, 0.33, 0.33])
n_classes = 3
# hard-coded here to avoid test explosion; correctness of other pooling layers is tested elsewhere
pooling_layer, num_features = get_attention_pooling_layer(pool_out_dim=1)
module = DeepMILModule(
encoder=get_supervised_imagenet_encoder(),
label_column="label",
n_classes=n_classes,
pooling_layer=AttentionLayer,
pool_hidden_dim=5,
pool_out_dim=1,
pooling_layer=pooling_layer,
num_features=num_features,
class_weights=class_weights,
)
logits = Tensor(randn(1, n_classes))

View file

@ -6,25 +6,22 @@
Created using the original DeepMIL paper and code from Ilse et al., 2018
https://github.com/AMLab-Amsterdam/AttentionDeepMIL (MIT License)
"""
from typing import Any, Tuple
from typing import Tuple, Optional
from torch import nn, Tensor, transpose, mm
import torch
import torch.nn.functional as F
from torch.nn import Module, TransformerEncoderLayer
class MeanPoolingLayer(nn.Module):
"""Mean pooling returns uniform weights and the average feature vector over the first axis"""
# args/kwargs added here for compatibility with parametrised pooling modules
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
num_instances = features.shape[0]
A = torch.full((1, num_instances), 1. / num_instances)
M = features.mean(dim=0)
M = M.view(1, -1)
return (A, M)
attention_weights = torch.full((1, num_instances), 1. / num_instances)
pooled_features = features.mean(dim=0)
pooled_features = pooled_features.view(1, -1)
return (attention_weights, pooled_features)
class AttentionLayer(nn.Module):
@ -48,12 +45,12 @@ class AttentionLayer(nn.Module):
)
def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
H = features.view(-1, self.input_dims) # N x L
A = self.attention(H) # N x K
A = transpose(A, 1, 0) # K x N
A = F.softmax(A, dim=1) # Softmax over N : K x N
M = mm(A, H) # Matrix multiplication : K x L
return(A, M)
features = features.view(-1, self.input_dims) # N x L
attention_weights = self.attention(features) # N x K
attention_weights = transpose(attention_weights, 1, 0) # K x N
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
return(attention_weights, pooled_features)
class GatedAttentionLayer(nn.Module):
@ -81,11 +78,137 @@ class GatedAttentionLayer(nn.Module):
self.attention_weights = nn.Linear(self.hidden_dims, self.attention_dims)
def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
H = features.view(-1, self.input_dims) # N x L
A_V = self.attention_V(H) # N x D
A_U = self.attention_U(H) # N x D
A = self.attention_weights(A_V * A_U) # Element-wise multiplication : N x K
A = transpose(A, 1, 0) # K x N
A = F.softmax(A, dim=1) # Softmax over N : K x N
M = mm(A, H) # Matrix multiplication : K x L
return(A, M)
features = features.view(-1, self.input_dims) # N x L
A_V = self.attention_V(features) # N x D
A_U = self.attention_U(features) # N x D
attention_weights = self.attention_weights(A_V * A_U) # Element-wise multiplication : N x K
attention_weights = transpose(attention_weights, 1, 0) # K x N
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
return(attention_weights, pooled_features)
class CustomTransformerEncoderLayer(TransformerEncoderLayer):
"""Adaptation of the pytorch TransformerEncoderLayer that always outputs the attention weights.
TransformerEncoderLayer is made up of self-attn and feedforward network.
This standard encoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False``.
norm_first: if ``True``, layer norm is done prior to attention and feedforward
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
Examples::
>>> encoder_layer = CustomTransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out, attention_weights = encoder_layer(src)
Alternatively, when ``batch_first`` is ``True``:
>>> encoder_layer = CustomTransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = torch.rand(32, 10, 512)
>>> out, attention_weights = encoder_layer(src)
"""
# new forward returns output as well as attention weights
def forward(self, src: torch.Tensor, # type: ignore
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = src
if self.norm_first:
sa_block_out, a = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + sa_block_out
x = x + self._ff_block(self.norm2(x))
else:
sa_block_out, a = self._sa_block(x, src_mask, src_key_padding_mask)
x = self.norm1(x + sa_block_out)
x = self.norm2(x + self._ff_block(x))
return x, a
# new self-attention block, returns output as well as attention weights
def _sa_block(self, x: Tensor, # type: ignore
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
x, a = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=True) # Just because of this flag I had to copy all of the code...
x = x[0]
return self.dropout1(x), a
class TransformerPooling(Module):
"""Create a Transformer encoder module consisting of multiple Transformer encoder layers.
We use an additional classification token (cls token) for pooling, as in ViT/BERT. First, the cls token is
appended to the list of tile encodings. Second, we perform self-attention between all tile encodings and the cls
token. Last, we extract the cls token and use it for classification.
Args:
num_layers: Number of Transformer encoder layers.
num_heads: Number of attention heads per layer.
dim_representation: Dimension of input encoding.
"""
def __init__(self, num_layers: int, num_heads: int, dim_representation: int) -> None:
super(TransformerPooling, self).__init__()
self.num_layers = num_layers
self.num_heads = num_heads
self.dim_representation = dim_representation
self.cls_token = nn.Parameter(torch.zeros([1, dim_representation]))
self.transformer_encoder_layers = []
for _ in range(self.num_layers):
self.transformer_encoder_layers.append(
CustomTransformerEncoderLayer(self.dim_representation,
self.num_heads,
dim_feedforward=self.dim_representation,
dropout=0.1,
activation=F.gelu,
batch_first=True))
self.transformer_encoder_layers = torch.nn.ModuleList(self.transformer_encoder_layers) # type: ignore
def forward(self, features: Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Append cls token
features = torch.vstack([self.cls_token, features]).unsqueeze(0)
for i in range(self.num_layers):
features, attention_weights = self.transformer_encoder_layers[i](features)
# Extract cls token
pooled_features = features[:, 0]
# Get attention weights with respect to the cls token, without the element where it attends to itself
self_attention_cls_token = attention_weights[0, 0, 0] # type: ignore
attention_weights = attention_weights[:, 0, 1:] # type: ignore
# We want the attention weights to sum to one; simple hack: add self_attention_cls_token/num_tiles to each element
attention_weights += self_attention_cls_token / attention_weights.shape[-1]
return (attention_weights, pooled_features)

View file

@ -4,7 +4,7 @@ from typing import Type, Union
from torch import nn, rand, sum, allclose, ones_like
from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
MeanPoolingLayer)
MeanPoolingLayer, TransformerPooling)
def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
@ -18,8 +18,11 @@ def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
row_sums = sum(attn_weights, dim=1, keepdim=True)
assert allclose(row_sums, ones_like(row_sums))
pooled_features = attn_weights @ features.flatten(start_dim=1)
assert allclose(pooled_features, output_features)
if isinstance(attentionlayer, TransformerPooling):
pass
else:
pooled_features = attn_weights @ features.flatten(start_dim=1)
assert allclose(pooled_features, output_features)
@pytest.mark.parametrize("dim_in", [1, 3])
@ -41,3 +44,14 @@ def test_attentionlayer(dim_in: int, dim_hid: int, dim_att: int, batch_size: int
@pytest.mark.parametrize("batch_size", [1, 7])
def test_mean_pooling(dim_in: int, batch_size: int,) -> None:
_test_attention_layer(MeanPoolingLayer(), dim_in=dim_in, dim_att=1, batch_size=batch_size)
@pytest.mark.parametrize("num_layers", [1, 4])
@pytest.mark.parametrize("num_heads", [1, 2])
@pytest.mark.parametrize("dim_in", [4, 8]) # dim_in % num_heads must be 0
@pytest.mark.parametrize("batch_size", [1, 7])
def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int, batch_size: int) -> None:
transformer_pooling = TransformerPooling(num_layers=num_layers,
num_heads=num_heads,
dim_representation=dim_in).eval()
_test_attention_layer(transformer_pooling, dim_in=dim_in, dim_att=1, batch_size=batch_size)