Fix saved model creation (#5468)
* Fix TF Serving when output_hidden_states and output_attentions are True * Add tests for saved model creation + bug fix for multiple choices models * remove unused import * Fix the input for several layers * Fix test * Fix conflict printing * Apply style * Fix XLM and Flaubert for TensorFlow * Apply style * Fix TF check version * Apply style * Trigger CI
This commit is contained in:
Родитель
5a0dac53bf
Коммит
9996f697e3
|
@ -35,7 +35,6 @@ from .modeling_tf_utils import (
|
|||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -99,7 +98,15 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs, mode="embedding", training=False):
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
|
@ -115,15 +122,15 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, training=training)
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
|
@ -175,6 +182,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
|
|||
), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.query = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||
|
@ -192,9 +200,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
|
|||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
|
@ -233,9 +239,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
|
|||
context_layer, (batch_size, -1, self.all_head_size)
|
||||
) # (batch_size, seq_len_q, all_head_size)
|
||||
|
||||
outputs = (
|
||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
||||
)
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -248,9 +252,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
|
|||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, input_tensor = inputs
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
@ -262,6 +264,7 @@ class TFAlbertAttention(TFBertSelfAttention):
|
|||
super().__init__(config, **kwargs)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.output_attentions = config.output_attentions
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
@ -271,9 +274,7 @@ class TFAlbertAttention(TFBertSelfAttention):
|
|||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
input_tensor, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
|
||||
batch_size = shape_list(input_tensor)[0]
|
||||
mixed_query_layer = self.query(input_tensor)
|
||||
mixed_key_layer = self.key(input_tensor)
|
||||
|
@ -312,9 +313,7 @@ class TFAlbertAttention(TFBertSelfAttention):
|
|||
context_layer, (batch_size, -1, self.all_head_size)
|
||||
) # (batch_size, seq_len_q, all_head_size)
|
||||
|
||||
self_outputs = (
|
||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
||||
)
|
||||
self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
hidden_states = self_outputs[0]
|
||||
|
||||
|
@ -349,11 +348,9 @@ class TFAlbertLayer(tf.keras.layers.Layer):
|
|||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
attention_outputs = self.attention(
|
||||
[hidden_states, attention_mask, head_mask, output_attentions], training=training
|
||||
hidden_states, attention_mask, head_mask, output_attentions, training=training
|
||||
)
|
||||
ffn_output = self.ffn(attention_outputs[0])
|
||||
ffn_output = self.activation(ffn_output)
|
||||
|
@ -371,32 +368,32 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
|
|||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.albert_layers = [
|
||||
TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num)
|
||||
]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||
layer_hidden_states = ()
|
||||
layer_attentions = ()
|
||||
|
||||
for layer_index, albert_layer in enumerate(self.albert_layers):
|
||||
layer_output = albert_layer(
|
||||
[hidden_states, attention_mask, head_mask[layer_index], output_attentions], training=training
|
||||
hidden_states, attention_mask, head_mask[layer_index], output_attentions, training=training
|
||||
)
|
||||
hidden_states = layer_output[0]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
layer_attentions = layer_attentions + (layer_output[1],)
|
||||
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
layer_hidden_states = layer_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (layer_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (layer_attentions,)
|
||||
# last-layer hidden state, (layer hidden states), (layer attentions)
|
||||
return outputs
|
||||
|
@ -417,13 +414,11 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
|
|||
for i in range(config.num_hidden_groups)
|
||||
]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
||||
all_attentions = ()
|
||||
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = (hidden_states,)
|
||||
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
|
@ -434,27 +429,25 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
|
|||
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
||||
|
||||
layer_group_output = self.albert_layer_groups[group_idx](
|
||||
[
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
],
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
training=training,
|
||||
)
|
||||
hidden_states = layer_group_output[0]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + layer_group_output[-1]
|
||||
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
|
||||
# last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
@ -619,9 +612,13 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
|
|||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
encoder_outputs = self.encoder(
|
||||
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
@ -1274,7 +1271,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [
|
||||
outputs = self.albert(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
|
@ -1283,9 +1280,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
outputs = self.albert(flat_inputs, training=training)
|
||||
training=training,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
|
|
|
@ -36,7 +36,6 @@ from .modeling_tf_utils import (
|
|||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -81,6 +80,7 @@ def gelu(x):
|
|||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
|
||||
|
||||
return x * cdf
|
||||
|
||||
|
||||
|
@ -94,6 +94,7 @@ def gelu_new(x):
|
|||
`x` with the GELU activation applied.
|
||||
"""
|
||||
cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||
|
||||
return x * cdf
|
||||
|
||||
|
||||
|
@ -118,7 +119,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.initializer_range = config.initializer_range
|
||||
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
|
@ -149,7 +149,15 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs, mode="embedding", training=False):
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
|
@ -165,15 +173,15 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, training=training)
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
|
@ -181,19 +189,22 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _linear(self, inputs):
|
||||
|
@ -205,7 +216,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||
"""
|
||||
batch_size = shape_list(inputs)[0]
|
||||
length = shape_list(inputs)[1]
|
||||
|
||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
||||
|
||||
|
@ -215,6 +225,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||
class TFBertSelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
|
@ -225,7 +236,6 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||
)
|
||||
|
@ -235,21 +245,18 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||
self.value = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
||||
)
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x, batch_size):
|
||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||
|
@ -277,15 +284,11 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = tf.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
|
||||
context_layer = tf.reshape(
|
||||
context_layer, (batch_size, -1, self.all_head_size)
|
||||
) # (batch_size, seq_len_q, all_head_size)
|
||||
|
||||
outputs = (
|
||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
||||
)
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -299,12 +302,11 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
|
|||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, input_tensor = inputs
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -317,14 +319,13 @@ class TFBertAttention(tf.keras.layers.Layer):
|
|||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
input_tensor, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
|
||||
self_outputs = self.self_attention(
|
||||
[input_tensor, attention_mask, head_mask, output_attentions], training=training
|
||||
input_tensor, attention_mask, head_mask, output_attentions, training=training
|
||||
)
|
||||
attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
|
||||
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -334,6 +335,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
|||
self.dense = tf.keras.layers.Dense(
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
|
@ -342,6 +344,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
|||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -354,12 +357,11 @@ class TFBertOutput(tf.keras.layers.Layer):
|
|||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, input_tensor = inputs
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -370,16 +372,15 @@ class TFBertLayer(tf.keras.layers.Layer):
|
|||
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
||||
self.bert_output = TFBertOutput(config, name="output")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
attention_outputs = self.attention(
|
||||
[hidden_states, attention_mask, head_mask, output_attentions], training=training
|
||||
hidden_states, attention_mask, head_mask, output_attentions, training=training
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.bert_output([intermediate_output, attention_output], training=training)
|
||||
layer_output = self.bert_output(intermediate_output, attention_output, training=training)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -388,32 +389,34 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
|||
super().__init__(**kwargs)
|
||||
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
[hidden_states, attention_mask, head_mask[i], output_attentions], training=training
|
||||
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
|
||||
if output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
|
||||
return outputs # outputs, (hidden states), (attentions)
|
||||
|
||||
|
||||
|
@ -432,6 +435,7 @@ class TFBertPooler(tf.keras.layers.Layer):
|
|||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
|
@ -441,16 +445,19 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
|||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -472,6 +479,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
|||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -482,6 +490,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
|
|||
|
||||
def call(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
|
||||
return prediction_scores
|
||||
|
||||
|
||||
|
@ -494,6 +503,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
|||
|
||||
def call(self, pooled_output):
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
|
||||
return seq_relationship_score
|
||||
|
||||
|
||||
|
@ -507,7 +517,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||
self.initializer_range = config.initializer_range
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFBertEncoder(config, name="encoder")
|
||||
self.pooler = TFBertPooler(config, name="pooler")
|
||||
|
@ -605,18 +614,22 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
encoder_outputs = self.encoder(
|
||||
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
|
@ -1211,8 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
flat_inputs = [
|
||||
outputs = self.bert(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
|
@ -1221,16 +1233,12 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
outputs = self.bert(flat_inputs, training=training)
|
||||
|
||||
training=training,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
|
|
|
@ -27,7 +27,6 @@ from .modeling_tf_utils import (
|
|||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
|
@ -87,10 +86,11 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
|
|||
|
||||
|
||||
class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, d_model_size, num_heads, **kwargs):
|
||||
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.num_heads = num_heads
|
||||
self.d_model_size = d_model_size
|
||||
self.output_attentions = output_attentions
|
||||
|
||||
self.depth = int(d_model_size / self.num_heads)
|
||||
|
||||
|
@ -104,8 +104,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
||||
def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||
batch_size = shape_list(q)[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
|
@ -121,10 +120,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||
k = tf.concat((past_key, k), axis=-2)
|
||||
v = tf.concat((past_value, v), axis=-2)
|
||||
|
||||
# to cope with keras serialization
|
||||
use_cache = cast_bool_to_primitive(use_cache, True)
|
||||
|
||||
if use_cache is True:
|
||||
if use_cache:
|
||||
present = tf.stack((k, v), axis=0)
|
||||
else:
|
||||
present = (None,)
|
||||
|
@ -134,10 +130,11 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||
attn = output[1]
|
||||
original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
|
||||
output = self.dense(original_size_attention)
|
||||
|
||||
outputs = (output, present)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
|
||||
if output_attentions:
|
||||
outputs = outputs + (attn,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -156,10 +153,16 @@ class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
|
|||
|
||||
|
||||
class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
|
||||
def __init__(
|
||||
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
|
||||
self.output_attentions = output_attentions
|
||||
|
||||
self.multi_head_attention = TFMultiHeadAttention(
|
||||
d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
|
||||
)
|
||||
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
|
||||
|
||||
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
|
||||
|
@ -168,11 +171,18 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||
self.dropout1 = tf.keras.layers.Dropout(rate)
|
||||
self.dropout2 = tf.keras.layers.Dropout(rate)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
||||
def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||
normed = self.layernorm1(x)
|
||||
attn_outputs = self.multi_head_attention(
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
|
||||
normed,
|
||||
normed,
|
||||
normed,
|
||||
mask,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
|
@ -215,6 +225,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||
config.dff,
|
||||
config.resid_pdrop,
|
||||
config.layer_norm_epsilon,
|
||||
self.output_attentions,
|
||||
name="h_._{}".format(i),
|
||||
)
|
||||
for i in range(config.n_layer)
|
||||
|
@ -367,31 +378,37 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||
all_hidden_states = ()
|
||||
all_attentions = []
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h(
|
||||
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
|
||||
hidden_states,
|
||||
mask,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
|
||||
if use_cache is True:
|
||||
if use_cache:
|
||||
presents = presents + (present,)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
|
||||
hidden_states = self.layernorm(hidden_states)
|
||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if use_cache is True:
|
||||
if use_cache:
|
||||
outputs = outputs + (presents,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||
|
|
|
@ -37,7 +37,6 @@ from .modeling_tf_utils import (
|
|||
TFSequenceClassificationLoss,
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -114,7 +113,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
|
||||
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, mode="embedding", training=False):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
|
||||
|
@ -130,13 +129,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
|
||||
return self._embedding(input_ids, position_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, inputs_embeds=None, training=False):
|
||||
def _embedding(self, input_ids, position_ids, inputs_embeds, training=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -148,11 +147,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||
embeddings: tf.Tensor(bs, max_seq_length, dim)
|
||||
The embedded tokens (plus position embeddings, no token_type embeddings)
|
||||
"""
|
||||
if not isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs
|
||||
position_ids = None
|
||||
else:
|
||||
input_ids, position_ids = inputs
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
seq_length = shape_list(input_ids)[1]
|
||||
|
@ -194,6 +189,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||
self.n_heads = config.n_heads
|
||||
self.dim = config.dim
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
|
||||
|
||||
|
@ -215,7 +211,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -231,7 +227,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||
context: tf.Tensor(bs, seq_length, dim)
|
||||
Contextualized layer. Optional: only if `output_attentions=True`
|
||||
"""
|
||||
query, key, value, mask, head_mask, output_attentions = inputs
|
||||
bs, q_length, dim = shape_list(query)
|
||||
k_length = shape_list(key)[1]
|
||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
|
@ -270,7 +265,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||
context = unshape(context) # (bs, q_length, dim)
|
||||
context = self.out_lin(context) # (bs, q_length, dim)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
return (context, weights)
|
||||
else:
|
||||
return (context,)
|
||||
|
@ -310,6 +305,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||
self.hidden_dim = config.hidden_dim
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.activation = config.activation
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
assert (
|
||||
config.dim % config.n_heads == 0
|
||||
|
@ -321,7 +317,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||
self.ffn = TFFFN(config, name="ffn")
|
||||
self.output_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
|
||||
|
||||
def call(self, inputs, training=False): # removed: src_enc=None, src_len=None
|
||||
def call(self, x, attn_mask, head_mask, output_attentions, training=False): # removed: src_enc=None, src_len=None
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -335,11 +331,9 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||
ffn_output: tf.Tensor(bs, seq_length, dim)
|
||||
The output of the transformer block contextualization.
|
||||
"""
|
||||
x, attn_mask, head_mask, output_attentions = inputs
|
||||
|
||||
# Self-Attention
|
||||
sa_output = self.attention([x, x, x, attn_mask, head_mask, output_attentions], training=training)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)
|
||||
if output_attentions:
|
||||
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
|
||||
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
||||
# assert type(sa_output) == tuple
|
||||
|
@ -351,7 +345,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
|
||||
|
||||
output = (ffn_output,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
output = (sa_weights,) + output
|
||||
return output
|
||||
|
||||
|
@ -360,10 +354,12 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.n_layers = config.n_layers
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -383,34 +379,32 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||
Tuple of length n_layers with the attention weights from each layer
|
||||
Optional: only if output_attentions=True
|
||||
"""
|
||||
x, attn_mask, head_mask, output_attentions, output_hidden_states = inputs
|
||||
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
|
||||
hidden_state = x
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
layer_outputs = layer_module([hidden_state, attn_mask, head_mask[i], output_attentions], training=training)
|
||||
layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
|
||||
hidden_state = layer_outputs[-1]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
assert len(layer_outputs) == 2, f"Incorrect number of outputs {len(layer_outputs)} instead of 2"
|
||||
if output_attentions:
|
||||
assert len(layer_outputs) == 2
|
||||
attentions = layer_outputs[0]
|
||||
all_attentions = all_attentions + (attentions,)
|
||||
else:
|
||||
assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"
|
||||
|
||||
# Add last layer
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
outputs = (hidden_state,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
@ -481,6 +475,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
|||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.ones(input_shape) # (bs, seq_length)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||
|
||||
# Prepare head mask if needed
|
||||
|
@ -491,11 +486,12 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
|||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
|
||||
tfmr_output = self.transformer(
|
||||
[embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states], training=training
|
||||
embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states, training=training
|
||||
)
|
||||
|
||||
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||
|
@ -986,24 +982,21 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
|||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
flat_inputs = [
|
||||
distilbert_output = self.distilbert(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
head_mask,
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
distilbert_output = self.distilbert(flat_inputs, training=training)
|
||||
training=training,
|
||||
)
|
||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + distilbert_output[1:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
|
|
|
@ -2,7 +2,8 @@ import logging
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from .configuration_electra import ElectraConfig
|
||||
from transformers import ElectraConfig
|
||||
|
||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
|
||||
from .modeling_tf_utils import (
|
||||
|
@ -71,7 +72,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
|||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs, mode="embedding", training=False):
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
|
@ -87,15 +96,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
|||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, training=training)
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
|
@ -289,13 +298,17 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
|
|||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
head_mask = self.get_head_mask(head_mask)
|
||||
|
||||
hidden_states = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
|
||||
if hasattr(self, "embeddings_project"):
|
||||
hidden_states = self.embeddings_project(hidden_states, training=training)
|
||||
|
||||
hidden_states = self.encoder(
|
||||
[hidden_states, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
||||
hidden_states,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import tensorflow as tf
|
|||
|
||||
from .configuration_flaubert import FlaubertConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list
|
||||
from .modeling_tf_utils import keras_serializable, shape_list
|
||||
from .modeling_tf_xlm import (
|
||||
TFXLMForMultipleChoice,
|
||||
TFXLMForQuestionAnsweringSimple,
|
||||
|
@ -274,10 +274,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||
# self attention
|
||||
if not self.pre_norm:
|
||||
attn_outputs = self.attentions[i](
|
||||
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
|
||||
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
attn = self.dropout(attn, training=training)
|
||||
tensor = tensor + attn
|
||||
|
@ -285,10 +285,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||
else:
|
||||
tensor_normalized = self.layer_norm1[i](tensor)
|
||||
attn_outputs = self.attentions[i](
|
||||
[tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training
|
||||
tensor_normalized, attn_mask, None, cache, head_mask[i], training=training
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
attn = self.dropout(attn, training=training)
|
||||
tensor = tensor + attn
|
||||
|
@ -311,7 +311,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||
tensor = tensor * mask[..., tf.newaxis]
|
||||
|
||||
# Add last hidden state
|
||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states = hidden_states + (tensor,)
|
||||
|
||||
# update cache length
|
||||
|
@ -322,9 +322,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||
# tensor = tensor.transpose(0, 1)
|
||||
|
||||
outputs = (tensor,)
|
||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (attentions,)
|
||||
return outputs # outputs, (hidden_states), (attentions)
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ from .modeling_tf_utils import (
|
|||
TFPreTrainedModel,
|
||||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -75,6 +74,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
||||
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
||||
|
@ -95,8 +95,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
m = i >= j - ns + nd
|
||||
return tf.cast(m, dtype)
|
||||
|
||||
def _attn(self, inputs, training=False):
|
||||
q, k, v, attention_mask, head_mask, output_attentions = inputs
|
||||
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
||||
# q, k, v have shape [batch, heads, sequence, features]
|
||||
w = tf.matmul(q, k, transpose_b=True)
|
||||
if self.scale:
|
||||
|
@ -121,7 +120,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
w = w * head_mask
|
||||
|
||||
outputs = [tf.matmul(w, v)]
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs.append(w)
|
||||
return outputs
|
||||
|
||||
|
@ -137,9 +136,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
x = tf.reshape(x, new_x_shape)
|
||||
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
||||
|
||||
def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||
x = self.c_attn(x)
|
||||
query, key, value = tf.split(x, 3, axis=2)
|
||||
query = self.split_heads(query)
|
||||
|
@ -151,12 +148,12 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
value = tf.concat([past_value, value], axis=-2)
|
||||
|
||||
# to cope with keras serialization
|
||||
if cast_bool_to_primitive(use_cache, True) is True:
|
||||
if use_cache:
|
||||
present = tf.stack([key, value], axis=0)
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
|
||||
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
|
||||
a = attn_outputs[0]
|
||||
|
||||
a = self.merge_heads(a)
|
||||
|
@ -192,12 +189,10 @@ class TFBlock(tf.keras.layers.Layer):
|
|||
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
||||
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
||||
|
||||
def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||
a = self.ln_1(x)
|
||||
output_attn = self.attn(
|
||||
[a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training
|
||||
a, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=training
|
||||
)
|
||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
||||
x = x + a
|
||||
|
@ -223,6 +218,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||
self.num_hidden_layers = config.n_layer
|
||||
self.vocab_size = config.vocab_size
|
||||
self.n_embd = config.n_embd
|
||||
self.output_hidden_states = self.output_hidden_states
|
||||
self.output_attentions = self.output_attentions
|
||||
|
||||
self.wte = TFSharedEmbeddings(
|
||||
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
|
||||
|
@ -362,34 +359,39 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||
all_attentions = []
|
||||
all_hidden_states = ()
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
|
||||
outputs = block(
|
||||
[hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
|
||||
hidden_states,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
presents = presents + (present,)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||
# Add last hidden state
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if use_cache is True:
|
||||
if use_cache:
|
||||
outputs = outputs + (presents,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||
|
@ -738,13 +740,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
past,
|
||||
flat_attention_mask,
|
||||
|
@ -755,18 +755,13 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
training=training,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||
|
||||
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
||||
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
|
||||
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||
|
||||
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||
|
||||
return outputs # lm logits, mc logits, presents, (all hidden_states), (attentions)
|
||||
|
|
|
@ -35,7 +35,6 @@ from .modeling_tf_utils import (
|
|||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -130,7 +129,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
|
|||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs, mode="embedding", training=False):
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
|
@ -146,15 +153,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
|
|||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, training=training)
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
|
@ -196,6 +203,7 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
|
|||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _linear(self, inputs):
|
||||
|
@ -224,6 +232,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.output_attentions = config.output_attentions
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
@ -244,14 +253,13 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(
|
||||
self, query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=False
|
||||
):
|
||||
batch_size = shape_list(attention_mask)[0]
|
||||
mixed_query_layer = self.query(query_tensor)
|
||||
mixed_key_layer = self.key(key_tensor)
|
||||
mixed_value_layer = self.value(value_tensor)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||
|
@ -285,9 +293,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||
context_layer, (batch_size, -1, self.all_head_size)
|
||||
) # (batch_size, seq_len_q, all_head_size)
|
||||
|
||||
outputs = (
|
||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
||||
)
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -305,8 +311,7 @@ class TFMobileBertSelfOutput(tf.keras.layers.Layer):
|
|||
if not self.use_bottleneck:
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, residual_tensor = inputs
|
||||
def call(self, hidden_states, residual_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
if not self.use_bottleneck:
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
|
@ -323,13 +328,22 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
|
|||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(
|
||||
self,
|
||||
query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
layer_input,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=False,
|
||||
):
|
||||
self_outputs = self.self(
|
||||
[query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions], training=training
|
||||
query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=training
|
||||
)
|
||||
attention_output = self.mobilebert_output([self_outputs[0], layer_input], training=training)
|
||||
|
||||
attention_output = self.mobilebert_output(self_outputs[0], layer_input, training=training)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
@ -349,8 +363,7 @@ class TFOutputBottleneck(tf.keras.layers.Layer):
|
|||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, residual_tensor = inputs
|
||||
def call(self, hidden_states, residual_tensor, training=False):
|
||||
layer_outputs = self.dense(hidden_states)
|
||||
layer_outputs = self.dropout(layer_outputs, training=training)
|
||||
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
|
||||
|
@ -372,16 +385,14 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
|
|||
else:
|
||||
self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, residual_tensor_1, residual_tensor_2 = inputs
|
||||
|
||||
def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
if not self.use_bottleneck:
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
||||
else:
|
||||
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
||||
hidden_states = self.bottleneck([hidden_states, residual_tensor_2])
|
||||
hidden_states = self.bottleneck(hidden_states, residual_tensor_2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -466,7 +477,6 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
|
|||
super().__init__(**kwargs)
|
||||
self.use_bottleneck = config.use_bottleneck
|
||||
self.num_feedforward_networks = config.num_feedforward_networks
|
||||
|
||||
self.attention = TFMobileBertAttention(config, name="attention")
|
||||
self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
|
||||
self.mobilebert_output = TFMobileBertOutput(config, name="output")
|
||||
|
@ -478,16 +488,20 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
|
|||
TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1)
|
||||
]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
if self.use_bottleneck:
|
||||
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
|
||||
else:
|
||||
query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
|
||||
|
||||
attention_outputs = self.attention(
|
||||
[query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions],
|
||||
query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
layer_input,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
@ -500,48 +514,57 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
|
|||
s += (attention_output,)
|
||||
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.mobilebert_output(
|
||||
[intermediate_output, attention_output, hidden_states], training=training
|
||||
)
|
||||
layer_output = self.mobilebert_output(intermediate_output, attention_output, hidden_states, training=training)
|
||||
|
||||
outputs = (
|
||||
(layer_output,)
|
||||
+ attention_outputs[1:]
|
||||
+ (0, query_tensor, key_tensor, value_tensor, layer_input, attention_output, intermediate_output)
|
||||
+ (
|
||||
tf.constant(0),
|
||||
query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
layer_input,
|
||||
attention_output,
|
||||
intermediate_output,
|
||||
)
|
||||
+ s
|
||||
) # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFMobileBertEncoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
[hidden_states, attention_mask, head_mask[i], output_attentions], training=training
|
||||
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # outputs, (hidden states), (attentions)
|
||||
|
||||
|
@ -732,11 +755,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
|||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
encoder_outputs = self.encoder(
|
||||
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
@ -1360,8 +1386,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
flat_inputs = [
|
||||
outputs = self.mobilebert(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
|
@ -1370,16 +1395,12 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
outputs = self.mobilebert(flat_inputs, training=training)
|
||||
|
||||
training=training,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
|
|
|
@ -29,7 +29,6 @@ from .modeling_tf_utils import (
|
|||
TFPreTrainedModel,
|
||||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -84,6 +83,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
||||
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
||||
|
@ -104,8 +104,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
m = i >= j - ns + nd
|
||||
return tf.cast(m, dtype)
|
||||
|
||||
def _attn(self, inputs, training=False):
|
||||
q, k, v, attention_mask, head_mask, output_attentions = inputs
|
||||
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
||||
# q, k, v have shape [batch, heads, sequence, features]
|
||||
w = tf.matmul(q, k, transpose_b=True)
|
||||
if self.scale:
|
||||
|
@ -130,7 +129,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
w = w * head_mask
|
||||
|
||||
outputs = [tf.matmul(w, v)]
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs.append(w)
|
||||
return outputs
|
||||
|
||||
|
@ -146,16 +145,14 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
x = tf.reshape(x, new_x_shape)
|
||||
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
|
||||
x = self.c_attn(x)
|
||||
query, key, value = tf.split(x, 3, axis=2)
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key)
|
||||
value = self.split_heads(value)
|
||||
|
||||
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
|
||||
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
|
||||
a = attn_outputs[0]
|
||||
|
||||
a = self.merge_heads(a)
|
||||
|
@ -191,10 +188,8 @@ class TFBlock(tf.keras.layers.Layer):
|
|||
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
||||
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, attention_mask, head_mask, output_attentions = inputs
|
||||
|
||||
output_attn = self.attn([x, attention_mask, head_mask, output_attentions], training=training)
|
||||
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
|
||||
output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
|
||||
a = output_attn[0] # output_attn: a, (attentions)
|
||||
|
||||
n = self.ln_1(x + a)
|
||||
|
@ -341,23 +336,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||
all_attentions = []
|
||||
all_hidden_states = ()
|
||||
for i, block in enumerate(self.h):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
|
||||
outputs = block([hidden_states, attention_mask, head_mask[i], output_attentions], training=training)
|
||||
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions, training=training)
|
||||
hidden_states = outputs[0]
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions.append(outputs[1])
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||
# Add last hidden state
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||
|
@ -671,13 +666,11 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
|
@ -686,18 +679,13 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
training=training,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||
|
||||
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
||||
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
|
||||
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||
|
||||
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||
|
||||
return outputs # lm logits, mc logits, (all hidden_states), (attentions)
|
||||
|
|
|
@ -86,9 +86,9 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
|||
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
|
||||
return position_ids
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
|
@ -97,7 +97,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
|||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
return super()._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
return super()._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
|
@ -546,8 +546,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [
|
||||
outputs = self.roberta(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
|
@ -556,16 +555,12 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
outputs = self.roberta(flat_inputs, training=training)
|
||||
|
||||
training=training,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
|
|
|
@ -115,6 +115,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||
self.is_decoder = config.is_decoder
|
||||
self.use_cache = config.use_cache
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||
self.d_model = config.d_model
|
||||
|
@ -296,7 +297,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||
|
||||
outputs = (context,) + present_key_value_state
|
||||
|
||||
if cast_bool_to_primitive(output_attentions, True) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (weights,)
|
||||
if self.has_relative_attention_bias:
|
||||
outputs = outputs + (position_bias,)
|
||||
|
@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
hidden_states = self.dropout(inputs_embeds, training=training)
|
||||
|
||||
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
|
@ -727,23 +728,23 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
# append next layer key value states
|
||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[2],)
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
|
||||
# Add last layer
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
# need to check if is decoder here as well for special cases when using keras compile
|
||||
if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
|
||||
outputs = outputs + (present_key_value_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
|
|
@ -24,13 +24,7 @@ import tensorflow as tf
|
|||
from .configuration_transfo_xl import TransfoXLConfig
|
||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||
from .modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||
from .tokenization_utils import BatchEncoding
|
||||
|
||||
|
||||
|
@ -119,6 +113,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||
r_w_bias=None,
|
||||
layer_norm_epsilon=1e-5,
|
||||
init_std=0.02,
|
||||
output_attentions=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
@ -127,6 +122,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||
self.d_model = d_model
|
||||
self.d_head = d_head
|
||||
self.dropout = dropout
|
||||
self.output_attentions = output_attentions
|
||||
|
||||
self.qkv_net = tf.keras.layers.Dense(
|
||||
3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net"
|
||||
|
@ -175,8 +171,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||
|
||||
return x
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
w, r, attn_mask, mems, head_mask, output_attentions = inputs
|
||||
def call(self, w, r, attn_mask, mems, head_mask, output_attentions, training=False):
|
||||
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
|
||||
|
||||
if mems is not None:
|
||||
|
@ -249,7 +244,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||
# residual connection + layer normalization
|
||||
outputs = [self.layer_norm(w + attn_out)]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs.append(attn_prob)
|
||||
|
||||
return outputs
|
||||
|
@ -272,6 +267,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||
r_r_bias=None,
|
||||
layer_norm_epsilon=1e-5,
|
||||
init_std=0.02,
|
||||
output_attentions=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
@ -290,6 +286,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||
r_r_bias=r_r_bias,
|
||||
init_std=init_std,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
output_attentions=output_attentions,
|
||||
name="dec_attn",
|
||||
)
|
||||
self.pos_ff = TFPositionwiseFF(
|
||||
|
@ -302,11 +299,8 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||
name="pos_ff",
|
||||
)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions = inputs
|
||||
attn_outputs = self.dec_attn(
|
||||
[dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions], training=training
|
||||
)
|
||||
def call(self, dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=False):
|
||||
attn_outputs = self.dec_attn(dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=training)
|
||||
ff_output = self.pos_ff(attn_outputs[0], training=training)
|
||||
|
||||
outputs = [ff_output] + attn_outputs[1:]
|
||||
|
@ -443,6 +437,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||
r_r_bias=None if self.untie_r else self.r_r_bias,
|
||||
layer_norm_epsilon=config.layer_norm_epsilon,
|
||||
init_std=config.init_std,
|
||||
output_attentions=self.output_attentions,
|
||||
name="layers_._{}".format(i),
|
||||
)
|
||||
)
|
||||
|
@ -625,10 +620,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||
hids.append(core_out)
|
||||
mems_i = None if mems is None else mems[i]
|
||||
layer_outputs = layer(
|
||||
[core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions], training=training,
|
||||
core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions, training=training,
|
||||
)
|
||||
core_out = layer_outputs[0]
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attentions.append(layer_outputs[1])
|
||||
else: # learnable embeddings and absolute embeddings
|
||||
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
||||
|
@ -639,12 +634,12 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||
|
||||
# We transpose back here to shape [bsz, len, hidden_dim]
|
||||
outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
|
||||
if cast_bool_to_primitive(output_hidden_states):
|
||||
if output_hidden_states:
|
||||
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
|
||||
hids.append(core_out)
|
||||
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
|
||||
outputs.append(hids)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
|
||||
attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
||||
outputs.append(attentions)
|
||||
|
@ -860,14 +855,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||
bsz, tgt_len = shape_list(inputs_embeds)[:2]
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
[input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states], training=training
|
||||
input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states, training=training
|
||||
)
|
||||
|
||||
last_hidden = transformer_outputs[0]
|
||||
pred_hid = last_hidden[:, -tgt_len:]
|
||||
outputs = transformer_outputs[1:]
|
||||
|
||||
softmax_output = self.crit([pred_hid, labels], training=training)
|
||||
softmax_output = self.crit(pred_hid, labels, training=training)
|
||||
outputs = [softmax_output] + outputs
|
||||
|
||||
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
||||
|
|
|
@ -114,8 +114,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||
idx = tf.stack([r, target], 1)
|
||||
return tf.gather_nd(logprob, idx)
|
||||
|
||||
def call(self, inputs, return_mean=True, training=False):
|
||||
hidden, target = inputs
|
||||
def call(self, hidden, target, return_mean=True, training=False):
|
||||
head_logprob = 0
|
||||
if self.n_clusters == 0:
|
||||
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
import functools
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
|
@ -173,7 +174,11 @@ class TFTokenClassificationLoss:
|
|||
)
|
||||
# make sure only labels that are not equal to -100
|
||||
# are taken into account as loss
|
||||
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||
if tf.math.reduce_any(labels == -1).numpy() is True:
|
||||
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
||||
else:
|
||||
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||
|
||||
|
@ -233,7 +238,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
"""
|
||||
:obj:`Dict[str, tf.Tensor]`: Dummy inputs to build the network.
|
||||
Dummy inputs to build the network.
|
||||
|
||||
Returns:
|
||||
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||
"""
|
||||
return {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||
|
||||
|
@ -774,14 +782,16 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
|
|||
return tf.gather(self.weight, input_ids)
|
||||
|
||||
def _linear(self, inputs):
|
||||
"""Computes logits by running inputs through a linear layer.
|
||||
Args:
|
||||
inputs: A float32 tensor with shape [..., hidden_size]
|
||||
Returns:
|
||||
float32 tensor with shape [..., vocab_size].
|
||||
"""
|
||||
Computes logits by running inputs through a linear layer.
|
||||
|
||||
Args:
|
||||
inputs: A float32 tensor with shape [..., hidden_size]
|
||||
|
||||
Returns:
|
||||
float32 tensor with shape [..., vocab_size].
|
||||
"""
|
||||
first_dims = shape_list(inputs)[:-1]
|
||||
|
||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||
logits = tf.matmul(x, self.weight, transpose_b=True)
|
||||
|
||||
|
@ -789,7 +799,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
|
|||
|
||||
|
||||
class TFSequenceSummary(tf.keras.layers.Layer):
|
||||
r"""
|
||||
"""
|
||||
Compute a single vector summary of a sequence hidden states.
|
||||
|
||||
Args:
|
||||
|
@ -852,26 +862,9 @@ class TFSequenceSummary(tf.keras.layers.Layer):
|
|||
if self.has_last_dropout:
|
||||
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
|
||||
|
||||
def call(self, inputs, training=False) -> tf.Tensor:
|
||||
"""
|
||||
Compute a single vector summary of a sequence hidden states.
|
||||
|
||||
Args:
|
||||
inputs (:obj:`Union[tf.Tensor, Tuple[tf.Tensor], List[tf.Tensor], Dict[str, tf.Tensor]]`):
|
||||
One or two tensors representing:
|
||||
|
||||
- **hidden_states** (:obj:`tf.Tensor` of shape :obj:`[batch_size, seq_len, hidden_size]`) -- The hidden
|
||||
states of the last layer.
|
||||
- **cls_index** :obj:`tf.Tensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are
|
||||
optional leading dimensions of :obj:`hidden_states`. Used if :obj:`summary_type == "cls_index"` and
|
||||
takes the last token of the sequence as classification token.
|
||||
|
||||
Returns:
|
||||
:obj:`tf.Tensor`: The summary of the sequence hidden states.
|
||||
"""
|
||||
def call(self, inputs, cls_index=None, training=False):
|
||||
if not isinstance(inputs, (dict, tuple, list)):
|
||||
hidden_states = inputs
|
||||
cls_index = None
|
||||
elif isinstance(inputs, (tuple, list)):
|
||||
hidden_states = inputs[0]
|
||||
cls_index = inputs[1] if len(inputs) > 1 else None
|
||||
|
|
|
@ -39,7 +39,6 @@ from .modeling_tf_utils import (
|
|||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -123,6 +122,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||
self.layer_id = next(TFMultiHeadAttention.NEW_ID)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.output_attentions = config.output_attentions
|
||||
assert self.dim % self.n_heads == 0
|
||||
|
||||
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
|
||||
|
@ -135,11 +135,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
|
||||
"""
|
||||
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
||||
"""
|
||||
input, mask, kv, cache, head_mask, output_attentions = inputs
|
||||
# Input is (bs, qlen, dim)
|
||||
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
||||
bs, qlen, dim = shape_list(input)
|
||||
|
@ -196,7 +195,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||
context = unshape(context) # (bs, qlen, dim)
|
||||
|
||||
outputs = (self.out_lin(context),)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (weights,)
|
||||
return outputs
|
||||
|
||||
|
@ -445,6 +444,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||
inputs_embeds = self.embeddings(input_ids)
|
||||
|
||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||
|
||||
if langs is not None and self.use_lang_emb and self.n_langs > 1:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
|
@ -457,15 +457,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||
hidden_states = ()
|
||||
attentions = ()
|
||||
for i in range(self.n_layers):
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states = hidden_states + (tensor,)
|
||||
|
||||
# self attention
|
||||
attn_outputs = self.attentions[i](
|
||||
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
|
||||
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
attn = self.dropout(attn, training=training)
|
||||
tensor = tensor + attn
|
||||
|
@ -484,7 +484,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||
tensor = tensor * mask[..., tf.newaxis]
|
||||
|
||||
# Add last hidden state
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states = hidden_states + (tensor,)
|
||||
|
||||
# update cache length
|
||||
|
@ -495,9 +495,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||
# tensor = tensor.transpose(0, 1)
|
||||
|
||||
outputs = (tensor,)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (attentions,)
|
||||
return outputs # outputs, (hidden_states), (attentions)
|
||||
|
||||
|
@ -930,7 +930,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]))
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
@ -943,7 +943,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||
)
|
||||
lengths = None
|
||||
|
||||
flat_inputs = [
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_langs,
|
||||
|
@ -955,14 +955,12 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
training=training,
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
logits = self.sequence_summary(output)
|
||||
logits = self.logits_proj(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
|
|
|
@ -38,7 +38,6 @@ from .modeling_tf_utils import (
|
|||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -92,6 +91,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
self.d_model = config.d_model
|
||||
self.scale = 1 / (config.d_head ** 0.5)
|
||||
self.initializer_range = config.initializer_range
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
@ -142,11 +142,10 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
|
||||
return x
|
||||
|
||||
def rel_attn_core(self, inputs, training=False):
|
||||
def rel_attn_core(
|
||||
self, q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions, training=False
|
||||
):
|
||||
"""Core relative positional attention operations."""
|
||||
|
||||
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs
|
||||
|
||||
# content based attention score
|
||||
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
|
||||
|
||||
|
@ -182,16 +181,14 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
# attention output
|
||||
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
return attn_vec, attn_prob
|
||||
|
||||
return attn_vec
|
||||
|
||||
def post_attention(self, inputs, residual=True, training=False):
|
||||
def post_attention(self, h, attn_vec, residual=True, training=False):
|
||||
"""Post-attention processing."""
|
||||
# post-attention projection (back to `d_model`)
|
||||
h, attn_vec = inputs
|
||||
|
||||
attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
|
||||
|
||||
attn_out = self.dropout(attn_out, training=training)
|
||||
|
@ -202,9 +199,20 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
|
||||
return output
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs
|
||||
|
||||
def call(
|
||||
self,
|
||||
h,
|
||||
g,
|
||||
attn_mask_h,
|
||||
attn_mask_g,
|
||||
r,
|
||||
seg_mat,
|
||||
mems,
|
||||
target_mapping,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=False,
|
||||
):
|
||||
if g is not None:
|
||||
# Two-stream attention with relative positional encoding.
|
||||
# content based attention score
|
||||
|
@ -228,15 +236,22 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
|
||||
# core attention ops
|
||||
attn_vec_h = self.rel_attn_core(
|
||||
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
|
||||
q_head_h,
|
||||
k_head_h,
|
||||
v_head_h,
|
||||
k_head_r,
|
||||
seg_mat,
|
||||
attn_mask_h,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attn_vec_h, attn_prob_h = attn_vec_h
|
||||
|
||||
# post processing
|
||||
output_h = self.post_attention([h, attn_vec_h], training=training)
|
||||
output_h = self.post_attention(h, attn_vec_h, training=training)
|
||||
|
||||
# g-stream
|
||||
# query-stream query head
|
||||
|
@ -246,27 +261,41 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
if target_mapping is not None:
|
||||
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
|
||||
attn_vec_g = self.rel_attn_core(
|
||||
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
|
||||
q_head_g,
|
||||
k_head_h,
|
||||
v_head_h,
|
||||
k_head_r,
|
||||
seg_mat,
|
||||
attn_mask_g,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attn_vec_g, attn_prob_g = attn_vec_g
|
||||
|
||||
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
|
||||
else:
|
||||
attn_vec_g = self.rel_attn_core(
|
||||
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
|
||||
q_head_g,
|
||||
k_head_h,
|
||||
v_head_h,
|
||||
k_head_r,
|
||||
seg_mat,
|
||||
attn_mask_g,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attn_vec_g, attn_prob_g = attn_vec_g
|
||||
|
||||
# post processing
|
||||
output_g = self.post_attention([g, attn_vec_g], training=training)
|
||||
output_g = self.post_attention(g, attn_vec_g, training=training)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attn_prob = attn_prob_h, attn_prob_g
|
||||
|
||||
else:
|
||||
|
@ -286,19 +315,26 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||
|
||||
# core attention ops
|
||||
attn_vec = self.rel_attn_core(
|
||||
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
|
||||
q_head_h,
|
||||
k_head_h,
|
||||
v_head_h,
|
||||
k_head_r,
|
||||
seg_mat,
|
||||
attn_mask_h,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attn_vec, attn_prob = attn_vec
|
||||
|
||||
# post processing
|
||||
output_h = self.post_attention([h, attn_vec], training=training)
|
||||
output_h = self.post_attention(h, attn_vec, training=training)
|
||||
output_g = None
|
||||
|
||||
outputs = (output_h, output_g)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
outputs = outputs + (attn_prob,)
|
||||
return outputs
|
||||
|
||||
|
@ -337,8 +373,33 @@ class TFXLNetLayer(tf.keras.layers.Layer):
|
|||
self.ff = TFXLNetFeedForward(config, name="ff")
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
outputs = self.rel_attn(inputs, training=training)
|
||||
def call(
|
||||
self,
|
||||
output_h,
|
||||
output_g,
|
||||
non_tgt_mask,
|
||||
attn_mask,
|
||||
pos_emb,
|
||||
seg_mat,
|
||||
mems,
|
||||
target_mapping,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=False,
|
||||
):
|
||||
outputs = self.rel_attn(
|
||||
output_h,
|
||||
output_g,
|
||||
non_tgt_mask,
|
||||
attn_mask,
|
||||
pos_emb,
|
||||
seg_mat,
|
||||
mems,
|
||||
target_mapping,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
output_h, output_g = outputs[:2]
|
||||
|
||||
if output_g is not None:
|
||||
|
@ -686,32 +747,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||
hidden_states = []
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# cache new mems
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache:
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
outputs = layer_module(
|
||||
[
|
||||
output_h,
|
||||
output_g,
|
||||
non_tgt_mask,
|
||||
attn_mask,
|
||||
pos_emb,
|
||||
seg_mat,
|
||||
mems[i],
|
||||
target_mapping,
|
||||
head_mask[i],
|
||||
output_attentions,
|
||||
],
|
||||
output_h,
|
||||
output_g,
|
||||
non_tgt_mask,
|
||||
attn_mask,
|
||||
pos_emb,
|
||||
seg_mat,
|
||||
mems[i],
|
||||
target_mapping,
|
||||
head_mask[i],
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
output_h, output_g = outputs[:2]
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attentions.append(outputs[2])
|
||||
|
||||
# Add last hidden state
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
output = self.dropout(output_g if output_g is not None else output_h, training=training)
|
||||
|
@ -719,16 +778,16 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||
outputs = (tf.transpose(output, perm=(1, 0, 2)),)
|
||||
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache:
|
||||
outputs = outputs + (new_mems,)
|
||||
|
||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
if output_g is not None:
|
||||
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
|
||||
else:
|
||||
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
|
||||
outputs = outputs + (hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
if output_attentions:
|
||||
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
||||
outputs = outputs + (attentions,)
|
||||
|
||||
|
@ -1240,8 +1299,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
flat_inputs = [
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
mems,
|
||||
|
@ -1254,14 +1312,12 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
training=training,
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
logits = self.sequence_summary(output)
|
||||
logits = self.logits_proj(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
|
|
|
@ -4,7 +4,6 @@ import datetime
|
|||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
|
@ -25,15 +24,6 @@ if is_wandb_available():
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if parse(tf.__version__).release < (2, 2, 0):
|
||||
logger.info(
|
||||
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is {}".format(
|
||||
tf.__version__
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class TFTrainer:
|
||||
"""
|
||||
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
|
||||
|
@ -77,6 +67,11 @@ class TFTrainer:
|
|||
None,
|
||||
),
|
||||
):
|
||||
assert parse(tf.__version__).release >= (2, 2, 0), (
|
||||
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
|
||||
% tf.__version__
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.args = args
|
||||
self.train_dataset = train_dataset
|
||||
|
|
|
@ -23,7 +23,7 @@ import unittest
|
|||
from importlib import import_module
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
|
@ -130,6 +130,61 @@ class TFModelTesterMixin:
|
|||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
num_out = len(model(inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
model._set_save_spec(inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf.saved_model.save(model, tmpdirname)
|
||||
model = tf.keras.models.load_model(tmpdirname)
|
||||
outputs = model(inputs_dict)
|
||||
hidden_states = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_attentions = True
|
||||
encoder_seq_length = (
|
||||
self.model_tester.encoder_seq_length
|
||||
if hasattr(self.model_tester, "encoder_seq_length")
|
||||
else self.model_tester.seq_length
|
||||
)
|
||||
encoder_key_length = (
|
||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
num_out = len(model(inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
model._set_save_spec(inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf.saved_model.save(model, tmpdirname)
|
||||
model = tf.keras.models.load_model(tmpdirname)
|
||||
outputs = model(inputs_dict)
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_keras_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
|
|
@ -342,11 +342,17 @@ class TFXLNetModelTester:
|
|||
"attention_mask": multiple_choice_input_mask,
|
||||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
(logits,) = model(inputs)
|
||||
(logits, mems_1) = model(inputs)
|
||||
result = {
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size * self.num_choices, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
Загрузка…
Ссылка в новой задаче