Fix the CI pipeline for the latest PyTorch release. (#759)

This commit is contained in:
Wenbing Li 2024-07-08 16:21:48 -07:00 коммит произвёл GitHub
Родитель f1abea14e8
Коммит b436d09459
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -92,17 +92,17 @@ class TestPreprocessing(unittest.TestCase):
merges_file=util.get_test_data_file("data", "gpt2.merges.txt"),
)
inputs = tok.forward(test_sentence)
pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx")
pnp.export(tok, test_sentence, opset_version=14, output_path="temp_tok2.onnx")
with open("temp_gpt2lmh.onnx", "wb") as f:
torch.onnx.export(
gpt2_m, inputs, f, opset_version=12, do_constant_folding=False
gpt2_m, inputs, f, opset_version=14, do_constant_folding=False
)
pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False)
pnp.export(gpt2_m, *inputs, opset_version=14, do_constant_folding=False)
full_model = pnp.SequentialProcessingModule(tok, gpt2_m)
expected = full_model.forward(test_sentence)
model = pnp.export(
full_model, test_sentence, opset_version=12, do_constant_folding=False
full_model, test_sentence, opset_version=14, do_constant_folding=False
)
mfunc = OrtPyFunction.from_model(model)
actuals = mfunc(test_sentence)