[Tests, GPU, SLOW] fix a bunch of GPU hardcoded tests in Pytorch (#4468)
* fix gpu slow tests in pytorch * change model to device syntax
This commit is contained in:
Родитель
5856999a9f
Коммит
aa925a52fa
|
@ -80,7 +80,7 @@ def main():
|
|||
|
||||
# Load a pre-trained model
|
||||
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
|
||||
model = model.to(device)
|
||||
model.to(device)
|
||||
|
||||
logger.info(
|
||||
"Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
|
||||
|
|
|
@ -770,7 +770,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
import torch_xla.core.xla_model as xm
|
||||
|
||||
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
||||
model = model.to(xm.xla_device())
|
||||
model.to(xm.xla_device())
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
@ -219,6 +219,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_ctrl(self):
|
||||
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor(
|
||||
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
|
||||
) # Legal the president is
|
||||
|
|
|
@ -329,5 +329,5 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_real_bert_model_from_pretrained(self):
|
||||
model = EncoderDecoderModel.from_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
self.assertIsNotNone(model)
|
||||
|
|
|
@ -343,6 +343,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
|
@ -372,6 +373,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
|
|
|
@ -214,32 +214,39 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = LongformerModel.from_pretrained("longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
# 'Hello world! ' repeated 1000 times
|
||||
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]]) # long input
|
||||
input_ids = torch.tensor(
|
||||
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
||||
) # long input
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
|
||||
attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions
|
||||
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
expected_output_sum = torch.tensor(74585.8594)
|
||||
expected_output_mean = torch.tensor(0.0243)
|
||||
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
|
||||
expected_output_mean = torch.tensor(0.0243, device=torch_device)
|
||||
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
# 'Hello world! ' repeated 1000 times
|
||||
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]]) # long input
|
||||
input_ids = torch.tensor(
|
||||
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
||||
) # long input
|
||||
|
||||
loss, prediction_scores = model(input_ids, masked_lm_labels=input_ids)
|
||||
|
||||
expected_loss = torch.tensor(0.0620)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1599e08)
|
||||
expected_prediction_scores_mean = torch.tensor(-3.0622)
|
||||
expected_loss = torch.tensor(0.0620, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
|
||||
expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4))
|
||||
|
|
|
@ -227,6 +227,7 @@ class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_openai_gpt(self):
|
||||
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[481, 4735, 544]], dtype=torch.long, device=torch_device) # the president is
|
||||
expected_output_ids = [
|
||||
481,
|
||||
|
|
|
@ -444,6 +444,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
input_ids = tok.encode(model.config.prefix + original_input, return_tensors="pt")
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
|
@ -471,6 +472,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||
expected_translation = "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre un « portrait familial » de générations innombrables de étoiles : les plus anciennes sont observées sous forme de pointes bleues, alors que les « nouveau-nés » de couleur rose dans la salle des accouchements doivent être plus difficiles "
|
||||
|
||||
input_ids = tok.encode(model.config.prefix + original_input, return_tensors="pt")
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
|
@ -498,6 +500,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
|
||||
|
||||
input_ids = tok.encode(model.config.prefix + original_input, return_tensors="pt")
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
|
|
|
@ -223,6 +223,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_transfo_xl_wt103(self):
|
||||
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
|
|
|
@ -434,6 +434,7 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_xlm_mlm_en_2048(self):
|
||||
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device) # the president
|
||||
expected_output_ids = [
|
||||
14,
|
||||
|
@ -459,4 +460,4 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
|
|||
] # the president the president the president the president the president the president the president the president the president the president
|
||||
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
self.assertListEqual(output_ids[0].cpu().numpy().tolist(), expected_output_ids)
|
||||
|
|
|
@ -517,6 +517,7 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_lm_generate_xlnet_base_cased(self):
|
||||
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
|
|
Загрузка…
Ссылка в новой задаче