diff --git a/net/dll/SEALdll.vcxproj b/net/dll/SEALdll.vcxproj index c78fe320..abb7b1d7 100644 --- a/net/dll/SEALdll.vcxproj +++ b/net/dll/SEALdll.vcxproj @@ -171,6 +171,8 @@ + + @@ -195,6 +197,8 @@ + + diff --git a/net/dll/SEALdll.vcxproj.filters b/net/dll/SEALdll.vcxproj.filters index 2bfebe61..1bc347ae 100644 --- a/net/dll/SEALdll.vcxproj.filters +++ b/net/dll/SEALdll.vcxproj.filters @@ -84,6 +84,12 @@ Header Files + + Header Files + + + Header Files + @@ -152,5 +158,11 @@ Source Files + + Source Files + + + Source Files + \ No newline at end of file diff --git a/net/dll/seal/memorymanager.cpp b/net/dll/seal/memorymanager.cpp new file mode 100644 index 00000000..db47ad51 --- /dev/null +++ b/net/dll/seal/memorymanager.cpp @@ -0,0 +1,45 @@ +// SEALDll +#include "stdafx.h" +#include "memorymanager.h" +#include "utilities.h" + +// SEAL +#include "seal/memorymanager.h" + +using namespace std; +using namespace seal; +using namespace seal::dll; + + +SEALDLL HRESULT SEALCALL MemoryManager_GetPool1(int prof_opt, bool clear_on_destruction, void** pool_handle) +{ + IfNullRet(pool_handle, E_POINTER); + + mm_prof_opt profile_opt = static_cast(prof_opt); + MemoryPoolHandle handle; + + // clear_on_destruction is only used when using FORCE_NEW + if (profile_opt == mm_prof_opt::FORCE_NEW) + { + handle = MemoryManager::GetPool(profile_opt, clear_on_destruction); + } + else + { + handle = MemoryManager::GetPool(profile_opt); + } + + MemoryPoolHandle* handle_ptr = new MemoryPoolHandle(move(handle)); + *pool_handle = handle_ptr; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryManager_GetPool2(void** pool_handle) +{ + IfNullRet(pool_handle, E_POINTER); + + MemoryPoolHandle handle = MemoryManager::GetPool(); + MemoryPoolHandle* handle_ptr = new MemoryPoolHandle(move(handle)); + *pool_handle = handle_ptr; + return S_OK; +} + diff --git a/net/dll/seal/memorymanager.h b/net/dll/seal/memorymanager.h new file mode 100644 index 00000000..667874a0 --- /dev/null +++ b/net/dll/seal/memorymanager.h @@ -0,0 +1,15 @@ +#pragma once + +/////////////////////////////////////////////////////////////////////////// +// +// This API is provided as a simple interface for the SEAL library +// that can be PInvoked by .Net code. +// +/////////////////////////////////////////////////////////////////////////// + +#include "defines.h" +#include + +SEALDLL HRESULT SEALCALL MemoryManager_GetPool1(int prof_opt, bool clear_on_destruction, void** pool_handle); + +SEALDLL HRESULT SEALCALL MemoryManager_GetPool2(void** pool_handle); diff --git a/net/dll/seal/memorypoolhandle.cpp b/net/dll/seal/memorypoolhandle.cpp new file mode 100644 index 00000000..cb34a4b5 --- /dev/null +++ b/net/dll/seal/memorypoolhandle.cpp @@ -0,0 +1,138 @@ +// SEALDll +#include "stdafx.h" +#include "memorypoolhandle.h" +#include "utilities.h" + +// SEAL +#include "seal/memorymanager.h" + +using namespace std; +using namespace seal; +using namespace seal::dll; + + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Create1(void** handle) +{ + IfNullRet(handle, E_POINTER); + + MemoryPoolHandle* handleptr = new MemoryPoolHandle(); + *handle = handleptr; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Create2(void* otherptr, void** handle) +{ + MemoryPoolHandle* other = FromVoid(otherptr); + IfNullRet(other, E_POINTER); + IfNullRet(handle, E_POINTER); + + MemoryPoolHandle* handleptr = new MemoryPoolHandle(*other); + *handle = handleptr; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Destroy(void* thisptr) +{ + MemoryPoolHandle* handle = FromVoid(thisptr); + IfNullRet(handle, E_POINTER); + + delete handle; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Set(void* thisptr, void* assignptr) +{ + MemoryPoolHandle* handle = FromVoid(thisptr); + IfNullRet(handle, E_POINTER); + MemoryPoolHandle* assign = FromVoid(assignptr); + IfNullRet(assign, E_POINTER); + + *handle = *assign; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Global(void** handle) +{ + IfNullRet(handle, E_POINTER); + + MemoryPoolHandle global = MemoryPoolHandle::Global(); + MemoryPoolHandle* handleptr = new MemoryPoolHandle(move(global)); + *handle = handleptr; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_ThreadLocal(void** handle) +{ + IfNullRet(handle, E_POINTER); + + MemoryPoolHandle threadlocal = MemoryPoolHandle::ThreadLocal(); + MemoryPoolHandle* handleptr = new MemoryPoolHandle(move(threadlocal)); + *handle = handleptr; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_New(bool clear_on_destruction, void** handle) +{ + IfNullRet(handle, E_POINTER); + + MemoryPoolHandle newhandle = MemoryPoolHandle::New(clear_on_destruction); + MemoryPoolHandle* handleptr = new MemoryPoolHandle(move(newhandle)); + *handle = handleptr; + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_PoolCount(void* thisptr, uint64_t* count) +{ + MemoryPoolHandle* pool = FromVoid(thisptr); + IfNullRet(pool, E_POINTER); + IfNullRet(count, E_POINTER); + + try + { + *count = pool->pool_count(); + return S_OK; + } + catch (const logic_error&) + { + return HRESULT_FROM_WIN32(ERROR_INVALID_OPERATION); + } +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_AllocByteCount(void* thisptr, uint64_t* count) +{ + MemoryPoolHandle* pool = FromVoid(thisptr); + IfNullRet(pool, E_POINTER); + IfNullRet(count, E_POINTER); + + try + { + *count = pool->alloc_byte_count(); + return S_OK; + } + catch (const logic_error&) + { + return HRESULT_FROM_WIN32(ERROR_INVALID_OPERATION); + } +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_IsInitialized(void* thisptr, bool* result) +{ + MemoryPoolHandle* pool = FromVoid(thisptr); + IfNullRet(pool, E_POINTER); + IfNullRet(result, E_POINTER); + + *result = (bool)(*pool); + return S_OK; +} + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Equals(void* thisptr, void* otherptr, bool* result) +{ + MemoryPoolHandle* pool = FromVoid(thisptr); + IfNullRet(pool, E_POINTER); + MemoryPoolHandle* other = FromVoid(otherptr); + IfNullRet(other, E_POINTER); + IfNullRet(result, E_POINTER); + + *result = (*pool == *other); + return S_OK; +} diff --git a/net/dll/seal/memorypoolhandle.h b/net/dll/seal/memorypoolhandle.h new file mode 100644 index 00000000..74118e3c --- /dev/null +++ b/net/dll/seal/memorypoolhandle.h @@ -0,0 +1,33 @@ +#pragma once + +/////////////////////////////////////////////////////////////////////////// +// +// This API is provided as a simple interface for the SEAL library +// that can be PInvoked by .Net code. +// +/////////////////////////////////////////////////////////////////////////// + +#include "defines.h" +#include + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Create1(void** handle); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Create2(void* otherptr, void** handle); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Destroy(void* thisptr); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Set(void* thisptr, void* assignptr); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Global(void** handle); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_ThreadLocal(void** handle); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_New(bool clear_on_destruction, void** handle); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_PoolCount(void* thisptr, uint64_t* count); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_AllocByteCount(void* thisptr, uint64_t* count); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_IsInitialized(void* thisptr, bool* result); + +SEALDLL HRESULT SEALCALL MemoryPoolHandle_Equals(void* thisptr, void* otherptr, bool* result); diff --git a/net/net/MemoryManager.cs b/net/net/MemoryManager.cs index e48df431..6beff71e 100644 --- a/net/net/MemoryManager.cs +++ b/net/net/MemoryManager.cs @@ -26,15 +26,15 @@ namespace Microsoft.Research.SEAL /// independent of the current profile: /// /// - /// MMProfOpt.ForceNew: return MemoryPoolHandle::New() - /// MMProfOpt.ForceGlobal: return MemoryPoolHandle::Global() - /// MMProfOpt.ForceThreadLocal: return MemoryPoolHandle::ThreadLocal() + /// MMProfOpt.ForceNew: return MemoryPoolHandle.New() + /// MMProfOpt.ForceGlobal: return MemoryPoolHandle.Global() + /// MMProfOpt.ForceThreadLocal: return MemoryPoolHandle.ThreadLocal() /// /// Other values for prof_opt are forwarded to the current profile and, depending /// on the profile, may or may not have an effect. The value mm_prof_opt::DEFAULT /// will always invoke a default behavior for the current profile. /// - /// A mm_prof_opt_t parameter used to provide additional + /// A MMProfOpt parameter used to provide additional /// instructions to the memory manager profile for internal logic. /// Indicates whether the memory pool data /// should be cleared when destroyed.This can be important when memory pools @@ -42,14 +42,19 @@ namespace Microsoft.Research.SEAL /// and ignored in all other cases. public static MemoryPoolHandle GetPool(MMProfOpt profOpt, bool clearOnDestruction = false) { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryManager_GetPool((int)profOpt, clearOnDestruction, out IntPtr handlePtr); + MemoryPoolHandle handle = new MemoryPoolHandle(handlePtr); + return handle; } + /// + /// Returns a MemoryPoolHandle according to the currently set memory manager profile. + /// public static MemoryPoolHandle GetPool() { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryManager_GetPool(out IntPtr handlePtr); + MemoryPoolHandle handle = new MemoryPoolHandle(handlePtr); + return handle; } } } diff --git a/net/net/MemoryPoolHandle.cs b/net/net/MemoryPoolHandle.cs index 3304cd90..86b39991 100644 --- a/net/net/MemoryPoolHandle.cs +++ b/net/net/MemoryPoolHandle.cs @@ -2,6 +2,7 @@ using Microsoft.Research.SEAL.Util; using System; using System.Collections.Generic; +using System.Runtime.InteropServices; using System.Text; namespace Microsoft.Research.SEAL @@ -63,19 +64,8 @@ namespace Microsoft.Research.SEAL /// public MemoryPoolHandle() { - // TODO: implement - throw new NotImplementedException(); - } - - /// - /// Creates a MemoryPoolHandle pointing to a given MemoryPool object. - /// - /// Pool to point to - /// if pool is null - public MemoryPoolHandle(MemoryPool pool) - { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryPoolHandle_Create(out IntPtr handlePtr); + NativePtr = handlePtr; } /// @@ -87,8 +77,21 @@ namespace Microsoft.Research.SEAL /// if copy is null. public MemoryPoolHandle(MemoryPoolHandle copy) { - // TODO: implement - throw new NotImplementedException(); + if (null == copy) + throw new ArgumentNullException(nameof(copy)); + + NativeMethods.MemoryPoolHandle_Create(copy.NativePtr, out IntPtr handlePtr); + NativePtr = handlePtr; + } + + /// + /// Create a MemoryPoolHandle through a native object pointer. + /// + /// Pointer to native MemoryPoolHandle + /// Whether this instance owns the native pointer + internal MemoryPoolHandle(IntPtr ptr, bool owned = true) + : base(ptr, owned) + { } /// @@ -101,8 +104,10 @@ namespace Microsoft.Research.SEAL /// if assign is null. public void Set(MemoryPoolHandle assign) { - // TODO: implement - throw new NotImplementedException(); + if (null == assign) + throw new ArgumentNullException(nameof(assign)); + + NativeMethods.MemoryPoolHandle_Set(NativePtr, assign.NativePtr); } /// @@ -110,8 +115,9 @@ namespace Microsoft.Research.SEAL /// public static MemoryPoolHandle Global() { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryPoolHandle_Global(out IntPtr handlePtr); + MemoryPoolHandle handle = new MemoryPoolHandle(handlePtr); + return handle; } /// @@ -121,8 +127,9 @@ namespace Microsoft.Research.SEAL /// public static MemoryPoolHandle ThreadLocal() { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryPoolHandle_ThreadLocal(out IntPtr handlePtr); + MemoryPoolHandle handle = new MemoryPoolHandle(handlePtr); + return handle; } /// @@ -133,8 +140,9 @@ namespace Microsoft.Research.SEAL /// are used to store private data. public static MemoryPoolHandle New(bool clearOnDestruction = false) { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryPoolHandle_New(clearOnDestruction, out IntPtr handlePtr); + MemoryPoolHandle handle = new MemoryPoolHandle(handlePtr); + return handle; } /// @@ -150,8 +158,17 @@ namespace Microsoft.Research.SEAL { get { - // TODO: implement - throw new NotImplementedException(); + try + { + NativeMethods.MemoryPoolHandle_PoolCount(NativePtr, out ulong count); + return count; + } + catch(COMException ex) + { + if ((uint)ex.HResult == NativeMethods.Errors.HRInvalidOperation) + throw new InvalidOperationException("MemoryPoolHandle is uninitialized", ex); + throw; + } } } @@ -165,8 +182,17 @@ namespace Microsoft.Research.SEAL { get { - // TODO: implement - throw new NotImplementedException(); + try + { + NativeMethods.MemoryPoolHandle_AllocByteCount(NativePtr, out ulong count); + return count; + } + catch (COMException ex) + { + if ((uint)ex.HResult == NativeMethods.Errors.HRInvalidOperation) + throw new InvalidOperationException("MemoryPoolHandle is uninitialized", ex); + throw; + } } } @@ -177,8 +203,8 @@ namespace Microsoft.Research.SEAL { get { - // TODO: implement - throw new NotImplementedException(); + NativeMethods.MemoryPoolHandle_IsInitialized(NativePtr, out bool result); + return result; } } @@ -189,8 +215,12 @@ namespace Microsoft.Research.SEAL /// Object to compare to. public override bool Equals(object obj) { - // TODO: implement - throw new NotImplementedException(); + MemoryPoolHandle other = obj as MemoryPoolHandle; + if (null == other) + return false; + + NativeMethods.MemoryPoolHandle_Equals(NativePtr, other.NativePtr, out bool result); + return result; } /// @@ -206,7 +236,7 @@ namespace Microsoft.Research.SEAL /// protected override void DestroyNativeObject() { - NativeMethods.MemPoolHandle_Destroy(NativePtr); + NativeMethods.MemoryPoolHandle_Destroy(NativePtr); } } } diff --git a/net/net/NativeMethods.cs b/net/net/NativeMethods.cs index 99a441b6..a1e97abb 100644 --- a/net/net/NativeMethods.cs +++ b/net/net/NativeMethods.cs @@ -870,12 +870,53 @@ namespace Microsoft.Research.SEAL [DllImport(SEALdll, PreserveSig = false)] internal static extern void SecretKey_ParmsId(IntPtr thisptr, ulong[] parmsId); + #endregion + + #region MemoryManager methods + + [DllImport(SEALdll, EntryPoint = "MemoryManager_GetPool1", PreserveSig = false)] + internal static extern void MemoryManager_GetPool(int profOpt, bool clearOnDestruction, out IntPtr handle); + + [DllImport(SEALdll, EntryPoint = "MemoryManager_GetPool2", PreserveSig = false)] + internal static extern void MemoryManager_GetPool(out IntPtr handle); + + #endregion #region MemoryPoolHandle methods + [DllImport(SEALdll, EntryPoint = "MemoryPoolHandle_Create1", PreserveSig = false)] + internal static extern void MemoryPoolHandle_Create(out IntPtr handlePtr); + + [DllImport(SEALdll, EntryPoint = "MemoryPoolHandle_Create2", PreserveSig = false)] + internal static extern void MemoryPoolHandle_Create(IntPtr other, out IntPtr handlePtr); + [DllImport(SEALdll, PreserveSig = false)] - internal static extern void MemPoolHandle_Destroy(IntPtr thisptr); + internal static extern void MemoryPoolHandle_Destroy(IntPtr thisptr); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_Set(IntPtr thisptr, IntPtr assignptr); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_Global(out IntPtr handlePtr); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_ThreadLocal(out IntPtr handlePtr); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_New(bool clearOnDestruction, out IntPtr handlePtr); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_PoolCount(IntPtr thisptr, out ulong count); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_AllocByteCount(IntPtr thisptr, out ulong count); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_IsInitialized(IntPtr thisptr, out bool initialized); + + [DllImport(SEALdll, PreserveSig = false)] + internal static extern void MemoryPoolHandle_Equals(IntPtr thisptr, IntPtr otherptr, out bool result); #endregion diff --git a/net/net/util/MemoryPool.cs b/net/net/util/MemoryPool.cs deleted file mode 100644 index faffbfd3..00000000 --- a/net/net/util/MemoryPool.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.Research.SEAL.Util -{ - /// - /// TODO: implement MemoryPool - /// - public class MemoryPool - { - } -} diff --git a/net/tests/MemoryPoolHandleTests.cs b/net/tests/MemoryPoolHandleTests.cs new file mode 100644 index 00000000..797de4a3 --- /dev/null +++ b/net/tests/MemoryPoolHandleTests.cs @@ -0,0 +1,69 @@ +using Microsoft.Research.SEAL; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; + +namespace SEALNetTest +{ + [TestClass] + public class MemoryPoolHandleTests + { + [TestMethod] + [ExpectedException(typeof(InvalidOperationException))] + public void PoolCountUninitializedTest() + { + MemoryPoolHandle handle = new MemoryPoolHandle(); + Assert.IsFalse(handle.IsInitialized); + ulong count = handle.PoolCount; + } + + [TestMethod] + [ExpectedException(typeof(InvalidOperationException))] + public void AllocByteCountUninitializedTest() + { + MemoryPoolHandle handle = new MemoryPoolHandle(); + Assert.IsFalse(handle.IsInitialized); + ulong count = handle.AllocByteCount; + } + + [TestMethod] + public void CreateTest() + { + MemoryPoolHandle handle = MemoryManager.GetPool(); + Assert.IsNotNull(handle); + Assert.IsTrue(handle.IsInitialized); + + MemoryPoolHandle handle2 = new MemoryPoolHandle(handle); + Assert.IsTrue(handle2.IsInitialized); + Assert.AreEqual(handle.PoolCount, handle2.PoolCount); + Assert.AreEqual(handle.AllocByteCount, handle2.AllocByteCount); + + MemoryPoolHandle handle3 = MemoryManager.GetPool(MMProfOpt.ForceNew, clearOnDestruction: true); + Assert.IsNotNull(handle3); + Assert.AreEqual(0ul, handle3.PoolCount); + Assert.AreEqual(0ul, handle3.AllocByteCount); + + MemoryPoolHandle handle4 = MemoryManager.GetPool(MMProfOpt.ForceThreadLocal); + Assert.IsNotNull(handle4); + Assert.AreEqual(0ul, handle4.PoolCount); + Assert.AreEqual(0ul, handle4.AllocByteCount); + } + + [TestMethod] + public void EqualsTest() + { + MemoryPoolHandle handle1 = MemoryManager.GetPool(MMProfOpt.ForceNew); + MemoryPoolHandle handle2 = MemoryManager.GetPool(MMProfOpt.Default); + MemoryPoolHandle handle3 = MemoryManager.GetPool(); + + Assert.IsNotNull(handle1); + Assert.IsNotNull(handle2); + Assert.IsNotNull(handle3); + + Assert.AreNotEqual(handle1, handle2); + Assert.AreNotEqual(handle1, handle3); + Assert.AreEqual(handle2, handle3); + } + } +}