Implementing the test integration of BertGeneration (#9990)
* claiming this issue * Integration test for BertGeneration(Encoder and Decoder) * fix code quality
This commit is contained in:
Родитель
cdd8659231
Коммит
3b7e612a5e
|
@ -297,3 +297,33 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||
def test_model_from_pretrained(self):
|
||||
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertGenerationEncoderIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size([1, 8, 1024])
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
[[[0.1775, 0.0083, -0.0321], [1.6002, 0.1287, 0.3912], [2.1473, 0.5791, 0.6066]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertGenerationDecoderIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size([1, 8, 50358])
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
[[[-0.5788, -2.5994, -3.7054], [0.0438, 4.7997, 1.8795], [1.5862, 6.6409, 4.4638]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
|
Загрузка…
Ссылка в новой задаче