load winml.dll from local folder
This commit is contained in:
Родитель
d0d7761411
Коммит
4f8b1ef33f
|
@ -0,0 +1,25 @@
|
|||
#include "pch.h"
|
||||
#include "Filehelper.h"
|
||||
#include <libloaderapi.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
|
||||
|
||||
namespace FileHelper
|
||||
{
|
||||
std::string GetModulePath()
|
||||
{
|
||||
std::string val;
|
||||
char modulePath[MAX_PATH] = {};
|
||||
GetModuleFileNameA(NULL, modulePath, ARRAYSIZE(modulePath));
|
||||
char drive[_MAX_DRIVE];
|
||||
char dir[_MAX_DIR];
|
||||
char filename[_MAX_FNAME];
|
||||
char ext[_MAX_EXT];
|
||||
_splitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
|
||||
|
||||
val = drive;
|
||||
val += dir;
|
||||
return val;
|
||||
}
|
||||
} // namespace FileHelper
|
|
@ -0,0 +1,7 @@
|
|||
#pragma once
|
||||
#include <string>
|
||||
#include <Windows.h>
|
||||
namespace FileHelper
|
||||
{
|
||||
std::string GetModulePath();
|
||||
}
|
|
@ -9,7 +9,7 @@
|
|||
<ProjectGuid>{2bf804d4-daa2-42be-9f21-0e94f021ef53}</ProjectGuid>
|
||||
<Keyword>Win32Proj</Keyword>
|
||||
<RootNamespace>SqueezeNetObjectDetection</RootNamespace>
|
||||
<WindowsTargetPlatformVersion Condition=" '$(WindowsTargetPlatformVersion)' == '' ">10.0.17763.0</WindowsTargetPlatformVersion>
|
||||
<WindowsTargetPlatformVersion Condition=" '$(WindowsTargetPlatformVersion)' == '' ">10.0.18362.0</WindowsTargetPlatformVersion>
|
||||
<WindowsTargetPlatformMinVersion>10.0.17763.0</WindowsTargetPlatformMinVersion>
|
||||
<ProjectName>SqueezeNetObjectDetectionCPP</ProjectName>
|
||||
</PropertyGroup>
|
||||
|
@ -123,9 +123,12 @@
|
|||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\..\..\SharedContent\models\SqueezeNet.h" />
|
||||
<ClInclude Include="Filehelper.h" />
|
||||
<ClInclude Include="pch.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="dllload.cpp" />
|
||||
<ClCompile Include="Filehelper.cpp" />
|
||||
<ClCompile Include="main.cpp" />
|
||||
<ClCompile Include="pch.cpp">
|
||||
<PrecompiledHeader>Create</PrecompiledHeader>
|
||||
|
|
|
@ -3,10 +3,13 @@
|
|||
<ItemGroup>
|
||||
<ClCompile Include="main.cpp" />
|
||||
<ClCompile Include="pch.cpp" />
|
||||
<ClCompile Include="dllload.cpp" />
|
||||
<ClCompile Include="Filehelper.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="pch.h" />
|
||||
<ClInclude Include="..\..\..\..\SharedContent\models\SqueezeNet.h" />
|
||||
<ClInclude Include="Filehelper.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
#include "pch.h"
|
||||
#include "FileHelper.h"
|
||||
#include <winrt/Windows.Foundation.h>
|
||||
#include <winstring.h>
|
||||
|
||||
extern "C"
|
||||
{
|
||||
HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept;
|
||||
}
|
||||
|
||||
#ifdef _M_IX86
|
||||
#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12")
|
||||
#else
|
||||
#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory")
|
||||
#endif
|
||||
|
||||
bool starts_with(std::wstring_view value, std::wstring_view match) noexcept
|
||||
{
|
||||
return 0 == value.compare(0, match.size(), match);
|
||||
}
|
||||
|
||||
int32_t __stdcall WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept
|
||||
{
|
||||
*factory = nullptr;
|
||||
std::wstring_view name{ WindowsGetStringRawBuffer(static_cast<HSTRING>(classId), nullptr),
|
||||
WindowsGetStringLen(static_cast<HSTRING>(classId)) };
|
||||
HMODULE library{ nullptr };
|
||||
|
||||
std::string modulePath = FileHelper::GetModulePath();
|
||||
std::wstring winmlDllPath = std::wstring(modulePath.begin(), modulePath.end()) + L"Windows.AI.MachineLearning.dll";
|
||||
|
||||
if (starts_with(name, L"Windows.AI.MachineLearning."))
|
||||
{
|
||||
const wchar_t* libPath = winmlDllPath.c_str();
|
||||
library = LoadLibraryW(libPath);
|
||||
}
|
||||
else
|
||||
{
|
||||
return OS_RoGetActivationFactory(static_cast<HSTRING>(classId), iid, factory);
|
||||
}
|
||||
|
||||
// If the library is not found, get the default one
|
||||
if (!library)
|
||||
{
|
||||
return OS_RoGetActivationFactory(static_cast<HSTRING>(classId), iid, factory);
|
||||
}
|
||||
|
||||
using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory);
|
||||
auto call = reinterpret_cast<DllGetActivationFactory*>(GetProcAddress(library, "DllGetActivationFactory"));
|
||||
|
||||
if (!call)
|
||||
{
|
||||
HRESULT const hr = HRESULT_FROM_WIN32(GetLastError());
|
||||
WINRT_VERIFY(FreeLibrary(library));
|
||||
return hr;
|
||||
}
|
||||
|
||||
winrt::com_ptr<winrt::Windows::Foundation::IActivationFactory> activation_factory;
|
||||
HRESULT const hr = call(static_cast<HSTRING>(classId), activation_factory.put_void());
|
||||
|
||||
if (FAILED(hr))
|
||||
{
|
||||
WINRT_VERIFY(FreeLibrary(library));
|
||||
return hr;
|
||||
}
|
||||
|
||||
if (iid != winrt::guid_of<winrt::Windows::Foundation::IActivationFactory>())
|
||||
{
|
||||
return activation_factory->QueryInterface(iid, factory);
|
||||
}
|
||||
|
||||
*factory = activation_factory.detach();
|
||||
return S_OK;
|
||||
}
|
|
@ -2,6 +2,7 @@
|
|||
//
|
||||
|
||||
#include "pch.h"
|
||||
#include "FileHelper.h"
|
||||
|
||||
using namespace winrt;
|
||||
using namespace Windows::Foundation;
|
||||
|
@ -19,8 +20,6 @@ LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Default;
|
|||
string deviceName = "default";
|
||||
hstring imagePath;
|
||||
|
||||
// helper functions
|
||||
string GetModulePath();
|
||||
void LoadLabels();
|
||||
VideoFrame LoadImageFile(hstring filePath, ColorManagementMode colorManagementMode);
|
||||
void PrintResults(IVectorView<float> results);
|
||||
|
@ -30,7 +29,7 @@ ColorManagementMode GetColorManagementMode(const LearningModel& model);
|
|||
wstring GetModelPath()
|
||||
{
|
||||
wostringstream woss;
|
||||
woss << GetModulePath().c_str();
|
||||
woss << FileHelper::GetModulePath().c_str();
|
||||
woss << "SqueezeNet.onnx";
|
||||
return woss.str();
|
||||
}
|
||||
|
@ -118,26 +117,10 @@ bool ParseArgs(int argc, char* argv[])
|
|||
return true;
|
||||
}
|
||||
|
||||
string GetModulePath()
|
||||
{
|
||||
string val;
|
||||
char modulePath[MAX_PATH] = {};
|
||||
GetModuleFileNameA(NULL, modulePath, ARRAYSIZE(modulePath));
|
||||
char drive[_MAX_DRIVE];
|
||||
char dir[_MAX_DIR];
|
||||
char filename[_MAX_FNAME];
|
||||
char ext[_MAX_EXT];
|
||||
_splitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
|
||||
|
||||
val = drive;
|
||||
val += dir;
|
||||
return val;
|
||||
}
|
||||
|
||||
void LoadLabels()
|
||||
{
|
||||
// Parse labels from labels file. We know the file's entries are already sorted in order.
|
||||
std::string labelsFilePath = GetModulePath() + labelsFileName;
|
||||
std::string labelsFilePath = FileHelper::GetModulePath() + labelsFileName;
|
||||
ifstream labelFile(labelsFilePath, ifstream::in);
|
||||
if (labelFile.fail())
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче