Replace strided slice with tf.expand_dims (#10078)
* Replace tf.newaxis -> tf.expand_dims
* Fix tests
* Fix tests
* Use reshape when a tensor needs a double expand
* Fix GPT2
* Fix GPT2

Parent: e7381c4596
Commit: b82fe7d258
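
The change is mechanical: indexing with tf.newaxis, which TensorFlow lowers to a StridedSlice op, is replaced by tf.expand_dims, or by a single tf.reshape where two axes are inserted at once. Below is a minimal standalone sketch of the equivalence this relies on, using toy values that do not come from the diff.

import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])   # shape (2, 3)

a = x[..., tf.newaxis]                     # strided slice under the hood, shape (2, 3, 1)
b = tf.expand_dims(x, axis=-1)             # same values and shape, no StridedSlice op
tf.debugging.assert_equal(a, b)
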
@@ -1631,7 +1631,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
 ) # A tensor full of shape [batch] or [batch, num choices] full of sequence length
 cls_shape = shape_list(cls_index)
 if len(cls_shape) <= len(hidden_shape) - 2:
-    cls_index = cls_index[..., tf.newaxis]
+    cls_index = tf.expand_dims(cls_index, axis=-1)
 # else:
 # cls_index = cls_index[..., tf.newaxis]
 # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
@@ -138,7 +138,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
     token_type_ids = tf.fill(dims=input_shape, value=0)

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
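
For reference, a small standalone sketch of the position-id pattern in this hunk; the batch size and sequence length below are toy values, not taken from any model config.

import tensorflow as tf

batch_size, seq_length = 2, 5
input_shape = (batch_size, seq_length)

# old: tf.range(...)[tf.newaxis, :]   new: tf.expand_dims(tf.range(...), axis=0)
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)   # shape (1, 5)

# the layer gathers embeddings for these ids and tiles them across the batch;
# tiling the ids themselves shows the same broadcast
print(tf.tile(position_ids, multiples=(input_shape[0], 1)).numpy())
# [[0 1 2 3 4]
#  [0 1 2 3 4]]
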
@@ -543,7 +543,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
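
The attention-mask hunks follow the commit note about using reshape when a tensor needs a double expand: a (batch, seq) mask gets two singleton axes so it broadcasts against (batch, num_heads, seq, seq) attention scores, and one tf.reshape replaces the two axes inserted by the strided slice. A sketch with toy shapes:

import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])   # (batch, seq) = (2, 4)
input_shape = (2, 4)

old = attention_mask[:, tf.newaxis, tf.newaxis, :]                          # strided slice
new = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))    # same (2, 1, 1, 4)
tf.debugging.assert_equal(old, new)

# downstream, the models turn the mask into an additive bias, as the comments describe
extended = (1.0 - tf.cast(new, tf.float32)) * -10000.0
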
@@ -192,7 +192,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
     token_type_ids = tf.fill(dims=input_shape, value=0)

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -655,7 +655,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -128,7 +128,7 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
     token_type_ids = tf.fill(dims=input_shape, value=0)

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -541,7 +541,7 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -312,9 +312,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 else:
     past_length = shape_list(inputs["past"][0][0])[-2]
 if inputs["position_ids"] is None:
-    inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
-        tf.newaxis, :
-    ]
+    inputs["position_ids"] = tf.expand_dims(
+        tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
+    )
     inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])

 # Attention mask.
@@ -324,7 +324,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+inputs["attention_mask"] = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -113,7 +113,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
 input_shape = shape_list(inputs_embeds)[:-1]

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -415,7 +415,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
     token_type_ids = tf.fill(dims=input_shape, value=0)

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -510,7 +510,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -181,12 +181,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
 else:
     # assert lengths.max().item() <= slen
     alen = tf.range(slen)
-    mask = tf.math.less(alen, lengths[:, tf.newaxis])
+    mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))

 # attention mask is the same as mask, or triangular inferior attention (causal)
 if causal:
     attn_mask = tf.less_equal(
-        tf.tile(alen[tf.newaxis, tf.newaxis, :], (bs, slen, 1)), alen[tf.newaxis, :, tf.newaxis]
+        tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1))
     )
 else:
     attn_mask = mask
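
A standalone sketch of the causal mask built in get_masks, with toy bs and slen; tf.reshape(alen, (1, 1, slen)) and tf.reshape(alen, (1, slen, 1)) stand in for the removed alen[tf.newaxis, tf.newaxis, :] and alen[tf.newaxis, :, tf.newaxis].

import tensorflow as tf

bs, slen = 2, 4
alen = tf.range(slen)

attn_mask = tf.less_equal(
    tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)),   # key positions
    tf.reshape(alen, (1, slen, 1)),                           # query positions
)
print(attn_mask[0].numpy())
# [[ True False False False]
#  [ True  True False False]
#  [ True  True  True False]
#  [ True  True  True  True]]
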
@@ -612,7 +612,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):

 tensor = self.layer_norm_emb(tensor)
 tensor = self.dropout(tensor, training=inputs["training"])
-tensor = tensor * mask[..., tf.newaxis]
+tensor = tensor * tf.expand_dims(mask, axis=-1)

 # hidden_states and attentions cannot be None in graph mode.
 hidden_states = () if inputs["output_hidden_states"] else None
@@ -682,7 +682,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
 tensor_normalized = self.layer_norm2[i](tensor)
 tensor = tensor + self.ffns[i](tensor_normalized)

-tensor = tensor * mask[..., tf.newaxis]
+tensor = tensor * tf.expand_dims(mask, axis=-1)

 # Add last hidden state
 if inputs["output_hidden_states"]:
@@ -302,9 +302,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
 past_length = shape_list(inputs["past"][0][0])[-2]

 if inputs["position_ids"] is None:
-    inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
-        tf.newaxis, :
-    ]
+    inputs["position_ids"] = tf.expand_dims(
+        tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
+    )

 if inputs["attention_mask"] is not None:
     # We create a 3D attention mask from a 2D tensor mask.
@@ -312,7 +312,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+attention_mask_shape = shape_list(inputs["attention_mask"])
+inputs["attention_mask"] = tf.reshape(
+    inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+)

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
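
The GPT-2 hunk differs from the others in that it takes the shape from the mask itself rather than from input_shape, presumably because with cached past key/values the attention mask spans past_length + current tokens and is wider than the current input. A sketch; shape_list here is a simplified stand-in for the transformers utility of the same name:

import tensorflow as tf

def shape_list(t):
    # simplified stand-in for transformers.modeling_tf_utils.shape_list
    static, dynamic = t.shape.as_list(), tf.shape(t)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

past_length, cur_length, batch = 3, 2, 1
attention_mask = tf.ones((batch, past_length + cur_length))        # (1, 5), wider than cur_length

attention_mask_shape = shape_list(attention_mask)
reshaped = tf.reshape(
    attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
print(reshaped.shape)   # (1, 1, 1, 5) -- not (1, 1, 1, cur_length)
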
@@ -547,9 +547,9 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
 # Create the position ids from the input token ids. Any padded tokens remain padded.
 position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
 else:
-    position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
-        tf.newaxis, :
-    ]
+    position_ids = tf.expand_dims(
+        tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
+    )
     position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
@@ -1661,7 +1661,10 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, :, tf.newaxis, tf.newaxis]
+attention_mask_shape = shape_list(inputs["attention_mask"])
+extended_attention_mask = tf.reshape(
+    inputs["attention_mask"], (attention_mask_shape[0], attention_mask_shape[1], 1, 1)
+)

 # Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
 # masked and global attn positions, this operation will create a tensor which is 0.0 for
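
Longformer is the one model in this commit whose mask keeps the batch and sequence axes first and appends the two singleton axes at the end, matching how its sliding-window attention broadcasts the mask; the reshape is otherwise the same trick. Toy shapes only:

import tensorflow as tf

attention_mask = tf.ones((2, 6))                     # (batch, seq_len)
shape = attention_mask.shape

old = attention_mask[:, :, tf.newaxis, tf.newaxis]   # strided slice, (2, 6, 1, 1)
new = tf.reshape(attention_mask, (shape[0], shape[1], 1, 1))
tf.debugging.assert_equal(old, new)
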
@@ -233,7 +233,7 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
 if token_type_ids is None:
     token_type_ids = tf.fill(dims=input_shape, value=0)

-position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
 token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
@@ -726,7 +726,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -738,7 +738,12 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

 if inputs["visual_attention_mask"] is not None:
-    extended_visual_attention_mask = inputs["visual_attention_mask"][:, tf.newaxis, tf.newaxis, :]
+    extended_visual_attention_mask = tf.reshape(
+        inputs["visual_attention_mask"], (input_shape[0], 1, 1, input_shape[1])
+    )
+    extended_visual_attention_mask = tf.expand_dims(
+        tf.expand_dims(inputs["visual_attention_mask"], axis=1), axis=1
+    )

     extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
     extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
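
For the visual attention mask, chained tf.expand_dims is used instead of a reshape built from input_shape, presumably because the visual sequence length (number of visual features) need not match the text sequence length. A sketch with illustrative shapes:

import tensorflow as tf

visual_attention_mask = tf.ones((2, 36))    # (batch, num_visual_feats), illustrative

extended = tf.expand_dims(tf.expand_dims(visual_attention_mask, axis=1), axis=1)
print(extended.shape)                        # (2, 1, 1, 36)

extended = tf.cast(extended, tf.float32)
extended = (1.0 - extended) * -10000.0       # additive mask bias, as in the hunk above
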
@@ -192,7 +192,7 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
 inputs_embeds = self.embedding_transformation(inputs_embeds)

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -731,7 +731,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -151,9 +151,9 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
 # Create the position ids from the input token ids. Any padded tokens remain padded.
 position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
 else:
-    position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
-        tf.newaxis, :
-    ]
+    position_ids = tf.expand_dims(
+        tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
+    )
     position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
@@ -533,7 +533,7 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -268,7 +268,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
 raise ValueError("You have to specify either input_ids or inputs_embeds")

 if inputs["position_ids"] is None:
-    inputs["position_ids"] = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
+    inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1], dtype=tf.int32), axis=0)

 if inputs["attention_mask"] is not None:
     # We create a 3D attention mask from a 2D tensor mask.
@@ -276,7 +276,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+inputs["attention_mask"] = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -147,9 +147,9 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
 # Create the position ids from the input token ids. Any padded tokens remain padded.
 position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
 else:
-    position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
-        tf.newaxis, :
-    ]
+    position_ids = tf.expand_dims(
+        tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
+    )
     position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
@@ -533,7 +533,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for
@@ -92,12 +92,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
 else:
     # assert lengths.max().item() <= slen
     alen = tf.range(slen)
-    mask = tf.math.less(alen, lengths[:, tf.newaxis])
+    mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))

 # attention mask is the same as mask, or triangular inferior attention (causal)
 if causal:
     attn_mask = tf.less_equal(
-        tf.tile(alen[tf.newaxis, tf.newaxis, :], (bs, slen, 1)), alen[tf.newaxis, :, tf.newaxis]
+        tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1))
     )
 else:
     attn_mask = mask
@@ -463,7 +463,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

 tensor = self.layer_norm_emb(tensor)
 tensor = self.dropout(tensor, training=inputs["training"])
-tensor = tensor * mask[..., tf.newaxis]
+tensor = tensor * tf.expand_dims(mask, axis=-1)

 # transformer layers
 hidden_states = () if inputs["output_hidden_states"] else None
@@ -502,7 +502,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
 # FFN
 tensor = tensor + self.ffns[i](tensor)
 tensor = self.layer_norm2[i](tensor)
-tensor = tensor * mask[..., tf.newaxis]
+tensor = tensor * tf.expand_dims(mask, axis=-1)

 # Add last hidden state
 if inputs["output_hidden_states"]:
@@ -134,7 +134,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
     token_type_ids = tf.fill(dims=input_shape, value=0)

 if position_ids is None:
-    position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+    position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

 position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
 position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -570,7 +570,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
 # this attention mask is more simple than the triangular masking of causal attention
 # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
+extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))

 # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
 # masked positions, this operation will create a tensor which is 0.0 for