refactor adapter selection into utility class

This commit is contained in:
Ori Levari 2019-01-07 16:16:37 -08:00
Родитель 876688d8b9
Коммит ef354f9e16
5 изменённых файлов: 41 добавлений и 57 удалений

Просмотреть файл

@ -5,6 +5,8 @@ VisualStudioVersion = 15.0.28302.56
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "AdapterSelection", "AdapterSelection\cpp\AdapterSelection.vcxproj", "{2E115EFB-F7EC-444E-A555-507D55A89BC9}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "AdapterSelectionTest", "AdapterSelectionTest\AdapterSelectionTest.vcxproj", "{5D062D17-5950-40B8-AD5B-970C95AE1E3C}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|x64 = Debug|x64
@ -21,6 +23,14 @@ Global
{2E115EFB-F7EC-444E-A555-507D55A89BC9}.Release|x64.Build.0 = Release|x64
{2E115EFB-F7EC-444E-A555-507D55A89BC9}.Release|x86.ActiveCfg = Release|Win32
{2E115EFB-F7EC-444E-A555-507D55A89BC9}.Release|x86.Build.0 = Release|Win32
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x64.ActiveCfg = Debug|x64
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x64.Build.0 = Debug|x64
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x86.ActiveCfg = Debug|Win32
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x86.Build.0 = Debug|Win32
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x64.ActiveCfg = Release|x64
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x64.Build.0 = Release|x64
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x86.ActiveCfg = Release|Win32
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x86.Build.0 = Release|Win32
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE

Просмотреть файл

@ -1,5 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props')" />
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
@ -165,10 +166,11 @@
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="AdapterSelection.h" />
<ClInclude Include="pch.h" />
<ClInclude Include="resource.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="AdapterSelection.cpp" />
<ClCompile Include="main.cpp" />
<ClCompile Include="pch.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
@ -180,7 +182,18 @@
<ItemGroup>
<CopyFileToFolders Include="Labels.txt" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets')" />
</ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets'))" />
</Target>
</Project>

Просмотреть файл

@ -18,7 +18,7 @@
<ClInclude Include="pch.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="resource.h">
<ClInclude Include="AdapterSelection.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>
@ -29,10 +29,14 @@
<ClCompile Include="main.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="AdapterSelection.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<Text Include="Labels.txt">
<Filter>Resource Files</Filter>
</Text>
<CopyFileToFolders Include="Labels.txt" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
</Project>

Просмотреть файл

@ -20,7 +20,6 @@ void LoadLabels();
VideoFrame LoadImageFile(hstring filePath);
void PrintResults(IVectorView<float> results);
bool ParseArgs(int argc, char* argv[]);
LearningModelDevice getLearningModelDeviceFromAdapter(com_ptr<IDXGIAdapter1> spAdapter);
int main(int argc, char* argv[])
{
@ -32,29 +31,14 @@ int main(int argc, char* argv[])
return -1;
}
// display all adapters
com_ptr<IDXGIFactory1> spFactory;
CreateDXGIFactory(__uuidof(IDXGIFactory), (void**)(spFactory.put()));
std::vector <com_ptr<IDXGIAdapter1>> validAdapters;
for (UINT i = 0; ; ++i) {
com_ptr<IDXGIAdapter1> spAdapter;
if (spFactory->EnumAdapters1(i, spAdapter.put()) != S_OK) {
break;
}
std::vector <com_ptr<IDXGIAdapter1>> validAdapters = AdapterSelection::EnumerateAdapters(true);
for (int i = 0; i < validAdapters.size(); i++) {
DXGI_ADAPTER_DESC1 pDesc;
spAdapter->GetDesc1(&pDesc);
// is a software adapter
if (pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE || (pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c)) {
continue;
}
// valid GPU adapter
else {
printf("Index: %" PRIu64 ", Description: %ls\n", validAdapters.size(), pDesc.Description);
wcout << pDesc.Description << endl;
validAdapters.push_back(spAdapter);
}
com_ptr<IDXGIAdapter1> currAdapter = validAdapters.at(i);
currAdapter->GetDesc1(&pDesc);
printf("Index: %d, Description: %ls\n", i, pDesc.Description);
}
LearningModelDevice device = nullptr;
if (validAdapters.size() == 0) {
printf("There are no available adapters, running on CPU...\n");
@ -70,8 +54,7 @@ int main(int argc, char* argv[])
printf("Invalid index, please try again.\n");
}
printf("Selected adapter at index %d\n", selectedIndex);
device = getLearningModelDeviceFromAdapter(validAdapters.at(selectedIndex));
device = AdapterSelection::GetLearningModelDeviceFromAdapter(validAdapters.at(selectedIndex));
}
// load the model
@ -110,29 +93,6 @@ int main(int argc, char* argv[])
PrintResults(resultVector);
}
LearningModelDevice getLearningModelDeviceFromAdapter(com_ptr<IDXGIAdapter1> spAdapter) {
// create D3D12Device
com_ptr<IUnknown> spIUnknownAdapter;
spAdapter->QueryInterface(IID_IUnknown, spIUnknownAdapter.put_void());
com_ptr<ID3D12Device> spD3D12Device;
D3D12CreateDevice(spIUnknownAdapter.get(), D3D_FEATURE_LEVEL_11_0, _uuidof(ID3D12Device), spD3D12Device.put_void());
// create D3D12 command queue from device
D3D12_COMMAND_QUEUE_DESC queueDesc = {};
queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
com_ptr<ID3D12CommandQueue> spCommandQueue;
spD3D12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(spCommandQueue.put()));
// create LearningModelDevice from command queue
com_ptr<ILearningModelDeviceFactoryNative> dFactory =
get_activation_factory<LearningModelDevice, ILearningModelDeviceFactoryNative>();
com_ptr<::IUnknown> spLearningDevice;
dFactory->CreateFromD3D12CommandQueue(spCommandQueue.get(), spLearningDevice.put());
return spLearningDevice.as<LearningModelDevice>();
}
bool ParseArgs(int argc, char* argv[])
{
if (argc < 3)

Просмотреть файл

@ -8,20 +8,18 @@
#define NOMINMAX
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING 1 // The C++ Standard doesn't provide equivalent non-deprecated functionality yet.
#include <iostream>
#include "AdapterSelection.h"
#include <iostream>
#include <vcruntime.h>
#include <windows.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.AI.MachineLearning.h>
#include <winrt/Windows.Media.h>
#include <winrt/Windows.Storage.h>
#include <winrt/Windows.Graphics.h>
#include <winrt/Windows.Graphics.Imaging.h>
#include <windows.ai.machinelearning.native.h>
#include <dxgi.h>
#include <string>
#include <codecvt>
@ -29,6 +27,5 @@
#include <inttypes.h>
#include <algorithm>
using convert_type = std::codecvt_utf8<wchar_t>;
using wstring_to_utf8 = std::wstring_convert<convert_type, wchar_t>;