refactor adapter selection into utility class
This commit is contained in:
Родитель
876688d8b9
Коммит
ef354f9e16
|
@ -5,6 +5,8 @@ VisualStudioVersion = 15.0.28302.56
|
|||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "AdapterSelection", "AdapterSelection\cpp\AdapterSelection.vcxproj", "{2E115EFB-F7EC-444E-A555-507D55A89BC9}"
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "AdapterSelectionTest", "AdapterSelectionTest\AdapterSelectionTest.vcxproj", "{5D062D17-5950-40B8-AD5B-970C95AE1E3C}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|x64 = Debug|x64
|
||||
|
@ -21,6 +23,14 @@ Global
|
|||
{2E115EFB-F7EC-444E-A555-507D55A89BC9}.Release|x64.Build.0 = Release|x64
|
||||
{2E115EFB-F7EC-444E-A555-507D55A89BC9}.Release|x86.ActiveCfg = Release|Win32
|
||||
{2E115EFB-F7EC-444E-A555-507D55A89BC9}.Release|x86.Build.0 = Release|Win32
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x64.ActiveCfg = Debug|x64
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x64.Build.0 = Debug|x64
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x86.ActiveCfg = Debug|Win32
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Debug|x86.Build.0 = Debug|Win32
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x64.ActiveCfg = Release|x64
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x64.Build.0 = Release|x64
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x86.ActiveCfg = Release|Win32
|
||||
{5D062D17-5950-40B8-AD5B-970C95AE1E3C}.Release|x86.Build.0 = Release|Win32
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props')" />
|
||||
<ItemGroup Label="ProjectConfigurations">
|
||||
<ProjectConfiguration Include="Debug|Win32">
|
||||
<Configuration>Debug</Configuration>
|
||||
|
@ -165,10 +166,11 @@
|
|||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="AdapterSelection.h" />
|
||||
<ClInclude Include="pch.h" />
|
||||
<ClInclude Include="resource.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="AdapterSelection.cpp" />
|
||||
<ClCompile Include="main.cpp" />
|
||||
<ClCompile Include="pch.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
|
||||
|
@ -180,7 +182,18 @@
|
|||
<ItemGroup>
|
||||
<CopyFileToFolders Include="Labels.txt" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets')" />
|
||||
</ImportGroup>
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.1.0.181214.3\build\native\Microsoft.Windows.CppWinRT.targets'))" />
|
||||
</Target>
|
||||
</Project>
|
|
@ -18,7 +18,7 @@
|
|||
<ClInclude Include="pch.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="resource.h">
|
||||
<ClInclude Include="AdapterSelection.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
|
@ -29,10 +29,14 @@
|
|||
<ClCompile Include="main.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="AdapterSelection.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Text Include="Labels.txt">
|
||||
<Filter>Resource Files</Filter>
|
||||
</Text>
|
||||
<CopyFileToFolders Include="Labels.txt" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -20,7 +20,6 @@ void LoadLabels();
|
|||
VideoFrame LoadImageFile(hstring filePath);
|
||||
void PrintResults(IVectorView<float> results);
|
||||
bool ParseArgs(int argc, char* argv[]);
|
||||
LearningModelDevice getLearningModelDeviceFromAdapter(com_ptr<IDXGIAdapter1> spAdapter);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
@ -32,29 +31,14 @@ int main(int argc, char* argv[])
|
|||
return -1;
|
||||
}
|
||||
|
||||
// display all adapters
|
||||
com_ptr<IDXGIFactory1> spFactory;
|
||||
CreateDXGIFactory(__uuidof(IDXGIFactory), (void**)(spFactory.put()));
|
||||
std::vector <com_ptr<IDXGIAdapter1>> validAdapters;
|
||||
for (UINT i = 0; ; ++i) {
|
||||
com_ptr<IDXGIAdapter1> spAdapter;
|
||||
if (spFactory->EnumAdapters1(i, spAdapter.put()) != S_OK) {
|
||||
break;
|
||||
}
|
||||
std::vector <com_ptr<IDXGIAdapter1>> validAdapters = AdapterSelection::EnumerateAdapters(true);
|
||||
for (int i = 0; i < validAdapters.size(); i++) {
|
||||
DXGI_ADAPTER_DESC1 pDesc;
|
||||
spAdapter->GetDesc1(&pDesc);
|
||||
|
||||
// is a software adapter
|
||||
if (pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE || (pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c)) {
|
||||
continue;
|
||||
}
|
||||
// valid GPU adapter
|
||||
else {
|
||||
printf("Index: %" PRIu64 ", Description: %ls\n", validAdapters.size(), pDesc.Description);
|
||||
wcout << pDesc.Description << endl;
|
||||
validAdapters.push_back(spAdapter);
|
||||
}
|
||||
com_ptr<IDXGIAdapter1> currAdapter = validAdapters.at(i);
|
||||
currAdapter->GetDesc1(&pDesc);
|
||||
printf("Index: %d, Description: %ls\n", i, pDesc.Description);
|
||||
}
|
||||
|
||||
LearningModelDevice device = nullptr;
|
||||
if (validAdapters.size() == 0) {
|
||||
printf("There are no available adapters, running on CPU...\n");
|
||||
|
@ -70,8 +54,7 @@ int main(int argc, char* argv[])
|
|||
printf("Invalid index, please try again.\n");
|
||||
}
|
||||
printf("Selected adapter at index %d\n", selectedIndex);
|
||||
|
||||
device = getLearningModelDeviceFromAdapter(validAdapters.at(selectedIndex));
|
||||
device = AdapterSelection::GetLearningModelDeviceFromAdapter(validAdapters.at(selectedIndex));
|
||||
}
|
||||
|
||||
// load the model
|
||||
|
@ -110,29 +93,6 @@ int main(int argc, char* argv[])
|
|||
PrintResults(resultVector);
|
||||
}
|
||||
|
||||
LearningModelDevice getLearningModelDeviceFromAdapter(com_ptr<IDXGIAdapter1> spAdapter) {
|
||||
|
||||
// create D3D12Device
|
||||
com_ptr<IUnknown> spIUnknownAdapter;
|
||||
spAdapter->QueryInterface(IID_IUnknown, spIUnknownAdapter.put_void());
|
||||
com_ptr<ID3D12Device> spD3D12Device;
|
||||
D3D12CreateDevice(spIUnknownAdapter.get(), D3D_FEATURE_LEVEL_11_0, _uuidof(ID3D12Device), spD3D12Device.put_void());
|
||||
|
||||
// create D3D12 command queue from device
|
||||
D3D12_COMMAND_QUEUE_DESC queueDesc = {};
|
||||
queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
|
||||
queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
|
||||
com_ptr<ID3D12CommandQueue> spCommandQueue;
|
||||
spD3D12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(spCommandQueue.put()));
|
||||
|
||||
// create LearningModelDevice from command queue
|
||||
com_ptr<ILearningModelDeviceFactoryNative> dFactory =
|
||||
get_activation_factory<LearningModelDevice, ILearningModelDeviceFactoryNative>();
|
||||
com_ptr<::IUnknown> spLearningDevice;
|
||||
dFactory->CreateFromD3D12CommandQueue(spCommandQueue.get(), spLearningDevice.put());
|
||||
return spLearningDevice.as<LearningModelDevice>();
|
||||
}
|
||||
|
||||
bool ParseArgs(int argc, char* argv[])
|
||||
{
|
||||
if (argc < 3)
|
||||
|
|
|
@ -8,20 +8,18 @@
|
|||
#define NOMINMAX
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING 1 // The C++ Standard doesn't provide equivalent non-deprecated functionality yet.
|
||||
|
||||
#include <iostream>
|
||||
#include "AdapterSelection.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <vcruntime.h>
|
||||
#include <windows.h>
|
||||
|
||||
#include <winrt/Windows.Foundation.h>
|
||||
#include <winrt/Windows.AI.MachineLearning.h>
|
||||
#include <winrt/Windows.Media.h>
|
||||
#include <winrt/Windows.Storage.h>
|
||||
#include <winrt/Windows.Graphics.h>
|
||||
#include <winrt/Windows.Graphics.Imaging.h>
|
||||
|
||||
#include <windows.ai.machinelearning.native.h>
|
||||
#include <dxgi.h>
|
||||
|
||||
#include <string>
|
||||
#include <codecvt>
|
||||
|
@ -29,6 +27,5 @@
|
|||
#include <inttypes.h>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
using convert_type = std::codecvt_utf8<wchar_t>;
|
||||
using wstring_to_utf8 = std::wstring_convert<convert_type, wchar_t>;
|
Загрузка…
Ссылка в новой задаче