onnxruntime-extensions/test/test_azure_ops.py

136 lines
5.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import unittest
import numpy as np
from onnx import checker, helper, onnx_pb as onnx_proto
from onnxruntime_extensions import PyOrtFunction, util, get_library_path
from onnxruntime import *
# Absolute, symlink-resolved directory containing this test module.
script_dir = os.path.split(os.path.realpath(__file__))[0]
# One level above the test directory: the onnxruntime-extensions checkout root.
ort_ext_root = os.path.abspath(os.path.join(script_dir, ".."))
# Folder holding the ONNX models and audio fixture used by TestAzureOps.
test_data_dir = os.path.join(ort_ext_root, "test", "data", "azure")
class TestAzureOps(unittest.TestCase):
    """End-to-end tests for the Azure custom ops shipped in onnxruntime-extensions.

    Each test loads a model from ``test_data_dir`` and runs it with the
    ``CPUExecutionProvider`` and ``AzureExecutionProvider``.  Every model takes
    an ``auth_token`` input read from a per-test environment variable
    (``ADDF``, ``ADDF8``, ``ADDI4``, ``AND``, ``STR``, ``AUDIO``, ``CHAT``).
    The models presumably call remote inference endpoints, so the suite is
    opt-in.  It only runs when ``OCOS_ENABLE_AZURE`` is ``'1'``.  Otherwise
    every test is reported as skipped.
    """

    def __init__(self, config='runTest'):
        # ``config`` is the test-method name; it is forwarded to TestCase.
        super().__init__(config)
        self.__enabled = os.getenv('OCOS_ENABLE_AZURE', '') == '1'
        if self.__enabled:
            # Session options with the extensions custom-op library registered,
            # used for every session this test case creates.
            self.__opt = SessionOptions()
            self.__opt.register_custom_ops_library(get_library_path())

    def setUp(self):
        # Skip explicitly when disabled so the report does not count tests
        # that exercised nothing as passes.
        if not self.__enabled:
            self.skipTest("OCOS_ENABLE_AZURE is not set to '1'")

    def _run_model(self, model_file, token_env, inputs):
        """Run ``model_file`` from ``test_data_dir`` and return all its outputs.

        Args:
            model_file: ONNX file name inside ``test_data_dir``.
            token_env: name of the environment variable holding the
                ``auth_token`` input; an empty string is sent when unset.
            inputs: dict of the model's remaining named inputs.

        Returns:
            The list produced by ``InferenceSession.run`` for all outputs.
        """
        sess = InferenceSession(os.path.join(test_data_dir, model_file),
                                self.__opt,
                                providers=["CPUExecutionProvider", "AzureExecutionProvider"])
        ort_inputs = {"auth_token": np.array([os.getenv(token_env, '')])}
        ort_inputs.update(inputs)
        return sess.run(None, ort_inputs)

    def test_add_f(self):
        """float32 element-wise add through ``triton_addf.onnx``."""
        out = self._run_model("triton_addf.onnx", 'ADDF', {
            "X": np.array([1, 2, 3, 4]).astype(np.float32),
            "Y": np.array([4, 3, 2, 1]).astype(np.float32),
        })[0]
        self.assertTrue(np.allclose(out, [5, 5, 5, 5]))

    def test_add_f8(self):
        """float64 element-wise add through ``triton_addf8.onnx``."""
        out = self._run_model("triton_addf8.onnx", 'ADDF8', {
            "X": np.array([1, 2, 3, 4]).astype(np.double),
            "Y": np.array([4, 3, 2, 1]).astype(np.double),
        })[0]
        self.assertTrue(np.allclose(out, [5, 5, 5, 5]))

    def test_add_i4(self):
        """int32 element-wise add through ``triton_addi4.onnx``."""
        out = self._run_model("triton_addi4.onnx", 'ADDI4', {
            "X": np.array([1, 2, 3, 4]).astype(np.int32),
            "Y": np.array([4, 3, 2, 1]).astype(np.int32),
        })[0]
        self.assertTrue(np.allclose(out, [5, 5, 5, 5]))

    def test_and(self):
        """Boolean AND through ``triton_and.onnx``."""
        out = self._run_model("triton_and.onnx", 'AND', {
            "X": np.array([True, True]),
            "Y": np.array([True, False]),
        })[0]
        self.assertTrue(np.allclose(out, [True, False]))

    def test_str(self):
        """``triton_str.onnx`` returns its string input on both of its two outputs."""
        outs = self._run_model("triton_str.onnx", 'STR', {
            "str_in": np.array(['this is the input']),
        })
        self.assertEqual(len(outs), 2)
        self.assertEqual(outs[0], ['this is the input'])
        self.assertEqual(outs[1], ['this is the input'])

    def test_open_ai_audio(self):
        """Transcribe ``test16.wav`` with ``whisper-1`` through ``openai_audio.onnx``."""
        # The model takes the raw WAV file bytes as a flat uint8 tensor.
        audio_blob = np.fromfile(os.path.join(test_data_dir, "test16.wav"), dtype=np.uint8)
        out = self._run_model("openai_audio.onnx", 'AUDIO', {
            "model_name": np.array(['whisper-1']),
            "response_format": np.array(['text']),
            "file": audio_blob,
        })[0]
        self.assertEqual(out, ['This is a test recording to test the Whisper model.\n'])

    def test_azure_chat(self):
        """Send a chat history to ``azure_chat.onnx`` and check the reply mentions ``chat.completion``."""
        chat = np.array([r'{"messages":[{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Does Azure OpenAI support customer managed keys?"},{"role": "assistant", "content": "Yes, customer managed keys are supported by Azure OpenAI."},{"role": "user", "content": "Do other Azure AI services support this too?"}]}'])
        out = self._run_model("azure_chat.onnx", 'CHAT', {"chat": chat})[0]
        self.assertIn('chat.completion', out[0])
if __name__ == "__main__":
    # Executing this file directly hands control to unittest's CLI runner.
    unittest.main()