* test async run

* format

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-08-23 10:00:32 -07:00 коммит произвёл GitHub
Родитель 613c5c0c9d
Коммит 4926156789
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 38 добавлений и 1 удалений

Просмотреть файл

@ -3,6 +3,7 @@
import os
import unittest
import numpy as np
import threading
from onnx import checker, helper, onnx_pb as onnx_proto
from onnxruntime_extensions import PyOrtFunction, util, get_library_path
@ -37,6 +38,42 @@ class TestAzureOps(unittest.TestCase):
out = sess.run(None, ort_inputs)[0]
self.assertTrue(np.allclose(out, [5,5,5,5]))
def test_add_f_async(self):
if self.__enabled:
sess = InferenceSession(os.path.join(test_data_dir, "triton_addf.onnx"),
self.__opt, providers=["CPUExecutionProvider", "AzureExecutionProvider"])
auth_token = np.array([os.getenv('ADDF', '')])
x = np.array([1,2,3,4]).astype(np.float32)
y = np.array([4,3,2,1]).astype(np.float32)
ort_inputs = {
"auth_token": auth_token,
"X": x,
"Y": y
}
class RunState:
def __init__(self):
self.__match = True
def set_match(self, match):
self.__match = match
def is_match(self):
return self.__match
event = threading.Event()
def callback(res: np.ndarray, state: RunState, err: str) -> None:
if len(err) != 0 or not np.allclose(res, [5,5,5,5]):
state.set_match(False)
event.set()
run_state = RunState()
sess.run_async(None, ort_inputs, callback, run_state)
event.wait(10) # timeout in 10 sec
self.assertTrue(event.is_set())
self.assertTrue(run_state.is_match())
def test_add_f8(self):
if self.__enabled:
opt = SessionOptions()
@ -133,4 +170,4 @@ class TestAzureOps(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main()