Add head_mask and decoder_head_mask to PyTorch LED (#9856)
* Add {decoder_,}head_mask to LED * Fix create_custom_forward signatue in encoder * Add head_mask to longformer * Add head_mask to longformer to fix dependencies of LED on Longformer. * Not working yet * Add mising one input in longofrmer_modeling.py * make fix-copies
This commit is contained in:
Родитель
d6217fb30c
Коммит
71bdc076dd
|
@ -164,6 +164,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
layer_head_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
|
@ -251,6 +252,12 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||
|
||||
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
||||
|
||||
if layer_head_mask is not None:
|
||||
assert layer_head_mask.size() == (
|
||||
self.num_heads,
|
||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
|
||||
|
||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
|
||||
attn_probs = attn_probs.type_as(attn_scores)
|
||||
|
@ -288,6 +295,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
||||
hidden_states=hidden_states,
|
||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||
layer_head_mask=layer_head_mask,
|
||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||
|
@ -595,6 +603,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
max_num_global_attn_indices,
|
||||
layer_head_mask,
|
||||
is_local_index_global_attn_nonzero,
|
||||
is_index_global_attn_nonzero,
|
||||
is_local_index_no_global_attn_nonzero,
|
||||
|
@ -656,6 +665,18 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||
global_attn_scores, dim=-1, dtype=torch.float32
|
||||
) # use fp32 for numerical stability
|
||||
|
||||
# apply layer head masking
|
||||
if layer_head_mask is not None:
|
||||
assert layer_head_mask.size() == (
|
||||
self.num_heads,
|
||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
|
||||
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
|
||||
)
|
||||
global_attn_probs_float = global_attn_probs_float.view(
|
||||
batch_size * self.num_heads, max_num_global_attn_indices, seq_len
|
||||
)
|
||||
|
||||
global_attn_probs = F.dropout(
|
||||
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
|
||||
)
|
||||
|
@ -686,6 +707,7 @@ class LEDEncoderAttention(nn.Module):
|
|||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
is_index_masked: Optional[torch.Tensor] = None,
|
||||
is_index_global_attn: Optional[torch.Tensor] = None,
|
||||
is_global_attn: Optional[bool] = None,
|
||||
|
@ -696,6 +718,7 @@ class LEDEncoderAttention(nn.Module):
|
|||
self_outputs = self.longformer_self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
|
@ -744,6 +767,7 @@ class LEDDecoderAttention(nn.Module):
|
|||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
@ -810,6 +834,12 @@ class LEDDecoderAttention(nn.Module):
|
|||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
if layer_head_mask is not None:
|
||||
assert layer_head_mask.size() == (
|
||||
self.num_heads,
|
||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
|
@ -859,6 +889,7 @@ class LEDEncoderLayer(nn.Module):
|
|||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_head_mask: torch.Tensor,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
|
@ -869,11 +900,14 @@ class LEDEncoderLayer(nn.Module):
|
|||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
"""
|
||||
residual = hidden_states
|
||||
attn_outputs = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
|
@ -931,6 +965,8 @@ class LEDDecoderLayer(nn.Module):
|
|||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
|
@ -943,6 +979,10 @@ class LEDDecoderLayer(nn.Module):
|
|||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`): Whether the base model outputs attentions.
|
||||
This requires the attentions tensor to be reshaped in this function.
|
||||
|
@ -957,6 +997,7 @@ class LEDDecoderLayer(nn.Module):
|
|||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
@ -975,6 +1016,7 @@ class LEDDecoderLayer(nn.Module):
|
|||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
@ -1155,6 +1197,17 @@ class LEDSeq2SeqModelOutput(ModelOutput):
|
|||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
|
@ -1166,6 +1219,8 @@ class LEDSeq2SeqModelOutput(ModelOutput):
|
|||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1221,6 +1276,17 @@ class LEDSeq2SeqLMOutput(ModelOutput):
|
|||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
|
@ -1233,6 +1299,8 @@ class LEDSeq2SeqLMOutput(ModelOutput):
|
|||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1288,6 +1356,17 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
|
@ -1300,6 +1379,8 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1357,6 +1438,17 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
|
@ -1370,6 +1462,8 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
LED_START_DOCSTRING = r"""
|
||||
|
@ -1442,6 +1536,17 @@ LED_INPUTS_DOCSTRING = r"""
|
|||
|
||||
- 0 for local attention (a sliding window attention),
|
||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
|
@ -1582,6 +1687,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
|
@ -1615,6 +1721,11 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||
|
||||
- 0 for local attention (a sliding window attention),
|
||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
|
@ -1686,7 +1797,12 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||
all_attentions = () if output_attentions else None
|
||||
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
|
@ -1707,6 +1823,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||
create_custom_forward(encoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
)
|
||||
|
@ -1714,6 +1831,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
|
@ -1787,6 +1905,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||
global_attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
|
@ -1833,6 +1953,19 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
|
@ -1910,6 +2043,12 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
|
@ -1942,6 +2081,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||
combined_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
|
@ -1950,6 +2091,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
|
@ -2027,6 +2170,8 @@ class LEDModel(LEDPreTrainedModel):
|
|||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
past_key_values=None,
|
||||
|
@ -2049,6 +2194,7 @@ class LEDModel(LEDPreTrainedModel):
|
|||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
@ -2069,6 +2215,8 @@ class LEDModel(LEDPreTrainedModel):
|
|||
attention_mask=decoder_attention_mask,
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -2148,6 +2296,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
past_key_values=None,
|
||||
|
@ -2198,6 +2348,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
|
@ -2231,7 +2383,14 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
**kwargs,
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
|
@ -2243,6 +2402,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
@ -2290,6 +2450,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -2320,6 +2482,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
|
@ -2394,6 +2558,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
|||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
start_positions=None,
|
||||
|
@ -2425,6 +2591,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
|||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
|
|
|
@ -553,6 +553,7 @@ class LongformerSelfAttention(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
layer_head_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
|
@ -640,6 +641,12 @@ class LongformerSelfAttention(nn.Module):
|
|||
|
||||
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
||||
|
||||
if layer_head_mask is not None:
|
||||
assert layer_head_mask.size() == (
|
||||
self.num_heads,
|
||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
|
||||
|
||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
|
||||
attn_probs = attn_probs.type_as(attn_scores)
|
||||
|
@ -677,6 +684,7 @@ class LongformerSelfAttention(nn.Module):
|
|||
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
||||
hidden_states=hidden_states,
|
||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||
layer_head_mask=layer_head_mask,
|
||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||
|
@ -984,6 +992,7 @@ class LongformerSelfAttention(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
max_num_global_attn_indices,
|
||||
layer_head_mask,
|
||||
is_local_index_global_attn_nonzero,
|
||||
is_index_global_attn_nonzero,
|
||||
is_local_index_no_global_attn_nonzero,
|
||||
|
@ -1045,6 +1054,18 @@ class LongformerSelfAttention(nn.Module):
|
|||
global_attn_scores, dim=-1, dtype=torch.float32
|
||||
) # use fp32 for numerical stability
|
||||
|
||||
# apply layer head masking
|
||||
if layer_head_mask is not None:
|
||||
assert layer_head_mask.size() == (
|
||||
self.num_heads,
|
||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
|
||||
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
|
||||
)
|
||||
global_attn_probs_float = global_attn_probs_float.view(
|
||||
batch_size * self.num_heads, max_num_global_attn_indices, seq_len
|
||||
)
|
||||
|
||||
global_attn_probs = F.dropout(
|
||||
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
|
||||
)
|
||||
|
@ -1109,6 +1130,7 @@ class LongformerAttention(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
layer_head_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
|
@ -1117,6 +1139,7 @@ class LongformerAttention(nn.Module):
|
|||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
|
@ -1171,6 +1194,7 @@ class LongformerLayer(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
layer_head_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
|
@ -1179,6 +1203,7 @@ class LongformerLayer(nn.Module):
|
|||
self_attn_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
|
@ -1209,6 +1234,7 @@ class LongformerEncoder(nn.Module):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
|
@ -1222,7 +1248,12 @@ class LongformerEncoder(nn.Module):
|
|||
all_attentions = () if output_attentions else None # All local attentions.
|
||||
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layer)
|
||||
), f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
|
@ -1238,6 +1269,7 @@ class LongformerEncoder(nn.Module):
|
|||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
)
|
||||
|
@ -1245,6 +1277,7 @@ class LongformerEncoder(nn.Module):
|
|||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=head_mask[idx] if head_mask is not None else None,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
|
@ -1386,6 +1419,18 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
|
|||
- 0 for local attention (a sliding window attention),
|
||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||
1]``:
|
||||
|
@ -1534,6 +1579,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -1617,6 +1663,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
|
@ -1667,6 +1714,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -1708,6 +1756,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -1767,6 +1816,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
|||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -1793,6 +1843,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
|||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -1871,6 +1922,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
|||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -1932,6 +1984,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
|||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -2011,6 +2064,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
|||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -2030,6 +2084,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
|||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -2101,6 +2156,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
|
|||
token_type_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
head_mask=None,
|
||||
labels=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
|
@ -2150,6 +2206,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
|
|||
token_type_ids=flat_token_type_ids,
|
||||
attention_mask=flat_attention_mask,
|
||||
global_attention_mask=flat_global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=flat_inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
|
|
@ -473,7 +473,6 @@ class ModelTesterMixin:
|
|||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
# Test that we can get a gradient back for importance score computation
|
||||
|
|
|
@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
|
|||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -160,9 +168,10 @@ class LEDModelTester:
|
|||
model = LEDModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
|
|
@ -273,7 +273,6 @@ class LongformerModelTester:
|
|||
@require_torch
|
||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False # pruning is not supported
|
||||
test_headmasking = False # head masking is not supported
|
||||
test_torchscript = False
|
||||
|
||||
all_model_classes = (
|
||||
|
|
Загрузка…
Ссылка в новой задаче