load winml.dll from local folder

This commit is contained in:
Xiang Zhang 2019-11-26 17:57:45 -08:00
Родитель d0d7761411
Коммит 4f8b1ef33f
6 изменённых файлов: 116 добавлений и 21 удалений

Просмотреть файл

@ -0,0 +1,25 @@
#include "pch.h"
#include "Filehelper.h"
#include <libloaderapi.h>
#include <stdlib.h>
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
namespace FileHelper
{
std::string GetModulePath()
{
std::string val;
char modulePath[MAX_PATH] = {};
GetModuleFileNameA(NULL, modulePath, ARRAYSIZE(modulePath));
char drive[_MAX_DRIVE];
char dir[_MAX_DIR];
char filename[_MAX_FNAME];
char ext[_MAX_EXT];
_splitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
val = drive;
val += dir;
return val;
}
} // namespace FileHelper

Просмотреть файл

@ -0,0 +1,7 @@
#pragma once
#include <string>
#include <Windows.h>
namespace FileHelper
{
std::string GetModulePath();
}

Просмотреть файл

@ -9,7 +9,7 @@
<ProjectGuid>{2bf804d4-daa2-42be-9f21-0e94f021ef53}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>SqueezeNetObjectDetection</RootNamespace>
<WindowsTargetPlatformVersion Condition=" '$(WindowsTargetPlatformVersion)' == '' ">10.0.17763.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformVersion Condition=" '$(WindowsTargetPlatformVersion)' == '' ">10.0.18362.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformMinVersion>10.0.17763.0</WindowsTargetPlatformMinVersion>
<ProjectName>SqueezeNetObjectDetectionCPP</ProjectName>
</PropertyGroup>
@ -123,9 +123,12 @@
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="..\..\..\..\SharedContent\models\SqueezeNet.h" />
<ClInclude Include="Filehelper.h" />
<ClInclude Include="pch.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="dllload.cpp" />
<ClCompile Include="Filehelper.cpp" />
<ClCompile Include="main.cpp" />
<ClCompile Include="pch.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>

Просмотреть файл

@ -3,10 +3,13 @@
<ItemGroup>
<ClCompile Include="main.cpp" />
<ClCompile Include="pch.cpp" />
<ClCompile Include="dllload.cpp" />
<ClCompile Include="Filehelper.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="pch.h" />
<ClInclude Include="..\..\..\..\SharedContent\models\SqueezeNet.h" />
<ClInclude Include="Filehelper.h" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />

Просмотреть файл

@ -0,0 +1,74 @@
#include "pch.h"
#include "FileHelper.h"
#include <winrt/Windows.Foundation.h>
#include <winstring.h>
extern "C"
{
HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept;
}
#ifdef _M_IX86
#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12")
#else
#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory")
#endif
bool starts_with(std::wstring_view value, std::wstring_view match) noexcept
{
return 0 == value.compare(0, match.size(), match);
}
int32_t __stdcall WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept
{
*factory = nullptr;
std::wstring_view name{ WindowsGetStringRawBuffer(static_cast<HSTRING>(classId), nullptr),
WindowsGetStringLen(static_cast<HSTRING>(classId)) };
HMODULE library{ nullptr };
std::string modulePath = FileHelper::GetModulePath();
std::wstring winmlDllPath = std::wstring(modulePath.begin(), modulePath.end()) + L"Windows.AI.MachineLearning.dll";
if (starts_with(name, L"Windows.AI.MachineLearning."))
{
const wchar_t* libPath = winmlDllPath.c_str();
library = LoadLibraryW(libPath);
}
else
{
return OS_RoGetActivationFactory(static_cast<HSTRING>(classId), iid, factory);
}
// If the library is not found, get the default one
if (!library)
{
return OS_RoGetActivationFactory(static_cast<HSTRING>(classId), iid, factory);
}
using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory);
auto call = reinterpret_cast<DllGetActivationFactory*>(GetProcAddress(library, "DllGetActivationFactory"));
if (!call)
{
HRESULT const hr = HRESULT_FROM_WIN32(GetLastError());
WINRT_VERIFY(FreeLibrary(library));
return hr;
}
winrt::com_ptr<winrt::Windows::Foundation::IActivationFactory> activation_factory;
HRESULT const hr = call(static_cast<HSTRING>(classId), activation_factory.put_void());
if (FAILED(hr))
{
WINRT_VERIFY(FreeLibrary(library));
return hr;
}
if (iid != winrt::guid_of<winrt::Windows::Foundation::IActivationFactory>())
{
return activation_factory->QueryInterface(iid, factory);
}
*factory = activation_factory.detach();
return S_OK;
}

Просмотреть файл

@ -2,6 +2,7 @@
//
#include "pch.h"
#include "FileHelper.h"
using namespace winrt;
using namespace Windows::Foundation;
@ -19,8 +20,6 @@ LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Default;
string deviceName = "default";
hstring imagePath;
// helper functions
string GetModulePath();
void LoadLabels();
VideoFrame LoadImageFile(hstring filePath, ColorManagementMode colorManagementMode);
void PrintResults(IVectorView<float> results);
@ -30,7 +29,7 @@ ColorManagementMode GetColorManagementMode(const LearningModel& model);
wstring GetModelPath()
{
wostringstream woss;
woss << GetModulePath().c_str();
woss << FileHelper::GetModulePath().c_str();
woss << "SqueezeNet.onnx";
return woss.str();
}
@ -118,26 +117,10 @@ bool ParseArgs(int argc, char* argv[])
return true;
}
string GetModulePath()
{
string val;
char modulePath[MAX_PATH] = {};
GetModuleFileNameA(NULL, modulePath, ARRAYSIZE(modulePath));
char drive[_MAX_DRIVE];
char dir[_MAX_DIR];
char filename[_MAX_FNAME];
char ext[_MAX_EXT];
_splitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
val = drive;
val += dir;
return val;
}
void LoadLabels()
{
// Parse labels from labels file. We know the file's entries are already sorted in order.
std::string labelsFilePath = GetModulePath() + labelsFileName;
std::string labelsFilePath = FileHelper::GetModulePath() + labelsFileName;
ifstream labelFile(labelsFilePath, ifstream::in);
if (labelFile.fail())
{