Add head_mask and decoder_head_mask to PyTorch LED (#9856)

* Add {decoder_,}head_mask to LED

* Fix create_custom_forward signatue in encoder

* Add head_mask to longformer

* Add head_mask to longformer to fix dependencies
of LED on Longformer.

* Not working yet

* Add mising one input in longofrmer_modeling.py

* make fix-copies
This commit is contained in:
Daniel Stancl 2021-02-02 20:06:52 +01:00 коммит произвёл GitHub
Родитель d6217fb30c
Коммит 71bdc076dd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 238 добавлений и 7 удалений

Просмотреть файл

@ -164,6 +164,7 @@ class LEDEncoderSelfAttention(nn.Module):
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
@ -251,6 +252,12 @@ class LEDEncoderSelfAttention(nn.Module):
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = attn_probs.type_as(attn_scores)
@ -288,6 +295,7 @@ class LEDEncoderSelfAttention(nn.Module):
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@ -595,6 +603,7 @@ class LEDEncoderSelfAttention(nn.Module):
self,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
@ -656,6 +665,18 @@ class LEDEncoderSelfAttention(nn.Module):
global_attn_scores, dim=-1, dtype=torch.float32
) # use fp32 for numerical stability
# apply layer head masking
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
)
global_attn_probs_float = global_attn_probs_float.view(
batch_size * self.num_heads, max_num_global_attn_indices, seq_len
)
global_attn_probs = F.dropout(
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
)
@ -686,6 +707,7 @@ class LEDEncoderAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
is_index_masked: Optional[torch.Tensor] = None,
is_index_global_attn: Optional[torch.Tensor] = None,
is_global_attn: Optional[bool] = None,
@ -696,6 +718,7 @@ class LEDEncoderAttention(nn.Module):
self_outputs = self.longformer_self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
@ -744,6 +767,7 @@ class LEDDecoderAttention(nn.Module):
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -810,6 +834,12 @@ class LEDDecoderAttention(nn.Module):
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit akward, but it's required to
@ -859,6 +889,7 @@ class LEDEncoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
@ -869,11 +900,14 @@ class LEDEncoderLayer(nn.Module):
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
"""
residual = hidden_states
attn_outputs = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
@ -931,6 +965,8 @@ class LEDDecoderLayer(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
@ -943,6 +979,10 @@ class LEDDecoderLayer(nn.Module):
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
size `(config.encoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`): Whether the base model outputs attentions.
This requires the attentions tensor to be reshaped in this function.
@ -957,6 +997,7 @@ class LEDDecoderLayer(nn.Module):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
@ -975,6 +1016,7 @@ class LEDDecoderLayer(nn.Module):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
@ -1155,6 +1197,17 @@ class LEDSeq2SeqModelOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
"""
last_hidden_state: torch.FloatTensor = None
@ -1166,6 +1219,8 @@ class LEDSeq2SeqModelOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass
@ -1221,6 +1276,17 @@ class LEDSeq2SeqLMOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
"""
loss: Optional[torch.FloatTensor] = None
@ -1233,6 +1299,8 @@ class LEDSeq2SeqLMOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass
@ -1288,6 +1356,17 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
"""
loss: Optional[torch.FloatTensor] = None
@ -1300,6 +1379,8 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass
@ -1357,6 +1438,17 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
"""
loss: Optional[torch.FloatTensor] = None
@ -1370,6 +1462,8 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
LED_START_DOCSTRING = r"""
@ -1442,6 +1536,17 @@ LED_INPUTS_DOCSTRING = r"""
- 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@ -1582,6 +1687,7 @@ class LEDEncoder(LEDPreTrainedModel):
input_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
@ -1615,6 +1721,11 @@ class LEDEncoder(LEDPreTrainedModel):
- 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@ -1686,7 +1797,12 @@ class LEDEncoder(LEDPreTrainedModel):
all_attentions = () if output_attentions else None
all_global_attentions = () if (output_attentions and is_global_attn) else None
for encoder_layer in self.layers:
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@ -1707,6 +1823,7 @@ class LEDEncoder(LEDPreTrainedModel):
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
)
@ -1714,6 +1831,7 @@ class LEDEncoder(LEDPreTrainedModel):
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
@ -1787,6 +1905,8 @@ class LEDDecoder(LEDPreTrainedModel):
global_attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
@ -1833,6 +1953,19 @@ class LEDDecoder(LEDPreTrainedModel):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@ -1910,6 +2043,12 @@ class LEDDecoder(LEDPreTrainedModel):
all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@ -1942,6 +2081,8 @@ class LEDDecoder(LEDPreTrainedModel):
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None,
None,
)
else:
@ -1950,6 +2091,8 @@ class LEDDecoder(LEDPreTrainedModel):
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@ -2027,6 +2170,8 @@ class LEDModel(LEDPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
global_attention_mask=None,
past_key_values=None,
@ -2049,6 +2194,7 @@ class LEDModel(LEDPreTrainedModel):
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@ -2069,6 +2215,8 @@ class LEDModel(LEDPreTrainedModel):
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
@ -2148,6 +2296,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
global_attention_mask=None,
past_key_values=None,
@ -2198,6 +2348,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
@ -2231,7 +2383,14 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past is not None:
@ -2243,6 +2402,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
@ -2290,6 +2450,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
global_attention_mask=None,
inputs_embeds=None,
@ -2320,6 +2482,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
@ -2394,6 +2558,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
global_attention_mask=None,
start_positions=None,
@ -2425,6 +2591,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,

Просмотреть файл

@ -553,6 +553,7 @@ class LongformerSelfAttention(nn.Module):
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
@ -640,6 +641,12 @@ class LongformerSelfAttention(nn.Module):
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = attn_probs.type_as(attn_scores)
@ -677,6 +684,7 @@ class LongformerSelfAttention(nn.Module):
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@ -984,6 +992,7 @@ class LongformerSelfAttention(nn.Module):
self,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
@ -1045,6 +1054,18 @@ class LongformerSelfAttention(nn.Module):
global_attn_scores, dim=-1, dtype=torch.float32
) # use fp32 for numerical stability
# apply layer head masking
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
)
global_attn_probs_float = global_attn_probs_float.view(
batch_size * self.num_heads, max_num_global_attn_indices, seq_len
)
global_attn_probs = F.dropout(
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
)
@ -1109,6 +1130,7 @@ class LongformerAttention(nn.Module):
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
@ -1117,6 +1139,7 @@ class LongformerAttention(nn.Module):
self_outputs = self.self(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
@ -1171,6 +1194,7 @@ class LongformerLayer(nn.Module):
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
@ -1179,6 +1203,7 @@ class LongformerLayer(nn.Module):
self_attn_outputs = self.attention(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
@ -1209,6 +1234,7 @@ class LongformerEncoder(nn.Module):
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
@ -1222,7 +1248,12 @@ class LongformerEncoder(nn.Module):
all_attentions = () if output_attentions else None # All local attentions.
all_global_attentions = () if (output_attentions and is_global_attn) else None
for i, layer_module in enumerate(self.layer):
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layer)
), f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}."
for idx, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -1238,6 +1269,7 @@ class LongformerEncoder(nn.Module):
create_custom_forward(layer_module),
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
)
@ -1245,6 +1277,7 @@ class LongformerEncoder(nn.Module):
layer_outputs = layer_module(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
@ -1386,6 +1419,18 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
- 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
@ -1534,6 +1579,7 @@ class LongformerModel(LongformerPreTrainedModel):
input_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
@ -1617,6 +1663,7 @@ class LongformerModel(LongformerPreTrainedModel):
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@ -1667,6 +1714,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
input_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
@ -1708,6 +1756,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
@ -1767,6 +1816,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
input_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
@ -1793,6 +1843,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
@ -1871,6 +1922,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
input_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
@ -1932,6 +1984,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
@ -2011,6 +2064,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
input_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
@ -2030,6 +2084,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
@ -2101,6 +2156,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
token_type_ids=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
labels=None,
position_ids=None,
inputs_embeds=None,
@ -2150,6 +2206,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
global_attention_mask=flat_global_attention_mask,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,

Просмотреть файл

@ -473,7 +473,6 @@ class ModelTesterMixin:
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True)
# Test that we can get a gradient back for importance score computation

Просмотреть файл

@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@ -160,9 +168,10 @@ class LEDModelTester:
model = LEDModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):

Просмотреть файл

@ -273,7 +273,6 @@ class LongformerModelTester:
@require_torch
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False # pruning is not supported
test_headmasking = False # head masking is not supported
test_torchscript = False
all_model_classes = (