Add UT testing async azure op (#536)
* test async run * format --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
Родитель
613c5c0c9d
Коммит
4926156789
|
@ -3,6 +3,7 @@
|
|||
import os
|
||||
import unittest
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
from onnx import checker, helper, onnx_pb as onnx_proto
|
||||
from onnxruntime_extensions import PyOrtFunction, util, get_library_path
|
||||
|
@ -37,6 +38,42 @@ class TestAzureOps(unittest.TestCase):
|
|||
out = sess.run(None, ort_inputs)[0]
|
||||
self.assertTrue(np.allclose(out, [5,5,5,5]))
|
||||
|
||||
def test_add_f_async(self):
|
||||
if self.__enabled:
|
||||
sess = InferenceSession(os.path.join(test_data_dir, "triton_addf.onnx"),
|
||||
self.__opt, providers=["CPUExecutionProvider", "AzureExecutionProvider"])
|
||||
auth_token = np.array([os.getenv('ADDF', '')])
|
||||
x = np.array([1,2,3,4]).astype(np.float32)
|
||||
y = np.array([4,3,2,1]).astype(np.float32)
|
||||
ort_inputs = {
|
||||
"auth_token": auth_token,
|
||||
"X": x,
|
||||
"Y": y
|
||||
}
|
||||
|
||||
class RunState:
|
||||
def __init__(self):
|
||||
self.__match = True
|
||||
|
||||
def set_match(self, match):
|
||||
self.__match = match
|
||||
|
||||
def is_match(self):
|
||||
return self.__match
|
||||
|
||||
event = threading.Event()
|
||||
|
||||
def callback(res: np.ndarray, state: RunState, err: str) -> None:
|
||||
if len(err) != 0 or not np.allclose(res, [5,5,5,5]):
|
||||
state.set_match(False)
|
||||
event.set()
|
||||
|
||||
run_state = RunState()
|
||||
sess.run_async(None, ort_inputs, callback, run_state)
|
||||
event.wait(10) # timeout in 10 sec
|
||||
self.assertTrue(event.is_set())
|
||||
self.assertTrue(run_state.is_match())
|
||||
|
||||
def test_add_f8(self):
|
||||
if self.__enabled:
|
||||
opt = SessionOptions()
|
||||
|
@ -133,4 +170,4 @@ class TestAzureOps(unittest.TestCase):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче