Making TF MPNet model compliant with XLA (#10260)

* Fix XLA

* Rework cast

* Apply style
Julien Plu 2021-02-19 12:56:41 +01:00 committed by GitHub
Parent fb56bf2584
Commit 3d72d47f09
No known key found for this signature
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 17 deletions

src/transformers/models/mpnet/modeling_tf_mpnet.py

@@ -348,15 +348,22 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
         self.n_heads = config.num_attention_heads
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.initializer_range = config.initializer_range
         self.layer = [TFMPNetLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
-        self.relative_attention_bias = tf.keras.layers.Embedding(
-            config.relative_attention_num_buckets,
-            self.n_heads,
-            name="relative_attention_bias",
-        )
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
+
+    def build(self, input_shape):
+        with tf.name_scope("relative_attention_bias"):
+            self.relative_attention_bias = self.add_weight(
+                name="embeddings",
+                shape=[self.relative_attention_num_buckets, self.n_heads],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        return super().build(input_shape)
 
     def call(
         self,
         hidden_states,
         attention_mask,
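
Note: this hunk replaces the tf.keras.layers.Embedding sub-layer with a plain weight created in build() under an explicit name_scope, so the checkpoint name stays "relative_attention_bias/embeddings" and the lookup can later become a tf.gather (see the @@ -441 hunk below), which XLA compiles cleanly. A minimal sketch of the same pattern outside MPNet; the class name BiasTable and the default sizes are hypothetical, and a stock Keras initializer stands in for transformers' get_initializer:

    import tensorflow as tf

    class BiasTable(tf.keras.layers.Layer):
        """Sketch of the add_weight-in-build pattern (names hypothetical)."""

        def __init__(self, num_buckets=32, n_heads=12, **kwargs):
            super().__init__(**kwargs)
            self.num_buckets = num_buckets
            self.n_heads = n_heads

        def build(self, input_shape):
            # A raw variable instead of an Embedding sub-layer: the lookup can
            # then be an explicit tf.gather rather than a layer call.
            with tf.name_scope("relative_attention_bias"):
                self.relative_attention_bias = self.add_weight(
                    name="embeddings",
                    shape=[self.num_buckets, self.n_heads],
                    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                )
            return super().build(input_shape)

        def call(self, bucket_ids):
            return tf.gather(self.relative_attention_bias, bucket_ids)

    layer = BiasTable()
    print(layer(tf.zeros([4, 4], dtype=tf.int32)).shape)  # (4, 4, 12)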
@@ -405,18 +412,16 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
         n = -relative_position
 
         num_buckets //= 2
-        ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
+        ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets
         n = tf.math.abs(n)
 
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(n, max_exact)
-        val_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
-            / math.log(max_distance / max_exact)
-            * (num_buckets - max_exact),
-            tf.int32,
-        )
+        val_if_large = max_exact + tf.cast(
+            tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact),
+            dtype=relative_position.dtype,
+        )
 
         val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
         ret += tf.where(is_small, n, val_if_large)
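
Note: the reworked casts ("Rework cast" in the commit message) keep the bucket computation in relative_position's own dtype instead of bouncing through hard-coded tf.float32/tf.int32, which removes dtype mismatches under XLA. A self-contained sketch of the bucketing function after this change; the ret initialization and surrounding signature are not shown in the hunk and are filled in here as an assumption, following the usual T5-style scheme:

    import math
    import tensorflow as tf

    def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        # Assumed framing (ret = 0, final tf.where) around the lines shown above.
        ret = 0
        n = -relative_position

        # Half the buckets encode the sign of the relative position.
        num_buckets //= 2
        ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets
        n = tf.math.abs(n)

        # Small offsets get one bucket each; larger ones are binned log-scale.
        max_exact = num_buckets // 2
        is_small = tf.math.less(n, max_exact)
        val_if_large = max_exact + tf.cast(
            tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact),
            dtype=relative_position.dtype,
        )
        val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
        ret += tf.where(is_small, n, val_if_large)
        return ret

    # Example: bucket the relative positions between 4 query and 4 key positions.
    context_position = tf.range(4)[:, None]
    memory_position = tf.range(4)[None, :]
    print(relative_position_bucket(memory_position - context_position))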
@@ -441,7 +446,7 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
             relative_position,
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
+        values = tf.gather(self.relative_attention_bias, rp_bucket)  # shape (qlen, klen, num_heads)
         values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
         return values
 
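Note: with the bias table now a plain variable (see the build() hunk above), the Embedding-layer call becomes an explicit tf.gather. A small shape check of that gather with made-up sizes, confirming the shape comments in the hunk:

    import tensorflow as tf

    # Hypothetical sizes: a (num_buckets=32, n_heads=12) table gathered with
    # (qlen=4, klen=4) bucket indices yields a (4, 4, 12) tensor.
    table = tf.random.normal([32, 12])
    rp_bucket = tf.zeros([4, 4], dtype=tf.int32)
    values = tf.gather(table, rp_bucket)
    print(values.shape)  # (4, 4, 12)
    values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)
    print(values.shape)  # (1, 12, 4, 4)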
@@ -541,7 +546,9 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
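
Note: the mask arithmetic is rewritten with explicit tf.constant values built in embedding_output.dtype, so no Python float literal gets implicitly promoted against a possibly float16 tensor inside the compiled graph. A standalone sketch of the same rework; the helper name extend_attention_mask is hypothetical:

    import tensorflow as tf

    def extend_attention_mask(attention_mask, dtype=tf.float32):
        # Broadcast the (batch, seq_len) mask to (batch, 1, 1, seq_len) and
        # turn masked positions into large negative additive biases.
        mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype)
        one_cst = tf.constant(1.0, dtype=dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=dtype)
        return tf.multiply(tf.subtract(one_cst, mask), ten_thousand_cst)

    mask = tf.constant([[1, 1, 1, 0]])
    print(extend_attention_mask(mask, tf.float16))  # last position gets -10000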

tests/test_modeling_tf_mpnet.py

@@ -232,10 +232,6 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
 
-    def test_xla_mode(self):
-        # TODO JP: Make MPNet XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in ["microsoft/mpnet-base"]:
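
Note: deleting the test_xla_mode override removes the skip, so TFMPNetModelTest now inherits the shared XLA test from TFModelTesterMixin. A hedged sketch of the kind of check such a test performs; this is illustrative, not the actual shared test, and "microsoft/mpnet-base" is taken from the test above:

    import tensorflow as tf
    from transformers import TFMPNetModel

    # Compile the forward pass with XLA and run it once. On TF < 2.5 the
    # tf.function flag is experimental_compile rather than jit_compile.
    model = TFMPNetModel.from_pretrained("microsoft/mpnet-base")

    @tf.function(jit_compile=True)
    def forward(input_ids):
        return model(input_ids)

    outputs = forward(tf.constant([[0, 5, 6, 2]]))
    print(outputs.last_hidden_state.shape)  # (1, 4, 768)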