diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index 56fc4ecd2..89e83995d 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -150,7 +150,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): attn_score = (ac + bd + ef) * self.scale if attn_mask is not None: # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask - if attn_mask.dtype == tf.float16: + if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16: attn_score = attn_score - 65500 * attn_mask else: attn_score = attn_score - 1e30 * attn_mask @@ -476,7 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): def _prune_heads(self, heads_to_prune): raise NotImplementedError - def create_mask(self, qlen, mlen, dtype=tf.float32): + def create_mask(self, qlen, mlen): """ Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked. @@ -495,10 +495,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0] """ - attn_mask = tf.ones([qlen, qlen], dtype=dtype) + attn_mask = tf.ones([qlen, qlen]) mask_u = tf.matrix_band_part(attn_mask, 0, -1) mask_dia = tf.matrix_band_part(attn_mask, 0, 0) - attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype) + attn_mask_pad = tf.zeros([qlen, mlen]) ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) if self.same_length: mask_l = tf.matrix_band_part(attn_mask, -1, 0) @@ -537,11 +537,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): return pos_emb - def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None): + def relative_positional_encoding(self, qlen, klen, bsz=None): """create relative positional encoding.""" freq_seq = tf.range(0, self.d_model, 2.0) - if dtype is not None and dtype != tf.float32: - freq_seq = tf.cast(freq_seq, dtype=dtype) inv_freq = 1 / (10000 ** (freq_seq / self.d_model)) if self.attn_type == "bi": @@ -557,10 +555,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): fwd_pos_seq = tf.range(beg, end, -1.0) bwd_pos_seq = tf.range(-beg, -end, 1.0) - if dtype is not None and dtype != tf.float32: - fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) - bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype) - if self.clamp_len > 0: fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len) bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len) @@ -576,8 +570,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1) else: fwd_pos_seq = tf.range(beg, end, -1.0) - if dtype is not None and dtype != tf.float32: - fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) if self.clamp_len > 0: fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len) pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) @@ -661,8 +653,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0 klen = mlen + qlen - dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32 - # Attention mask # causal attention mask if self.attn_type == "uni": @@ -679,7 +669,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one." ) if inputs["input_mask"] is None and inputs["attention_mask"] is not None: - inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float) + one_cst = tf.constant(1.0) + inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype) if inputs["input_mask"] is not None and inputs["perm_mask"] is not None: data_mask = inputs["input_mask"][None] + inputs["perm_mask"] elif inputs["input_mask"] is not None and inputs["perm_mask"] is None: @@ -692,7 +683,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): if data_mask is not None: # all mems can be attended to if mlen > 0: - mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float) + mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz]) data_mask = tf.concat([mems_mask, data_mask], axis=1) if attn_mask is None: attn_mask = data_mask[:, :, :, None] @@ -700,13 +691,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): attn_mask += data_mask[:, :, :, None] if attn_mask is not None: - attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float) + attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype) if attn_mask is not None: - non_tgt_mask = -tf.eye(qlen, dtype=dtype_float) + non_tgt_mask = -tf.eye(qlen) if mlen > 0: - non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1) - non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float) + non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1) + non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype) else: non_tgt_mask = None @@ -729,19 +720,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): if inputs["token_type_ids"] is not None: # Convert `token_type_ids` to one-hot `seg_mat` if mlen > 0: - mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) + mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype) cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0) else: cat_ids = inputs["token_type_ids"] # `1` indicates not in the same segment [qlen x klen x bsz] - seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32) - seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float) + seg_mat = tf.cast( + tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), + dtype=inputs["token_type_ids"].dtype, + ) + seg_mat = tf.one_hot(seg_mat, 2) else: seg_mat = None # Positional encoding - pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float) + pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz) pos_emb = self.dropout(pos_emb, training=inputs["training"]) # Prepare head mask if needed @@ -1258,7 +1252,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): offset = 2 effective_batch_size = inputs.shape[0] - dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32) + dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype) if past: inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1) @@ -1267,13 +1261,13 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): # Build permutation mask so that previous tokens don't see last token sequence_length = inputs.shape[1] - perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32) - perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32) + perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1)) + perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1)) perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1) # We'll only predict the last token - target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32) - target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32) + target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1)) + target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1)) target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1) inputs = {