This commit is contained in:
Julien Plu 2021-02-24 14:38:29 +01:00 committed by GitHub
Parent 2d458b2c7d
Commit cdcdd5f03a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 24 additions and 30 deletions


@@ -150,7 +150,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
# attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
- if attn_mask.dtype == tf.float16:
+ if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
attn_score = attn_score - 65500 * attn_mask
else:
attn_score = attn_score - 1e30 * attn_mask
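A note on the branch above, with a minimal standalone sketch (not part of the diff; toy values): float16 cannot represent 1e30 (its largest finite value is roughly 65504), so the additive penalty has to be capped at 65500 when the mask is in half precision, and this commit simply routes bfloat16 masks through the same capped branch.

    import tensorflow as tf

    # 1e30 overflows to inf in float16, while 65500 stays finite,
    # so the additive mask penalty is capped for half-precision masks.
    attn_mask = tf.ones([2, 2], dtype=tf.float16)
    print(tf.cast(1e30, tf.float16))   # inf
    print(-65500 * attn_mask)          # finite, large negative penalty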
@@ -476,7 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
- def create_mask(self, qlen, mlen, dtype=tf.float32):
+ def create_mask(self, qlen, mlen):
"""
Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
@@ -495,10 +495,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
- attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+ attn_mask = tf.ones([qlen, qlen])
mask_u = tf.matrix_band_part(attn_mask, 0, -1)
mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
- attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+ attn_mask_pad = tf.zeros([qlen, mlen])
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if self.same_length:
mask_l = tf.matrix_band_part(attn_mask, -1, 0)
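For reference, a quick standalone sketch of the causal mask that create_mask now builds without an explicit dtype (everything defaults to float32; the sketch uses the tf.linalg.band_part spelling and made-up sizes):

    import tensorflow as tf

    qlen, mlen = 4, 2
    attn_mask = tf.ones([qlen, qlen])                # float32 by default
    mask_u = tf.linalg.band_part(attn_mask, 0, -1)   # upper triangle incl. diagonal
    mask_dia = tf.linalg.band_part(attn_mask, 0, 0)  # diagonal only
    ret = tf.concat([tf.zeros([qlen, mlen]), mask_u - mask_dia], axis=1)
    print(ret)  # 1.0 over future positions, 0.0 over memory, past and current tokens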
@@ -537,11 +537,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return pos_emb
- def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
+ def relative_positional_encoding(self, qlen, klen, bsz=None):
"""create relative positional encoding."""
freq_seq = tf.range(0, self.d_model, 2.0)
- if dtype is not None and dtype != tf.float32:
-     freq_seq = tf.cast(freq_seq, dtype=dtype)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == "bi":
@@ -557,10 +555,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
fwd_pos_seq = tf.range(beg, end, -1.0)
bwd_pos_seq = tf.range(-beg, -end, 1.0)
- if dtype is not None and dtype != tf.float32:
-     fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
-     bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
@@ -576,8 +570,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
else:
fwd_pos_seq = tf.range(beg, end, -1.0)
- if dtype is not None and dtype != tf.float32:
-     fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
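With the dtype argument gone, the sinusoidal encoding is simply built in float32 and cast later if needed; a rough sketch of that construction (d_model, the position range, and the einsum/sin/cos step are illustrative of the surrounding code, not copied from it):

    import tensorflow as tf

    d_model = 8
    freq_seq = tf.range(0, d_model, 2.0)              # float32 by default
    inv_freq = 1 / (10000 ** (freq_seq / d_model))
    fwd_pos_seq = tf.range(5.0, -1.0, -1.0)           # toy forward position sequence
    sinusoid = tf.einsum("i,d->id", fwd_pos_seq, inv_freq)
    pos_emb = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], axis=-1)
    print(pos_emb.dtype)  # float32; any mixed-precision cast now happens downstream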
@@ -661,8 +653,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
klen = mlen + qlen
- dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
# Attention mask
# causal attention mask
if self.attn_type == "uni":
@@ -679,7 +669,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
)
if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
one_cst = tf.constant(1.0)
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
@@ -692,7 +683,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if data_mask is not None:
# all mems can be attended to
if mlen > 0:
- mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
+ mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
data_mask = tf.concat([mems_mask, data_mask], axis=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
@@ -700,13 +691,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
- attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)
+ attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)
if attn_mask is not None:
- non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
+ non_tgt_mask = -tf.eye(qlen)
if mlen > 0:
- non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
- non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
+ non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
+ non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
else:
non_tgt_mask = None
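In this hunk the boolean re-thresholding reuses the dtype of the tensor being rebuilt rather than a module-level dtype_float; a tiny sketch with invented values:

    import tensorflow as tf

    attn_mask = tf.constant([[0.0, 2.0], [1.0, 0.0]])
    # Rebuild the mask as 0/1 in whatever float dtype it already has.
    attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)
    print(attn_mask)  # [[0. 1.] [1. 0.]], still float32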
@@ -729,19 +720,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if inputs["token_type_ids"] is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
- mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+ mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
else:
cat_ids = inputs["token_type_ids"]
# `1` indicates not in the same segment [qlen x klen x bsz]
- seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
- seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
+ seg_mat = tf.cast(
+     tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
+     dtype=inputs["token_type_ids"].dtype,
+ )
+ seg_mat = tf.one_hot(seg_mat, 2)
else:
seg_mat = None
# Positional encoding
- pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
+ pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb, training=inputs["training"])
# Prepare head mask if needed
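A sketch of the segment-matrix step above with toy shapes (token_type_ids is laid out as [qlen, bsz] at this point in the layer; the values below are invented):

    import tensorflow as tf

    token_type_ids = tf.constant([[0], [0], [1]], dtype=tf.int32)  # [qlen, bsz]
    cat_ids = token_type_ids                                       # no memory, so mlen == 0
    # 1 means the query and key tokens come from different segments.
    seg_mat = tf.cast(
        tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
        dtype=token_type_ids.dtype,
    )
    seg_mat = tf.one_hot(seg_mat, 2)  # one-hot dtype now defaults to float32
    print(seg_mat.shape)              # (3, 3, 1, 2) -> [qlen, klen, bsz, 2]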
@@ -1258,7 +1252,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
offset = 2
effective_batch_size = inputs.shape[0]
- dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
+ dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
if past:
inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
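The dummy token above now inherits its dtype from the incoming ids instead of hard-coding int32; a minimal sketch with invented ids:

    import tensorflow as tf

    inputs = tf.constant([[5, 6, 7]], dtype=tf.int64)        # e.g. int64 input_ids
    effective_batch_size = inputs.shape[0]
    dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
    print(tf.concat([inputs[:, -1:], dummy_token], axis=1))  # [[7 0]] as int64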
@@ -1267,13 +1261,13 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
# Build permutation mask so that previous tokens don't see last token
sequence_length = inputs.shape[1]
- perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
- perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
+ perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
+ perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
# We'll only predict the last token
- target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
- target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
+ target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
+ target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {