Add classobject generation. Removed unsafe/safe examples, class files and target docs

This commit is contained in:
adrianwithah 2019-09-08 00:59:20 +01:00
Родитель 1ad152aefc
Коммит 201f370cb0
44 изменённых файлов: 261 добавлений и 2119 удалений

Просмотреть файл

@ -1,101 +0,0 @@
com_inproc_dll_module![
(CLSID_CAT_CLASS, BritishShortHairCatClass),
(CLSID_WINDOWS_FILE_MANAGER_CLASS, WindowsFileManagerClass),
(CLSID_LOCAL_FILE_MANAGER_CLASS, LocalFileManagerClass)
];
// ------------------------- DESIRED EXPANSION ------------------------------------------------
#[no_mangle]
extern "stdcall" fn DllGetClassObject(rclsid: REFCLSID, riid: REFIID, ppv: *mut LPVOID) -> HRESULT {
unsafe {
let rclsid_ref = &*rclsid;
if IsEqualGUID(rclsid_ref, &CLSID_CAT_CLASS) {
println!("Allocating new object CatClass...");
let mut cat = Box::new(<BritishShortHairCatClass>::new());
cat.add_ref();
let hr = cat.query_interface(riid, ppv);
cat.release();
Box::into_raw(cat);
hr
} else if IsEqualGUID(rclsid_ref, &CLSID_WINDOWS_FILE_MANAGER_CLASS) {
println!("Allocating new object WindowsFileManagerClass...");
let mut wfm = Box::new(<WindowsFileManagerClass>::new());
wfm.add_ref();
let hr = wfm.query_interface(riid, ppv);
wfm.release();
Box::into_raw(wfm);
hr
} else if IsEqualGUID(rclsid_ref, &CLSID_LOCAL_FILE_MANAGER_CLASS) {
println!("Allocating new object LocalFileManagerClass...");
let mut lfm = Box::new(<LocalFileManagerClass>::new());
lfm.add_ref();
let hr = lfm.query_interface(riid, ppv);
lfm.release();
Box::into_raw(lfm);
hr
} else {
CLASS_E_CLASSNOTAVAILABLE
}
}
}
// Function tries to add ALL-OR-NONE of the registry keys.
#[no_mangle]
extern "stdcall" fn DllRegisterServer() -> HRESULT {
let hr = register_keys(get_relevant_registry_keys());
if failed(hr) {
DllUnregisterServer();
}
hr
}
// Function tries to delete as many registry keys as possible.
#[no_mangle]
extern "stdcall" fn DllUnregisterServer() -> HRESULT {
let mut registry_keys_to_remove = get_relevant_registry_keys();
registry_keys_to_remove.reverse();
unregister_keys(registry_keys_to_remove)
}
fn get_relevant_registry_keys() -> Vec<RegistryKeyInfo> {
let file_path = get_dll_file_path();
// IMPORTANT: Assumption of order: Subkeys are located at a higher index than the parent key.
vec![
RegistryKeyInfo::new(
class_key_path(CLSID_CAT_CLASS).as_str(),
"",
"BritishShortHairCat",
),
RegistryKeyInfo::new(
class_inproc_key_path(CLSID_CAT_CLASS).as_str(),
"",
file_path.clone().as_str(),
),
RegistryKeyInfo::new(
class_key_path(CLSID_WINDOWS_FILE_MANAGER_CLASS).as_str(),
"",
"WindowsFileManagerClass",
),
RegistryKeyInfo::new(
class_inproc_key_path(CLSID_WINDOWS_FILE_MANAGER_CLASS).as_str(),
"",
file_path.clone().as_str(),
),
RegistryKeyInfo::new(
class_key_path(CLSID_LOCAL_FILE_MANAGER_CLASS).as_str(),
"",
"LocalFileManagerClass",
),
RegistryKeyInfo::new(
class_inproc_key_path(CLSID_LOCAL_FILE_MANAGER_CLASS).as_str(),
"",
file_path.clone().as_str(),
),
]
}

Просмотреть файл

@ -1,99 +0,0 @@
#[com_interface(00000001-0000-0000-c000-000000000046)]
pub trait IClassFactory: IUnknown {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT;
fn lock_server(&mut self, increment: BOOL) -> HRESULT;
}
impl ComPtr<IClassFactory> {
pub fn get_instance<T: ComInterface + ?Sized>(&mut self) -> Option<ComPtr<T>> {
let mut ppv = std::ptr::null_mut::<c_void>();
let aggr = std::ptr::null_mut();
let hr = self.create_instance(aggr, &T::IID as *const IID, &mut ppv);
if failed(hr) {
// TODO: decide what failures are possible
return None;
}
Some(ComPtr::new(std::ptr::NonNull::new(ppv as *mut c_void)?))
}
}
// ----------------------------- DESIRED EXPANSION ------------------------------------------------------
use super::*;
use winapi::ctypes::c_void;
use winapi::shared::guiddef::IID;
use winapi::shared::guiddef::REFIID;
use winapi::shared::minwindef::BOOL;
use winapi::shared::ntdef::HRESULT;
use std::marker::PhantomData;
#[allow(non_upper_case_globals)]
pub const IID_ICLASSFACTORY: IID = IID {
Data1: 0x01u32,
Data2: 0u16,
Data3: 0u16,
Data4: [0xC0, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0x46u8],
};
#[repr(C)]
pub struct IClassFactoryVTable {
pub base: <IUnknown as ComInterface>::VTable,
pub CreateInstance: unsafe extern "stdcall" fn(
*mut IClassFactoryVPtr,
*mut IUnknownVPtr,
REFIID,
*mut *mut c_void,
) -> HRESULT,
pub LockServer: unsafe extern "stdcall" fn(*mut IClassFactoryVPtr, BOOL) -> HRESULT,
}
pub type IClassFactoryVPtr = *const IClassFactoryVTable;
pub trait IClassFactory: IUnknown {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT;
fn lock_server(&mut self, increment: BOOL) -> HRESULT;
}
impl<T: IClassFactory + ComInterface + ?Sized> IClassFactory for ComPtr<T> {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT {
let itf_ptr = self.into_raw() as *mut IClassFactoryVPtr;
unsafe { ((**itf_ptr).CreateInstance)(itf_ptr, aggr, riid, ppv) }
}
fn lock_server(&mut self, increment: BOOL) -> HRESULT {
let itf_ptr = self.into_raw() as *mut IClassFactoryVPtr;
unsafe { ((**itf_ptr).LockServer)(itf_ptr, increment) }
}
}
unsafe impl ComInterface for IClassFactory {
type VTable = IClassFactoryVTable;
const IID: IID = IID_ICLASSFACTORY;
}
impl ComPtr<IClassFactory> {
pub fn get_instance<T: ComInterface + ?Sized>(&mut self) -> Option<ComPtr<T>> {
let mut ppv = std::ptr::null_mut::<c_void>();
let mut aggr = std::ptr::null_mut();
let hr = unsafe { self.create_instance(aggr, &T::IID as *const IID, &mut ppv) };
if failed(hr) {
// TODO: decide what failures are possible
return None;
}
Some(ComPtr::new(std::ptr::NonNull::new(ppv as *mut c_void)?))
}
}

Просмотреть файл

@ -1,110 +0,0 @@
use super::*;
use com_interface_attribute::com_interface;
use winapi::ctypes::c_void;
use winapi::shared::guiddef::GUID;
use winapi::shared::ntdef::HRESULT;
#[com_interface(00000000-0000-0000-C000-000000000046)]
pub trait IUnknown {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT;
fn add_ref(&mut self) -> u32;
fn release(&mut self) -> u32;
}
#[macro_export]
macro_rules! iunknown_gen_vtable {
($type:ty, $offset:literal) => {{
unsafe extern "stdcall" fn iunknown_query_interface(
this: *mut IUnknownVPtr,
riid: *const IID,
ppv: *mut *mut c_void,
) -> HRESULT {
let this = this.sub($offset) as *mut $type;
(*this).query_interface(riid, ppv)
}
unsafe extern "stdcall" fn iunknown_add_ref(this: *mut IUnknownVPtr) -> u32 {
let this = this.sub($offset) as *mut $type;
(*this).add_ref()
}
unsafe extern "stdcall" fn iunknown_release(this: *mut IUnknownVPtr) -> u32 {
let this = this.sub($offset) as *mut $type;
(*this).release()
}
IUnknownVTable {
QueryInterface: iunknown_query_interface,
Release: iunknown_release,
AddRef: iunknown_add_ref,
}
}};
}
// -------------------------------------- DESIRED EXPANSION -----------------------------------------
use super::*;
use winapi::shared::guiddef::GUID;
use winapi::shared::ntdef::HRESULT;
use winapi::ctypes::c_void;
#[allow(non_upper_case_globals)]
pub const IID_IUNKNOWN: GUID = GUID {
Data1: 0u32,
Data2: 0u16,
Data3: 0u16,
Data4: [192u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 70u8],
};
#[allow(non_snake_case)]
#[repr(C)]
pub struct IUnknownVTable {
pub QueryInterface:
unsafe extern "stdcall" fn(*mut IUnknownVPtr, *const IID, *mut *mut c_void) -> HRESULT,
pub AddRef: unsafe extern "stdcall" fn(*mut IUnknownVPtr) -> u32,
pub Release: unsafe extern "stdcall" fn(*mut IUnknownVPtr) -> u32,
}
pub type IUnknownVPtr = *const IUnknownVTable;
pub trait IUnknown {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT;
fn add_ref(&mut self) -> u32;
fn release(&mut self) -> u32;
}
impl <T: IUnknown + ComInterface + ?Sized> IUnknown for ComPtr<T> {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
let itf_ptr = self.into_raw() as *mut IUnknownVPtr;
unsafe { ((**itf_ptr).QueryInterface)(itf_ptr, riid, ppv) }
}
fn add_ref(&mut self) -> u32 {
let itf_ptr = self.into_raw() as *mut IUnknownVPtr;
unsafe { ((**itf_ptr).AddRef)(itf_ptr) }
}
fn release(&mut self) -> u32 {
let itf_ptr = self.into_raw() as *mut IUnknownVPtr;
unsafe { ((**itf_ptr).Release)(itf_ptr) }
}
}
unsafe impl ComInterface for IUnknown {
type VTable = IUnknownVTable;
const IID: IID = IID_IUNKNOWN;
}
impl<T: IUnknown + ComInterface + ?Sized> ComPtr<T> {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
let itf_ptr = self.into_raw() as *mut IUnknownVPtr;
unsafe { ((**itf_ptr).QueryInterface)(itf_ptr, riid, ppv) }
}
fn add_ref(&mut self) -> u32 {
let itf_ptr = self.into_raw() as *mut IUnknownVPtr;
unsafe { ((**itf_ptr).AddRef)(itf_ptr) }
}
fn release(&mut self) -> u32 {
let itf_ptr = self.into_raw() as *mut IUnknownVPtr;
unsafe { ((**itf_ptr).Release)(itf_ptr) }
}
}

Просмотреть файл

@ -1,14 +1,12 @@
use interface::{CLSID_LOCAL_FILE_MANAGER_CLASS, CLSID_WINDOWS_FILE_MANAGER_CLASS}; use interface::{CLSID_LOCAL_FILE_MANAGER_CLASS, CLSID_WINDOWS_FILE_MANAGER_CLASS};
mod local_file_manager; mod local_file_manager;
mod local_file_manager_class;
mod windows_file_manager; mod windows_file_manager;
mod windows_file_manager_class;
use local_file_manager_class::LocalFileManagerClass; use local_file_manager::LocalFileManager;
use windows_file_manager_class::WindowsFileManagerClass; use windows_file_manager::WindowsFileManager;
com::com_inproc_dll_module![ com::com_inproc_dll_module![
(CLSID_WINDOWS_FILE_MANAGER_CLASS, WindowsFileManagerClass), (CLSID_WINDOWS_FILE_MANAGER_CLASS, WindowsFileManager),
(CLSID_LOCAL_FILE_MANAGER_CLASS, LocalFileManagerClass) (CLSID_LOCAL_FILE_MANAGER_CLASS, LocalFileManager)
]; ];

Просмотреть файл

@ -32,148 +32,4 @@ impl LocalFileManager {
}; };
LocalFileManager::allocate(init) LocalFileManager::allocate(init)
} }
} }
// ----------------------------------------- MACRO GENERATED ------------------------------------------
// #[repr(C)]
// pub struct LocalFileManager {
// ilocalfilemanager: ILocalFileManagerVPtr,
// non_delegating_unk: IUnknownVPtr,
// iunk_to_use: *mut IUnknownVPtr,
// ref_count: u32,
// value: InitLocalFileManager,
// }
// impl Drop for LocalFileManager {
// fn drop(&mut self) {
// println!("Dropping LocalFileManager");
// let _ = unsafe { Box::from_raw(self.ilocalfilemanager as *mut ILocalFileManagerVTable) };
// }
// }
// // Default implementation should delegate to iunk_to_use.
// impl IUnknown for LocalFileManager {
// fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
// println!("Delegating QI");
// let mut iunk_to_use: ComPtr<dyn IUnknown> = unsafe { ComPtr::new(self.iunk_to_use as *mut c_void) };
// let hr = iunk_to_use.query_interface(riid, ppv);
// forget(iunk_to_use);
// hr
// }
// fn add_ref(&mut self) -> u32 {
// let mut iunk_to_use: ComPtr<dyn IUnknown> = unsafe { ComPtr::new(self.iunk_to_use as *mut c_void) };
// let res = iunk_to_use.add_ref();
// forget(iunk_to_use);
// res
// }
// fn release(&mut self) -> u32 {
// let mut iunk_to_use: ComPtr<dyn IUnknown> = unsafe { ComPtr::new(self.iunk_to_use as *mut c_void) };
// let res = iunk_to_use.release();
// forget(iunk_to_use);
// res
// }
// }
// impl LocalFileManager {
// fn allocate(value: InitLocalFileManager) -> Box<LocalFileManager> {
// println!("Allocating new Vtable for LocalFileManager...");
// // Initialising the non-delegating IUnknown
// let non_del_iunknown = IUnknownVTable {
// QueryInterface: non_delegating_ilocalfilemanager_query_interface,
// Release: non_delegating_ilocalfilemanager_release,
// AddRef: non_delegating_ilocalfilemanager_add_ref,
// };
// let non_del_unknown_vptr = Box::into_raw(Box::new(non_del_iunknown));
// // Initialising VTable for ILocalFileManager
// let ilocalfilemanager_vptr = Box::into_raw(Box::new(ilocalfilemanager));
// let out = LocalFileManager {
// ilocalfilemanager: ilocalfilemanager_vptr,
// non_delegating_unk: non_del_unknown_vptr,
// iunk_to_use: std::ptr::null_mut::<IUnknownVPtr>(),
// ref_count: 0,
// value
// };
// Box::new(out)
// }
// // Implementations only for Aggregable objects.
// pub(crate) fn set_iunknown(&mut self, aggr: *mut IUnknownVPtr) {
// if aggr.is_null() {
// self.iunk_to_use = &self.non_delegating_unk as *const _ as *mut IUnknownVPtr;
// } else {
// self.iunk_to_use = aggr;
// }
// }
// pub(crate) fn inner_query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
// println!("Non delegating QI");
// unsafe {
// let riid = &*riid;
// if IsEqualGUID(riid, &IID_IUNKNOWN) {
// // Returns the nondelegating IUnknown, as in COM specification.
// *ppv = &self.non_delegating_unk as *const _ as *mut c_void;
// } else if IsEqualGUID(riid, &IID_ILOCAL_FILE_MANAGER) {
// // Returns the original VTable.
// *ppv = &self.ilocalfilemanager as *const _ as *mut c_void;
// } else {
// *ppv = std::ptr::null_mut::<c_void>();
// println!("Returning NO INTERFACE.");
// return E_NOINTERFACE;
// }
// self.inner_add_ref();
// NOERROR
// }
// }
// pub(crate) fn inner_add_ref(&mut self) -> u32 {
// self.ref_count += 1;
// println!("Count now {}", self.ref_count);
// self.ref_count
// }
// pub(crate) fn inner_release(&mut self) -> u32 {
// self.ref_count -= 1;
// println!("Count now {}", self.ref_count);
// let count = self.ref_count;
// if count == 0 {
// println!("Count is 0 for LocalFileManager. Freeing memory...");
// drop(self);
// }
// count
// }
// }
// // Non-delegating methods.
// unsafe extern "stdcall" fn non_delegating_ilocalfilemanager_query_interface(
// this: *mut IUnknownVPtr,
// riid: *const IID,
// ppv: *mut *mut c_void,
// ) -> HRESULT {
// let this = this.sub(1) as *mut LocalFileManager;
// (*this).inner_query_interface(riid, ppv)
// }
// unsafe extern "stdcall" fn non_delegating_ilocalfilemanager_add_ref(
// this: *mut IUnknownVPtr,
// ) -> u32 {
// let this = this.sub(1) as *mut LocalFileManager;
// (*this).inner_add_ref()
// }
// unsafe extern "stdcall" fn non_delegating_ilocalfilemanager_release(
// this: *mut IUnknownVPtr,
// ) -> u32 {
// let this = this.sub(1) as *mut LocalFileManager;
// (*this).inner_release()
// }

Просмотреть файл

@ -1,115 +0,0 @@
use crate::local_file_manager::LocalFileManager;
use com::{
IClassFactory, IClassFactoryVPtr, IClassFactoryVTable, IUnknown, IUnknownVPtr,
IID_ICLASS_FACTORY, IID_IUNKNOWN,
};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID, REFIID},
minwindef::BOOL,
winerror::{E_INVALIDARG, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct LocalFileManagerClass {
inner: IClassFactoryVPtr,
ref_count: u32,
}
impl Drop for LocalFileManagerClass {
fn drop(&mut self) {
let _ = unsafe { Box::from_raw(self.inner as *mut IClassFactoryVTable) };
}
}
impl IClassFactory for LocalFileManagerClass {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT {
println!("Creating instance...");
unsafe {
let riid = &*riid;
if !aggr.is_null() && !IsEqualGUID(riid, &IID_IUNKNOWN) {
*ppv = std::ptr::null_mut::<c_void>();
return E_INVALIDARG;
}
let mut lfm = LocalFileManager::new();
// This check has to be here because it can only be done after object
// is allocated on the heap (address of nonDelegatingUnknown fixed)
lfm.set_iunknown(aggr);
// As an aggregable object, we have to add_ref through the
// non-delegating IUnknown on creation. Otherwise, we might
// add_ref the outer object if aggregated.
lfm.inner_add_ref();
let hr = lfm.inner_query_interface(riid, ppv);
lfm.inner_release();
Box::into_raw(lfm);
hr
}
}
fn lock_server(&mut self, _increment: BOOL) -> HRESULT {
println!("LockServer called");
S_OK
}
}
impl IUnknown for LocalFileManagerClass {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
/* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
unsafe {
println!("Querying interface on LocalFileManagerClass...");
let riid = &*riid;
if IsEqualGUID(riid, &IID_IUNKNOWN) | IsEqualGUID(riid, &IID_ICLASS_FACTORY) {
*ppv = self as *const _ as *mut c_void;
self.add_ref();
NOERROR
} else {
E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
println!("Adding ref...");
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
println!("Releasing...");
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for LocalFileManagerClass. Freeing memory...");
drop(self);
}
count
}
}
impl LocalFileManagerClass {
pub(crate) fn new() -> LocalFileManagerClass {
let iclass_factory = com::vtable!(LocalFileManagerClass: IClassFactory);
let vptr = Box::into_raw(Box::new(iclass_factory));
LocalFileManagerClass {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -32,12 +32,6 @@ pub struct InitWindowsFileManager {
lfm_iunknown: *mut IUnknownVPtr, lfm_iunknown: *mut IUnknownVPtr,
} }
impl DerefMut for WindowsFileManager {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.__init_struct
}
}
impl Drop for InitWindowsFileManager { impl Drop for InitWindowsFileManager {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
@ -87,94 +81,4 @@ impl WindowsFileManager {
wfm wfm
} }
} }
// ------------------------- MACRO GENERATED --------------------------------
// #[repr(C)]
// pub struct WindowsFileManager {
// ifilemanager: IFileManagerVPtr,
// ref_count: u32,
// value: InitWindowsFileManager,
// }
// impl Drop for WindowsFileManager {
// fn drop(&mut self) {
// unsafe {
// Box::from_raw(self.ifilemanager as *mut IFileManagerVTable);
// };
// }
// }
// impl IUnknown for WindowsFileManager {
// fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
// /* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
// unsafe {
// let riid = &*riid;
// if IsEqualGUID(riid, &IID_IUNKNOWN) | IsEqualGUID(riid, &IID_IFILE_MANAGER) {
// *ppv = self as *const _ as *mut c_void;
// } else if IsEqualGUID(riid, &IID_ILOCAL_FILE_MANAGER) {
// let mut lfm_iunknown: ComPtr<dyn IUnknown> =
// ComPtr::new(self.lfm_iunknown as *mut c_void);
// let hr = lfm_iunknown.query_interface(riid, ppv);
// if failed(hr) {
// return E_NOINTERFACE;
// }
// // We release it as the previous call add_ref-ed the inner object.
// // The intention is to transfer reference counting logic to the
// // outer object.
// lfm_iunknown.release();
// forget(lfm_iunknown);
// } else {
// return E_NOINTERFACE;
// }
// self.add_ref();
// NOERROR
// }
// }
// fn add_ref(&mut self) -> u32 {
// self.ref_count += 1;
// println!("Count now {}", self.ref_count);
// self.ref_count
// }
// fn release(&mut self) -> u32 {
// self.ref_count -= 1;
// println!("Count now {}", self.ref_count);
// let count = self.ref_count;
// if count == 0 {
// println!("Count is 0 for WindowsFileManager. Freeing memory...");
// unsafe { Box::from_raw(self as *const _ as *mut WindowsFileManager); }
// }
// count
// }
// }
// impl WindowsFileManager {
// fn allocate(value: InitWindowsFileManager) -> Box<WindowsFileManager> {
// println!("Allocating new Vtable...");
// // Initialising VTable for IFileManager
// let ifilemanager = ifile_manager_gen_vtable!(WindowsFileManager, 0);
// let ifilemanager_vptr = Box::into_raw(Box::new(ifilemanager));
// let wfm = WindowsFileManager {
// ifilemanager: ifilemanager_vptr,
// ref_count: 0,
// value
// };
// Box::new(wfm)
// }
// }
//
// impl Deref for WindowsFileManager {
// type Target = InitWindowsFileManager;
// fn deref(&self) -> &Self::Target {
// &self.value
// }
// }

Просмотреть файл

@ -1,100 +0,0 @@
use crate::windows_file_manager::WindowsFileManager;
use com::{
IClassFactory, IClassFactoryVPtr, IClassFactoryVTable, IUnknown, IUnknownVPtr,
IID_ICLASS_FACTORY, IID_IUNKNOWN,
};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID, REFIID},
minwindef::{BOOL,},
winerror::{CLASS_E_NOAGGREGATION, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct WindowsFileManagerClass {
inner: IClassFactoryVPtr,
ref_count: u32,
}
impl Drop for WindowsFileManagerClass {
fn drop(&mut self) {
println!("Dropping WindowsFileManagerClass");
let _ = unsafe { Box::from_raw(self.inner as *mut IClassFactoryVTable) };
}
}
impl IClassFactory for WindowsFileManagerClass {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT {
println!("Creating instance...");
if aggr != std::ptr::null_mut() {
return CLASS_E_NOAGGREGATION;
}
let mut wfm = WindowsFileManager::new();
wfm.add_ref();
let hr = wfm.query_interface(riid, ppv);
wfm.release();
Box::into_raw(wfm);
hr
}
fn lock_server(&mut self, _increment: BOOL) -> HRESULT {
println!("LockServer called");
S_OK
}
}
impl IUnknown for WindowsFileManagerClass {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
unsafe {
println!("Querying interface on WindowsFileManagerClass...");
let riid_ref = &*riid;
if IsEqualGUID(riid_ref, &IID_IUNKNOWN) | IsEqualGUID(riid_ref, &IID_ICLASS_FACTORY) {
*ppv = self as *const _ as *mut c_void;
self.add_ref();
NOERROR
} else {
E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for WindowsFileManagerClass. Freeing memory...");
drop(self);
}
count
}
}
impl WindowsFileManagerClass {
pub(crate) fn new() -> WindowsFileManagerClass {
println!("Allocating new Vtable for WindowsFileManagerClass...");
let class_vtable = com::vtable!(WindowsFileManagerClass: IClassFactory);
let vptr = Box::into_raw(Box::new(class_vtable));
WindowsFileManagerClass {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -1,102 +0,0 @@
use crate::british_short_hair_cat::BritishShortHairCat;
use com::{
IClassFactory, IClassFactoryVPtr, IUnknown, IUnknownVPtr, IID_ICLASS_FACTORY, IID_IUNKNOWN,
};
use interface::icat_class::{ICatClassVTable, IID_ICAT_CLASS};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID, REFIID},
minwindef::BOOL,
winerror::{CLASS_E_NOAGGREGATION, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct BritishShortHairCatClass {
inner: IClassFactoryVPtr,
ref_count: u32,
}
impl IClassFactory for BritishShortHairCatClass {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT {
println!("Creating instance...");
if !aggr.is_null() {
return CLASS_E_NOAGGREGATION;
}
let mut cat = BritishShortHairCat::new();
cat.add_ref();
let hr = cat.query_interface(riid, ppv);
cat.release();
Box::into_raw(cat);
hr
}
fn lock_server(&mut self, _increment: BOOL) -> HRESULT {
println!("LockServer called");
S_OK
}
}
impl IUnknown for BritishShortHairCatClass {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
/* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
unsafe {
let riid = &*riid;
if IsEqualGUID(riid, &IID_IUNKNOWN)
|| IsEqualGUID(riid, &IID_ICLASS_FACTORY)
|| IsEqualGUID(riid, &IID_ICAT_CLASS)
{
*ppv = self as *const _ as *mut c_void;
self.add_ref();
NOERROR
} else {
E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for BritishShortHairCatClass. Freeing memory...");
drop(self);
}
count
}
}
impl Drop for BritishShortHairCatClass {
fn drop(&mut self) {
let _ = unsafe { Box::from_raw(self.inner as *mut ICatClassVTable) };
}
}
impl BritishShortHairCatClass {
pub(crate) fn new() -> BritishShortHairCatClass {
println!("Allocating new vtable for CatClass...");
let icat_class_vtable = com::vtable!(BritishShortHairCatClass: IClassFactory);
let vptr = Box::into_raw(Box::new(icat_class_vtable));
BritishShortHairCatClass {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -1,7 +1,6 @@
mod british_short_hair_cat; mod british_short_hair_cat;
mod british_short_hair_cat_class;
use british_short_hair_cat_class::BritishShortHairCatClass; use british_short_hair_cat::BritishShortHairCat;
use interface::CLSID_CAT_CLASS; use interface::CLSID_CAT_CLASS;
com::com_inproc_dll_module![(CLSID_CAT_CLASS, BritishShortHairCatClass),]; com::com_inproc_dll_module![(CLSID_CAT_CLASS, BritishShortHairCat),];

Просмотреть файл

@ -1,12 +0,0 @@
[package]
name = "basic"
version = "0.0.1"
authors = ["Microsoft Corp"]
edition = "2018"
[workspace]
members = [
"client",
"interface",
"server",
]

Просмотреть файл

@ -1,35 +0,0 @@
# COM Example
A COM example in Rust
# Run
To install the server and run the client, simply run the following from the basic folder:
```bash
cargo run
```
Alternatively, you can choose to build/install/run the server and client seperately.
# Build & Install Server
You can build the server by running the following in the server folder:
```bash
cargo build
```
To "install" the server, you need to add the CLSIDs to your Windows registry. You can do that by running:
```bash
regsvr32 path/to/your/server/dll/file
```
# Run Client
To run the client which talks to the server, simply run the following from the client folder:
```bash
cargo run
```

2
examples/safe/client/.gitignore поставляемый
Просмотреть файл

@ -1,2 +0,0 @@
/target
**/*.rs.bk

Просмотреть файл

@ -1,12 +0,0 @@
[package]
name = "client"
version = "0.1.0"
authors = ["Microsoft Corp"]
edition = "2018"
[dependencies]
com = { path = "../../.." }
interface = { path = "../interface" }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winreg", "combaseapi", "objbase"] }

Просмотреть файл

@ -1,74 +0,0 @@
// import "unknwn.idl";
// [object, uuid(DF12E151-A29A-l1dO-8C2D-00BOC73925BA)]
// interface IAnimal : IUnknown {
// HRESULT Eat(void);
// }
// [object, uuid(DF12E152-A29A-l1dO-8C2D-0080C73925BA)]
// interface ICat : IAnimal {
// HRESULT IgnoreHumans(void);
// }
use com::{ComOutPtr, Runtime};
use winapi::shared::winerror::{E_FAIL, S_OK};
use interface::{ISuperman, CLSID_CLARK_KENT_CLASS};
fn main() {
let runtime = match Runtime::new() {
Ok(runtime) => {
println!("Got a runtime");
runtime
}
Err(hr) => {
println!("Failed to initialize COM Library: {}", hr);
return;
}
};
run_safe_test(runtime);
}
fn run_safe_test(runtime: Runtime) {
let mut clark_kent = match runtime.create_instance::<dyn ISuperman>(&CLSID_CLARK_KENT_CLASS) {
Ok(clark_kent) => clark_kent,
Err(e) => {
println!("Failed to get clark kent, {:x}", e as u32);
return;
}
};
println!("Got clark kent!");
// [in] tests
assert!(clark_kent.take_input(10) == E_FAIL);
assert!(clark_kent.take_input(4) == S_OK);
// [out] tests
let mut var_to_populate = ComOutPtr::<u32>::new();
clark_kent.populate_output(&mut var_to_populate);
assert!(*var_to_populate.get().unwrap() == 6);
// [in, out] tests
let mut ptr_to_mutate = Some(Box::new(6));
clark_kent.mutate_and_return(&mut ptr_to_mutate);
match ptr_to_mutate {
Some(n) => assert!(*n == 100),
None => assert!(false),
};
let mut ptr_to_mutate = None;
clark_kent.mutate_and_return(&mut ptr_to_mutate);
match ptr_to_mutate {
Some(_n) => assert!(false),
None => (),
};
// [in] ptr tests
let in_var = Some(50);
assert!(clark_kent.take_input_ptr(&in_var) == E_FAIL);
let in_var = Some(2);
assert!(clark_kent.take_input_ptr(&in_var) == S_OK);
let in_var = None;
assert!(clark_kent.take_input_ptr(&in_var) == E_FAIL);
println!("Tests passed!");
}

3
examples/safe/interface/.gitignore поставляемый
Просмотреть файл

@ -1,3 +0,0 @@
/target
**/*.rs.bk
Cargo.lock

Просмотреть файл

@ -1,14 +0,0 @@
[package]
name = "interface"
version = "0.1.0"
authors = ["Microsoft Corp"]
edition = "2018"
[dependencies]
com = { path = "../../.." }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winreg", "winerror", "winnt"] }
[lib]
crate-type = ["rlib", "cdylib"]

Просмотреть файл

@ -1,107 +0,0 @@
use com::{ComInterface, ComOutPtr, ComPtr, IUnknown, IUnknownVTable};
use std::mem::MaybeUninit;
use winapi::{shared::guiddef::IID, um::winnt::HRESULT};
pub const IID_ISUPERMAN: IID = IID {
Data1: 0xa56b76e4,
Data2: 0x4bd7,
Data3: 0x48b9,
Data4: [0x8a, 0xf6, 0xb9, 0x3f, 0x43, 0xe8, 0x69, 0xc8],
};
pub trait ISuperman: IUnknown {
// [in]
fn take_input(&mut self, in_var: u32) -> HRESULT;
// [out]
fn populate_output(&mut self, out_var: &mut ComOutPtr<u32>) -> HRESULT;
// [in, out]
fn mutate_and_return(&mut self, in_out_var: &mut Option<Box<u32>>) -> HRESULT;
// [in] pointer
fn take_input_ptr(&mut self, in_ptr_var: &Option<u32>) -> HRESULT;
// // [in, out] Interface
// fn take_interface();
// // [out] Interface
// fn populate_interface(ComOutPtr<ComItf>);
}
unsafe impl ComInterface for dyn ISuperman {
type VTable = ISupermanVTable;
const IID: IID = IID_ISUPERMAN;
}
pub type ISupermanVPtr = *const ISupermanVTable;
impl<T: ISuperman + ComInterface + ?Sized> ISuperman for ComPtr<T> {
fn take_input(&mut self, in_var: u32) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
unsafe { ((**itf_ptr).TakeInput)(itf_ptr, in_var) }
}
fn populate_output(&mut self, out_var: &mut ComOutPtr<u32>) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
// Let called-procedure write to possibly uninit memory.
let mut proxy = MaybeUninit::<u32>::uninit();
unsafe {
let hr = ((**itf_ptr).PopulateOutput)(itf_ptr, proxy.as_mut_ptr());
println!("Returned from populate output!");
// Consumes the MaybeUninit. Not exactly sure what happens to the
// allocated memory here. Working verison for now.
let value = proxy.assume_init();
out_var.set(value);
// Attempt 2:
// out_var.wrap(proxy.as_mut_ptr());
// let mut value = proxy.assume_init();
// out_var.set(value);
//
// Remarks: This should be the ideal way to do it (with the old "set" that
// just writes to the underlying pointer), as we are
// pointing to the same location that the server wrote to.
// However, failed later during client code when doing
// &*com_out_ptr.as_mut_ptr(). Might be triggering UB somewhere.
hr
}
}
fn mutate_and_return(&mut self, in_out_var: &mut Option<Box<u32>>) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
let in_out_raw = match in_out_var {
Some(ref mut n) => n.as_mut() as *mut u32,
None => std::ptr::null_mut::<u32>(),
};
unsafe { ((**itf_ptr).MutateAndReturn)(itf_ptr, in_out_raw) }
}
fn take_input_ptr(&mut self, in_ptr_var: &Option<u32>) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
let in_out_raw = match in_ptr_var {
Some(n) => n as *const u32,
None => std::ptr::null_mut::<u32>(),
};
unsafe { ((**itf_ptr).TakeInputPtr)(itf_ptr, in_out_raw) }
}
}
#[allow(non_snake_case)]
#[repr(C)]
pub struct ISupermanVTable {
pub base: IUnknownVTable,
pub TakeInput: unsafe extern "stdcall" fn(*mut ISupermanVPtr, in_var: u32) -> HRESULT,
pub PopulateOutput:
unsafe extern "stdcall" fn(*mut ISupermanVPtr, out_var: *mut u32) -> HRESULT,
pub MutateAndReturn:
unsafe extern "stdcall" fn(*mut ISupermanVPtr, in_out_var: *mut u32) -> HRESULT,
pub TakeInputPtr:
unsafe extern "stdcall" fn(*mut ISupermanVPtr, in_ptr_var: *const u32) -> HRESULT,
}

Просмотреть файл

@ -1,12 +0,0 @@
pub mod isuperman;
pub use isuperman::ISuperman;
use winapi::shared::guiddef::IID;
pub const CLSID_CLARK_KENT_CLASS: IID = IID {
Data1: 0xf26c011d,
Data2: 0xa586,
Data3: 0x4819,
Data4: [0xa3, 0x34, 0xa7, 0x40, 0xb4, 0xe7, 0xfd, 0x3c],
};

3
examples/safe/server/.gitignore поставляемый
Просмотреть файл

@ -1,3 +0,0 @@
/target
**/*.rs.bk
Cargo.lock

Просмотреть файл

@ -1,15 +0,0 @@
[package]
name = "server"
version = "0.1.0"
authors = ["Microsoft Corp"]
edition = "2018"
[dependencies]
com = { path = "../../.." }
interface = { path = "../interface" }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winreg", "winerror", "combaseapi"] }
[lib]
crate-type = ["rlib", "cdylib"]

Просмотреть файл

@ -1,198 +0,0 @@
use com::{ComOutPtr, IUnknown, IUnknownVPtr, IUnknownVTable, IID_IUNKNOWN};
use interface::isuperman::{ISuperman, ISupermanVPtr, ISupermanVTable, IID_ISUPERMAN};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID},
winerror::{E_FAIL, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct ClarkKent {
inner: ISupermanVPtr,
ref_count: u32,
}
impl Drop for ClarkKent {
fn drop(&mut self) {
let _ = unsafe {
Box::from_raw(self.inner as *mut ISupermanVTable);
};
}
}
impl ISuperman for ClarkKent {
fn take_input(&mut self, in_var: u32) -> HRESULT {
println!("Received Input! Input is: {}", in_var);
if in_var > 5 {
return E_FAIL;
}
S_OK
}
fn populate_output(&mut self, out_var: &mut ComOutPtr<u32>) -> HRESULT {
out_var.set(6);
S_OK
}
fn mutate_and_return(&mut self, in_out_var: &mut Option<Box<u32>>) -> HRESULT {
match in_out_var {
Some(n) => **n = 100,
None => println!("Received null, unable to mutate!"),
};
S_OK
}
fn take_input_ptr(&mut self, in_ptr_var: &Option<u32>) -> HRESULT {
match in_ptr_var {
Some(n) => {
if *n > 5 {
return E_FAIL;
} else {
return S_OK;
}
}
None => {
return E_FAIL;
}
};
}
}
impl IUnknown for ClarkKent {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
/* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
unsafe {
let riid = &*riid;
if IsEqualGUID(riid, &IID_IUNKNOWN) || IsEqualGUID(riid, &IID_ISUPERMAN) {
*ppv = &self.inner as *const _ as *mut c_void;
} else {
println!("Returning NO INTERFACE.");
return E_NOINTERFACE;
}
println!("Successful!.");
self.add_ref();
NOERROR
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for ClarkKent. Freeing memory...");
drop(self)
}
count
}
}
// Adjustor Thunks for ISuperman
unsafe extern "stdcall" fn query_interface(
this: *mut IUnknownVPtr,
riid: *const IID,
ppv: *mut *mut c_void,
) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).query_interface(riid, ppv)
}
unsafe extern "stdcall" fn add_ref(this: *mut IUnknownVPtr) -> u32 {
println!("Adding ref...");
let this = this as *mut ClarkKent;
(*this).add_ref()
}
unsafe extern "stdcall" fn release(this: *mut IUnknownVPtr) -> u32 {
println!("Releasing...");
let this = this as *mut ClarkKent;
(*this).release()
}
unsafe extern "stdcall" fn take_input(this: *mut ISupermanVPtr, in_var: u32) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).take_input(in_var)
}
unsafe extern "stdcall" fn populate_output(this: *mut ISupermanVPtr, out_var: *mut u32) -> HRESULT {
let this = this as *mut ClarkKent;
let mut ptr = ComOutPtr::from_ptr(out_var);
(*this).populate_output(&mut ptr)
}
unsafe extern "stdcall" fn mutate_and_return(
this: *mut ISupermanVPtr,
in_out_var: *mut u32,
) -> HRESULT {
let this = this as *mut ClarkKent;
let mut opt = if in_out_var.is_null() {
None
} else {
Some(Box::from_raw(in_out_var))
};
let hr = (*this).mutate_and_return(&mut opt);
// Server side should not be responsible for memory allocated by client.
match opt {
Some(n) => {
Box::into_raw(n);
}
_ => (),
};
hr
}
unsafe extern "stdcall" fn take_input_ptr(
this: *mut ISupermanVPtr,
in_ptr_var: *const u32,
) -> HRESULT {
let this = this as *mut ClarkKent;
let opt = if in_ptr_var.is_null() {
None
} else {
Some(*in_ptr_var)
};
(*this).take_input_ptr(&opt)
}
impl ClarkKent {
pub(crate) fn new() -> ClarkKent {
println!("Allocating new Vtable...");
/* Initialising VTable for ISuperman */
let iunknown = IUnknownVTable {
QueryInterface: query_interface,
Release: release,
AddRef: add_ref,
};
let isuperman = ISupermanVTable {
base: iunknown,
TakeInput: take_input,
PopulateOutput: populate_output,
MutateAndReturn: mutate_and_return,
TakeInputPtr: take_input_ptr,
};
let vptr = Box::into_raw(Box::new(isuperman));
ClarkKent {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -1,98 +0,0 @@
use crate::clark_kent::ClarkKent;
use com::{
IClassFactory, IClassFactoryVPtr, IClassFactoryVTable, IUnknown, IUnknownVPtr,
IID_ICLASS_FACTORY, IID_IUNKNOWN,
};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID, REFIID},
minwindef::BOOL,
winerror::{CLASS_E_NOAGGREGATION, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct ClarkKentClass {
inner: IClassFactoryVPtr,
ref_count: u32,
}
impl IClassFactory for ClarkKentClass {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT {
println!("Creating instance...");
if !aggr.is_null() {
return CLASS_E_NOAGGREGATION;
}
let mut ck = Box::new(ClarkKent::new());
ck.add_ref();
let hr = ck.query_interface(riid, ppv);
ck.release();
let _ptr = Box::into_raw(ck);
hr
}
fn lock_server(&mut self, _increment: BOOL) -> HRESULT {
println!("LockServer called");
S_OK
}
}
impl IUnknown for ClarkKentClass {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
/* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
unsafe {
let riid = &*riid;
if IsEqualGUID(riid, &IID_IUNKNOWN) || IsEqualGUID(riid, &IID_ICLASS_FACTORY) {
*ppv = self as *const _ as *mut c_void;
self.add_ref();
NOERROR
} else {
E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for ClarkKentClass. Freeing memory...");
drop(self);
}
count
}
}
impl Drop for ClarkKentClass {
fn drop(&mut self) {
let _ = unsafe { Box::from_raw(self.inner as *mut IClassFactoryVTable) };
}
}
impl ClarkKentClass {
pub(crate) fn new() -> ClarkKentClass {
println!("Allocating new Vtable for ClarkKentClass...");
let iclassfactory = com::vtable!(ClarkKentClass: IClassFactory);
let vptr = Box::into_raw(Box::new(iclassfactory));
ClarkKentClass {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -1,7 +0,0 @@
mod clark_kent;
mod clark_kent_class;
use clark_kent_class::ClarkKentClass;
use interface::CLSID_CLARK_KENT_CLASS;
com::com_inproc_dll_module![(CLSID_CLARK_KENT_CLASS, ClarkKentClass),];

Просмотреть файл

@ -1,31 +0,0 @@
use std::process::Command;
fn main() {
let mut child_proc = Command::new("cmd")
.args(&["/C", "cls && cargo build --all --release"])
.spawn()
.expect("Something went wrong!");
if !child_proc.wait().unwrap().success() {
println!("Build failed.");
return;
}
let mut child_proc = Command::new("cmd")
.args(&["/C", "regsvr32 /s target/release/server.dll"])
.spawn()
.expect("Something went wrong!");
if !child_proc.wait().unwrap().success() {
println!("Failed to register server.dll! Make sure you have administrator rights!");
return;
}
let mut child_proc = Command::new("cmd")
.args(&["/C", "cargo run --release --package client"])
.spawn()
.expect("Something went wrong!");
if !child_proc.wait().unwrap().success() {
println!("Execution of client failed.");
return;
}
}

Просмотреть файл

@ -1,12 +0,0 @@
[package]
name = "basic"
version = "0.0.1"
authors = ["Microsoft Corp"]
edition = "2018"
[workspace]
members = [
"client",
"interface",
"server",
]

Просмотреть файл

@ -1,35 +0,0 @@
# COM Example
A COM example in Rust
# Run
To install the server and run the client, simply run the following from the basic folder:
```bash
cargo run
```
Alternatively, you can choose to build/install/run the server and client seperately.
# Build & Install Server
You can build the server by running the following in the server folder:
```bash
cargo build
```
To "install" the server, you need to add the CLSIDs to your Windows registry. You can do that by running:
```bash
regsvr32 path/to/your/server/dll/file
```
# Run Client
To run the client which talks to the server, simply run the following from the client folder:
```bash
cargo run
```

2
examples/unsafe/client/.gitignore поставляемый
Просмотреть файл

@ -1,2 +0,0 @@
/target
**/*.rs.bk

Просмотреть файл

@ -1,12 +0,0 @@
[package]
name = "client"
version = "0.1.0"
authors = ["Microsoft Corp"]
edition = "2018"
[dependencies]
com = { path = "../../.." }
interface = { path = "../interface" }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winreg", "combaseapi", "objbase"] }

Просмотреть файл

@ -1,60 +0,0 @@
// import "unknwn.idl";
// [object, uuid(DF12E151-A29A-l1dO-8C2D-00BOC73925BA)]
// interface IAnimal : IUnknown {
// HRESULT Eat(void);
// }
// [object, uuid(DF12E152-A29A-l1dO-8C2D-0080C73925BA)]
// interface ICat : IAnimal {
// HRESULT IgnoreHumans(void);
// }
use com::Runtime;
use winapi::shared::winerror::{E_FAIL, S_OK};
use interface::{ISuperman, CLSID_CLARK_KENT_CLASS};
fn main() {
let runtime = match Runtime::new() {
Ok(runtime) => runtime,
Err(hr) => {
println!("Failed to initialize COM Library: {}", hr);
return;
}
};
run_safe_test(runtime);
}
fn run_safe_test(runtime: Runtime) {
let mut clark_kent = match runtime.create_instance::<dyn ISuperman>(&CLSID_CLARK_KENT_CLASS) {
Ok(clark_kent) => clark_kent,
Err(e) => {
println!("Failed to get clark kent, {:x}", e as u32);
return;
}
};
println!("Got clark kent!");
// [in] tests
assert!(clark_kent.take_input(10) == E_FAIL);
assert!(clark_kent.take_input(4) == S_OK);
// [out] tests
let mut var_to_populate = 0u32;
// let ptr = std::ptr::null_mut::<u32>();
clark_kent.populate_output(&mut var_to_populate as *mut u32);
assert!(var_to_populate == 6);
// [in, out] tests
let ptr_to_mutate = Box::into_raw(Box::new(6));
clark_kent.mutate_and_return(ptr_to_mutate);
assert!(unsafe { *ptr_to_mutate == 100 });
// [in] ptr tests
let in_var = Box::into_raw(Box::new(50));
assert!(clark_kent.take_input_ptr(in_var) == E_FAIL);
let in_var = Box::into_raw(Box::new(2));
assert!(clark_kent.take_input_ptr(in_var) == S_OK);
println!("Tests passed!");
}

3
examples/unsafe/interface/.gitignore поставляемый
Просмотреть файл

@ -1,3 +0,0 @@
/target
**/*.rs.bk
Cargo.lock

Просмотреть файл

@ -1,14 +0,0 @@
[package]
name = "interface"
version = "0.1.0"
authors = ["Microsoft Corp"]
edition = "2018"
[dependencies]
com = { path = "../../.." }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winreg", "winerror", "winnt"] }
[lib]
crate-type = ["rlib", "cdylib"]

Просмотреть файл

@ -1,72 +0,0 @@
use com::{ComInterface, ComPtr, IUnknown, IUnknownVTable};
use winapi::shared::guiddef::IID;
use winapi::um::winnt::HRESULT;
pub const IID_ISUPERMAN: IID = IID {
Data1: 0xa56b76e4,
Data2: 0x4bd7,
Data3: 0x48b9,
Data4: [0x8a, 0xf6, 0xb9, 0x3f, 0x43, 0xe8, 0x69, 0xc8],
};
pub trait ISuperman: IUnknown {
// [in]
fn take_input(&mut self, in_var: u32) -> HRESULT;
// [out]
fn populate_output(&mut self, out_var: *mut u32) -> HRESULT;
// [in, out]
fn mutate_and_return(&mut self, in_out_var: *mut u32) -> HRESULT;
// [in] pointer
fn take_input_ptr(&mut self, in_ptr_var: *const u32) -> HRESULT;
// // [in, out] Interface
// fn take_interface();
// // [out] Interface
// fn populate_interface(ComOutPtr<ComItf>);
}
unsafe impl ComInterface for dyn ISuperman {
type VTable = ISupermanVTable;
const IID: IID = IID_ISUPERMAN;
}
pub type ISupermanVPtr = *const ISupermanVTable;
impl<T: ISuperman + ComInterface + ?Sized> ISuperman for ComPtr<T> {
fn take_input(&mut self, in_var: u32) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
unsafe { ((**itf_ptr).TakeInput)(itf_ptr, in_var) }
}
fn populate_output(&mut self, out_var: *mut u32) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
unsafe { ((**itf_ptr).PopulateOutput)(itf_ptr, out_var) }
}
fn mutate_and_return(&mut self, in_out_var: *mut u32) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
unsafe { ((**itf_ptr).MutateAndReturn)(itf_ptr, in_out_var) }
}
fn take_input_ptr(&mut self, in_ptr_var: *const u32) -> HRESULT {
let itf_ptr = self.into_raw() as *mut ISupermanVPtr;
unsafe { ((**itf_ptr).TakeInputPtr)(itf_ptr, in_ptr_var) }
}
}
#[allow(non_snake_case)]
#[repr(C)]
pub struct ISupermanVTable {
pub base: IUnknownVTable,
pub TakeInput: unsafe extern "stdcall" fn(*mut ISupermanVPtr, in_var: u32) -> HRESULT,
pub PopulateOutput:
unsafe extern "stdcall" fn(*mut ISupermanVPtr, out_var: *mut u32) -> HRESULT,
pub MutateAndReturn:
unsafe extern "stdcall" fn(*mut ISupermanVPtr, in_out_var: *mut u32) -> HRESULT,
pub TakeInputPtr:
unsafe extern "stdcall" fn(*mut ISupermanVPtr, in_ptr_var: *const u32) -> HRESULT,
}

Просмотреть файл

@ -1,12 +0,0 @@
pub mod isuperman;
pub use isuperman::ISuperman;
use winapi::shared::guiddef::IID;
pub const CLSID_CLARK_KENT_CLASS: IID = IID {
Data1: 0xf26c011d,
Data2: 0xa586,
Data3: 0x4819,
Data4: [0xa3, 0x34, 0xa7, 0x40, 0xb4, 0xe7, 0xfd, 0x3c],
};

3
examples/unsafe/server/.gitignore поставляемый
Просмотреть файл

@ -1,3 +0,0 @@
/target
**/*.rs.bk
Cargo.lock

Просмотреть файл

@ -1,15 +0,0 @@
[package]
name = "server"
version = "0.1.0"
authors = ["Microsoft Corp"]
edition = "2018"
[dependencies]
com = { path = "../../.." }
interface = { path = "../interface" }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winreg", "winerror", "combaseapi"] }
[lib]
crate-type = ["rlib", "cdylib"]

Просмотреть файл

@ -1,175 +0,0 @@
use com::{IUnknown, IUnknownVPtr, IUnknownVTable, IID_IUNKNOWN};
use interface::isuperman::{ISuperman, ISupermanVPtr, ISupermanVTable, IID_ISUPERMAN};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID},
winerror::{E_FAIL, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct ClarkKent {
// inner must always be first because Cat is actually an ISuperman with one extra field at the end
inner: ISupermanVPtr,
ref_count: u32,
}
impl Drop for ClarkKent {
fn drop(&mut self) {
let _ = unsafe {
Box::from_raw(self.inner as *mut ISupermanVTable);
};
}
}
impl ISuperman for ClarkKent {
fn take_input(&mut self, in_var: u32) -> HRESULT {
println!("Received Input! Input is: {}", in_var);
if in_var > 5 {
return E_FAIL;
}
S_OK
}
fn populate_output(&mut self, out_var: *mut u32) -> HRESULT {
// let allocated_value = Box::into_raw(Box::new(6));
unsafe {
*out_var = 6;
}
S_OK
}
fn mutate_and_return(&mut self, in_out_var: *mut u32) -> HRESULT {
unsafe {
*in_out_var = 100;
}
S_OK
}
fn take_input_ptr(&mut self, in_ptr_var: *const u32) -> HRESULT {
unsafe {
let in_ptr_var = *in_ptr_var;
if in_ptr_var > 5 {
return E_FAIL;
}
}
S_OK
}
}
impl IUnknown for ClarkKent {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
/* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
unsafe {
let riid = &*riid;
if IsEqualGUID(riid, &IID_IUNKNOWN) || IsEqualGUID(riid, &IID_ISUPERMAN) {
*ppv = &self.inner as *const _ as *mut c_void;
} else {
println!("Returning NO INTERFACE.");
return E_NOINTERFACE;
}
println!("Successful!.");
self.add_ref();
NOERROR
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for ClarkKent. Freeing memory...");
drop(self)
}
count
}
}
// Adjustor Thunks for ISuperman
unsafe extern "stdcall" fn query_interface(
this: *mut IUnknownVPtr,
riid: *const IID,
ppv: *mut *mut c_void,
) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).query_interface(riid, ppv)
}
unsafe extern "stdcall" fn add_ref(this: *mut IUnknownVPtr) -> u32 {
println!("Adding ref...");
let this = this as *mut ClarkKent;
(*this).add_ref()
}
// TODO: This could potentially be null or pointing to some invalid memory
unsafe extern "stdcall" fn release(this: *mut IUnknownVPtr) -> u32 {
println!("Releasing...");
let this = this as *mut ClarkKent;
(*this).release()
}
unsafe extern "stdcall" fn take_input(this: *mut ISupermanVPtr, in_var: u32) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).take_input(in_var)
}
unsafe extern "stdcall" fn populate_output(this: *mut ISupermanVPtr, out_var: *mut u32) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).populate_output(out_var)
}
unsafe extern "stdcall" fn mutate_and_return(
this: *mut ISupermanVPtr,
in_out_var: *mut u32,
) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).mutate_and_return(in_out_var)
}
unsafe extern "stdcall" fn take_input_ptr(
this: *mut ISupermanVPtr,
in_ptr_var: *const u32,
) -> HRESULT {
let this = this as *mut ClarkKent;
(*this).take_input_ptr(in_ptr_var)
}
impl ClarkKent {
pub(crate) fn new() -> ClarkKent {
println!("Allocating new Vtable...");
/* Initialising VTable for ISuperman */
let iunknown = IUnknownVTable {
QueryInterface: query_interface,
Release: release,
AddRef: add_ref,
};
let isuperman = ISupermanVTable {
base: iunknown,
TakeInput: take_input,
PopulateOutput: populate_output,
MutateAndReturn: mutate_and_return,
TakeInputPtr: take_input_ptr,
};
let vptr = Box::into_raw(Box::new(isuperman));
ClarkKent {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -1,98 +0,0 @@
use crate::clark_kent::ClarkKent;
use com::{
IClassFactory, IClassFactoryVPtr, IClassFactoryVTable, IUnknown, IUnknownVPtr,
IID_ICLASS_FACTORY, IID_IUNKNOWN,
};
use winapi::{
ctypes::c_void,
shared::{
guiddef::{IsEqualGUID, IID, REFIID},
minwindef::BOOL,
winerror::{CLASS_E_NOAGGREGATION, E_NOINTERFACE, HRESULT, NOERROR, S_OK},
},
};
#[repr(C)]
pub struct ClarkKentClass {
inner: IClassFactoryVPtr,
ref_count: u32,
}
impl IClassFactory for ClarkKentClass {
fn create_instance(
&mut self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT {
println!("Creating instance...");
if !aggr.is_null() {
return CLASS_E_NOAGGREGATION;
}
let mut ck = Box::new(ClarkKent::new());
ck.add_ref();
let hr = ck.query_interface(riid, ppv);
ck.release();
let _ptr = Box::into_raw(ck);
hr
}
fn lock_server(&mut self, _increment: BOOL) -> HRESULT {
println!("LockServer called");
S_OK
}
}
impl IUnknown for ClarkKentClass {
fn query_interface(&mut self, riid: *const IID, ppv: *mut *mut c_void) -> HRESULT {
/* TODO: This should be the safe wrapper. You shouldn't need to write unsafe code here. */
unsafe {
let riid = &*riid;
if IsEqualGUID(riid, &IID_IUNKNOWN) || IsEqualGUID(riid, &IID_ICLASS_FACTORY) {
*ppv = self as *const _ as *mut c_void;
self.add_ref();
NOERROR
} else {
E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for ClarkKentClass. Freeing memory...");
drop(self);
}
count
}
}
impl Drop for ClarkKentClass {
fn drop(&mut self) {
let _ = unsafe { Box::from_raw(self.inner as *mut IClassFactoryVTable) };
}
}
impl ClarkKentClass {
pub(crate) fn new() -> ClarkKentClass {
println!("Allocating new Vtable for ClarkKentClass...");
let iclassfactory = com::vtable!(ClarkKentClass: IClassFactory);
let vptr = Box::into_raw(Box::new(iclassfactory));
ClarkKentClass {
inner: vptr,
ref_count: 0,
}
}
}

Просмотреть файл

@ -1,71 +0,0 @@
use com::{
class_inproc_key_path, class_key_path, failed, get_dll_file_path, register_keys,
unregister_keys, IUnknown, RegistryKeyInfo,
};
use winapi::shared::{
guiddef::{IsEqualGUID, REFCLSID, REFIID},
minwindef::LPVOID,
winerror::{CLASS_E_CLASSNOTAVAILABLE, HRESULT},
};
pub use interface::CLSID_CLARK_KENT_CLASS;
mod clark_kent;
mod clark_kent_class;
use clark_kent_class::ClarkKentClass;
#[no_mangle]
extern "stdcall" fn DllGetClassObject(rclsid: REFCLSID, riid: REFIID, ppv: *mut LPVOID) -> HRESULT {
unsafe {
let rclsid = &*rclsid;
if IsEqualGUID(rclsid, &CLSID_CLARK_KENT_CLASS) {
println!("Allocating new object ClarkKentClass...");
let mut ckc = Box::new(ClarkKentClass::new());
ckc.add_ref();
let hr = ckc.query_interface(riid, ppv);
ckc.release();
Box::into_raw(ckc);
hr
} else {
CLASS_E_CLASSNOTAVAILABLE
}
}
}
// Function tries to add ALL-OR-NONE of the registry keys.
#[no_mangle]
extern "stdcall" fn DllRegisterServer() -> HRESULT {
let hr = register_keys(get_relevant_registry_keys());
if failed(hr) {
DllUnregisterServer();
}
hr
}
// Function tries to delete as many registry keys as possible.
#[no_mangle]
extern "stdcall" fn DllUnregisterServer() -> HRESULT {
let mut registry_keys_to_remove = get_relevant_registry_keys();
registry_keys_to_remove.reverse();
unregister_keys(registry_keys_to_remove)
}
fn get_relevant_registry_keys() -> Vec<RegistryKeyInfo> {
let file_path = get_dll_file_path();
// IMPORTANT: Assumption of order: Subkeys are located at a higher index than the parent key.
vec![
RegistryKeyInfo::new(
class_key_path(CLSID_CLARK_KENT_CLASS).as_str(),
"",
"Clark Kent Component",
),
RegistryKeyInfo::new(
class_inproc_key_path(CLSID_CLARK_KENT_CLASS).as_str(),
"",
file_path.clone().as_str(),
),
]
}

Просмотреть файл

@ -1,31 +0,0 @@
use std::process::Command;
fn main() {
let mut child_proc = Command::new("cmd")
.args(&["/C", "cls && cargo build --all --release"])
.spawn()
.expect("Something went wrong!");
if !child_proc.wait().unwrap().success() {
println!("Build failed.");
return;
}
let mut child_proc = Command::new("cmd")
.args(&["/C", "regsvr32 /s target/release/server.dll"])
.spawn()
.expect("Something went wrong!");
if !child_proc.wait().unwrap().success() {
println!("Failed to register server.dll! Make sure you have administrator rights!");
return;
}
let mut child_proc = Command::new("cmd")
.args(&["/C", "cargo run --release --package client"])
.spawn()
.expect("Something went wrong!");
if !child_proc.wait().unwrap().success() {
println!("Execution of client failed.");
return;
}
}

Просмотреть файл

@ -87,10 +87,121 @@ pub fn expand_derive_aggr_com_class(item: TokenStream) -> TokenStream {
out.push(gen_iunknown_impl(&input).into()); out.push(gen_iunknown_impl(&input).into());
out.push(gen_drop_impl(&base_itf_idents, &input).into()); out.push(gen_drop_impl(&base_itf_idents, &input).into());
out.push(gen_deref_impl(&input).into()); out.push(gen_deref_impl(&input).into());
out.push(gen_class_factory(&input).into());
TokenStream::from_iter(out) TokenStream::from_iter(out)
} }
// We manually generate a ClassFactory without macros, otherwise
// it leads to an infinite loop.
fn gen_class_factory(struct_item: &ItemStruct) -> HelperTokenStream {
let real_ident = macro_utils::get_real_ident(&struct_item.ident);
let class_factory_ident = macro_utils::get_class_factory_ident(&real_ident);
quote!(
#[repr(C)]
pub struct #class_factory_ident {
inner: <com::IClassFactory as com::ComInterface>::VPtr,
ref_count: u32,
}
impl com::IClassFactory for #class_factory_ident {
fn create_instance(
&mut self,
aggr: *mut <com::IUnknown as com::ComInterface>::VPtr,
riid: winapi::shared::guiddef::REFIID,
ppv: *mut *mut winapi::ctypes::c_void,
) -> winapi::shared::winerror::HRESULT {
use com::IUnknown;
let riid = unsafe { &*riid };
println!("Creating instance for {}", stringify!(#real_ident));
if !aggr.is_null() && !winapi::shared::guiddef::IsEqualGUID(riid, &<com::IUnknown as com::ComInterface>::IID) {
unsafe {
*ppv = std::ptr::null_mut::<winapi::ctypes::c_void>();
}
return winapi::shared::winerror::E_INVALIDARG;
}
let mut instance = #real_ident::new();
// This check has to be here because it can only be done after object
// is allocated on the heap (address of nonDelegatingUnknown fixed)
instance.set_iunknown(aggr);
// As an aggregable object, we have to add_ref through the
// non-delegating IUnknown on creation. Otherwise, we might
// add_ref the outer object if aggregated.
instance.inner_add_ref();
let hr = instance.inner_query_interface(riid, ppv);
instance.inner_release();
Box::into_raw(instance);
hr
}
fn lock_server(&mut self, _increment: winapi::shared::minwindef::BOOL) -> winapi::shared::winerror::HRESULT {
println!("LockServer called");
winapi::shared::winerror::S_OK
}
}
impl com::IUnknown for #class_factory_ident {
fn query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT {
// Bringing trait into scope to access add_ref method.
use com::IUnknown;
unsafe {
println!("Querying interface on {}...", stringify!(#class_factory_ident));
let riid = &*riid;
if winapi::shared::guiddef::IsEqualGUID(riid, &<com::IUnknown as com::ComInterface>::IID) | winapi::shared::guiddef::IsEqualGUID(riid, &<com::IClassFactory as com::ComInterface>::IID) {
*ppv = &self.inner as *const _ as *mut winapi::ctypes::c_void;
self.add_ref();
winapi::shared::winerror::NOERROR
} else {
*ppv = std::ptr::null_mut::<winapi::ctypes::c_void>();
winapi::shared::winerror::E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for {}. Freeing memory...", stringify!(#class_factory_ident));
unsafe { Box::from_raw(self as *const _ as *mut #class_factory_ident); }
}
count
}
}
impl #class_factory_ident {
pub(crate) fn new() -> Box<#class_factory_ident> {
use com::IClassFactory;
println!("Allocating new Vtable for {}...", stringify!(#class_factory_ident));
let class_vtable = com::vtable!(#class_factory_ident: IClassFactory);
let vptr = Box::into_raw(Box::new(class_vtable));
let class_factory = #class_factory_ident {
inner: vptr,
ref_count: 0,
};
Box::new(class_factory)
}
}
)
}
fn gen_impl( fn gen_impl(
base_itf_idents: &[Ident], base_itf_idents: &[Ident],
aggr_itf_idents: &HashMap<Ident, Vec<Ident>>, aggr_itf_idents: &HashMap<Ident, Vec<Ident>>,
@ -100,12 +211,25 @@ fn gen_impl(
let allocate_fn = gen_allocate_fn(base_itf_idents, struct_item); let allocate_fn = gen_allocate_fn(base_itf_idents, struct_item);
let set_iunknown_fn = gen_set_iunknown_fn(); let set_iunknown_fn = gen_set_iunknown_fn();
let inner_iunknown_fns = gen_inner_iunknown_fns(base_itf_idents, aggr_itf_idents, struct_item); let inner_iunknown_fns = gen_inner_iunknown_fns(base_itf_idents, aggr_itf_idents, struct_item);
let get_class_object_fn = gen_get_class_object_fn(struct_item);
quote!( quote!(
impl #real_ident { impl #real_ident {
#allocate_fn #allocate_fn
#set_iunknown_fn #set_iunknown_fn
#inner_iunknown_fns #inner_iunknown_fns
#get_class_object_fn
}
)
}
fn gen_get_class_object_fn(struct_item: &ItemStruct) -> HelperTokenStream {
let real_ident = macro_utils::get_real_ident(&struct_item.ident);
let class_factory_ident = macro_utils::get_class_factory_ident(&real_ident);
quote!(
pub fn get_class_object() -> Box<#class_factory_ident> {
<#class_factory_ident>::new()
} }
) )
} }
@ -263,6 +387,12 @@ fn gen_deref_impl(struct_item: &ItemStruct) -> HelperTokenStream {
&self.#inner_init_field_ident &self.#inner_init_field_ident
} }
} }
impl std::ops::DerefMut for #real_ident {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.#inner_init_field_ident
}
}
) )
} }

Просмотреть файл

@ -23,8 +23,110 @@ pub fn expand_derive_com_class(item: TokenStream) -> TokenStream {
out.push(gen_iunknown_impl(&base_itf_idents, &aggr_itf_idents, &input).into()); out.push(gen_iunknown_impl(&base_itf_idents, &aggr_itf_idents, &input).into());
out.push(gen_drop_impl(&base_itf_idents, &input).into()); out.push(gen_drop_impl(&base_itf_idents, &input).into());
out.push(gen_deref_impl(&input).into()); out.push(gen_deref_impl(&input).into());
out.push(gen_class_factory(&input).into());
TokenStream::from_iter(out) // TokenStream::from_iter(out)
let out = TokenStream::from_iter(out);
println!("Result:\n{}", out.to_string());
out
}
// We manually generate a ClassFactory without macros, otherwise
// it leads to an infinite loop.
fn gen_class_factory(struct_item: &ItemStruct) -> HelperTokenStream {
let real_ident = macro_utils::get_real_ident(&struct_item.ident);
let class_factory_ident = macro_utils::get_class_factory_ident(&real_ident);
quote!(
#[repr(C)]
pub struct #class_factory_ident {
inner: <com::IClassFactory as com::ComInterface>::VPtr,
ref_count: u32,
}
impl com::IClassFactory for #class_factory_ident {
fn create_instance(
&mut self,
aggr: *mut <com::IUnknown as com::ComInterface>::VPtr,
riid: winapi::shared::guiddef::REFIID,
ppv: *mut *mut winapi::ctypes::c_void,
) -> winapi::shared::winerror::HRESULT {
use com::IUnknown;
println!("Creating instance for {}", stringify!(#real_ident));
if aggr != std::ptr::null_mut() {
return winapi::shared::winerror::CLASS_E_NOAGGREGATION;
}
let mut instance = #real_ident::new();
instance.add_ref();
let hr = instance.query_interface(riid, ppv);
instance.release();
Box::into_raw(instance);
hr
}
fn lock_server(&mut self, _increment: winapi::shared::minwindef::BOOL) -> winapi::shared::winerror::HRESULT {
println!("LockServer called");
winapi::shared::winerror::S_OK
}
}
impl com::IUnknown for #class_factory_ident {
fn query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT {
// Bringing trait into scope to access add_ref method.
use com::IUnknown;
unsafe {
println!("Querying interface on {}...", stringify!(#class_factory_ident));
let riid = &*riid;
if winapi::shared::guiddef::IsEqualGUID(riid, &<com::IUnknown as com::ComInterface>::IID) | winapi::shared::guiddef::IsEqualGUID(riid, &<com::IClassFactory as com::ComInterface>::IID) {
*ppv = &self.inner as *const _ as *mut winapi::ctypes::c_void;
self.add_ref();
winapi::shared::winerror::NOERROR
} else {
*ppv = std::ptr::null_mut::<winapi::ctypes::c_void>();
winapi::shared::winerror::E_NOINTERFACE
}
}
}
fn add_ref(&mut self) -> u32 {
self.ref_count += 1;
println!("Count now {}", self.ref_count);
self.ref_count
}
fn release(&mut self) -> u32 {
self.ref_count -= 1;
println!("Count now {}", self.ref_count);
let count = self.ref_count;
if count == 0 {
println!("Count is 0 for {}. Freeing memory...", stringify!(#class_factory_ident));
unsafe { Box::from_raw(self as *const _ as *mut #class_factory_ident); }
}
count
}
}
impl #class_factory_ident {
pub(crate) fn new() -> Box<#class_factory_ident> {
use com::IClassFactory;
println!("Allocating new Vtable for {}...", stringify!(#class_factory_ident));
let class_vtable = com::vtable!(#class_factory_ident: IClassFactory);
let vptr = Box::into_raw(Box::new(class_vtable));
let class_factory = #class_factory_ident {
inner: vptr,
ref_count: 0,
};
Box::new(class_factory)
}
}
)
} }
fn gen_drop_impl(base_itf_idents: &[Ident], struct_item: &ItemStruct) -> HelperTokenStream { fn gen_drop_impl(base_itf_idents: &[Ident], struct_item: &ItemStruct) -> HelperTokenStream {
@ -59,6 +161,12 @@ fn gen_deref_impl(struct_item: &ItemStruct) -> HelperTokenStream {
&self.#inner_init_field_ident &self.#inner_init_field_ident
} }
} }
impl std::ops::DerefMut for #real_ident {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.#inner_init_field_ident
}
}
) )
} }
@ -165,6 +273,7 @@ fn gen_allocate_impl(base_itf_idents: &[Ident], struct_item: &ItemStruct) -> Hel
let init_ident = &struct_item.ident; let init_ident = &struct_item.ident;
let real_ident = get_real_ident(&struct_item.ident); let real_ident = get_real_ident(&struct_item.ident);
// Allocate stuff
let mut offset_count: usize = 0; let mut offset_count: usize = 0;
let base_inits = base_itf_idents.iter().map(|base| { let base_inits = base_itf_idents.iter().map(|base| {
let vtable_var_ident = format_ident!("{}_vtable", base.to_string().to_lowercase()); let vtable_var_ident = format_ident!("{}_vtable", base.to_string().to_lowercase());
@ -185,6 +294,9 @@ fn gen_allocate_impl(base_itf_idents: &[Ident], struct_item: &ItemStruct) -> Hel
let ref_count_ident = get_ref_count_ident(); let ref_count_ident = get_ref_count_ident();
let inner_init_field_ident = get_inner_init_field_ident(); let inner_init_field_ident = get_inner_init_field_ident();
// GetClassObject stuff
let class_factory_ident = macro_utils::get_class_factory_ident(&real_ident);
quote!( quote!(
impl #real_ident { impl #real_ident {
fn allocate(init_struct: #init_ident) -> Box<#real_ident> { fn allocate(init_struct: #init_ident) -> Box<#real_ident> {
@ -197,6 +309,10 @@ fn gen_allocate_impl(base_itf_idents: &[Ident], struct_item: &ItemStruct) -> Hel
}; };
Box::new(out) Box::new(out)
} }
pub fn get_class_object() -> Box<#class_factory_ident> {
<#class_factory_ident>::new()
}
} }
) )
} }

Просмотреть файл

@ -5,6 +5,10 @@ use syn::{
use std::collections::HashMap; use std::collections::HashMap;
pub fn get_class_factory_ident(class_ident: &Ident) -> Ident {
format_ident!("{}ClassFactory", class_ident)
}
pub fn get_vtable_ident(trait_ident: &Ident) -> Ident { pub fn get_vtable_ident(trait_ident: &Ident) -> Ident {
format_ident!("{}VTable", trait_ident) format_ident!("{}VTable", trait_ident)
} }

Просмотреть файл

@ -179,7 +179,7 @@ macro_rules! com_inproc_dll_module {
use com::IUnknown; use com::IUnknown;
let rclsid = unsafe{ &*rclsid }; let rclsid = unsafe{ &*rclsid };
if $crate::_winapi::shared::guiddef::IsEqualGUID(rclsid, &$clsid_one) { if $crate::_winapi::shared::guiddef::IsEqualGUID(rclsid, &$clsid_one) {
let mut instance = Box::new(<$classtype_one>::new()); let mut instance = <$classtype_one>::get_class_object();
instance.add_ref(); instance.add_ref();
let hr = instance.query_interface(riid, ppv); let hr = instance.query_interface(riid, ppv);
instance.release(); instance.release();
@ -187,7 +187,7 @@ macro_rules! com_inproc_dll_module {
hr hr
} $(else if $crate::_winapi::shared::guiddef::IsEqualGUID(rclsid, &$clsid) { } $(else if $crate::_winapi::shared::guiddef::IsEqualGUID(rclsid, &$clsid) {
let mut instance = Box::new(<$classtype>::new()); let mut instance = <$classtype>::get_class_object();
instance.add_ref(); instance.add_ref();
let hr = instance.query_interface(riid, ppv); let hr = instance.query_interface(riid, ppv);
instance.release(); instance.release();