ENH: Dropout in transformer layers in DeepMIL (#590)

In this PR:
- A transformer_dropout parameter is added to the TransformerPooling and TransformerPoolingBenchmark pooling layers (see the usage sketch below).
- Tests are updated to cover the transformer_dropout parameter.
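
For reference, a minimal usage sketch of the new option. The attention-layer import matches the test file below; the import path of PoolingParams and the value of num_encoding are assumptions for illustration, not part of this PR.

# Hedged sketch: enable dropout in the transformer pooling layers via PoolingParams.
from health_ml.networks.layers.attention_layers import TransformerPooling
from health_cpath.models.deepmil import PoolingParams  # assumed import path

pooling_params = PoolingParams(pool_type=TransformerPooling.__name__,
                               num_transformer_pool_layers=4,
                               num_transformer_pool_heads=4,
                               transformer_dropout=0.1)  # new parameter, defaults to 0.0
pooling_layer, pooled_dim = pooling_params.get_pooling_layer(num_encoding=512)

Note that the new default of 0.0 replaces the dropout of 0.1 that was previously hard-coded inside TransformerPooling, so runs that want the old behaviour should set transformer_dropout=0.1 explicitly.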
Harshita Sharma 2022-09-05 11:20:24 +01:00 committed by GitHub
Parent 763dc7f865
Commit f7c21a622a
4 changed files with 61 additions and 24 deletions

View file

@@ -127,8 +127,13 @@ class PoolingParams(param.Parameterized):
         doc="If True (default), fine-tune the pooling layer during training. If False, keep the pooling layer frozen.",
     )
     pretrained_pooling = param.Boolean(
-        False, doc="If True, transfer weights from the pretrained model (specified in `src_checkpoint`) to the pooling"
-        "layer. Else (False), initialize the pooling layer randomly."
+        default=False,
+        doc="If True, transfer weights from the pretrained model (specified in `src_checkpoint`) to the pooling"
+        "layer. Else (False), initialize the pooling layer randomly.",
     )
+    transformer_dropout: float = param.Number(
+        default=0.0,
+        doc="If transformer pooling is chosen, this defines the dropout of the transformer encoder layers.",
+    )
 
     def get_pooling_layer(self, num_encoding: int) -> Tuple[nn.Module, int]:
@@ -141,22 +146,31 @@ class PoolingParams(param.Parameterized):
         """
         pooling_layer: nn.Module
         if self.pool_type == AttentionLayer.__name__:
-            pooling_layer = AttentionLayer(num_encoding, self.pool_hidden_dim, self.pool_out_dim)
+            pooling_layer = AttentionLayer(input_dims=num_encoding,
+                                           hidden_dims=self.pool_hidden_dim,
+                                           attention_dims=self.pool_out_dim)
         elif self.pool_type == GatedAttentionLayer.__name__:
-            pooling_layer = GatedAttentionLayer(num_encoding, self.pool_hidden_dim, self.pool_out_dim)
+            pooling_layer = GatedAttentionLayer(input_dims=num_encoding,
+                                                hidden_dims=self.pool_hidden_dim,
+                                                attention_dims=self.pool_out_dim)
         elif self.pool_type == MeanPoolingLayer.__name__:
             pooling_layer = MeanPoolingLayer()
         elif self.pool_type == MaxPoolingLayer.__name__:
             pooling_layer = MaxPoolingLayer()
         elif self.pool_type == TransformerPooling.__name__:
             pooling_layer = TransformerPooling(
-                self.num_transformer_pool_layers, self.num_transformer_pool_heads, num_encoding
-            )
+                num_layers=self.num_transformer_pool_layers,
+                num_heads=self.num_transformer_pool_heads,
+                dim_representation=num_encoding,
+                transformer_dropout=self.transformer_dropout)
             self.pool_out_dim = 1  # currently this is hardcoded in forward of the TransformerPooling
         elif self.pool_type == TransformerPoolingBenchmark.__name__:
             pooling_layer = TransformerPoolingBenchmark(
-                self.num_transformer_pool_layers, self.num_transformer_pool_heads, num_encoding, self.pool_hidden_dim
-            )
+                num_layers=self.num_transformer_pool_layers,
+                num_heads=self.num_transformer_pool_heads,
+                dim_representation=num_encoding,
+                hidden_dim=self.pool_hidden_dim,
+                transformer_dropout=self.transformer_dropout)
             self.pool_out_dim = 1  # currently this is hardcoded in forward of the TransformerPooling
         else:
             raise ValueError(f"Unsupported pooling type: {self.pool_type}")
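
Under the same import assumptions as the sketch near the top of this page, the TransformerPoolingBenchmark branch above can be exercised as follows; the asserted attributes (transformer_dropout, pool_out_dim) are the ones introduced or set in this diff.

# Sketch only: the benchmark variant receives the dropout value and forces pool_out_dim to 1.
from health_ml.networks.layers.attention_layers import TransformerPoolingBenchmark

params = PoolingParams(pool_type=TransformerPoolingBenchmark.__name__,
                       num_transformer_pool_layers=2,
                       num_transformer_pool_heads=2,
                       pool_hidden_dim=128,
                       transformer_dropout=0.25)
layer, _ = params.get_pooling_layer(num_encoding=512)
assert isinstance(layer, TransformerPoolingBenchmark)
assert layer.transformer_dropout == 0.25  # stored by the new __init__
assert params.pool_out_dim == 1  # hardcoded, see the comment in the branch above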

View file

@@ -47,11 +47,13 @@ def get_attention_pooling_layer_params(pool_out_dim: int = 1, tune_pooling: bool
                          tune_pooling=tune_pooling)
 
 
-def get_transformer_pooling_layer_params(num_layers: int, num_heads: int, hidden_dim: int) -> PoolingParams:
+def get_transformer_pooling_layer_params(num_layers: int, num_heads: int,
+                                         hidden_dim: int, transformer_dropout: float) -> PoolingParams:
     return PoolingParams(pool_type=TransformerPoolingBenchmark.__name__,
                          num_transformer_pool_layers=num_layers,
                          num_transformer_pool_heads=num_heads,
-                         pool_hidden_dim=hidden_dim)
+                         pool_hidden_dim=hidden_dim,
+                         transformer_dropout=transformer_dropout)
 
 
 def _test_lightningmodule(
@@ -563,12 +565,13 @@ def _get_tiles_deepmil_module(
     num_layers: int = 2,
     num_heads: int = 1,
     hidden_dim: int = 8,
+    transformer_dropout: float = 0.1
 ) -> TilesDeepMILModule:
     module = TilesDeepMILModule(
         n_classes=n_classes,
         label_column=MockPandaTilesGenerator.ISUP_GRADE,
         encoder_params=get_supervised_imagenet_encoder_params(),
-        pooling_params=get_transformer_pooling_layer_params(num_layers, num_heads, hidden_dim),
+        pooling_params=get_transformer_pooling_layer_params(num_layers, num_heads, hidden_dim, transformer_dropout),
     )
     module.encoder_params.pretrained_encoder = pretrained_encoder
     module.pooling_params.pretrained_pooling = pretrained_pooling

View file

@@ -187,14 +187,15 @@ class TransformerPooling(Module):
         num_layers: Number of Transformer encoder layers.
         num_heads: Number of attention heads per layer.
         dim_representation: Dimension of input encoding.
+        transformer_dropout: The dropout value of transformer encoder layers.
     """
 
-    def __init__(self, num_layers: int, num_heads: int, dim_representation: int) -> None:
+    def __init__(self, num_layers: int, num_heads: int, dim_representation: int, transformer_dropout: float) -> None:
         super(TransformerPooling, self).__init__()
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.dim_representation = dim_representation
+        self.transformer_dropout = transformer_dropout
         self.cls_token = nn.Parameter(torch.zeros([1, dim_representation]))
 
         self.transformer_encoder_layers = []
@@ -203,7 +204,7 @@ class TransformerPooling(Module):
                 CustomTransformerEncoderLayer(self.dim_representation,
                                               self.num_heads,
                                               dim_feedforward=self.dim_representation,
-                                              dropout=0.1,
+                                              dropout=self.transformer_dropout,
                                               activation=F.gelu,
                                               batch_first=True))
         self.transformer_encoder_layers = torch.nn.ModuleList(self.transformer_encoder_layers)  # type: ignore
@@ -239,17 +240,21 @@ class TransformerPoolingBenchmark(Module):
         num_layers: Number of Transformer encoder layers.
         num_heads: Number of attention heads per layer.
         dim_representation: Dimension of input encoding.
+        transformer_dropout: The dropout value of transformer encoder layers.
     """
 
-    def __init__(self, num_layers: int, num_heads: int, dim_representation: int, hidden_dim: int) -> None:
+    def __init__(self, num_layers: int, num_heads: int,
+                 dim_representation: int, hidden_dim: int,
+                 transformer_dropout: float) -> None:
         super().__init__()
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.dim_representation = dim_representation
         self.hidden_dim = hidden_dim
+        self.transformer_dropout = transformer_dropout
         transformer_layer = nn.TransformerEncoderLayer(d_model=self.dim_representation,
                                                        nhead=self.num_heads,
-                                                       dropout=0.0,
+                                                       dropout=self.transformer_dropout,
                                                        batch_first=True)
         self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=self.num_layers)
         self.attention = nn.Sequential(nn.Linear(self.dim_representation, self.hidden_dim),
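
To see where the value lands, a small hedged check: TransformerPoolingBenchmark passes transformer_dropout straight to nn.TransformerEncoderLayer, so the rate is visible on the standard PyTorch submodules (the .layers and .dropout attribute names come from PyTorch, not from this diff), and like any dropout it is only active in training mode.

from health_ml.networks.layers.attention_layers import TransformerPoolingBenchmark

pooling = TransformerPoolingBenchmark(num_layers=2, num_heads=2,
                                      dim_representation=8, hidden_dim=4,
                                      transformer_dropout=0.5)
# nn.TransformerEncoder keeps its encoder layers in .layers; each nn.TransformerEncoderLayer
# exposes its feed-forward dropout as an nn.Dropout module with probability p.
assert all(layer.dropout.p == 0.5 for layer in pooling.transformer.layers)
pooling.eval()  # as in the tests below; dropout is disabled outside training mode
assert not pooling.training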

View file

@@ -5,7 +5,7 @@ from torch import nn, rand, sum, allclose, ones_like
 
 from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
                                                         MeanPoolingLayer, TransformerPooling,
-                                                        MaxPoolingLayer)
+                                                        MaxPoolingLayer, TransformerPoolingBenchmark)
 
 
 def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
@@ -19,11 +19,7 @@ def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
     row_sums = sum(attn_weights, dim=1, keepdim=True)
     assert allclose(row_sums, ones_like(row_sums))
 
-    if isinstance(attentionlayer, TransformerPooling):
-        pass
-    elif isinstance(attentionlayer, MaxPoolingLayer):
-        pass
-    else:
+    if not isinstance(attentionlayer, (MaxPoolingLayer, TransformerPooling, TransformerPoolingBenchmark)):
         pooled_features = attn_weights @ features.flatten(start_dim=1)
         assert allclose(pooled_features, output_features)
@@ -59,8 +55,27 @@ def test_max_pooling(dim_in: int, batch_size: int,) -> None:
 @pytest.mark.parametrize("num_heads", [1, 2])
 @pytest.mark.parametrize("dim_in", [4, 8])  # dim_in % num_heads must be 0
 @pytest.mark.parametrize("batch_size", [1, 7])
-def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int, batch_size: int) -> None:
+def test_transformer_pooling(num_layers: int, num_heads: int, dim_in: int,
+                             batch_size: int) -> None:
+    transformer_dropout = 0.5
     transformer_pooling = TransformerPooling(num_layers=num_layers,
                                              num_heads=num_heads,
-                                             dim_representation=dim_in).eval()
+                                             dim_representation=dim_in,
+                                             transformer_dropout=transformer_dropout).eval()
     _test_attention_layer(transformer_pooling, dim_in=dim_in, dim_att=1, batch_size=batch_size)
+
+
+@pytest.mark.parametrize("num_layers", [1, 4])
+@pytest.mark.parametrize("num_heads", [1, 2])
+@pytest.mark.parametrize("dim_in", [4, 8])  # dim_in % num_heads must be 0
+@pytest.mark.parametrize("batch_size", [1, 7])
+@pytest.mark.parametrize("dim_hid", [1, 4])
+def test_transformer_pooling_benchmark(num_layers: int, num_heads: int, dim_in: int,
+                                       batch_size: int, dim_hid: int) -> None:
+    transformer_dropout = 0.5
+    transformer_pooling_benchmark = TransformerPoolingBenchmark(num_layers=num_layers,
+                                                                num_heads=num_heads,
+                                                                dim_representation=dim_in,
+                                                                hidden_dim=dim_hid,
+                                                                transformer_dropout=transformer_dropout).eval()
+    _test_attention_layer(transformer_pooling_benchmark, dim_in=dim_in, dim_att=1, batch_size=batch_size)
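
A possible follow-up check, not part of this PR: assert that every nn.Dropout submodule created by the pooling layers carries the configured rate. This assumes CustomTransformerEncoderLayer builds its dropout modules the same way nn.TransformerEncoderLayer does.

import torch.nn as nn
from health_ml.networks.layers.attention_layers import TransformerPooling, TransformerPoolingBenchmark

def _assert_dropout_rate(pooling_module: nn.Module, expected: float) -> None:
    # Hypothetical helper: collect all dropout submodules and check their probability.
    dropouts = [m for m in pooling_module.modules() if isinstance(m, nn.Dropout)]
    assert dropouts, "expected at least one dropout module"
    assert all(m.p == expected for m in dropouts)

_assert_dropout_rate(TransformerPooling(num_layers=2, num_heads=2, dim_representation=8,
                                        transformer_dropout=0.5), expected=0.5)
_assert_dropout_rate(TransformerPoolingBenchmark(num_layers=2, num_heads=2, dim_representation=8,
                                                 hidden_dim=4, transformer_dropout=0.5), expected=0.5)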