Fixing flaky conversational test + flag it as a pipeline test. (#9837)

This commit is contained in:
Nicolas Patry 2021-01-28 10:19:55 +01:00 коммит произвёл GitHub
Родитель 58fbef9ebc
Коммит b936582f71
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 4 добавлений и 1 удалений

Просмотреть файл

@ -22,7 +22,7 @@ from transformers import (
is_torch_available,
pipeline,
)
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device
from .test_pipelines_common import MonoInputPipelineCommonMixin
@ -35,6 +35,7 @@ if is_torch_available():
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
@is_pipeline_test
class SimpleConversationPipelineTests(unittest.TestCase):
def get_pipeline(self):
# When
@ -52,9 +53,11 @@ class SimpleConversationPipelineTests(unittest.TestCase):
# Force model output to be L
V, D = model.lm_head.weight.shape
bias = torch.zeros(V, requires_grad=True)
weight = torch.zeros((V, D), requires_grad=True)
bias[76] = 1
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.weight = torch.nn.Parameter(weight)
# # Created with:
# import tempfile