Making TF MPNet model compliant with XLA (#10260)
* Fix XLA

* Rework cast

* Apply style
Parent: fb56bf2584
Commit: 3d72d47f09
src/transformers/models/mpnet/modeling_tf_mpnet.py

@@ -348,15 +348,22 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
         self.n_heads = config.num_attention_heads
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.initializer_range = config.initializer_range
 
         self.layer = [TFMPNetLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
-        self.relative_attention_bias = tf.keras.layers.Embedding(
-            config.relative_attention_num_buckets,
-            self.n_heads,
-            name="relative_attention_bias",
-        )
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
 
+    def build(self, input_shape):
+        with tf.name_scope("relative_attention_bias"):
+            self.relative_attention_bias = self.add_weight(
+                name="embeddings",
+                shape=[self.relative_attention_num_buckets, self.n_heads],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        return super().build(input_shape)
+
     def call(
         self,
         hidden_states,
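The rework above swaps a `tf.keras.layers.Embedding` sub-layer for a weight created explicitly in `build()`: the variable is guaranteed to exist before the forward pass is traced, so `call()` contains only plain tensor ops that XLA can compile. A minimal sketch of the same pattern (the layer and attribute names here are hypothetical, chosen for illustration):

import tensorflow as tf

# Hypothetical layer demonstrating the build()/add_weight pattern.
class RelativeBiasDemo(tf.keras.layers.Layer):
    def __init__(self, num_buckets, n_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads

    def build(self, input_shape):
        # The bias table is created exactly once, when the layer is built.
        with tf.name_scope("relative_attention_bias"):
            self.bias_table = self.add_weight(
                name="embeddings",
                shape=[self.num_buckets, self.n_heads],
                initializer="zeros",
            )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # A plain gather replaces the Embedding sub-layer call.
        return tf.gather(self.bias_table, bucket_ids)

layer = RelativeBiasDemo(num_buckets=32, n_heads=4)
print(layer(tf.constant([[0, 1], [2, 3]])).shape)  # (2, 2, 4)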
@@ -405,18 +412,16 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
         n = -relative_position
 
         num_buckets //= 2
-        ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
+        ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets
         n = tf.math.abs(n)
 
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(n, max_exact)
 
-        val_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
-            / math.log(max_distance / max_exact)
-            * (num_buckets - max_exact),
-            tf.int32,
+        val_if_large = max_exact + tf.cast(
+            tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact),
+            dtype=relative_position.dtype,
         )
 
         val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
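The cast rework keeps every intermediate of the bucketing math in `relative_position.dtype` instead of round-tripping through `tf.float32` and a hard-coded `tf.int32`. A standalone, runnable version of the same formula (wrapped in a free function purely for illustration; in the model it is a static method):

import math
import tensorflow as tf

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    ret = 0
    n = -relative_position

    num_buckets //= 2
    # The sign flag is cast to the input dtype, not hard-coded to tf.int32.
    ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets
    n = tf.math.abs(n)

    max_exact = num_buckets // 2
    is_small = tf.math.less(n, max_exact)

    # n / max_exact promotes to float, so tf.math.log is well defined; the
    # single cast at the end returns the result to the input dtype.
    val_if_large = max_exact + tf.cast(
        tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact),
        dtype=relative_position.dtype,
    )
    val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
    return ret + tf.where(is_small, n, val_if_large)

offsets = tf.range(4)[None, :] - tf.range(4)[:, None]  # (qlen, klen) position offsets
print(relative_position_bucket(offsets))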
@@ -441,7 +446,7 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
             relative_position,
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
+        values = tf.gather(self.relative_attention_bias, rp_bucket)  # shape (qlen, klen, num_heads)
         values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
         return values
 
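With the table now stored as a raw weight, the lookup becomes `tf.gather`, which performs the same row selection an `Embedding` layer does internally. A quick sanity sketch with made-up shapes:

import tensorflow as tf

table = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])  # (num_buckets, n_heads)
rp_bucket = tf.constant([[0, 1], [2, 3]])                   # (qlen, klen) bucket ids

values = tf.gather(table, rp_bucket)  # one row of the table per bucket id
print(values.shape)  # (2, 2, 3) == (qlen, klen, n_heads)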
@@ -541,7 +546,9 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
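Building the constants with `tf.constant(..., dtype=embedding_output.dtype)` makes the mask arithmetic explicit about its dtype instead of relying on the implicit conversion of the Python literals `1.0` and `-10000.0`, which matters once the model runs in reduced precision. A small sketch of the same computation (`float16` chosen here only to make the dtype visible):

import tensorflow as tf

mask = tf.constant([[1, 1, 0]], dtype=tf.float16)  # 1 = real token, 0 = padding

one_cst = tf.constant(1.0, dtype=mask.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=mask.dtype)

# (1 - mask) * -10000: padded positions get a large negative additive score.
additive_mask = tf.multiply(tf.subtract(one_cst, mask), ten_thousand_cst)
print(additive_mask)  # kept positions stay ~0, the padded position gets -10000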
tests/test_modeling_tf_mpnet.py

@@ -232,10 +232,6 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
 
-    def test_xla_mode(self):
-        # TODO JP: Make MPNet XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in ["microsoft/mpnet-base"]:
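Removing this override re-enables the shared XLA test for MPNet. Roughly, such a test compiles the forward pass with XLA and runs it once; a hedged sketch of the idea (not the exact test from the repo; `jit_compile` requires TF >= 2.5, and the token ids below are illustrative, not a real encoding):

import tensorflow as tf
from transformers import TFMPNetModel

model = TFMPNetModel.from_pretrained("microsoft/mpnet-base")

# Compile the forward pass with XLA; this fails if the graph still contains
# ops XLA cannot handle.
@tf.function(jit_compile=True)
def forward(input_ids):
    return model(input_ids).last_hidden_state

input_ids = tf.constant([[0, 31414, 232, 2]])  # illustrative ids
print(forward(input_ids).shape)  # (1, 4, 768) for mpnet-base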