From 4f8b1ef33f3b19beeb818abd306a0ce3bee9199a Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Tue, 26 Nov 2019 17:57:45 -0800 Subject: [PATCH] load winml.dll from local folder --- .../Desktop/cpp/Filehelper.cpp | 25 +++++++ .../Desktop/cpp/Filehelper.h | 7 ++ .../cpp/SqueezeNetObjectDetectionCPP.vcxproj | 5 +- ...ueezeNetObjectDetectionCPP.vcxproj.filters | 3 + .../Desktop/cpp/dllload.cpp | 74 +++++++++++++++++++ .../Desktop/cpp/main.cpp | 23 +----- 6 files changed, 116 insertions(+), 21 deletions(-) create mode 100644 Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.cpp create mode 100644 Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.h create mode 100644 Samples/SqueezeNetObjectDetection/Desktop/cpp/dllload.cpp diff --git a/Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.cpp b/Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.cpp new file mode 100644 index 00000000..17804360 --- /dev/null +++ b/Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.cpp @@ -0,0 +1,25 @@ +#include "pch.h" +#include "Filehelper.h" +#include +#include + +EXTERN_C IMAGE_DOS_HEADER __ImageBase; + +namespace FileHelper +{ + std::string GetModulePath() + { + std::string val; + char modulePath[MAX_PATH] = {}; + GetModuleFileNameA(NULL, modulePath, ARRAYSIZE(modulePath)); + char drive[_MAX_DRIVE]; + char dir[_MAX_DIR]; + char filename[_MAX_FNAME]; + char ext[_MAX_EXT]; + _splitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); + + val = drive; + val += dir; + return val; + } +} // namespace FileHelper diff --git a/Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.h b/Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.h new file mode 100644 index 00000000..b993a282 --- /dev/null +++ b/Samples/SqueezeNetObjectDetection/Desktop/cpp/Filehelper.h @@ -0,0 +1,7 @@ +#pragma once +#include +#include +namespace FileHelper +{ + std::string GetModulePath(); +} \ No newline at end of file diff --git a/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj b/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj index 33944cc4..5866edf1 100644 --- a/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj +++ b/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj @@ -9,7 +9,7 @@ {2bf804d4-daa2-42be-9f21-0e94f021ef53} Win32Proj SqueezeNetObjectDetection - 10.0.17763.0 + 10.0.18362.0 10.0.17763.0 SqueezeNetObjectDetectionCPP @@ -123,9 +123,12 @@ + + + Create diff --git a/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj.filters b/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj.filters index 2f137129..1346510b 100644 --- a/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj.filters +++ b/Samples/SqueezeNetObjectDetection/Desktop/cpp/SqueezeNetObjectDetectionCPP.vcxproj.filters @@ -3,10 +3,13 @@ + + + diff --git a/Samples/SqueezeNetObjectDetection/Desktop/cpp/dllload.cpp b/Samples/SqueezeNetObjectDetection/Desktop/cpp/dllload.cpp new file mode 100644 index 00000000..54acdaa2 --- /dev/null +++ b/Samples/SqueezeNetObjectDetection/Desktop/cpp/dllload.cpp @@ -0,0 +1,74 @@ +#include "pch.h" +#include "FileHelper.h" +#include +#include + +extern "C" +{ + HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept; +} + +#ifdef _M_IX86 +#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12") +#else +#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory") +#endif + +bool starts_with(std::wstring_view value, std::wstring_view match) noexcept +{ + return 0 == value.compare(0, match.size(), match); +} + +int32_t __stdcall WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept +{ + *factory = nullptr; + std::wstring_view name{ WindowsGetStringRawBuffer(static_cast(classId), nullptr), + WindowsGetStringLen(static_cast(classId)) }; + HMODULE library{ nullptr }; + + std::string modulePath = FileHelper::GetModulePath(); + std::wstring winmlDllPath = std::wstring(modulePath.begin(), modulePath.end()) + L"Windows.AI.MachineLearning.dll"; + + if (starts_with(name, L"Windows.AI.MachineLearning.")) + { + const wchar_t* libPath = winmlDllPath.c_str(); + library = LoadLibraryW(libPath); + } + else + { + return OS_RoGetActivationFactory(static_cast(classId), iid, factory); + } + + // If the library is not found, get the default one + if (!library) + { + return OS_RoGetActivationFactory(static_cast(classId), iid, factory); + } + + using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory); + auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory")); + + if (!call) + { + HRESULT const hr = HRESULT_FROM_WIN32(GetLastError()); + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + winrt::com_ptr activation_factory; + HRESULT const hr = call(static_cast(classId), activation_factory.put_void()); + + if (FAILED(hr)) + { + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + if (iid != winrt::guid_of()) + { + return activation_factory->QueryInterface(iid, factory); + } + + *factory = activation_factory.detach(); + return S_OK; +} diff --git a/Samples/SqueezeNetObjectDetection/Desktop/cpp/main.cpp b/Samples/SqueezeNetObjectDetection/Desktop/cpp/main.cpp index 76d8c0ff..17b17167 100644 --- a/Samples/SqueezeNetObjectDetection/Desktop/cpp/main.cpp +++ b/Samples/SqueezeNetObjectDetection/Desktop/cpp/main.cpp @@ -2,6 +2,7 @@ // #include "pch.h" +#include "FileHelper.h" using namespace winrt; using namespace Windows::Foundation; @@ -19,8 +20,6 @@ LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Default; string deviceName = "default"; hstring imagePath; -// helper functions -string GetModulePath(); void LoadLabels(); VideoFrame LoadImageFile(hstring filePath, ColorManagementMode colorManagementMode); void PrintResults(IVectorView results); @@ -30,7 +29,7 @@ ColorManagementMode GetColorManagementMode(const LearningModel& model); wstring GetModelPath() { wostringstream woss; - woss << GetModulePath().c_str(); + woss << FileHelper::GetModulePath().c_str(); woss << "SqueezeNet.onnx"; return woss.str(); } @@ -118,26 +117,10 @@ bool ParseArgs(int argc, char* argv[]) return true; } -string GetModulePath() -{ - string val; - char modulePath[MAX_PATH] = {}; - GetModuleFileNameA(NULL, modulePath, ARRAYSIZE(modulePath)); - char drive[_MAX_DRIVE]; - char dir[_MAX_DIR]; - char filename[_MAX_FNAME]; - char ext[_MAX_EXT]; - _splitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); - - val = drive; - val += dir; - return val; -} - void LoadLabels() { // Parse labels from labels file. We know the file's entries are already sorted in order. - std::string labelsFilePath = GetModulePath() + labelsFileName; + std::string labelsFilePath = FileHelper::GetModulePath() + labelsFileName; ifstream labelFile(labelsFilePath, ifstream::in); if (labelFile.fail()) {