[fix] T5 ONNX test: model.to(torch_device) (#5769)
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Родитель
d0486c8bc2
Коммит
d533c7e9b9
|
@ -336,7 +336,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs[0].return_tuple = True
|
||||
model = T5Model(config_and_inputs[0])
|
||||
model = T5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
||||
|
|
Загрузка…
Ссылка в новой задаче