Rework casts (#10274)

Parent: 2d458b2c7d
Commit: cdcdd5f03a
@@ -150,7 +150,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            if attn_mask.dtype == tf.float16:
+            if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
                 attn_score = attn_score - 65500 * attn_mask
             else:
                 attn_score = attn_score - 1e30 * attn_mask
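Note on this hunk: tf.float16 cannot represent 1e30, so the -1e30 additive mask would overflow to -inf and poison the attention scores; -65500 stays finite while still zeroing the masked softmax weights. The commit routes tf.bfloat16 through the same guarded branch. A minimal sketch of the overflow being guarded against (illustrative values only):

import tensorflow as tf

# float16 tops out at 65504, so a -1e30 additive mask would overflow
# to -inf; -65500 stays finite but still drives masked weights to ~0.
print(tf.float16.max)                               # 65504.0
print(tf.cast(tf.constant(1e30), tf.float16))       # inf
print(tf.nn.softmax(tf.constant([0.0, -65500.0])))  # ~[1., 0.]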
@@ -476,7 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def create_mask(self, qlen, mlen, dtype=tf.float32):
+    def create_mask(self, qlen, mlen):
         """
         Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

@@ -495,10 +495,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
            v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]

        """
-        attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+        attn_mask = tf.ones([qlen, qlen])
         mask_u = tf.matrix_band_part(attn_mask, 0, -1)
         mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
-        attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+        attn_mask_pad = tf.zeros([qlen, mlen])
         ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
         if self.same_length:
             mask_l = tf.matrix_band_part(attn_mask, -1, 0)
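Note: create_mask no longer takes a dtype; the mask is built at the framework default (float32) and any lower-precision cast happens at the point of use. A standalone sketch of the reworked function, written with tf.linalg.band_part (the TF2 spelling of the tf.matrix_band_part seen in the diff) and ignoring the same_length branch:

import tensorflow as tf

def create_mask(qlen, mlen):
    # 1.0 marks a masked (future) position, 0.0 an attendable one.
    attn_mask = tf.ones([qlen, qlen])
    mask_u = tf.linalg.band_part(attn_mask, 0, -1)   # upper triangle incl. diagonal
    mask_dia = tf.linalg.band_part(attn_mask, 0, 0)  # diagonal only
    attn_mask_pad = tf.zeros([qlen, mlen])           # memory is always visible
    return tf.concat([attn_mask_pad, mask_u - mask_dia], 1)

print(create_mask(3, 2))
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]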
@@ -537,11 +537,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):

         return pos_emb

-    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
+    def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = tf.range(0, self.d_model, 2.0)
-        if dtype is not None and dtype != tf.float32:
-            freq_seq = tf.cast(freq_seq, dtype=dtype)
         inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

         if self.attn_type == "bi":
@@ -557,10 +555,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             fwd_pos_seq = tf.range(beg, end, -1.0)
             bwd_pos_seq = tf.range(-beg, -end, 1.0)

-            if dtype is not None and dtype != tf.float32:
-                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
-                bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
-
             if self.clamp_len > 0:
                 fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
                 bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
@@ -576,8 +570,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
         else:
             fwd_pos_seq = tf.range(beg, end, -1.0)
-            if dtype is not None and dtype != tf.float32:
-                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
             if self.clamp_len > 0:
                 fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
             pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
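Note: with the dtype parameter gone, both the bidirectional and unidirectional branches of relative_positional_encoding build their position sequences in float32 and leave any downcast to the caller. A minimal sketch of the surviving logic, with stand-in values for beg, end, and clamp_len:

import tensorflow as tf

beg, end, clamp_len = 4, -4, 2          # stand-in values
fwd_pos_seq = tf.range(beg, end, -1.0)  # float32 by default
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
print(fwd_pos_seq)  # [ 2.  2.  2.  1.  0. -1. -2. -2.]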
@@ -661,8 +653,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
        klen = mlen + qlen

-        dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
-
        # Attention mask
        # causal attention mask
        if self.attn_type == "uni":
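Note: dtype_float was the single hard-coded switch (bfloat16 vs. float32) that every downstream mask and embedding then had to thread through. The hunks below replace each use either with dtype inference from a nearby tensor (x.dtype) or with the framework default, which is what makes this variable removable.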
@@ -679,7 +669,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
                 "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
             )
         if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
-            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
+            one_cst = tf.constant(1.0)
+            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
         if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
             data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
         elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
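Note: rather than hard-coding the float dtype, the new code materializes a scalar constant and reuses its dtype, so the mask follows whatever float type the runtime assigns by default. A minimal sketch of the pattern:

import tensorflow as tf

attention_mask = tf.constant([[1, 1, 0]])  # 0 marks padding (BERT-style)
one_cst = tf.constant(1.0)                 # float32 under the default policy
input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
print(input_mask)  # [[0. 0. 1.]] -- 1 now marks padding (XLNet-style)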
@@ -692,7 +683,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         if data_mask is not None:
             # all mems can be attended to
             if mlen > 0:
-                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
+                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
                 data_mask = tf.concat([mems_mask, data_mask], axis=1)
             if attn_mask is None:
                 attn_mask = data_mask[:, :, :, None]
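Note: tf.zeros already defaults to float32, so dropping dtype=dtype_float leaves behavior unchanged in the default configuration; the bfloat16 path is simply no longer special-cased here.

import tensorflow as tf

print(tf.zeros([2, 3]).dtype)  # float32 -- the implicit default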
@@ -700,13 +691,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
                 attn_mask += data_mask[:, :, :, None]

         if attn_mask is not None:
-            attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)
+            attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)

         if attn_mask is not None:
-            non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
+            non_tgt_mask = -tf.eye(qlen)
             if mlen > 0:
-                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
-            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
+                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
+            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
         else:
             non_tgt_mask = None

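Note: attn_mask > 0 yields a boolean tensor, and casting back to attn_mask.dtype (the dtype of the tensor being compared, evaluated before the reassignment) re-binarizes the mask in its own float type instead of a globally chosen one. A minimal sketch:

import tensorflow as tf

attn_mask = tf.constant([[0.0, 2.0], [1.0, 0.0]])
# The comparison produces bools; the cast restores the mask's own dtype.
attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)
print(attn_mask)  # [[0. 1.] [1. 0.]]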
@@ -729,19 +720,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         if inputs["token_type_ids"] is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
-                mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+                mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
                 cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
             else:
                 cat_ids = inputs["token_type_ids"]

             # `1` indicates not in the same segment [qlen x klen x bsz]
-            seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
-            seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
+            seg_mat = tf.cast(
+                tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
+                dtype=inputs["token_type_ids"].dtype,
+            )
+            seg_mat = tf.one_hot(seg_mat, 2)
         else:
             seg_mat = None

         # Positional encoding
-        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
+        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb, training=inputs["training"])

         # Prepare head mask if needed
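Note: the segment-comparison mask now inherits the integer dtype of token_type_ids instead of hard-coded tf.int32, and the one-hot expansion is left at the default float type. A minimal sketch for the mlen == 0 case, with a stand-in [qlen, bsz] ids tensor:

import tensorflow as tf

token_type_ids = tf.constant([[0], [0], [1]])  # stand-in, qlen=3, bsz=1
cat_ids = token_type_ids                       # mlen == 0, nothing to pad
seg_mat = tf.cast(
    tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
    dtype=token_type_ids.dtype,
)
seg_mat = tf.one_hot(seg_mat, 2)
print(seg_mat.shape)  # (3, 3, 1, 2) -- float32 one-hot over same/different segment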
@@ -1258,7 +1252,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         offset = 2

         effective_batch_size = inputs.shape[0]
-        dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
+        dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)

         if past:
             inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
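Note: the dummy token appended during generation now matches the dtype of the incoming ids, so int32 and int64 inputs both concatenate without an explicit cast. A minimal sketch:

import tensorflow as tf

inputs = tf.constant([[5, 6, 7]], dtype=tf.int64)  # stand-in ids
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
print(tf.concat([inputs[:, -1:], dummy_token], axis=1))  # [[7 0]], still int64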
@@ -1267,13 +1261,13 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):

         # Build permutation mask so that previous tokens don't see last token
         sequence_length = inputs.shape[1]
-        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
-        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
+        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
+        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
         perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

         # We'll only predict the last token
-        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
-        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
+        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
+        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
         target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

         inputs = {
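Note: perm_mask and target_mapping drop their explicit dtype=tf.float32 because tf.zeros and tf.ones default to float32 anyway. A minimal sketch of the generation-time permutation mask with stand-in sizes; its final all-ones column hides the appended dummy token from every position:

import tensorflow as tf

effective_batch_size, sequence_length = 1, 3  # stand-in sizes
perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
print(perm_mask[0])
# [[0. 0. 1.]
#  [0. 0. 1.]
#  [0. 0. 1.]]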