Implement C# binding for RunAsync.

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-08-07 22:19:38 -07:00 коммит произвёл GitHub
Родитель 249917a093
Коммит 063e9054b8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 235 добавлений и 4 удалений

Просмотреть файл

@ -6,7 +6,9 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.NetworkInformation;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
namespace Microsoft.ML.OnnxRuntime
{
@ -604,7 +606,7 @@ namespace Microsoft.ML.OnnxRuntime
throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
}
var inputNamesArray = LookupUtf8Names(inputNames, n => n, LookupInputMetadata);
var inputNamesArray = LookupUtf8Names(inputNames, n => n, LookupInputMetadata);
var inputHandlesArray = inputValues.Select(v => v.Handle).ToArray();
var outputNamesArray = LookupUtf8Names(outputNames, n => n, LookupOutputMetadata);
@ -636,7 +638,7 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr[] inputHandlesArray = new IntPtr[inputs.Count];
int count = 0;
foreach(var input in inputs)
foreach (var input in inputs)
{
inputNamesArray[count] = LookupInputMetadata(input.Key).ZeroTerminatedName;
inputHandlesArray[count] = input.Value.Handle;
@ -1044,6 +1046,130 @@ namespace Microsoft.ML.OnnxRuntime
}
}
/// <summary>
/// Native completion callback invoked by ORT when an async run finishes.
/// Recovers the CallbackHost from the GCHandle packed into userData,
/// forwards the pre-allocated output values and the native status to the
/// user callback, and always frees the handle afterwards.
/// </summary>
/// <param name="userData">GCHandle (as IntPtr) to the CallbackHost allocated in RunAsyncInternal.</param>
/// <param name="outputs">native OrtValue* array from ORT (unused; the managed outputValues are forwarded instead).</param>
/// <param name="numOutputs">number of entries in <paramref name="outputs"/>.</param>
/// <param name="status">native OrtStatus* describing success or failure of the run.</param>
private static void OrtCallback(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status)
{
    var hostHdl = GCHandle.FromIntPtr(userData);
    CallbackHost host = (CallbackHost)hostHdl.Target;
    try
    {
        host.callback(host.outputValues, status);
    }
    finally
    {
        // Release the handle so the host (and the native arrays it keeps
        // reachable) can be collected, even if the user callback throws.
        hostHdl.Free();
    }
}

// Managed mirror of ORT's RunAsyncCallbackFn signature.
private delegate void OrtCallbackDelegate(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status);

// Static field keeps the delegate (and thus the function pointer handed to
// native code) alive for the lifetime of the process.
private static OrtCallbackDelegate ortCallback = new OrtCallbackDelegate(OrtCallback);

// Shape of the user-facing completion callback: the caller's pre-allocated
// output values plus the raw native status pointer.
private delegate void UserCallbackDelegate(IReadOnlyCollection<OrtValue> outputs, IntPtr status);
/// <summary>
/// Holds everything an in-flight RunAsync call needs to stay reachable until
/// the native callback fires: the managed names/values supplied by the caller,
/// the user callback, and the pre-marshalled native name/handle arrays.
/// </summary>
private class CallbackHost
{
    // Managed inputs/outputs, kept alive for the duration of the async run.
    public IReadOnlyCollection<string> inputNames { get; }
    public IReadOnlyCollection<OrtValue> inputValues { get; }
    public IReadOnlyCollection<string> outputNames { get; }
    public IReadOnlyCollection<OrtValue> outputValues { get; }
    // Completion callback to invoke from the native callback thread.
    public UserCallbackDelegate callback { get; }
    // Native representations, computed once at construction time.
    public IntPtr[] rawInputNames { get; }
    public IntPtr[] rawInputValues { get; }
    public IntPtr[] rawOutputNames { get; }
    public IntPtr[] rawOutputValues { get; }
    public CallbackHost(InferenceSession session,
        IReadOnlyCollection<string> cbInputNames,
        IReadOnlyCollection<OrtValue> cbinputValues,
        IReadOnlyCollection<string> cbOutputNames,
        IReadOnlyCollection<OrtValue> cbOutputValues,
        UserCallbackDelegate userCallback)
    {
        callback = userCallback;
        inputNames = cbInputNames;
        inputValues = cbinputValues;
        outputNames = cbOutputNames;
        outputValues = cbOutputValues;
        // Resolve names against the session metadata and snapshot the value
        // handles up front, so no managed-to-native marshalling is needed
        // once the native run has been started.
        rawInputNames = LookupUtf8Names(cbInputNames, n => n, session.LookupInputMetadata);
        rawOutputNames = LookupUtf8Names(cbOutputNames, n => n, session.LookupOutputMetadata);
        rawInputValues = cbinputValues.Select(v => v.Handle).ToArray();
        rawOutputValues = cbOutputValues.Select(v => v.Handle).ToArray();
    }
}
/// <summary>
/// Starts an asynchronous run on the native session. The CallbackHost (and
/// the GCHandle around it) keeps all managed state alive until the native
/// callback fires; the callback is responsible for freeing the handle on the
/// success path.
/// </summary>
/// <param name="options">run options, may be null to use the defaults</param>
/// <param name="inputNames">names of the inputs</param>
/// <param name="inputValues">input ort values, parallel to <paramref name="inputNames"/></param>
/// <param name="outputNames">names of the outputs</param>
/// <param name="outputValues">pre-allocated output ort values, parallel to <paramref name="outputNames"/></param>
/// <param name="callback">invoked from an intra-op pool thread when the run completes</param>
private void RunAsyncInternal(RunOptions options,
    IReadOnlyCollection<string> inputNames,
    IReadOnlyCollection<OrtValue> inputValues,
    IReadOnlyCollection<string> outputNames,
    IReadOnlyCollection<OrtValue> outputValues,
    UserCallbackDelegate callback)
{
    CallbackHost host = new CallbackHost(this, inputNames, inputValues, outputNames, outputValues, callback);
    // Normal (non-pinned) handle: it only needs to keep the host reachable;
    // the native callback recovers it via GCHandle.FromIntPtr.
    var hostHandle = GCHandle.Alloc(host, GCHandleType.Normal);
    try
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtRunAsync(
            _nativeHandle,
            options == null ? (IntPtr)null : options.Handle,
            host.rawInputNames,
            host.rawInputValues,
            (UIntPtr)host.rawInputNames.Length,
            host.rawOutputNames,
            (UIntPtr)host.rawOutputNames.Length,
            host.rawOutputValues,
            Marshal.GetFunctionPointerForDelegate(ortCallback),
            GCHandle.ToIntPtr(hostHandle)
        ));
    }
    catch
    {
        // If the native call itself failed, the callback will never run and
        // never free the handle, so it must be released here. Catching only
        // OnnxRuntimeException (as before) leaked the handle for any other
        // exception type escaping the interop call.
        hostHandle.Free();
        throw;
    }
}
/// <summary>
/// Run inference asynchronously in a thread of the intra-op thread pool.
/// </summary>
/// <param name="options">run options, can be null</param>
/// <param name="inputNames">names of the inputs</param>
/// <param name="inputValues">input ort values</param>
/// <param name="outputNames">names of the outputs</param>
/// <param name="outputValues">pre-allocated output ort values</param>
/// <returns>task that completes with the populated output values</returns>
/// <exception cref="OnnxRuntimeException">when the native run reports an error</exception>
public async Task<IReadOnlyCollection<OrtValue>> RunAsync(RunOptions options,
    IReadOnlyCollection<string> inputNames,
    IReadOnlyCollection<OrtValue> inputValues,
    IReadOnlyCollection<string> outputNames,
    IReadOnlyCollection<OrtValue> outputValues)
{
    // RunContinuationsAsynchronously prevents awaiting user code from
    // executing inline on the ORT intra-op thread that invokes the native
    // completion callback, which could otherwise stall the pool or deadlock.
    var promise = new TaskCompletionSource<IReadOnlyCollection<OrtValue>>(
        TaskCreationOptions.RunContinuationsAsynchronously);
    RunAsyncInternal(options,
        inputNames,
        inputValues,
        outputNames,
        outputValues,
        (IReadOnlyCollection<OrtValue> outputs, IntPtr status) =>
        {
            try
            {
                NativeApiStatus.VerifySuccess(status);
                promise.TrySetResult(outputs);
            }
            catch (Exception ex)
            {
                promise.TrySetException(ex);
            }
        });
    return await promise.Task;
}
#endregion
#region private methods

Просмотреть файл

@ -292,6 +292,8 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr UpdateROCMProviderOptions;
public IntPtr GetROCMProviderOptionsAsString;
public IntPtr ReleaseROCMProviderOptions;
public IntPtr CreateAndRegisterAllocatorV2;
public IntPtr RunAsync;
}
internal static class NativeMethods
@ -510,6 +512,8 @@ namespace Microsoft.ML.OnnxRuntime
OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions));
OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString));
OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions));
OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2));
OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync));
}
internal class NativeLib
@ -916,6 +920,32 @@ namespace Microsoft.ML.OnnxRuntime
out UIntPtr /*(ulong* out)*/ startTime);
public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs;
// Delegate for OrtApi.CreateAndRegisterAllocatorV2: registers an allocator on
// the environment for the given provider type, configured by key/value
// provider options. Returns OrtStatus* (null on success).
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DCreateAndRegisterAllocatorV2(
    IntPtr /* (OrtEnv*) */ environment,
    IntPtr /*(char*)*/ provider_type,
    IntPtr /*(OrtMemoryInfo*)*/ mem_info,
    IntPtr /*(OrtArenaCfg*)*/ arena_cfg,
    IntPtr /*(char**)*/ provider_options_keys,
    IntPtr /*(char**)*/ provider_options_values,
    UIntPtr /*(size_t)*/num_keys);
public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2;
// Delegate for OrtApi.RunAsync: schedules an inference run on the intra-op
// thread pool and invokes the callback (with user_data) on completion.
// Name/value arrays must stay reachable until the callback fires.
// Returns OrtStatus* (null on success).
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtRunAsync(
    IntPtr /*(OrtSession*)*/ session,
    IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
    IntPtr[] /*(char**)*/ inputNames,
    IntPtr[] /*(OrtValue*[])*/ inputValues,
    UIntPtr /*(size_t)*/ inputCount,
    IntPtr[] /*(char**)*/ outputNames,
    UIntPtr /*(size_t)*/ outputCount,
    IntPtr[] /*(OrtValue*[])*/ outputValues,
    IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function
    IntPtr /*(void*)*/ user_data
    );
public static DOrtRunAsync OrtRunAsync;
#endregion InferenceSession API
#region SessionOptions API

Просмотреть файл

@ -5,7 +5,10 @@ using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
@ -476,7 +479,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
TensorElementType.Float, expectedShape))
{
// Run inference
var inputValues = new List<OrtValue>{ inputOrtValue }.AsReadOnly();
var inputValues = new List<OrtValue> { inputOrtValue }.AsReadOnly();
var outputValues = new List<OrtValue> { outputOrtValue }.AsReadOnly();
session.Run(runOptions, inputNames, inputValues,
expectedOutputNames, outputValues);
@ -1342,7 +1345,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
[Fact(DisplayName = "TestModelInputBFLOAT16")]
private void TestModelInputBFLOAT16()
{
BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
new BFloat16(16448), new BFloat16(16512), new BFloat16(16544) };
int[] inputShape = { 1, 5 };
// model takes 1x5 input of fixed type, echoes back
@ -2025,6 +2028,78 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
}
[Fact(DisplayName = "TestModelRunAsyncTask")]
// async Task (not async void) so the xUnit runner awaits the test and
// observes exceptions; previously a failure could go unreported or crash
// the process.
private async Task TestModelRunAsyncTask()
{
    // 1x5 fp16 tensor; the model echoes its input back as "output".
    Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
    long[] shape = { 1, 5 };
    var inputNames = new List<string> { "input" };
    var inputValues = new List<OrtValue> { OrtValue.CreateTensorValueFromMemory(inputData, shape) };
    var outputNames = new List<string> { "output" };
    var outputValues = new List<OrtValue> { OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
        TensorElementType.Float16, shape) };
    var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
    using (SessionOptions opt = new SessionOptions())
    {
        // RunAsync needs more than one intra-op thread (see TestModelRunAsyncTaskFail).
        opt.IntraOpNumThreads = 2;
        using (var session = new InferenceSession(model, opt))
        {
            // No try/catch wrapper: an exception fails the test directly with
            // its real message, instead of the opaque Assert.True(false).
            var outputs = await session.RunAsync(null, inputNames, inputValues, outputNames, outputValues);
            var valueOut = outputs.ElementAt<OrtValue>(0);
            var float16s = valueOut.GetTensorDataAsSpan<Float16>().ToArray();
            Assert.Equal(new Float16(16896), float16s[2]);
        }
    }
}
[Fact(DisplayName = "TestModelRunAsyncTaskFail")]
// async Task (not async void) so the xUnit runner awaits the test and
// observes exceptions.
private async Task TestModelRunAsyncTaskFail()
{
    // 1x5 fp16 tensor; the model echoes its input back as "output".
    Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
    long[] shape = { 1, 5 };
    var inputNames = new List<string> { "input" };
    var inputValues = new List<OrtValue> { OrtValue.CreateTensorValueFromMemory(inputData, shape) };
    var outputNames = new List<string> { "output" };
    var outputValues = new List<OrtValue> { OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
        TensorElementType.Float16, shape) };
    var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
    using (SessionOptions opt = new SessionOptions())
    {
        opt.IntraOpNumThreads = 1; // a single intra-op thread makes RunAsync fail
        using (var session = new InferenceSession(model, opt))
        {
            // Assert.ThrowsAnyAsync replaces the manual catch/finally: it
            // fails the test if no exception is thrown at all, which the
            // original (asserting on an empty err string) also caught only
            // indirectly.
            var ex = await Assert.ThrowsAnyAsync<Exception>(
                () => session.RunAsync(null, inputNames, inputValues, outputNames, outputValues));
            Assert.Contains("intra op thread pool must have at least one thread for RunAsync", ex.Message);
        }
    }
}
}
}