Add head_mask and decoder_head_mask to TF LED (#9988)

* Add head masking to TF LED

* Add head_mask to Longformer + one doc piece to LED

* Fix integration tests
Daniel Stancl 2021-02-09 17:45:18 +01:00, committed by GitHub
Parent 77c0ce8c0c
Commit e7381c4596
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files: 217 additions and 11 deletions

View file

@@ -200,6 +200,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
+ layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
@@ -275,6 +276,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
attn_probs,
)
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ )
+ attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
attn_probs = self.dropout(attn_probs, training=training)
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
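The per-head scaling added above is plain broadcasting. A minimal, self-contained sketch (with made-up shapes, not the real LED sizes) of how a `(num_heads,)` mask silences whole heads in the local attention probabilities:

```python
import tensorflow as tf

# Illustrative shapes only: batch=2, seq_len=8, num_heads=4, window=5.
batch_size, seq_len, num_heads, window = 2, 8, 4, 5
attn_probs = tf.nn.softmax(tf.random.normal((batch_size, seq_len, num_heads, window)), axis=-1)

# A per-layer head mask has shape (num_heads,); reshaped to (1, 1, num_heads, 1) it broadcasts
# across the batch, sequence, and window dimensions, zeroing out whole heads at once.
layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])
masked_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

print(float(tf.reduce_sum(masked_probs[:, :, 1, :])))  # 0.0 -> head 1 contributes nothing
```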
@@ -310,6 +319,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
+ layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -752,6 +762,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
attn_output,
hidden_states,
max_num_global_attn_indices,
+ layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
@@ -817,6 +828,20 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# compute global attn probs
global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
+ # apply layer head masking
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ )
+ global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+ )
+ global_attn_probs_float = tf.reshape(
+ global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+ )
# dropout
global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
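Because the global attention probabilities are stored with the head axis folded into the batch axis, the mask is applied by unflattening, scaling, and flattening back. A small sketch with hypothetical sizes:

```python
import tensorflow as tf

# Hypothetical sizes: global-attention probabilities stored as (batch * num_heads, q_len, k_len).
batch_size, num_heads, q_len, k_len = 2, 4, 3, 8
probs = tf.nn.softmax(tf.random.normal((batch_size * num_heads, q_len, k_len)), axis=-1)
layer_head_mask = tf.constant([1.0, 1.0, 0.0, 1.0])

# Unflatten the head axis, scale each head by its mask entry, then flatten back.
probs = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(probs, (batch_size, num_heads, q_len, k_len))
probs = tf.reshape(probs, (batch_size * num_heads, q_len, k_len))
```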
@@ -875,13 +900,14 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
+ layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
) = inputs
self_outputs = self.longformer_self_attn(
- [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training,
)
@@ -927,6 +953,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -993,6 +1020,17 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ )
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
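Each application site first validates the per-layer mask shape. A sketch of that check in isolation; it uses `tf.shape` instead of the repository's `shape_list` helper, so treat it as an approximation:

```python
import tensorflow as tf

def check_layer_head_mask(layer_head_mask: tf.Tensor, num_heads: int) -> None:
    # Fail fast (raises InvalidArgumentError in eager mode) if the per-layer mask
    # does not carry exactly one value per attention head.
    tf.debugging.assert_equal(
        tf.shape(layer_head_mask),
        [num_heads],
        message=f"Head mask for a single layer should be of size {num_heads}",
    )

check_layer_head_mask(tf.ones((16,)), 16)  # passes silently; a (8,) mask would raise
```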
@@ -1031,6 +1069,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
+ layer_head_mask: tf.Tensor,
is_index_masked: tf.Tensor,
is_index_global_attn: tf.Tensor,
is_global_attn: bool,
@@ -1041,10 +1080,12 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
+ `(config.encoder_attention_heads,)`.
"""
residual = hidden_states
layer_outputs = self.self_attn(
- [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training,
)
@@ -1104,6 +1145,8 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -1115,6 +1158,10 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
+ `(config.encoder_attention_heads,)`.
+ encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of
+ size `(config.encoder_attention_heads,)`.
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -1127,6 +1174,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -1143,6 +1191,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
+ layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -1438,6 +1487,18 @@ LED_INPUTS_DOCSTRING = r"""
shifting the input_ids right, following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+ head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
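An end-to-end sketch of how the two new arguments are meant to be passed. The checkpoint name, input text, and which heads get zeroed are illustrative; head masking changes the attention pattern, so outputs will differ from an unmasked run:

```python
import numpy as np
import tensorflow as tf
from transformers import LEDTokenizer, TFLEDForConditionalGeneration

# Illustrative checkpoint; any TF LED checkpoint works as long as the mask shapes match
# (encoder_layers, encoder_attention_heads) and (decoder_layers, decoder_attention_heads).
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A long document to summarize ...", return_tensors="tf")

cfg = model.config
head_mask = np.ones((cfg.encoder_layers, cfg.encoder_attention_heads), dtype=np.float32)
decoder_head_mask = np.ones((cfg.decoder_layers, cfg.decoder_attention_heads), dtype=np.float32)
head_mask[0, 0] = 0.0           # silence head 0 in encoder layer 0
decoder_head_mask[1, :2] = 0.0  # silence the first two heads in decoder layer 1

outputs = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=inputs["input_ids"][:, :5],
    head_mask=tf.constant(head_mask),
    decoder_head_mask=tf.constant(decoder_head_mask),
)
```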
@@ -1517,6 +1578,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
inputs_embeds=None,
attention_mask=None,
global_attention_mask=None,
+ head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -1541,6 +1603,12 @@ class TFLEDEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
+ head_mask (:obj:`tf.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+ Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -1559,6 +1627,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
@@ -1617,8 +1686,15 @@ class TFLEDEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
+ # check if head_mask has a correct number of layers specified if desired
+ if inputs["head_mask"] is not None:
+ tf.debugging.assert_equal(
+ shape_list(inputs["head_mask"])[0],
+ len(self.layers),
+ message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+ )
# encoder layers
- for encoder_layer in self.layers:
+ for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
@@ -1631,6 +1707,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attention_mask=inputs["attention_mask"],
+ layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
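A minimal sketch of the dispatch pattern used above, with placeholder names instead of the real Keras layers: a `(num_layers, num_heads)` mask is validated once against the layer count and then sliced row by row, so each layer only ever receives its own `(num_heads,)` slice (or `None` when no mask was given).

```python
import tensorflow as tf

num_layers, num_heads = 6, 12  # made-up sizes
head_mask = tf.ones((num_layers, num_heads))
layers = [f"encoder_layer_{i}" for i in range(num_layers)]  # stand-ins for the real TF layers

# Validate the mask once against the number of layers.
tf.debugging.assert_equal(tf.shape(head_mask)[0], len(layers))

for idx, layer in enumerate(layers):
    layer_head_mask = head_mask[idx] if head_mask is not None else None
    # ... the real encoder forwards layer_head_mask into encoder_layer(...) here
```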
@@ -1753,6 +1830,8 @@ class TFLEDDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
+ head_mask=None,
+ encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -1784,6 +1863,19 @@ class TFLEDDecoder(tf.keras.layers.Layer):
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
+ head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+ Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+ Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
+ on hidden heads. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. If :obj:`past_key_values` are used, the user can optionally input only the last
@@ -1810,6 +1902,8 @@ class TFLEDDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -1865,6 +1959,14 @@ class TFLEDDecoder(tf.keras.layers.Layer):
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
+ # check if head_mask has a correct number of layers specified if desired
+ if inputs["head_mask"] is not None:
+ tf.debugging.assert_equal(
+ shape_list(inputs["head_mask"])[0],
+ len(self.layers),
+ message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+ )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -1881,6 +1983,10 @@ class TFLEDDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
+ layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+ encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
+ if inputs["encoder_head_mask"] is not None
+ else None,
past_key_value=past_key_value,
)
@@ -1950,6 +2056,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
global_attention_mask=None,
past_key_values=None,
@@ -1969,6 +2077,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
past_key_values=past_key_values,
@@ -1990,6 +2100,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
+ head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -2012,6 +2123,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
+ head_mask=inputs["decoder_head_mask"],
+ encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -2065,6 +2178,8 @@ class TFLEDModel(TFLEDPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
global_attention_mask=None,
past_key_values=None,
@@ -2084,6 +2199,8 @@ class TFLEDModel(TFLEDPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
past_key_values=past_key_values,
@@ -2103,6 +2220,8 @@ class TFLEDModel(TFLEDPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
+ head_mask=inputs["head_mask"],
+ decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -2180,6 +2299,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
encoder_outputs: Optional[TFLEDEncoderBaseModelOutput] = None,
global_attention_mask=None,
past_key_values=None,
@@ -2217,6 +2338,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
past_key_values=past_key_values,
@@ -2245,6 +2368,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
+ head_mask=inputs["head_mask"],
+ decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],

View file

@@ -719,6 +719,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
+ layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
@@ -794,6 +795,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_probs,
)
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ )
+ attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
attn_probs = self.dropout(attn_probs, training=training)
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
@@ -829,6 +838,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
+ layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -1271,6 +1281,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_output,
hidden_states,
max_num_global_attn_indices,
+ layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
@@ -1336,6 +1347,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# compute global attn probs
global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
+ # apply layer head masking
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ )
+ global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+ )
+ global_attn_probs_float = tf.reshape(
+ global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+ )
# dropout
global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
@@ -1398,13 +1423,14 @@ class TFLongformerAttention(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
+ layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
) = inputs
self_outputs = self.self_attention(
- [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training,
)
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
@@ -1425,13 +1451,14 @@ class TFLongformerLayer(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
+ layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
) = inputs
attention_outputs = self.attention(
- [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training,
)
attention_output = attention_outputs[0]
@@ -1469,7 +1496,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
all_hidden_states = () if output_hidden_states else None
all_attentions = all_global_attentions = () if output_attentions else None
- for i, layer_module in enumerate(self.layer):
+ for idx, layer_module in enumerate(self.layer):
if output_hidden_states:
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
@@ -1478,6 +1505,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
[
hidden_states,
attention_mask,
+ head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
is_global_attn,
@@ -1558,6 +1586,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1573,6 +1602,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1649,6 +1679,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
+ head_mask=head_mask,
padding_len=padding_len,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
@@ -1842,6 +1873,12 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
+ head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+ Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
global_attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to decide the attention given on each token, local attention or global attention. Tokens with global
attention attends to all other tokens, and all other tokens attend to them. This is important for
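A hedged usage sketch for the Longformer side of the change. The checkpoint and the zeroed head are illustrative, and the docstring's `(encoder_layers, encoder_attention_heads)` shape is assumed here to correspond to `config.num_hidden_layers` and `config.num_attention_heads`:

```python
import numpy as np
import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

# Illustrative checkpoint.
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
head_mask = np.ones((model.config.num_hidden_layers, model.config.num_attention_heads), dtype=np.float32)
head_mask[2, 5] = 0.0  # silence head 5 in layer 2

outputs = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    head_mask=tf.constant(head_mask),
)
```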
@@ -1918,6 +1955,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1933,6 +1971,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1946,6 +1985,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2004,6 +2044,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2026,6 +2067,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2040,6 +2082,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2109,6 +2152,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2136,6 +2180,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2170,6 +2215,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2274,6 +2320,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
@@ -2290,6 +2337,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2321,6 +2369,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2397,6 +2446,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
@@ -2419,6 +2469,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2464,6 +2515,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
+ head_mask=head_mask,
global_attention_mask=flat_global_attention_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
@@ -2547,6 +2599,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
self,
input_ids=None,
attention_mask=None,
+ head_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
@@ -2568,6 +2621,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
+ head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2582,6 +2636,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],

View file

@@ -162,6 +162,8 @@ def prepare_led_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -173,11 +175,17 @@ def prepare_led_inputs_dict(
],
axis=-1,
)
+ if head_mask is None:
+ head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
+ if decoder_head_mask is None:
+ decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
}
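An all-ones default scales every head by 1.0, so the defaults added to the test helper keep the unmasked baseline behaviour while still exercising the new `head_mask` code path. A tiny sketch with made-up shapes:

```python
import tensorflow as tf

num_heads, window = 4, 5  # illustrative sizes
attn_probs = tf.nn.softmax(tf.random.normal((1, 3, num_heads, window)), axis=-1)
full_mask = tf.ones((num_heads,))
masked = tf.reshape(full_mask, (1, 1, -1, 1)) * attn_probs
assert bool(tf.reduce_all(masked == attn_probs))  # multiplying by exactly 1.0 is a no-op
```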
@@ -187,7 +195,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
- test_head_masking = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)

View file

@@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
- test_head_masking = False
def setUp(self):
self.model_tester = TFLongformerModelTester(self)
@@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
+ layer_head_mask = None
output_hidden_states = layer(
- [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
+ [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn]
)[0]
expected_slice = tf.convert_to_tensor(
@@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
+ layer_head_mask = None
output_hidden_states = layer(
- [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
+ [
+ hidden_states,
+ -tf.math.abs(attention_mask),
+ layer_head_mask,
+ is_index_masked,
+ is_index_global_attn,
+ is_global_attn,
+ ]
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
+ layer_head_mask = None
output_hidden_states, local_attentions, global_attentions = layer(
- [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
+ [
+ hidden_states,
+ -tf.math.abs(attention_mask),
+ layer_head_mask,
+ is_index_masked,
+ is_index_global_attn,
+ is_global_attn,
+ ]
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))