Making TF MPNet model compliant with XLA (#10260)
* Fix XLA * Rework cast * Apply style
This commit is contained in:
Родитель
fb56bf2584
Коммит
3d72d47f09
|
@ -348,15 +348,22 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
|
|||
self.n_heads = config.num_attention_heads
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||
self.initializer_range = config.initializer_range
|
||||
|
||||
self.layer = [TFMPNetLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||
self.relative_attention_bias = tf.keras.layers.Embedding(
|
||||
config.relative_attention_num_buckets,
|
||||
self.n_heads,
|
||||
name="relative_attention_bias",
|
||||
)
|
||||
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||
|
||||
def build(self, input_shape):
|
||||
with tf.name_scope("relative_attention_bias"):
|
||||
self.relative_attention_bias = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.relative_attention_num_buckets, self.n_heads],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
return super().build(input_shape)
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
|
@ -405,18 +412,16 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
|
|||
n = -relative_position
|
||||
|
||||
num_buckets //= 2
|
||||
ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
|
||||
ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets
|
||||
n = tf.math.abs(n)
|
||||
|
||||
# now n is in the range [0, inf)
|
||||
max_exact = num_buckets // 2
|
||||
is_small = tf.math.less(n, max_exact)
|
||||
|
||||
val_if_large = max_exact + tf.dtypes.cast(
|
||||
tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact),
|
||||
tf.int32,
|
||||
val_if_large = max_exact + tf.cast(
|
||||
tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact),
|
||||
dtype=relative_position.dtype,
|
||||
)
|
||||
|
||||
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
|
||||
|
@ -441,7 +446,7 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
|
|||
relative_position,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
)
|
||||
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
|
||||
values = tf.gather(self.relative_attention_bias, rp_bucket) # shape (qlen, klen, num_heads)
|
||||
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen)
|
||||
return values
|
||||
|
||||
|
@ -541,7 +546,9 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
|
|||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
|
||||
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
|
|
@ -232,10 +232,6 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_xla_mode(self):
|
||||
# TODO JP: Make MPNet XLA compliant
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ["microsoft/mpnet-base"]:
|
||||
|
|
Загрузка…
Ссылка в новой задаче