updating tests and TF 2.0 model

Parent 0558c9cb9b
Commit 8ae1044f80

@@ -726,8 +726,11 @@ class T5Model(T5PreTrainedModel):
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         encoder_attention_mask = kwargs_encoder.get("attention_mask", None)
         if encoder_hidden_states is None:
-            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
-            hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+            # Convert encoder inputs in embeddings if needed
+            hidden_states = kwargs_encoder.pop("inputs_embeds", None)
+            if hidden_states is None:
+                encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+                hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings

             if encoder_attention_mask is not None:
                 # Apply masking

@@ -740,8 +743,12 @@ class T5Model(T5PreTrainedModel):
             encoder_outputs = ()

         # Decode
-        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
-        hidden_states = self.shared(decoder_inputs_ids) # Convert inputs in embeddings
+        # Convert decoder inputs in embeddings if needed
+        hidden_states = kwargs_decoder.pop("inputs_embeds", None)
+        if hidden_states is None:
+            decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+            hidden_states = self.shared(decoder_inputs_ids)
+
         kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         kwargs_decoder["encoder_attention_mask"] = encoder_attention_mask
         decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
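Note (not part of the diff): the kwargs_encoder / kwargs_decoder dicts that these hunks pop "inputs_embeds" and "input_ids" from are built earlier in forward() by splitting the keyword arguments on their name prefixes. A rough, hypothetical sketch of that routing, reconstructed for orientation only:

def split_kwargs(kwargs):
    # Assumed behaviour: un-prefixed kwargs are shared by both stacks,
    # "encoder_"-prefixed ones go to the encoder, "decoder_"-prefixed ones to the decoder.
    kwargs_common = {k: v for k, v in kwargs.items()
                     if not k.startswith("encoder_") and not k.startswith("decoder_")}
    kwargs_encoder = dict(kwargs_common)
    kwargs_decoder = dict(kwargs_common)
    kwargs_encoder.update({k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")})
    kwargs_decoder.update({k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")})
    return kwargs_encoder, kwargs_decoder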
@@ -825,16 +832,24 @@ class T5WithLMHeadModel(T5PreTrainedModel):
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
-            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
-            hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+            # Convert encoder inputs in embeddings if needed
+            hidden_states = kwargs_encoder.pop("inputs_embeds", None)
+            if hidden_states is None:
+                encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+                hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+
             encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()

         # Decode
-        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
-        hidden_states = self.shared(decoder_inputs_ids) # Convert inputs in embeddings
+        # Convert decoder inputs in embeddings if needed
+        hidden_states = kwargs_decoder.pop("inputs_embeds", None)
+        if hidden_states is None:
+            decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+            hidden_states = self.shared(decoder_inputs_ids)
+
         kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
         decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
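Taken together, the PyTorch hunks above let callers bypass the input_ids lookup and feed pre-computed embeddings. A minimal usage sketch, not part of the commit; the "t5-small" checkpoint name, example sentences and shapes are illustrative assumptions:

import torch
from transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")   # assumed checkpoint
model = T5Model.from_pretrained("t5-small")
model.eval()

encoder_ids = torch.tensor([tokenizer.encode("Studies have shown that owning a dog is good for you")])
decoder_ids = torch.tensor([tokenizer.encode("Studies show that")])

# Look the ids up in the shared embedding matrix ourselves ...
wte = model.get_input_embeddings()
encoder_embeds = wte(encoder_ids)        # (batch, src_len, d_model)
decoder_embeds = wte(decoder_ids)        # (batch, tgt_len, d_model)

# ... and pass embeddings instead of ids; the encoder_/decoder_ prefixes are
# routed to kwargs_encoder / kwargs_decoder as sketched earlier.
with torch.no_grad():
    outputs = model(encoder_inputs_embeds=encoder_embeds,
                    decoder_inputs_embeds=decoder_embeds)
last_hidden_states = outputs[0]          # decoder hidden states, (batch, tgt_len, d_model)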
@@ -613,6 +613,12 @@ class TFT5Model(TFT5PreTrainedModel):
         decoder_config.is_decoder = True
         self.decoder = TFT5MainLayer(decoder_config, name='decoder')

+    def get_input_embeddings(self):
+        return self.shared
+
+    def get_output_embeddings(self):
+        return self.shared
+
     def call(self, decoder_input_ids, **kwargs):
         # We allow two types of multi-inputs:
         # - traditional keyword arguments in the call method

@@ -634,16 +640,24 @@ class TFT5Model(TFT5PreTrainedModel):
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
-            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
-            hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+            # Convert encoder inputs in embeddings if needed
+            hidden_states = kwargs_encoder.pop("inputs_embeds", None)
+            if hidden_states is None:
+                encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+                hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+
             encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()

         # Decode
-        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
-        hidden_states = self.shared(decoder_inputs_ids) # Convert inputs in embeddings
+        # Convert decoder inputs in embeddings if needed
+        hidden_states = kwargs_decoder.pop("inputs_embeds", None)
+        if hidden_states is None:
+            decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+            hidden_states = self.shared(decoder_inputs_ids)
+
         kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
         decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

@@ -692,6 +706,12 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel):
         decoder_config.is_decoder = True
         self.decoder = TFT5MainLayer(decoder_config, name='decoder')

+    def get_input_embeddings(self):
+        return self.shared
+
+    def get_output_embeddings(self):
+        return self.shared
+
     def call(self, decoder_input_ids, **kwargs):
         # We allow two types of multi-inputs:
         # - traditional keyword arguments in the call method

@@ -713,16 +733,24 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel):
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
-            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
-            hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+            # Convert encoder inputs in embeddings if needed
+            hidden_states = kwargs_encoder.pop("inputs_embeds", None)
+            if hidden_states is None:
+                encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+                hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
+
             encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()

         # Decode
-        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
-        hidden_states = self.shared(decoder_inputs_ids) # Convert inputs in embeddings
+        # Convert decoder inputs in embeddings if needed
+        hidden_states = kwargs_decoder.pop("inputs_embeds", None)
+        if hidden_states is None:
+            decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+            hidden_states = self.shared(decoder_inputs_ids)
+
         kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
         decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
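As on the PyTorch side, the TF 2.0 models can now be fed pre-computed embeddings through the single input dict, the way the updated TF tests do. A small hypothetical sketch; the default-size config, random initialisation and example ids are assumptions:

import tensorflow as tf
from transformers import T5Config, TFT5Model

config = T5Config()                        # default (t5-small-sized) hyper-parameters, assumed
model = TFT5Model(config)                  # randomly initialised, just to exercise the code path

encoder_ids = tf.constant([[31, 51, 99, 1]])
decoder_ids = tf.constant([[31, 51, 1]])

wte = model.get_input_embeddings()         # the shared embedding layer exposed by the new getter
inputs = {
    "encoder_inputs_embeds": wte(encoder_ids, mode="embedding"),
    "decoder_inputs_embeds": wte(decoder_ids, mode="embedding"),
}
outputs = model(inputs)                    # decoder hidden states come first, as usual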
@@ -568,8 +568,14 @@ class CommonTestCases:

         def test_inputs_embeds(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            input_ids = inputs_dict["input_ids"]
-            del inputs_dict["input_ids"]
+            if not self.is_encoder_decoder:
+                input_ids = inputs_dict["input_ids"]
+                del inputs_dict["input_ids"]
+            else:
+                encoder_input_ids = inputs_dict["encoder_input_ids"]
+                decoder_input_ids = inputs_dict["decoder_input_ids"]
+                del inputs_dict["encoder_input_ids"]
+                del inputs_dict["decoder_input_ids"]

             for model_class in self.all_model_classes:
                 model = model_class(config)

@@ -577,9 +583,13 @@ class CommonTestCases:
                 model.eval()

                 wte = model.get_input_embeddings()
-                inputs_dict["inputs_embeds"] = wte(input_ids)
-                outputs = model(**inputs_dict)
+                if not self.is_encoder_decoder:
+                    inputs_dict["inputs_embeds"] = wte(input_ids)
+                else:
+                    inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
+                    inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+                outputs = model(**inputs_dict)

     class GPTModelTester(CommonModelTester):

@@ -18,20 +18,19 @@ from __future__ import print_function

 import unittest
 import shutil
 import pytest

 from transformers import is_torch_available

-from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
 from .configuration_common_test import ConfigTester
 from .utils import require_torch, slow, torch_device

 if is_torch_available():
     from transformers import (T5Config, T5Model, T5WithLMHeadModel)
     from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")


 @require_torch
 class T5ModelTest(CommonTestCases.CommonModelTester):

     all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()

@@ -174,7 +173,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

-    @pytest.mark.slow
+    @slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/transformers_test/"
         for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -130,12 +130,12 @@ class TFCommonTestCases:
                                        for name, key in inputs_dict.items())
             with torch.no_grad():
                 pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(inputs_dict)
-            tfo = tfo[0].numpy()
-            pto = pto[0].numpy()
-            tfo[np.isnan(tfo)] = 0
-            pto[np.isnan(pto)] = 0
-            max_diff = np.amax(np.abs(tfo - pto))
+            tfo = tf_model(inputs_dict, training=False)
+            tf_hidden_states = tfo[0].numpy()
+            pt_hidden_states = pto[0].numpy()
+            tf_hidden_states[np.isnan(tf_hidden_states)] = 0
+            pt_hidden_states[np.isnan(pt_hidden_states)] = 0
+            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
             self.assertLessEqual(max_diff, 2e-2)

             # Check we can load pt model in tf and vice-versa with checkpoint => model functions
@@ -296,33 +296,46 @@ class TFCommonTestCases:
             first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
             self.assertTrue(tf.math.equal(first, second).numpy().all())

+        def _get_embeds(self, wte, input_ids):
+            # ^^ In our TF models, the input_embeddings can take slightly different forms,
+            # so we try a few of them.
+            # We used to fall back to just synthetically creating a dummy tensor of ones:
+            try:
+                x = wte(input_ids, mode="embedding")
+            except:
+                try:
+                    x = wte([input_ids], mode="embedding")
+                except:
+                    try:
+                        x = wte([input_ids, None, None, None], mode="embedding")
+                    except:
+                        if hasattr(self.model_tester, "embedding_size"):
+                            x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32)
+                        else:
+                            x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
+            return x
+
         def test_inputs_embeds(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            input_ids = inputs_dict["input_ids"]
-            del inputs_dict["input_ids"]
+            if not self.is_encoder_decoder:
+                input_ids = inputs_dict["input_ids"]
+                del inputs_dict["input_ids"]
+            else:
+                encoder_input_ids = inputs_dict["encoder_input_ids"]
+                decoder_input_ids = inputs_dict["decoder_input_ids"]
+                del inputs_dict["encoder_input_ids"]
+                del inputs_dict["decoder_input_ids"]

             for model_class in self.all_model_classes:
                 model = model_class(config)

                 wte = model.get_input_embeddings()
-                try:
-                    x = wte(input_ids, mode="embedding")
-                except:
-                    try:
-                        x = wte([input_ids], mode="embedding")
-                    except:
-                        try:
-                            x = wte([input_ids, None, None, None], mode="embedding")
-                        except:
-                            if hasattr(self.model_tester, "embedding_size"):
-                                x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32)
-                            else:
-                                x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
-                # ^^ In our TF models, the input_embeddings can take slightly different forms,
-                # so we try a few of them.
-                # We used to fall back to just synthetically creating a dummy tensor of ones:
-                #
-                inputs_dict["inputs_embeds"] = x
+                if not self.is_encoder_decoder:
+                    inputs_dict["inputs_embeds"] = self._get_embeds(wte, input_ids)
+                else:
+                    inputs_dict["encoder_inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
+                    inputs_dict["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
+
                 outputs = model(inputs_dict)

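The try/except ladder in _get_embeds exists because get_input_embeddings() does not return the same layer type for every TF model. The most common case it handles is sketched below as a standalone illustration; the TFSharedEmbeddings import path and sizes are assumptions from memory, not taken from this diff:

import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings  # assumed location of the helper

# The shared layer both embeds ids and projects hidden states back to vocabulary
# logits, selected via the `mode` argument; the tests only need the "embedding" mode.
wte = TFSharedEmbeddings(vocab_size=32128, hidden_size=512, name="shared")
input_ids = tf.constant([[31, 51, 99, 1]])
embeds = wte(input_ids, mode="embedding")   # (1, 4, 512) float tensor
logits = wte(embeds, mode="linear")         # (1, 4, 32128) float tensor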
@@ -18,21 +18,21 @@ from __future__ import print_function

 import unittest
 import shutil
 import pytest
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
 from .utils import require_tf, slow

 from transformers import T5Config, is_tf_available

 if is_tf_available():
     import tensorflow as tf
+    from transformers.modeling_tf_t5 import (TFT5Model, TFT5WithLMHeadModel,TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
     pytestmark = pytest.mark.skip("Require TensorFlow")
-from transformers.modeling_tf_t5 import (TFT5Model, TFT5WithLMHeadModel,
-                                         TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)


 @require_tf
 class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):

+    is_encoder_decoder = True

@@ -160,7 +160,7 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

-    @pytest.mark.slow
+    @slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/transformers_test/"
         for model_name in ['t5-small']:
@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals

 import os
 import unittest
-import pytest

 from transformers.tokenization_t5 import (T5Tokenizer)
 from transformers.tokenization_xlnet import SPIECE_UNDERLINE
