From 71bdc076dd4ba2f3264283d4bc8617755206dccd Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Tue, 2 Feb 2021 20:06:52 +0100 Subject: [PATCH] Add head_mask and decoder_head_mask to PyTorch LED (#9856) * Add {decoder_,}head_mask to LED * Fix create_custom_forward signature in encoder * Add head_mask to longformer * Add head_mask to longformer to fix dependencies of LED on Longformer. * Not working yet * Add one missing input in modeling_longformer.py * make fix-copies --- src/transformers/models/led/modeling_led.py | 172 +++++++++++++++++- .../models/longformer/modeling_longformer.py | 59 +++++- tests/test_modeling_common.py | 1 - tests/test_modeling_led.py | 12 +- tests/test_modeling_longformer.py | 1 - 5 files changed, 238 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 7e04e95de..64efdf619 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -164,6 +164,7 @@ class LEDEncoderSelfAttention(nn.Module): self, hidden_states, attention_mask=None, + layer_head_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None, @@ -251,6 +252,12 @@ class LEDEncoderSelfAttention(nn.Module): attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0) attn_probs = attn_probs.type_as(attn_scores) @@ -288,6 +295,7 @@ class LEDEncoderSelfAttention(nn.Module): global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, @@ -595,6 +603,7 @@ class LEDEncoderSelfAttention(nn.Module): self, hidden_states, max_num_global_attn_indices, + layer_head_mask, is_local_index_global_attn_nonzero, is_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero, @@ -656,6 +665,18 @@ class LEDEncoderSelfAttention(nn.Module): global_attn_scores, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability + # apply layer head masking + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view( + batch_size, self.num_heads, max_num_global_attn_indices, seq_len + ) + global_attn_probs_float = global_attn_probs_float.view( + batch_size * self.num_heads, max_num_global_attn_indices, seq_len + ) + global_attn_probs = F.dropout( global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training ) @@ -686,6 +707,7 @@ class LEDEncoderAttention(nn.Module): self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask:
Optional[torch.Tensor] = None, is_index_masked: Optional[torch.Tensor] = None, is_index_global_attn: Optional[torch.Tensor] = None, is_global_attn: Optional[bool] = None, @@ -696,6 +718,7 @@ class LEDEncoderAttention(nn.Module): self_outputs = self.longformer_self_attn( hidden_states=hidden_states, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -744,6 +767,7 @@ class LEDDecoderAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -810,6 +834,12 @@ class LEDDecoderAttention(nn.Module): attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit akward, but it's required to @@ -859,6 +889,7 @@ class LEDEncoderLayer(nn.Module): self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, is_index_masked=None, is_index_global_attn=None, is_global_attn=None, @@ -869,11 +900,14 @@ class LEDEncoderLayer(nn.Module): hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. """ residual = hidden_states attn_outputs = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -931,6 +965,8 @@ class LEDDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -943,6 +979,10 @@ class LEDDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. 
+ encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function. @@ -957,6 +997,7 @@ class LEDDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -975,6 +1016,7 @@ class LEDDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -1155,6 +1197,17 @@ class LEDSeq2SeqModelOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. """ last_hidden_state: torch.FloatTensor = None @@ -1166,6 +1219,8 @@ class LEDSeq2SeqModelOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None + head_mask: Optional[torch.FloatTensor] = None + decoder_head_mask: Optional[torch.FloatTensor] = None @dataclass @@ -1221,6 +1276,17 @@ class LEDSeq2SeqLMOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
""" loss: Optional[torch.FloatTensor] = None @@ -1233,6 +1299,8 @@ class LEDSeq2SeqLMOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None + head_mask: Optional[torch.FloatTensor] = None + decoder_head_mask: Optional[torch.FloatTensor] = None @dataclass @@ -1288,6 +1356,17 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. """ loss: Optional[torch.FloatTensor] = None @@ -1300,6 +1379,8 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None + head_mask: Optional[torch.FloatTensor] = None + decoder_head_mask: Optional[torch.FloatTensor] = None @dataclass @@ -1357,6 +1438,17 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. """ loss: Optional[torch.FloatTensor] = None @@ -1370,6 +1462,8 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None + head_mask: Optional[torch.FloatTensor] = None + decoder_head_mask: Optional[torch.FloatTensor] = None LED_START_DOCSTRING = r""" @@ -1442,6 +1536,17 @@ LED_INPUTS_DOCSTRING = r""" - 0 for local attention (a sliding window attention), - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. 
+ + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -1582,6 +1687,7 @@ class LEDEncoder(LEDPreTrainedModel): input_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -1615,6 +1721,11 @@ class LEDEncoder(LEDPreTrainedModel): - 0 for local attention (a sliding window attention), - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -1686,7 +1797,12 @@ class LEDEncoder(LEDPreTrainedModel): all_attentions = () if output_attentions else None all_global_attentions = () if (output_attentions and is_global_attn) else None - for encoder_layer in self.layers: + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1707,6 +1823,7 @@ class LEDEncoder(LEDPreTrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, ) @@ -1714,6 +1831,7 @@ class LEDEncoder(LEDPreTrainedModel): layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -1787,6 +1905,8 @@ class LEDDecoder(LEDPreTrainedModel): global_attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -1833,6 +1953,19 @@ class LEDDecoder(LEDPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
+ + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -1910,6 +2043,12 @@ class LEDDecoder(LEDPreTrainedModel): all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -1942,6 +2081,8 @@ class LEDDecoder(LEDPreTrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -1950,6 +2091,8 @@ class LEDDecoder(LEDPreTrainedModel): attention_mask=combined_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -2027,6 +2170,8 @@ class LEDModel(LEDPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, global_attention_mask=None, past_key_values=None, @@ -2049,6 +2194,7 @@ class LEDModel(LEDPreTrainedModel): input_ids=input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -2069,6 +2215,8 @@ class LEDModel(LEDPreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -2148,6 +2296,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, global_attention_mask=None, past_key_values=None, @@ -2198,6 +2348,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_outputs, global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2231,7 +2383,14 @@
class LEDForConditionalGeneration(LEDPreTrainedModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, ): # cut decoder_input_ids if past is used if past is not None: @@ -2243,6 +2402,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } @@ -2290,6 +2450,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, global_attention_mask=None, inputs_embeds=None, @@ -2320,6 +2482,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel): decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2394,6 +2558,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, global_attention_mask=None, start_positions=None, @@ -2425,6 +2591,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel): decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 2754afc34..df850524f 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -553,6 +553,7 @@ class LongformerSelfAttention(nn.Module): self, hidden_states, attention_mask=None, + layer_head_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None, @@ -640,6 +641,12 @@ class LongformerSelfAttention(nn.Module): attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0) attn_probs = attn_probs.type_as(attn_scores) @@ -677,6 +684,7 @@ class LongformerSelfAttention(nn.Module): global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, is_index_global_attn_nonzero=is_index_global_attn_nonzero, 
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, @@ -984,6 +992,7 @@ class LongformerSelfAttention(nn.Module): self, hidden_states, max_num_global_attn_indices, + layer_head_mask, is_local_index_global_attn_nonzero, is_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero, @@ -1045,6 +1054,18 @@ class LongformerSelfAttention(nn.Module): global_attn_scores, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability + # apply layer head masking + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view( + batch_size, self.num_heads, max_num_global_attn_indices, seq_len + ) + global_attn_probs_float = global_attn_probs_float.view( + batch_size * self.num_heads, max_num_global_attn_indices, seq_len + ) + global_attn_probs = F.dropout( global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training ) @@ -1109,6 +1130,7 @@ class LongformerAttention(nn.Module): self, hidden_states, attention_mask=None, + layer_head_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None, @@ -1117,6 +1139,7 @@ class LongformerAttention(nn.Module): self_outputs = self.self( hidden_states, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -1171,6 +1194,7 @@ class LongformerLayer(nn.Module): self, hidden_states, attention_mask=None, + layer_head_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None, @@ -1179,6 +1203,7 @@ class LongformerLayer(nn.Module): self_attn_outputs = self.attention( hidden_states, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -1209,6 +1234,7 @@ class LongformerEncoder(nn.Module): self, hidden_states, attention_mask=None, + head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True, @@ -1222,7 +1248,12 @@ class LongformerEncoder(nn.Module): all_attentions = () if output_attentions else None # All local attentions. all_global_attentions = () if (output_attentions and is_global_attn) else None - for i, layer_module in enumerate(self.layer): + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layer) + ), f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}." 
+ for idx, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1238,6 +1269,7 @@ class LongformerEncoder(nn.Module): create_custom_forward(layer_module), hidden_states, attention_mask, + head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, ) @@ -1245,6 +1277,7 @@ class LongformerEncoder(nn.Module): layer_outputs = layer_module( hidden_states, attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -1386,6 +1419,18 @@ LONGFORMER_INPUTS_DOCSTRING = r""" - 0 for local attention (a sliding window attention), - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: @@ -1534,6 +1579,7 @@ class LongformerModel(LongformerPreTrainedModel): input_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -1617,6 +1663,7 @@ class LongformerModel(LongformerPreTrainedModel): encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, + head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1667,6 +1714,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel): input_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -1708,6 +1756,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel): input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, + head_mask=head_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -1767,6 +1816,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel): input_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -1793,6 +1843,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel): input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, + head_mask=head_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -1871,6 +1922,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel): input_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -1932,6 +1984,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel): input_ids, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, + head_mask=head_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -2011,6 +2064,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel): input_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -2030,6 +2084,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel): input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, + head_mask=head_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -2101,6 +2156,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel): token_type_ids=None, attention_mask=None, global_attention_mask=None, + head_mask=None, labels=None, position_ids=None, inputs_embeds=None, @@ -2150,6 +2206,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel): token_type_ids=flat_token_type_ids, attention_mask=flat_attention_mask, global_attention_mask=flat_global_attention_mask, + head_mask=head_mask, inputs_embeds=flat_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4a0cf5e1c..2fe722a2b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -473,7 +473,6 @@ class ModelTesterMixin: arg_names = [*signature.parameters.keys()] if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model inputs["decoder_head_mask"] = head_mask - outputs = model(**inputs, return_dict=True) # Test that we can get a gradient back for importance score computation diff --git a/tests/test_modeling_led.py b/tests/test_modeling_led.py index 0e9990778..416606014 100644 --- a/tests/test_modeling_led.py +++ b/tests/test_modeling_led.py @@ -49,16 +49,24 @@ def prepare_led_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -160,9 +168,10 @@ class LEDModelTester: model = LEDModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False test_missing_keys = False def setUp(self): diff --git 
a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index b577640e6..96333fced 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -273,7 +273,6 @@ class LongformerModelTester: @require_torch class LongformerModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False # pruning is not supported - test_headmasking = False # head masking is not supported test_torchscript = False all_model_classes = (
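
Note (not part of the patch): a minimal usage sketch of the new arguments added by this PR, assuming the public "allenai/led-base-16384" checkpoint and the existing LEDTokenizer / LEDForConditionalGeneration API. The mask shapes follow the test helper above, (encoder_layers, encoder_attention_heads) for head_mask and (decoder_layers, decoder_attention_heads) for decoder_head_mask; a 0 entry zeroes that head's attention probabilities in the corresponding layer, a 1 leaves it untouched. The masked head indices below are arbitrary and purely illustrative.

    import torch
    from transformers import LEDForConditionalGeneration, LEDTokenizer

    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

    inputs = tokenizer("A long input document ...", return_tensors="pt")

    # 1.0 keeps a head, 0.0 masks it: here the first head of encoder layer 0
    # and the last head of the final decoder layer are disabled as an example.
    head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
    head_mask[0, 0] = 0.0
    decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)
    decoder_head_mask[-1, -1] = 0.0

    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=inputs["input_ids"][:, :5],  # toy decoder prefix for illustration
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
    )
    print(outputs.logits.shape)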