Fix the CI pipeline for the latest PyTorch release. (#759)
This commit is contained in:
Parent: f1abea14e8
Commit: b436d09459
@@ -92,17 +92,17 @@ class TestPreprocessing(unittest.TestCase):
             merges_file=util.get_test_data_file("data", "gpt2.merges.txt"),
         )
         inputs = tok.forward(test_sentence)
-        pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx")
+        pnp.export(tok, test_sentence, opset_version=14, output_path="temp_tok2.onnx")

         with open("temp_gpt2lmh.onnx", "wb") as f:
             torch.onnx.export(
-                gpt2_m, inputs, f, opset_version=12, do_constant_folding=False
+                gpt2_m, inputs, f, opset_version=14, do_constant_folding=False
             )
-        pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False)
+        pnp.export(gpt2_m, *inputs, opset_version=14, do_constant_folding=False)
         full_model = pnp.SequentialProcessingModule(tok, gpt2_m)
         expected = full_model.forward(test_sentence)
         model = pnp.export(
-            full_model, test_sentence, opset_version=12, do_constant_folding=False
+            full_model, test_sentence, opset_version=14, do_constant_folding=False
         )
         mfunc = OrtPyFunction.from_model(model)
         actuals = mfunc(test_sentence)
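Note: the change only bumps the target ONNX opset from 12 to 14 in every export call, presumably because the newer PyTorch exporter needs the higher opset for some of the operators involved. A minimal sketch (not part of the commit, assuming the "temp_tok2.onnx" file produced above exists) of how one could verify that an exported model actually declares the expected opset:

import onnx

# Load the exported model and read the opset declared for the default ONNX domain.
model = onnx.load("temp_tok2.onnx")
default_opset = next(
    (op.version for op in model.opset_import if op.domain in ("", "ai.onnx")),
    None,
)
# The CI fix targets opset 14, so anything lower would indicate a stale export.
assert default_opset is not None and default_opset >= 14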