Add head_mask and decoder_head_mask to TF LED (#9988)
* Add head masking to TF LED
* Add head_mask to Longformer + one doc piece to LED
* Fix integration tests
Parent: 77c0ce8c0c
Commit: e7381c4596
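For context, the arguments this commit adds are used as sketched below. This is an illustrative usage sketch, not part of the diff; the checkpoint name is only an example, and it assumes the head_mask / decoder_head_mask semantics documented in the diff (1 keeps a head, 0 nullifies it).

import numpy as np
import tensorflow as tf
from transformers import LEDTokenizer, TFLEDForConditionalGeneration

# Illustrative checkpoint; any LED checkpoint with TF weights should behave the same.
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("Head masking disables individual attention heads.", return_tensors="tf")

# One row per layer, one column per head; 1.0 keeps a head, 0.0 masks it.
head_mask = np.ones((model.config.encoder_layers, model.config.encoder_attention_heads), dtype=np.float32)
head_mask[0, 0] = 0.0  # silence head 0 of the first encoder layer
decoder_head_mask = tf.ones((model.config.decoder_layers, model.config.decoder_attention_heads))

outputs = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=inputs["input_ids"],
    head_mask=tf.constant(head_mask),
    decoder_head_mask=decoder_head_mask,
)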
modeling_tf_led.py:
@@ -200,6 +200,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         (
             hidden_states,
             attention_mask,
+            layer_head_mask,
             is_index_masked,
             is_index_global_attn,
             is_global_attn,
@@ -275,6 +276,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             attn_probs,
         )
 
+        if layer_head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+            )
+            attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
+
         # apply dropout
         attn_probs = self.dropout(attn_probs, training=training)
         value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
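The block added above multiplies the local attention probabilities by the per-layer head mask; reshaping the mask to (1, 1, num_heads, 1) lets it broadcast over the batch, sequence and window dimensions. A minimal standalone sketch of that broadcast, with made-up shapes:

import tensorflow as tf

batch_size, seq_len, num_heads, window = 2, 8, 4, 3
# local attention probabilities laid out as (batch, seq_len, num_heads, window)
attn_probs = tf.nn.softmax(tf.random.normal((batch_size, seq_len, num_heads, window)), axis=-1)

layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # zero out head 1 of this layer

# same operation as in the diff: broadcast the (num_heads,) mask over the head axis
masked_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

print(float(tf.reduce_sum(masked_probs[:, :, 1, :])))  # 0.0 – the masked head contributes nothing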
@@ -310,6 +319,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
                 attn_output=attn_output,
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
+                layer_head_mask=layer_head_mask,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -752,6 +762,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         attn_output,
         hidden_states,
         max_num_global_attn_indices,
+        layer_head_mask,
         is_local_index_global_attn_nonzero,
         is_index_global_attn_nonzero,
         is_local_index_no_global_attn_nonzero,
@@ -817,6 +828,20 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # compute global attn probs
         global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
 
+        # apply layer head masking
+        if layer_head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+            )
+            global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+                global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+            )
+            global_attn_probs_float = tf.reshape(
+                global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+            )
+
         # dropout
         global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
 
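The global-attention branch keeps its probabilities flattened as (batch_size * num_heads, max_num_global_attn_indices, seq_len), so the code added above temporarily unflattens the head axis, applies the mask, and flattens back. A standalone sketch of that round trip (the shapes are invented):

import tensorflow as tf

batch_size, num_heads, max_num_global_attn_indices, seq_len = 2, 4, 3, 8
global_attn_probs_float = tf.nn.softmax(
    tf.random.normal((batch_size * num_heads, max_num_global_attn_indices, seq_len)), axis=-1
)
layer_head_mask = tf.constant([1.0, 1.0, 0.0, 1.0])  # drop head 2

# unflatten to (batch, heads, global, seq), mask the head axis, flatten back
unflat = tf.reshape(global_attn_probs_float, (batch_size, num_heads, max_num_global_attn_indices, seq_len))
unflat = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * unflat
global_attn_probs_float = tf.reshape(unflat, (batch_size * num_heads, max_num_global_attn_indices, seq_len))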
@@ -875,13 +900,14 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer):
         (
             hidden_states,
             attention_mask,
+            layer_head_mask,
             is_index_masked,
             is_index_global_attn,
             is_global_attn,
         ) = inputs
 
         self_outputs = self.longformer_self_attn(
-            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
             training=training,
         )
 
@@ -927,6 +953,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         key_value_states: Optional[tf.Tensor] = None,
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
+        layer_head_mask: Optional[tf.Tensor] = None,
         training=False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -993,6 +1020,17 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
 
         attn_weights = tf.nn.softmax(attn_weights, axis=-1)
 
+        if layer_head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+            )
+            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+            )
+            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
         attn_probs = self.dropout(attn_weights, training=training)
 
         attn_output = tf.matmul(attn_probs, value_states)
@@ -1031,6 +1069,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
         self,
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
+        layer_head_mask: tf.Tensor,
         is_index_masked: tf.Tensor,
         is_index_global_attn: tf.Tensor,
         is_global_attn: bool,
@@ -1041,10 +1080,12 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
             hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (:obj:`tf.Tensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
+                `(config.encoder_attention_heads,)`.
         """
         residual = hidden_states
         layer_outputs = self.self_attn(
-            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
             training=training,
         )
 
@@ -1104,6 +1145,8 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
         attention_mask: Optional[tf.Tensor] = None,
         encoder_hidden_states: Optional[tf.Tensor] = None,
         encoder_attention_mask: Optional[tf.Tensor] = None,
+        layer_head_mask: Optional[tf.Tensor] = None,
+        encoder_layer_head_mask: Optional[tf.Tensor] = None,
         past_key_value: Optional[Tuple[tf.Tensor]] = None,
         training=False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -1115,6 +1158,10 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
             encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
+                `(config.encoder_attention_heads,)`.
+            encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of
+                size `(config.encoder_attention_heads,)`.
             past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
         """
         residual = hidden_states
@@ -1127,6 +1174,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
         )
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
@@ -1143,6 +1191,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
+                layer_head_mask=encoder_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
             )
             hidden_states = self.dropout(hidden_states, training=training)
@@ -1438,6 +1487,18 @@ LED_INPUTS_DOCSTRING = r"""
             shifting the input_ids right, following the paper.
         decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+        head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
             hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
             of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
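Following the docstring entries added above (shape (encoder_layers, encoder_attention_heads), 1 = keep, 0 = mask), a mask that disables a few chosen heads can be built as in this sketch; the sizes and indices are arbitrary examples, not tied to any checkpoint:

import numpy as np
import tensorflow as tf

encoder_layers, encoder_attention_heads = 6, 8  # illustrative sizes
heads_to_disable = {0: [1, 3], 4: [0]}  # layer index -> heads to nullify

head_mask = np.ones((encoder_layers, encoder_attention_heads), dtype=np.float32)
for layer, heads in heads_to_disable.items():
    head_mask[layer, heads] = 0.0  # 0 indicates the head is masked

head_mask = tf.constant(head_mask)  # pass this as `head_mask=...` to the model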
@@ -1517,6 +1578,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         inputs_embeds=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -1541,6 +1603,12 @@ class TFLEDEncoder(tf.keras.layers.Layer):
                 - 0 for tokens that are **masked**.
 
                 `What are attention masks? <../glossary.html#attention-mask>`__
+            head_mask (:obj:`tf.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+                Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
             inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
                 Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
                 representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -1559,6 +1627,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
@@ -1617,8 +1686,15 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         encoder_states = () if inputs["output_hidden_states"] else None
         all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
 
+        # check if head_mask has a correct number of layers specified if desired
+        if inputs["head_mask"] is not None:
+            tf.debugging.assert_equal(
+                shape_list(inputs["head_mask"])[0],
+                len(self.layers),
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+            )
         # encoder layers
-        for encoder_layer in self.layers:
+        for idx, encoder_layer in enumerate(self.layers):
 
             if inputs["output_hidden_states"]:
                 hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
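The check added above guards against a head_mask whose first dimension does not match the number of encoder layers (the decoder below gets the same guard). A small eager-mode illustration of how that assertion fires, outside the model:

import tensorflow as tf

num_layers = 6
head_mask = tf.ones((4, 8))  # deliberately wrong: 4 rows for a 6-layer stack

try:
    tf.debugging.assert_equal(
        tf.shape(head_mask)[0],
        num_layers,
        message=f"The head_mask should be specified for {num_layers} layers, but it is for {head_mask.shape[0]}.",
    )
except tf.errors.InvalidArgumentError as err:
    print(err.message)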
@@ -1631,6 +1707,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
             layer_outputs = encoder_layer(
                 hidden_states=hidden_states,
                 attention_mask=inputs["attention_mask"],
+                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
                 is_index_masked=is_index_masked,
                 is_index_global_attn=is_index_global_attn,
                 is_global_attn=is_global_attn,
@@ -1753,6 +1830,8 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        head_mask=None,
+        encoder_head_mask=None,
         past_key_values=None,
         use_cache=None,
         output_attentions=None,
@@ -1784,6 +1863,19 @@ class TFLEDDecoder(tf.keras.layers.Layer):
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
                 `What are attention masks? <../glossary.html#attention-mask>`__
+            head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+                Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
+                on hidden heads. Mask values selected in ``[0, 1]``:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
             past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                 decoding. If :obj:`past_key_values` are used, the user can optionally input only the last
@@ -1810,6 +1902,8 @@ class TFLEDDecoder(tf.keras.layers.Layer):
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            encoder_head_mask=encoder_head_mask,
             inputs_embeds=inputs_embeds,
             past_key_values=past_key_values,
             use_cache=use_cache,
@@ -1865,6 +1959,14 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         all_hidden_states = ()
         all_self_attns = ()
         present_key_values = ()
+
+        # check if head_mask has a correct number of layers specified if desired
+        if inputs["head_mask"] is not None:
+            tf.debugging.assert_equal(
+                shape_list(inputs["head_mask"])[0],
+                len(self.layers),
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+            )
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
@@ -1881,6 +1983,10 @@ class TFLEDDecoder(tf.keras.layers.Layer):
                 attention_mask=combined_attention_mask,
                 encoder_hidden_states=inputs["encoder_hidden_states"],
                 encoder_attention_mask=inputs["encoder_attention_mask"],
+                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
+                if inputs["encoder_head_mask"] is not None
+                else None,
                 past_key_value=past_key_value,
             )
 
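As the decoder loop above shows, the model-level masks are consumed one row per layer (head_mask[idx] for self-attention, encoder_head_mask[idx] for cross-attention). A toy sketch of that indexing pattern, independent of the model code:

import tensorflow as tf

num_layers, num_heads = 3, 4
head_mask = tf.ones((num_layers, num_heads))
encoder_head_mask = None  # both masks stay optional

for idx in range(num_layers):
    layer_head_mask = head_mask[idx] if head_mask is not None else None
    encoder_layer_head_mask = encoder_head_mask[idx] if encoder_head_mask is not None else None
    print(idx, layer_head_mask.shape, encoder_layer_head_mask)  # each layer gets a (num_heads,) vector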
@@ -1950,6 +2056,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
         global_attention_mask=None,
         past_key_values=None,
@@ -1969,6 +2077,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             global_attention_mask=global_attention_mask,
             past_key_values=past_key_values,
@@ -1990,6 +2100,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
                 input_ids=inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 global_attention_mask=inputs["global_attention_mask"],
+                head_mask=inputs["head_mask"],
                 inputs_embeds=inputs["inputs_embeds"],
                 output_attentions=inputs["output_attentions"],
                 output_hidden_states=inputs["output_hidden_states"],
@@ -2012,6 +2123,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
             attention_mask=inputs["decoder_attention_mask"],
             encoder_hidden_states=inputs["encoder_outputs"][0],
             encoder_attention_mask=inputs["attention_mask"],
+            head_mask=inputs["decoder_head_mask"],
+            encoder_head_mask=inputs["head_mask"],
             past_key_values=inputs["past_key_values"],
             inputs_embeds=inputs["decoder_inputs_embeds"],
             use_cache=inputs["use_cache"],
@@ -2065,6 +2178,8 @@ class TFLEDModel(TFLEDPreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
         global_attention_mask=None,
         past_key_values=None,
@@ -2084,6 +2199,8 @@ class TFLEDModel(TFLEDPreTrainedModel):
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             global_attention_mask=global_attention_mask,
             past_key_values=past_key_values,
@@ -2103,6 +2220,8 @@ class TFLEDModel(TFLEDPreTrainedModel):
             decoder_attention_mask=inputs["decoder_attention_mask"],
             encoder_outputs=inputs["encoder_outputs"],
             global_attention_mask=inputs["global_attention_mask"],
+            head_mask=inputs["head_mask"],
+            decoder_head_mask=inputs["decoder_head_mask"],
             past_key_values=inputs["past_key_values"],
             inputs_embeds=inputs["inputs_embeds"],
             decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -2180,6 +2299,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs: Optional[TFLEDEncoderBaseModelOutput] = None,
         global_attention_mask=None,
         past_key_values=None,
@@ -2217,6 +2338,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             global_attention_mask=global_attention_mask,
             past_key_values=past_key_values,
@@ -2245,6 +2368,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
             decoder_attention_mask=inputs["decoder_attention_mask"],
             encoder_outputs=inputs["encoder_outputs"],
             global_attention_mask=inputs["global_attention_mask"],
+            head_mask=inputs["head_mask"],
+            decoder_head_mask=inputs["decoder_head_mask"],
             past_key_values=inputs["past_key_values"],
             inputs_embeds=inputs["inputs_embeds"],
             decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
modeling_tf_longformer.py:
@@ -719,6 +719,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         (
             hidden_states,
             attention_mask,
+            layer_head_mask,
             is_index_masked,
             is_index_global_attn,
             is_global_attn,
@@ -794,6 +795,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             attn_probs,
         )
 
+        if layer_head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+            )
+            attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
+
         # apply dropout
         attn_probs = self.dropout(attn_probs, training=training)
         value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
@@ -829,6 +838,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
                 attn_output=attn_output,
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
+                layer_head_mask=layer_head_mask,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -1271,6 +1281,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         attn_output,
         hidden_states,
         max_num_global_attn_indices,
+        layer_head_mask,
         is_local_index_global_attn_nonzero,
         is_index_global_attn_nonzero,
         is_local_index_no_global_attn_nonzero,
@@ -1336,6 +1347,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # compute global attn probs
         global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
 
+        # apply layer head masking
+        if layer_head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+            )
+            global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+                global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+            )
+            global_attn_probs_float = tf.reshape(
+                global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+            )
+
         # dropout
         global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
 
@@ -1398,13 +1423,14 @@ class TFLongformerAttention(tf.keras.layers.Layer):
         (
             hidden_states,
             attention_mask,
+            layer_head_mask,
             is_index_masked,
             is_index_global_attn,
             is_global_attn,
         ) = inputs
 
         self_outputs = self.self_attention(
-            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
             training=training,
         )
         attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
@@ -1425,13 +1451,14 @@ class TFLongformerLayer(tf.keras.layers.Layer):
         (
             hidden_states,
             attention_mask,
+            layer_head_mask,
             is_index_masked,
             is_index_global_attn,
             is_global_attn,
         ) = inputs
 
         attention_outputs = self.attention(
-            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
             training=training,
         )
         attention_output = attention_outputs[0]
@@ -1469,7 +1496,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
         all_hidden_states = () if output_hidden_states else None
         all_attentions = all_global_attentions = () if output_attentions else None
 
-        for i, layer_module in enumerate(self.layer):
+        for idx, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
                 all_hidden_states = all_hidden_states + (hidden_states_to_add,)
@@ -1478,6 +1505,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
                 [
                     hidden_states,
                     attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
                     is_index_masked,
                     is_index_global_attn,
                     is_global_attn,
@@ -1558,6 +1586,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
        global_attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -1573,6 +1602,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1649,6 +1679,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
+            head_mask=head_mask,
             padding_len=padding_len,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
@@ -1842,6 +1873,12 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
+        head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
        global_attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Mask to decide the attention given on each token, local attention or global attention. Tokens with global
            attention attends to all other tokens, and all other tokens attend to them. This is important for
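With the Longformer changes above, an encoder-only head_mask can be passed the same way. This is an illustrative sketch (the checkpoint name is only an example), using the config's num_hidden_layers / num_attention_heads for the mask shape:

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Longformer now accepts a head_mask as well.", return_tensors="tf")

# one row per layer, one column per head; all ones keeps every head active
head_mask = tf.ones((model.config.num_hidden_layers, model.config.num_attention_heads))

outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"], head_mask=head_mask)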
@@ -1918,6 +1955,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         global_attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -1933,6 +1971,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1946,6 +1985,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
         outputs = self.longformer(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
+            head_mask=inputs["head_mask"],
             global_attention_mask=inputs["global_attention_mask"],
             token_type_ids=inputs["token_type_ids"],
             position_ids=inputs["position_ids"],
@@ -2004,6 +2044,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         global_attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -2026,6 +2067,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -2040,6 +2082,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
         outputs = self.longformer(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
+            head_mask=inputs["head_mask"],
             global_attention_mask=inputs["global_attention_mask"],
             token_type_ids=inputs["token_type_ids"],
             position_ids=inputs["position_ids"],
@@ -2109,6 +2152,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         global_attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -2136,6 +2180,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -2170,6 +2215,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
         outputs = self.longformer(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
+            head_mask=inputs["head_mask"],
             global_attention_mask=inputs["global_attention_mask"],
             token_type_ids=inputs["token_type_ids"],
             position_ids=inputs["position_ids"],
@@ -2274,6 +2320,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         global_attention_mask=None,
@@ -2290,6 +2337,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -2321,6 +2369,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
         outputs = self.longformer(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
+            head_mask=inputs["head_mask"],
             global_attention_mask=inputs["global_attention_mask"],
             token_type_ids=inputs["token_type_ids"],
             position_ids=inputs["position_ids"],
@@ -2397,6 +2446,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         global_attention_mask=None,
@@ -2419,6 +2469,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -2464,6 +2515,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
             position_ids=flat_position_ids,
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
+            head_mask=head_mask,
             global_attention_mask=flat_global_attention_mask,
             inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
@@ -2547,6 +2599,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         global_attention_mask=None,
@@ -2568,6 +2621,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
             config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
+            head_mask=head_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -2582,6 +2636,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
         outputs = self.longformer(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
+            head_mask=inputs["head_mask"],
             global_attention_mask=inputs["global_attention_mask"],
             token_type_ids=inputs["token_type_ids"],
             position_ids=inputs["position_ids"],
test_modeling_tf_led.py:
@@ -162,6 +162,8 @@ def prepare_led_inputs_dict(
     decoder_input_ids,
     attention_mask=None,
     decoder_attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -173,11 +175,17 @@ def prepare_led_inputs_dict(
             ],
             axis=-1,
         )
+    if head_mask is None:
+        head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
+    if decoder_head_mask is None:
+        decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
         "decoder_input_ids": decoder_input_ids,
         "decoder_attention_mask": decoder_attention_mask,
+        "head_mask": head_mask,
+        "decoder_head_mask": decoder_head_mask,
     }
 
 
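The test helper above now defaults head_mask and decoder_head_mask to all-ones tensors, which leaves every head active. A quick sketch of why an all-ones mask is a no-op for the attention probabilities:

import tensorflow as tf

num_heads = 4
attn_probs = tf.nn.softmax(tf.random.normal((2, 6, num_heads, 5)), axis=-1)

ones_mask = tf.ones((num_heads,))  # the new default in prepare_led_inputs_dict
tf.debugging.assert_near(tf.reshape(ones_mask, (1, 1, -1, 1)) * attn_probs, attn_probs)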
@@ -187,7 +195,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
     all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
     is_encoder_decoder = True
     test_pruning = False
-    test_head_masking = False
 
     def setUp(self):
         self.model_tester = TFLEDModelTester(self)
test_modeling_tf_longformer.py:
@@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
         if is_tf_available()
         else ()
     )
-    test_head_masking = False
 
     def setUp(self):
         self.model_tester = TFLongformerModelTester(self)
@@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
         is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
 
+        layer_head_mask = None
+
         output_hidden_states = layer(
-            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
+            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn]
         )[0]
 
         expected_slice = tf.convert_to_tensor(
@@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
         is_global_attn = tf.math.reduce_any(is_index_global_attn)
 
+        layer_head_mask = None
+
         output_hidden_states = layer(
-            [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
+            [
+                hidden_states,
+                -tf.math.abs(attention_mask),
+                layer_head_mask,
+                is_index_masked,
+                is_index_global_attn,
+                is_global_attn,
+            ]
         )[0]
 
         self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
         is_global_attn = tf.math.reduce_any(is_index_global_attn)
 
+        layer_head_mask = None
+
         output_hidden_states, local_attentions, global_attentions = layer(
-            [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
+            [
+                hidden_states,
+                -tf.math.abs(attention_mask),
+                layer_head_mask,
+                is_index_masked,
+                is_index_global_attn,
+                is_global_attn,
+            ]
         )
 
         self.assertEqual(local_attentions.shape, (2, 4, 2, 8))