RunAsync in C# (#16890)
Implement C# binding for RunAsync. --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
Родитель
249917a093
Коммит
063e9054b8
|
@ -6,7 +6,9 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Net.NetworkInformation;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
|
@ -604,7 +606,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
|
||||
}
|
||||
|
||||
var inputNamesArray = LookupUtf8Names(inputNames, n => n, LookupInputMetadata);
|
||||
var inputNamesArray = LookupUtf8Names(inputNames, n => n, LookupInputMetadata);
|
||||
var inputHandlesArray = inputValues.Select(v => v.Handle).ToArray();
|
||||
|
||||
var outputNamesArray = LookupUtf8Names(outputNames, n => n, LookupOutputMetadata);
|
||||
|
@ -636,7 +638,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IntPtr[] inputHandlesArray = new IntPtr[inputs.Count];
|
||||
|
||||
int count = 0;
|
||||
foreach(var input in inputs)
|
||||
foreach (var input in inputs)
|
||||
{
|
||||
inputNamesArray[count] = LookupInputMetadata(input.Key).ZeroTerminatedName;
|
||||
inputHandlesArray[count] = input.Value.Handle;
|
||||
|
@ -1044,6 +1046,130 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
/// Native callback trampoline invoked by ORT when an async run completes.
/// Recovers the CallbackHost from the GCHandle passed as user data, forwards
/// the run status and the managed output values to the user callback, and
/// always frees the handle so the host can be garbage-collected.
/// </summary>
/// <param name="userData">GCHandle (as IntPtr) referencing the CallbackHost.</param>
/// <param name="outputs">Native OrtValue* array of outputs (unused here; the host already holds the managed OrtValues).</param>
/// <param name="numOutputs">Number of entries in <paramref name="outputs"/>.</param>
/// <param name="status">Native OrtStatus*; a non-null value indicates failure.</param>
private static void OrtCallback(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status)
{
    var hostHdl = GCHandle.FromIntPtr(userData);
    CallbackHost host = (CallbackHost)hostHdl.Target;
    try
    {
        host.callback(host.outputValues, status);
    }
    finally
    {
        // Free the handle even if the user callback throws; otherwise the
        // CallbackHost (and the arrays it keeps alive) would leak.
        hostHdl.Free();
    }
}
|
||||
|
||||
// Managed signature matching the native RunAsyncCallbackFn:
// void(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status).
private delegate void OrtCallbackDelegate(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status);

// Held in a static field so the delegate (and the native thunk created for it)
// cannot be garbage-collected while a native async run is still in flight.
private static OrtCallbackDelegate ortCallback = new OrtCallbackDelegate(OrtCallback);

// Managed callback surface handed to RunAsyncInternal; receives the populated
// output OrtValues and the raw native status pointer.
private delegate void UserCallbackDelegate(IReadOnlyCollection<OrtValue> outputs, IntPtr status);
|
||||
|
||||
/// <summary>
/// Keeps the managed inputs/outputs of an in-flight async run alive and
/// pre-marshals them into the raw arrays OrtRunAsync expects. An instance is
/// rooted via a GCHandle for the duration of the native call and released by
/// the completion callback.
/// </summary>
private class CallbackHost
{
    public IReadOnlyCollection<string> inputNames { get; }
    public IReadOnlyCollection<OrtValue> inputValues { get; }
    public IReadOnlyCollection<string> outputNames { get; }
    public IReadOnlyCollection<OrtValue> outputValues { get; }
    public UserCallbackDelegate callback { get; }

    // Marshaled forms: zero-terminated UTF-8 name pointers and raw OrtValue handles.
    public IntPtr[] rawInputNames { get; }
    public IntPtr[] rawInputValues { get; }
    public IntPtr[] rawOutputNames { get; }
    public IntPtr[] rawOutputValues { get; }

    public CallbackHost(InferenceSession session,
        IReadOnlyCollection<string> cbInputNames,
        IReadOnlyCollection<OrtValue> cbInputValues,
        IReadOnlyCollection<string> cbOutputNames,
        IReadOnlyCollection<OrtValue> cbOutputValues,
        UserCallbackDelegate userCallback)
    {
        inputNames = cbInputNames;
        inputValues = cbInputValues;
        outputNames = cbOutputNames;
        outputValues = cbOutputValues;
        callback = userCallback;

        // Resolving names through the session metadata validates them against
        // the model and yields cached zero-terminated UTF-8 name pointers.
        rawInputNames = LookupUtf8Names(inputNames, n => n, session.LookupInputMetadata);
        rawInputValues = inputValues.Select(v => v.Handle).ToArray();

        rawOutputNames = LookupUtf8Names(outputNames, n => n, session.LookupOutputMetadata);
        rawOutputValues = outputValues.Select(v => v.Handle).ToArray();
    }
}
|
||||
|
||||
/// <summary>
/// Starts a native async run. A GCHandle is allocated around a CallbackHost so
/// the managed names/values stay reachable until the native callback fires;
/// the handle is freed by OrtCallback on completion, or here if the native
/// call fails to start.
/// </summary>
/// <param name="options">Run options; null uses the session defaults.</param>
/// <param name="inputNames">Names of the inputs, parallel to <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Input OrtValues.</param>
/// <param name="outputNames">Names of the outputs, parallel to <paramref name="outputValues"/>.</param>
/// <param name="outputValues">Pre-allocated output OrtValues to be filled by the run.</param>
/// <param name="callback">Invoked from an intra-op pool thread when the run completes.</param>
private void RunAsyncInternal(RunOptions options,
    IReadOnlyCollection<string> inputNames,
    IReadOnlyCollection<OrtValue> inputValues,
    IReadOnlyCollection<string> outputNames,
    IReadOnlyCollection<OrtValue> outputValues,
    UserCallbackDelegate callback)
{
    CallbackHost host = new CallbackHost(this, inputNames, inputValues, outputNames, outputValues, callback);
    var hostHandle = GCHandle.Alloc(host, GCHandleType.Normal);

    try
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtRunAsync(
            _nativeHandle,
            options?.Handle ?? IntPtr.Zero,
            host.rawInputNames,
            host.rawInputValues,
            (UIntPtr)host.rawInputNames.Length,
            host.rawOutputNames,
            (UIntPtr)host.rawOutputNames.Length,
            host.rawOutputValues,
            Marshal.GetFunctionPointerForDelegate(ortCallback),
            GCHandle.ToIntPtr(hostHandle)
            ));
    }
    catch
    {
        // If the native call failed to start, the callback will never run, so
        // free the handle here to avoid leaking the host. Catching only
        // OnnxRuntimeException would leak on any other failure (e.g. from
        // marshalling), so rethrow after cleanup for every exception type.
        hostHandle.Free();
        throw;
    }
}
|
||||
|
||||
/// <summary>
/// Runs inference asynchronously; the actual work executes on a thread of the
/// session's intra-op thread pool, which must have at least one thread.
/// </summary>
/// <param name="options">Run options; can be null to use the defaults.</param>
/// <param name="inputNames">Names of the inputs.</param>
/// <param name="inputValues">Input OrtValues, parallel to <paramref name="inputNames"/>.</param>
/// <param name="outputNames">Names of the outputs.</param>
/// <param name="outputValues">Pre-allocated output OrtValues, parallel to <paramref name="outputNames"/>; filled on completion.</param>
/// <returns>Task producing the (filled) output values; await it to observe completion or failure.</returns>
/// <exception cref="OnnxRuntimeException">If the run fails or cannot be scheduled.</exception>
public async Task<IReadOnlyCollection<OrtValue>> RunAsync(RunOptions options,
    IReadOnlyCollection<string> inputNames,
    IReadOnlyCollection<OrtValue> inputValues,
    IReadOnlyCollection<string> outputNames,
    IReadOnlyCollection<OrtValue> outputValues)
{
    // RunContinuationsAsynchronously prevents awaiting user code from running
    // inline on the ORT intra-op pool thread that completes the promise, which
    // could otherwise stall the pool or deadlock.
    var promise = new TaskCompletionSource<IReadOnlyCollection<OrtValue>>(TaskCreationOptions.RunContinuationsAsynchronously);
    RunAsyncInternal(options,
        inputNames,
        inputValues,
        outputNames,
        outputValues,
        (IReadOnlyCollection<OrtValue> outputs, IntPtr status) =>
        {
            try
            {
                // Throws OnnxRuntimeException when the native status is an error.
                NativeApiStatus.VerifySuccess(status);
                promise.SetResult(outputs);
            }
            catch (Exception ex)
            {
                promise.SetException(ex);
            }
        });
    return await promise.Task;
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region private methods
|
||||
|
|
|
@ -292,6 +292,8 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr UpdateROCMProviderOptions;
|
||||
public IntPtr GetROCMProviderOptionsAsString;
|
||||
public IntPtr ReleaseROCMProviderOptions;
|
||||
public IntPtr CreateAndRegisterAllocatorV2;
|
||||
public IntPtr RunAsync;
|
||||
}
|
||||
|
||||
internal static class NativeMethods
|
||||
|
@ -510,6 +512,8 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions));
|
||||
OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString));
|
||||
OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions));
|
||||
OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2));
|
||||
OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync));
|
||||
}
|
||||
|
||||
internal class NativeLib
|
||||
|
@ -916,6 +920,32 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
out UIntPtr /*(ulong* out)*/ startTime);
|
||||
public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs;
|
||||
|
||||
// P/Invoke signature for OrtApi::CreateAndRegisterAllocatorV2: registers an
// allocator for the given execution-provider type with the environment.
// Returns a native OrtStatus* (IntPtr.Zero on success).
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DCreateAndRegisterAllocatorV2(
    IntPtr /* (OrtEnv*) */ environment,
    IntPtr /*(char*)*/ provider_type,
    IntPtr /*(OrtMemoryInfo*)*/ mem_info,
    IntPtr /*(OrtArenaCfg*)*/ arena_cfg,
    IntPtr /*(char**)*/ provider_options_keys,
    IntPtr /*(char**)*/ provider_options_values,
    UIntPtr /*(size_t)*/num_keys);
public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2;
|
||||
|
||||
// P/Invoke signature for OrtApi::RunAsync: schedules inference on the
// session's intra-op thread pool and invokes `callback` with `user_data` when
// the run completes. Returns a native OrtStatus* (IntPtr.Zero on success).
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtRunAsync(
    IntPtr /*(OrtSession*)*/ session,
    IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
    IntPtr[] /*(char**)*/ inputNames,
    IntPtr[] /*(OrtValue*[])*/ inputValues,
    UIntPtr /*(size_t)*/ inputCount,
    IntPtr[] /*(char**)*/ outputNames,
    UIntPtr /*(size_t)*/ outputCount,
    IntPtr[] /*(OrtValue*[])*/ outputValues,
    IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function
    IntPtr /*(void*)*/ user_data
);
public static DOrtRunAsync OrtRunAsync;
|
||||
|
||||
#endregion InferenceSession API
|
||||
|
||||
#region SessionOptions API
|
||||
|
|
|
@ -5,7 +5,10 @@ using Microsoft.ML.OnnxRuntime.Tensors;
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text.RegularExpressions;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
@ -476,7 +479,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
TensorElementType.Float, expectedShape))
|
||||
{
|
||||
// Run inference
|
||||
var inputValues = new List<OrtValue>{ inputOrtValue }.AsReadOnly();
|
||||
var inputValues = new List<OrtValue> { inputOrtValue }.AsReadOnly();
|
||||
var outputValues = new List<OrtValue> { outputOrtValue }.AsReadOnly();
|
||||
session.Run(runOptions, inputNames, inputValues,
|
||||
expectedOutputNames, outputValues);
|
||||
|
@ -1342,7 +1345,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
[Fact(DisplayName = "TestModelInputBFLOAT16")]
|
||||
private void TestModelInputBFLOAT16()
|
||||
{
|
||||
BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
|
||||
BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
|
||||
new BFloat16(16448), new BFloat16(16512), new BFloat16(16544) };
|
||||
int[] inputShape = { 1, 5 };
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
|
@ -2025,6 +2028,78 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestModelRunAsyncTask")]
// async Task (not async void) so xUnit awaits the test and observes failures;
// exceptions in an async void method are unobservable by the test runner.
private async Task TestModelRunAsyncTask()
{
    // FLOAT16 echo model: input values are Float16 bit patterns (15360 == 1.0 etc.).
    Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
    long[] shape = { 1, 5 };

    var inputNames = new List<string> { "input" };
    var inputValues = new List<OrtValue> { OrtValue.CreateTensorValueFromMemory(inputData, shape) };

    var outputNames = new List<string> { "output" };
    var outputValues = new List<OrtValue> { OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
        TensorElementType.Float16, shape) };

    var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
    using (SessionOptions opt = new SessionOptions())
    {
        // RunAsync requires the intra-op pool to have at least one thread.
        opt.IntraOpNumThreads = 2;
        using (var session = new InferenceSession(model, opt))
        {
            // No try/catch: an unexpected exception should fail the test with
            // its real message instead of an opaque Assert.True(false).
            var outputs = await session.RunAsync(null, inputNames, inputValues, outputNames, outputValues);
            var valueOut = outputs.ElementAt<OrtValue>(0);
            var float16s = valueOut.GetTensorDataAsSpan<Float16>().ToArray();
            // The model echoes its input; element 2 must round-trip unchanged.
            Assert.Equal(new Float16(16896), float16s[2]);
        }
    }
}
|
||||
|
||||
[Fact(DisplayName = "TestModelRunAsyncTaskFail")]
// async Task (not async void) so xUnit awaits the test and observes failures.
private async Task TestModelRunAsyncTaskFail()
{
    Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
    long[] shape = { 1, 5 };

    var inputNames = new List<string> { "input" };
    var inputValues = new List<OrtValue> { OrtValue.CreateTensorValueFromMemory(inputData, shape) };

    var outputNames = new List<string> { "output" };
    var outputValues = new List<OrtValue> { OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
        TensorElementType.Float16, shape) };

    var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
    using (SessionOptions opt = new SessionOptions())
    {
        opt.IntraOpNumThreads = 1; // this will make RunAsync fail
        using (var session = new InferenceSession(model, opt))
        {
            // Assert the failure directly instead of capturing the message in
            // a catch block and checking it in finally (which obscured the
            // "no exception thrown" case behind an empty-string mismatch).
            var ex = await Assert.ThrowsAnyAsync<Exception>(
                () => session.RunAsync(null, inputNames, inputValues, outputNames, outputValues));
            Assert.Contains("intra op thread pool must have at least one thread for RunAsync", ex.Message);
        }
    }
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче