Fixing flaky conversational test + flag it as a pipeline test. (#9837)
Parent
58fbef9ebc
Commit
b936582f71
@@ -22,7 +22,7 @@ from transformers import (
     is_torch_available,
     pipeline,
 )
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin
 
@@ -35,6 +35,7 @@ if is_torch_available():
     DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
 
 
+@is_pipeline_test
 class SimpleConversationPipelineTests(unittest.TestCase):
     def get_pipeline(self):
         # When
@@ -52,9 +53,11 @@ class SimpleConversationPipelineTests(unittest.TestCase):
         # Force model output to be L
         V, D = model.lm_head.weight.shape
         bias = torch.zeros(V, requires_grad=True)
+        weight = torch.zeros((V, D), requires_grad=True)
         bias[76] = 1
 
         model.lm_head.bias = torch.nn.Parameter(bias)
+        model.lm_head.weight = torch.nn.Parameter(weight)
 
         # # Created with:
         # import tempfile
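Why this fixes the flakiness: with model.lm_head.weight zeroed out, the hidden states contribute nothing to the logits, so the single non-zero bias entry makes greedy decoding always pick token id 76 (the "L" from the comment) regardless of random initialization. The sketch below reproduces the trick on a toy model to show the effect; TinyLM, its sizes, and the reuse of token id 76 are illustrative assumptions, not part of this commit.

import torch


class TinyLM(torch.nn.Module):
    """Toy causal-LM stand-in (an assumption for this sketch): embedding + lm_head."""

    def __init__(self, vocab_size=100, hidden=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return self.lm_head(self.embed(input_ids))


model = TinyLM()
V, D = model.lm_head.weight.shape

# Same manipulation as in the hunk above: zero the projection so hidden
# states cannot influence the logits, then put the only non-zero logit
# on token id 76.
bias = torch.zeros(V)
weight = torch.zeros((V, D))
bias[76] = 1
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.weight = torch.nn.Parameter(weight)

# Every position now greedily decodes to token 76, no matter the input,
# so assertions on the generated ids can no longer flake.
logits = model(torch.tensor([[1, 2, 3]]))
assert (logits.argmax(dim=-1) == 76).all()

The other half of the commit, the @is_pipeline_test decorator, flags the class as part of the pipeline test suite (per the commit title) so CI can select or skip pipeline tests as a group.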