Add batching size 1 and session options (#264)

* check session options version

* Create Session Considering Support for Session Options

* Only set batchsize to 1

* Made batch size override always 1

* Merged

* Update minimum SDK
This commit is contained in:
Ryan Lai 2019-08-01 17:03:34 -07:00 коммит произвёл GitHub
Родитель d813cef500
Коммит 78c4b65525
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 64 добавлений и 37 удалений

Просмотреть файл

@ -23,7 +23,7 @@
<ProjectGuid>{E9D4AC92-8295-4FB4-BF7D-3FAF74B564E8}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>WinMLRunnerTest</RootNamespace>
<WindowsTargetPlatformVersion>10.0.17763.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformVersion>10.0.18362.0</WindowsTargetPlatformVersion>
<ProjectSubType>NativeUnitTestProject</ProjectSubType>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />

Просмотреть файл

@ -16,7 +16,7 @@ You can either download the x64 executable or build it yourself.
#### Prerequisites
- [Visual Studio 2017 Version 15.7.4 or Newer](https://developer.microsoft.com/en-us/windows/downloads)
- [Windows 10 - Build 17763 or higher](https://www.microsoft.com/en-us/software-download/windowsinsiderpreviewiso)
- [Windows SDK - Build 17763 or higher](https://www.microsoft.com/en-us/software-download/windowsinsiderpreviewSDK)
- [Windows SDK - Build 18362 or higher](https://www.microsoft.com/en-us/software-download/windowsinsiderpreviewSDK)
The easiest way to use these samples without using Git is to download the zip file containing the current version (using the following link or by clicking the "Download ZIP" button on the repo page). You can then unzip the entire archive and use the samples in Visual Studio 2017. Notes: Before you unzip the archive, right-click it, select Properties, and then select Unblock.
Be sure to unzip the entire archive, and not just individual samples. The samples all depend on the SharedContent folder in the archive. In Visual Studio 2017, the platform target defaults to ARM, so be sure to change that to x64 or x86 if you want to test on a non-ARM device. Reminder: If you unzip individual samples, they will not build due to references to other portions of the ZIP file that were not unzipped.

Просмотреть файл

@ -48,7 +48,7 @@
<ProjectGuid>{31653A2F-02CC-4A95-9880-BF86965FB262}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>WinMLRunner</RootNamespace>
<WindowsTargetPlatformVersion>10.0.17763.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformVersion>10.0.18362.0</WindowsTargetPlatformVersion>
<WindowsSDKDesktopARM64Support>true</WindowsSDKDesktopARM64Support>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />

Просмотреть файл

@ -31,7 +31,7 @@
<ProjectGuid>{C174D45D-C189-475B-B1A7-494939EE7491}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>WinMLRunnerScenarios</RootNamespace>
<WindowsTargetPlatformVersion>10.0.17763.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformVersion>10.0.18362.0</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">

Просмотреть файл

@ -47,7 +47,7 @@
<ProjectGuid>{C3BCBEA1-90E6-426F-88AC-64C274BCEF45}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>WinMLRunnerStaticLib</RootNamespace>
<WindowsTargetPlatformVersion>10.0.17763.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformVersion>10.0.18362.0</WindowsTargetPlatformVersion>
<WindowsSDKDesktopARM64Support>true</WindowsSDKDesktopARM64Support>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />

Просмотреть файл

@ -6,8 +6,9 @@
#include <Windows.Graphics.DirectX.Direct3D11.interop.h>
#include "Run.h"
#include "Scenarios.h"
#include <winrt/Windows.Foundation.Metadata.h>
using namespace winrt::Windows::Graphics::DirectX::Direct3D11;
using namespace winrt::Windows::Foundation::Metadata;
std::vector<ILearningModelFeatureValue> GenerateInputFeatures(const LearningModel& model, const CommandLineArgs& args,
InputBindingType inputBindingType,
InputDataType inputDataType,
@ -137,6 +138,57 @@ HRESULT CreateDXGIFactory2SEH(void** dxgiFactory)
}
#endif
void PopulateSessionOptions(LearningModelSessionOptions& sessionOptions)
{
// Batch Size Override as 1
try
{
sessionOptions.BatchSizeOverride(1);
}
catch (...)
{
printf("Batch size override couldn't be set.\n");
throw;
}
}
void CreateSessionConsideringSupportForSessionOptions(LearningModelSession& session,
LearningModel& model,
Profiler<WINML_MODEL_TEST_PERF>& profiler,
CommandLineArgs& args,
LearningModelDevice& learningModelDevice)
{
auto statics = get_activation_factory<ApiInformation, IApiInformationStatics>();
bool isSessionOptionsTypePresent = isSessionOptionsTypePresent =
statics.IsTypePresent(L"Windows.AI.MachineLearning.LearningModelSessionOptions");
if (isSessionOptionsTypePresent)
{
LearningModelSessionOptions sessionOptions;
PopulateSessionOptions(sessionOptions);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_START(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
session = LearningModelSession(model, learningModelDevice, sessionOptions);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_STOP(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
}
else
{
if (args.IsPerformanceCapture())
{
WINML_PROFILING_START(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
session = LearningModelSession(model, learningModelDevice);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_STOP(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
}
}
HRESULT CreateSession(LearningModelSession& session, IDirect3DDevice& winrtDevice, LearningModel& model,
CommandLineArgs& args, OutputHelper& output, DeviceType deviceType,
DeviceCreationLocation deviceCreationLocation, Profiler<WINML_MODEL_TEST_PERF>& profiler)
@ -150,6 +202,7 @@ HRESULT CreateSession(LearningModelSession& session, IDirect3DDevice& winrtDevic
#endif
try
{
LearningModelDevice learningModelDevice = NULL;
if (deviceCreationLocation == DeviceCreationLocation::UserD3DDevice && deviceType != DeviceType::CPU)
{
// Enumerate Adapters to pick the requested one.
@ -194,17 +247,7 @@ HRESULT CreateSession(LearningModelSession& session, IDirect3DDevice& winrtDevic
THROW_IF_FAILED(hr);
winrtDevice = inspectableDevice.as<IDirect3DDevice>();
LearningModelDevice learningModelDevice = LearningModelDevice::CreateFromDirect3D11Device(winrtDevice);
output.PrintLearningModelDevice(deviceType, learningModelDevice);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_START(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
session = LearningModelSession(model, learningModelDevice);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_STOP(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
learningModelDevice = LearningModelDevice::CreateFromDirect3D11Device(winrtDevice);
}
#ifdef DXCORE_SUPPORTED_BUILD
else if ((TypeHelper::GetWinmlDeviceKind(deviceType) != LearningModelDeviceKind::Cpu) && !adapterName.empty())
@ -319,31 +362,15 @@ HRESULT CreateSession(LearningModelSession& session, IDirect3DDevice& winrtDevic
com_ptr<::IUnknown> spUnkLearningModelDevice;
THROW_IF_FAILED(
factory->CreateFromD3D12CommandQueue(d3d12CommandQueue.get(), spUnkLearningModelDevice.put()));
if (args.IsPerformanceCapture())
{
WINML_PROFILING_START(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
session = LearningModelSession(model, spUnkLearningModelDevice.as<LearningModelDevice>());
if (args.IsPerformanceCapture())
{
WINML_PROFILING_STOP(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
learningModelDevice = spUnkLearningModelDevice.as<LearningModelDevice>();
}
#endif
else
{
LearningModelDevice learningModelDevice(TypeHelper::GetWinmlDeviceKind(deviceType));
output.PrintLearningModelDevice(deviceType, learningModelDevice);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_START(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
session = LearningModelSession(model, learningModelDevice);
if (args.IsPerformanceCapture())
{
WINML_PROFILING_STOP(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
}
learningModelDevice = LearningModelDevice(TypeHelper::GetWinmlDeviceKind(deviceType));
}
output.PrintLearningModelDevice(deviceType, learningModelDevice);
CreateSessionConsideringSupportForSessionOptions(session, model, profiler, args, learningModelDevice);
}
catch (hresult_error hr)
{