Merge branch 'master' into interface-change

This commit is contained in:
Ryan Levick 2019-09-18 14:59:44 +02:00 коммит произвёл GitHub
Родитель 8c588b4916 9b9012e232
Коммит 915995616a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
27 изменённых файлов: 509 добавлений и 407 удалений

Просмотреть файл

@ -21,7 +21,7 @@ fn main() {
fn run_aggr_test(runtime: Runtime) {
let result = runtime.create_instance::<dyn IFileManager>(&CLSID_WINDOWS_FILE_MANAGER_CLASS);
let mut filemanager = match result {
let filemanager = match result {
Ok(filemanager) => filemanager,
Err(e) => {
println!("Failed to get filemanager, {:x}", e as u32);
@ -32,7 +32,7 @@ fn run_aggr_test(runtime: Runtime) {
filemanager.delete_all();
let result = filemanager.get_interface::<dyn ILocalFileManager>();
let mut lfm = match result {
let lfm = match result {
Some(lfm) => lfm,
None => {
println!("Failed to get Local File Manager!");
@ -43,7 +43,7 @@ fn run_aggr_test(runtime: Runtime) {
lfm.delete_local();
let result = runtime.create_instance::<dyn ILocalFileManager>(&CLSID_LOCAL_FILE_MANAGER_CLASS);
let mut localfilemanager = match result {
let localfilemanager = match result {
Ok(localfilemanager) => localfilemanager,
Err(e) => {
println!("Failed to get localfilemanager, {:x}", e as u32);

Просмотреть файл

@ -4,5 +4,5 @@ use winapi::shared::winerror::HRESULT;
#[com_interface(25A41124-23D0-46BE-8351-044889D5E37E)]
pub trait IFileManager: IUnknown {
fn delete_all(&mut self) -> HRESULT;
fn delete_all(&self) -> HRESULT;
}

Просмотреть файл

@ -4,5 +4,5 @@ use winapi::shared::winerror::HRESULT;
#[com_interface(4FC333E3-C389-4C48-B108-7895B0AF21AD)]
pub trait ILocalFileManager: IUnknown {
fn delete_local(&mut self) -> HRESULT;
fn delete_local(&self) -> HRESULT;
}

Просмотреть файл

@ -11,7 +11,7 @@ pub struct LocalFileManager {
}
impl ILocalFileManager for LocalFileManager {
fn delete_local(&mut self) -> HRESULT {
fn delete_local(&self) -> HRESULT {
println!("Deleting Locally...");
NOERROR
}

Просмотреть файл

@ -24,7 +24,7 @@ pub struct WindowsFileManager {
}
impl IFileManager for WindowsFileManager {
fn delete_all(&mut self) -> HRESULT {
fn delete_all(&self) -> HRESULT {
println!("Deleting all by delegating to Local and Remote File Managers...");
NOERROR
}

Просмотреть файл

@ -10,7 +10,7 @@ fn main() {
}
};
let mut factory = match runtime.get_class_object(&CLSID_CAT_CLASS) {
let factory = match runtime.get_class_object(&CLSID_CAT_CLASS) {
Ok(factory) => {
println!("Got cat class object");
factory
@ -21,7 +21,7 @@ fn main() {
}
};
let mut unknown = match factory.get_instance::<dyn IUnknown>() {
let unknown = match factory.get_instance::<dyn IUnknown>() {
Some(unknown) => {
println!("Got IUnknown");
unknown
@ -32,7 +32,7 @@ fn main() {
}
};
let mut animal = match unknown.get_interface::<dyn IAnimal>() {
let animal = match unknown.get_interface::<dyn IAnimal>() {
Some(animal) => {
println!("Got IAnimal");
animal
@ -46,7 +46,7 @@ fn main() {
animal.eat();
// Test cross-vtable interface queries for both directions.
let mut domestic_animal = match animal.get_interface::<dyn IDomesticAnimal>() {
let domestic_animal = match animal.get_interface::<dyn IDomesticAnimal>() {
Some(domestic_animal) => {
println!("Got IDomesticAnimal");
domestic_animal
@ -59,7 +59,7 @@ fn main() {
domestic_animal.train();
let mut new_cat = match domestic_animal.get_interface::<dyn ICat>() {
let new_cat = match domestic_animal.get_interface::<dyn ICat>() {
Some(new_cat) => {
println!("Got ICat");
new_cat
@ -72,7 +72,7 @@ fn main() {
new_cat.ignore_humans();
// Test querying within second vtable.
let mut domestic_animal_two = match domestic_animal.get_interface::<dyn IDomesticAnimal>() {
let domestic_animal_two = match domestic_animal.get_interface::<dyn IDomesticAnimal>() {
Some(domestic_animal_two) => {
println!("Got IDomesticAnimal");
domestic_animal_two
@ -84,12 +84,7 @@ fn main() {
};
domestic_animal_two.train();
// These doesn't compile
// animal.ignore_humans();
// animal.raw_add_ref();
// animal.add_ref();
let mut cat = match runtime.create_instance::<dyn ICat>(&CLSID_CAT_CLASS) {
let cat = match runtime.create_instance::<dyn ICat>(&CLSID_CAT_CLASS) {
Ok(cat) => {
println!("Got another cat");
cat

Просмотреть файл

@ -3,5 +3,5 @@ use winapi::um::winnt::HRESULT;
#[com_interface(EFF8970E-C50F-45E0-9284-291CE5A6F771)]
pub trait IAnimal: IUnknown {
fn eat(&mut self) -> HRESULT;
fn eat(&self) -> HRESULT;
}

Просмотреть файл

@ -5,5 +5,5 @@ use crate::IAnimal;
#[com_interface(F5353C58-CFD9-4204-8D92-D274C7578B53)]
pub trait ICat: IAnimal {
fn ignore_humans(&mut self) -> HRESULT;
fn ignore_humans(&self) -> HRESULT;
}

Просмотреть файл

@ -5,5 +5,5 @@ use crate::IAnimal;
#[com_interface(C22425DF-EFB2-4B85-933E-9CF7B23459E8)]
pub trait IDomesticAnimal: IAnimal {
fn train(&mut self) -> HRESULT;
fn train(&self) -> HRESULT;
}

Просмотреть файл

@ -12,21 +12,21 @@ pub struct BritishShortHairCat {
}
impl IDomesticAnimal for BritishShortHairCat {
fn train(&mut self) -> HRESULT {
fn train(&self) -> HRESULT {
println!("Training...");
NOERROR
}
}
impl ICat for BritishShortHairCat {
fn ignore_humans(&mut self) -> HRESULT {
fn ignore_humans(&self) -> HRESULT {
println!("Ignoring Humans...");
NOERROR
}
}
impl IAnimal for BritishShortHairCat {
fn eat(&mut self) -> HRESULT {
fn eat(&self) -> HRESULT {
println!("Eating...");
NOERROR
}

Просмотреть файл

@ -5,21 +5,31 @@ use syn::ItemStruct;
// We manually generate a ClassFactory without macros, otherwise
// it leads to an infinite loop.
pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream {
let base_interface_idents = co_class::class_factory::get_class_factory_base_interface_idents();
let aggr_map = co_class::class_factory::get_class_factory_aggr_map();
let struct_ident = &struct_item.ident;
let class_factory_ident = macro_utils::class_factory_ident(&struct_ident);
let struct_definition =
co_class::class_factory::gen_class_factory_struct_definition(&class_factory_ident);
let lock_server = co_class::class_factory::gen_lock_server();
let iunknown_impl = co_class::class_factory::gen_iunknown_impl(&class_factory_ident);
let class_factory_impl = co_class::class_factory::gen_class_factory_impl(&class_factory_ident);
let iunknown_impl = co_class::class_factory::gen_iunknown_impl(
&base_interface_idents,
&aggr_map,
&class_factory_ident,
);
let class_factory_impl = co_class::class_factory::gen_class_factory_impl(
&base_interface_idents,
&class_factory_ident,
);
quote! {
#struct_definition
impl com::IClassFactory for #class_factory_ident {
unsafe fn create_instance(
&mut self,
&self,
aggr: *mut <dyn com::IUnknown as com::ComInterface>::VPtr,
riid: winapi::shared::guiddef::REFIID,
ppv: *mut *mut winapi::ctypes::c_void,

Просмотреть файл

@ -1,7 +1,7 @@
use proc_macro2::TokenStream as HelperTokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{Fields, Ident, ItemStruct};
use syn::{Ident, ItemStruct};
/// As an aggregable COM object, you need to have an inner non-delegating IUnknown vtable.
/// All IUnknown calls to this COM object will delegate to the IUnknown interface pointer
@ -14,36 +14,25 @@ pub fn generate(
let struct_ident = &struct_item.ident;
let vis = &struct_item.vis;
let bases_interface_idents = base_interface_idents.iter().map(|base| {
let field_ident = macro_utils::vptr_field_ident(&base);
quote!(#field_ident: <dyn #base as com::ComInterface>::VPtr)
});
let base_fields = co_class::com_struct::gen_base_fields(base_interface_idents);
let ref_count_field = co_class::com_struct::gen_ref_count_field();
let user_fields = co_class::com_struct::gen_user_fields(struct_item);
let aggregate_fields = co_class::com_struct::gen_aggregate_fields(aggr_map);
let ref_count_ident = macro_utils::ref_count_ident();
// COM Fields for an aggregable coclass.
let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident();
let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident();
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
let aggregates = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
#aggr_field_ident: *mut <dyn com::IUnknown as com::ComInterface>::VPtr
)
});
quote!(
#[repr(C)]
#vis struct #struct_ident {
#(#bases_interface_idents,)*
#base_fields
#non_delegating_iunknown_field_ident: <dyn com::IUnknown as com::ComInterface>::VPtr,
// Non-reference counted interface pointer to outer IUnknown.
#iunknown_to_use_field_ident: *mut <dyn com::IUnknown as com::ComInterface>::VPtr,
#ref_count_ident: u32,
#(#aggregates,)*
#fields
#ref_count_field
#aggregate_fields
#user_fields
}
)
}

Просмотреть файл

@ -1,7 +1,7 @@
use proc_macro2::TokenStream as HelperTokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{Fields, Ident, ItemStruct};
use syn::{Ident, ItemStruct};
/// Generates the methods that the com struct needs to have. These include:
/// allocate: To initialise the vtables, including the non_delegatingegating_iunknown one.
@ -17,9 +17,9 @@ pub fn generate(
let struct_ident = &struct_item.ident;
let allocate_fn = gen_allocate_fn(aggr_map, base_interface_idents, struct_item);
let set_iunknown_fn = gen_set_iunknown_fn();
let inner_iunknown_fns = gen_inner_iunknown_fns(base_interface_idents, aggr_map, struct_item);
let get_class_object_fn = gen_get_class_object_fn(struct_item);
let set_aggregate_fns = gen_set_aggregate_fns(aggr_map);
let inner_iunknown_fns = gen_inner_iunknown_fns(base_interface_idents, aggr_map, struct_ident);
let get_class_object_fn = co_class::com_struct_impl::gen_get_class_object_fn(struct_item);
let set_aggregate_fns = co_class::com_struct_impl::gen_set_aggregate_fns(aggr_map);
quote!(
impl #struct_ident {
@ -32,19 +32,6 @@ pub fn generate(
)
}
/// Function used by in-process DLL macro to get an instance of the
/// class object.
fn gen_get_class_object_fn(struct_item: &ItemStruct) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let class_factory_ident = macro_utils::class_factory_ident(&struct_ident);
quote!(
pub fn get_class_object() -> Box<#class_factory_ident> {
<#class_factory_ident>::new()
}
)
}
/// Function that should only be used by Class Object, to set the
/// object's iunknown_to_use, if the object is going to get aggregated.
fn gen_set_iunknown_fn() -> HelperTokenStream {
@ -67,32 +54,66 @@ fn gen_set_iunknown_fn() -> HelperTokenStream {
fn gen_inner_iunknown_fns(
base_interface_idents: &[Ident],
aggr_map: &HashMap<Ident, Vec<Ident>>,
struct_item: &ItemStruct,
struct_ident: &Ident,
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let ref_count_ident = macro_utils::ref_count_ident();
let inner_query_interface = gen_inner_query_interface(base_interface_idents, aggr_map);
let inner_add_ref = gen_inner_add_ref();
let inner_release = gen_inner_release(base_interface_idents, aggr_map, struct_ident);
quote!(
#inner_query_interface
#inner_add_ref
#inner_release
)
}
pub(crate) fn inner_add_ref(&mut self) -> u32 {
self.#ref_count_ident += 1;
println!("Count now {}", self.#ref_count_ident);
self.#ref_count_ident
pub fn gen_inner_add_ref() -> HelperTokenStream {
let add_ref_implementation = co_class::iunknown_impl::gen_add_ref_implementation();
quote! {
pub(crate) fn inner_add_ref(&self) -> u32 {
#add_ref_implementation
}
}
}
pub(crate) fn inner_release(&mut self) -> u32 {
self.#ref_count_ident -= 1;
println!("Count now {}", self.#ref_count_ident);
let count = self.#ref_count_ident;
if count == 0 {
println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident));
// drop(self)
unsafe { Box::from_raw(self as *const _ as *mut #struct_ident); }
pub fn gen_inner_release(
base_interface_idents: &[Ident],
aggr_map: &HashMap<Ident, Vec<Ident>>,
struct_ident: &Ident,
) -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
let release_decrement = co_class::iunknown_impl::gen_release_decrement(&ref_count_ident);
let release_assign_new_count_to_var =
co_class::iunknown_impl::gen_release_assign_new_count_to_var(
&ref_count_ident,
&ref_count_ident,
);
let release_new_count_var_zero_check =
co_class::iunknown_impl::gen_new_count_var_zero_check(&ref_count_ident);
let release_drops =
co_class::iunknown_impl::gen_release_drops(base_interface_idents, aggr_map, struct_ident);
let non_delegating_iunknown_drop = gen_non_delegating_iunknown_drop();
quote! {
unsafe fn inner_release(&self) -> u32 {
#release_decrement
#release_assign_new_count_to_var
if #release_new_count_var_zero_check {
#non_delegating_iunknown_drop
#release_drops
}
count
#ref_count_ident
}
}
}
fn gen_non_delegating_iunknown_drop() -> HelperTokenStream {
let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident();
quote!(
Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut <dyn com::IUnknown as com::ComInterface>::VTable);
)
}
@ -104,50 +125,13 @@ fn gen_inner_query_interface(
let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident();
// Generate match arms for implemented interfaces
let match_arms = base_interface_idents.iter().map(|base| {
let match_condition =
quote!(<dyn #base as com::ComInterface>::is_iid_in_inheritance_chain(riid));
let vptr_field_ident = macro_utils::vptr_field_ident(&base);
quote!(
else if #match_condition {
*ppv = &self.#vptr_field_ident as *const _ as *mut winapi::ctypes::c_void;
}
)
});
let base_match_arms = co_class::iunknown_impl::gen_base_match_arms(base_interface_idents);
// Generate match arms for aggregated interfaces
let aggr_match_arms = aggr_map.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| {
// Construct the OR match conditions for a single aggregated object.
let first_base_interface_ident = &aggr_base_interface_idents[0];
let first_aggr_match_condition = quote!(
<dyn #first_base_interface_ident as com::ComInterface>::is_iid_in_inheritance_chain(riid)
);
let rem_aggr_match_conditions = aggr_base_interface_idents.iter().skip(1).map(|base| {
quote!(|| <dyn #base as com::ComInterface>::is_iid_in_inheritance_chain(riid))
});
quote!(
else if #first_aggr_match_condition #(#rem_aggr_match_conditions)* {
let mut aggr_interface_ptr: ComPtr<dyn com::IUnknown> = ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void);
let hr = aggr_interface_ptr.query_interface(riid, ppv);
if com::failed(hr) {
return winapi::shared::winerror::E_NOINTERFACE;
}
// We release it as the previous call add_ref-ed the inner object.
// The intention is to transfer reference counting logic to the
// outer object.
aggr_interface_ptr.release();
core::mem::forget(aggr_interface_ptr);
}
)
});
let aggr_match_arms = co_class::iunknown_impl::gen_aggregate_match_arms(aggr_map);
quote!(
pub(crate) fn inner_query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> HRESULT {
pub(crate) fn inner_query_interface(&self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> HRESULT {
println!("Non delegating QI");
unsafe {
@ -155,7 +139,7 @@ fn gen_inner_query_interface(
if winapi::shared::guiddef::IsEqualGUID(riid, &com::IID_IUNKNOWN) {
*ppv = &self.#non_delegating_iunknown_field_ident as *const _ as *mut winapi::ctypes::c_void;
} #(#match_arms)* #(#aggr_match_arms)* else {
} #base_match_arms #aggr_match_arms else {
*ppv = std::ptr::null_mut::<winapi::ctypes::c_void>();
println!("Returning NO INTERFACE.");
return winapi::shared::winerror::E_NOINTERFACE;
@ -179,45 +163,26 @@ fn gen_allocate_fn(
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let mut offset_count: usize = 0;
let base_inits = base_interface_idents.iter().map(|base| {
let vtable_var_ident = quote::format_ident!("{}_vtable", base.to_string().to_lowercase());
let vptr_field_ident = macro_utils::vptr_field_ident(&base);
let base_inits =
co_class::com_struct_impl::gen_allocate_base_inits(struct_ident, base_interface_idents);
let out = quote!(
let #vtable_var_ident = com::vtable!(#struct_ident: #base, #offset_count);
let #vptr_field_ident = Box::into_raw(Box::new(#vtable_var_ident));
);
// Allocate function signature
let allocate_parameters =
co_class::com_struct_impl::gen_allocate_function_parameters_signature(struct_item);
offset_count += 1;
out
});
let base_fields = base_interface_idents.iter().map(|base| {
let vptr_field_ident = macro_utils::vptr_field_ident(base);
quote!(#vptr_field_ident)
});
let ref_count_ident = macro_utils::ref_count_ident();
// Syntax for instantiating the fields of the struct.
let base_fields = co_class::com_struct_impl::gen_allocate_base_fields(base_interface_idents);
let ref_count_field = co_class::com_struct_impl::gen_allocate_ref_count_field();
let user_fields = co_class::com_struct_impl::gen_allocate_user_fields(struct_item);
let aggregate_fields = co_class::com_struct_impl::gen_allocate_aggregate_fields(aggr_map);
// Aggregable COM struct specific fields
let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident();
let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident();
let non_delegating_iunknown_offset = base_interface_idents.len();
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
let field_idents = fields.iter().map(|field| {
let field_ident = field.ident.as_ref().unwrap().clone();
quote!(#field_ident)
});
let aggregate_inits = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
#aggr_field_ident: std::ptr::null_mut()
)
});
quote!(
fn allocate(#fields) -> Box<#struct_ident> {
fn allocate(#allocate_parameters) -> Box<#struct_ident> {
println!("Allocating new VTable for {}", stringify!(#struct_ident));
// Non-delegating methods.
@ -253,33 +218,17 @@ fn gen_allocate_fn(
};
let #non_delegating_iunknown_field_ident = Box::into_raw(Box::new(__non_delegating_iunknown_vtable));
#(#base_inits)*
#base_inits
let out = #struct_ident {
#(#base_fields,)*
#base_fields
#non_delegating_iunknown_field_ident,
#iunknown_to_use_field_ident: std::ptr::null_mut::<<dyn com::IUnknown as com::ComInterface>::VPtr>(),
#ref_count_ident: 0,
#(#aggregate_inits,)*
#(#field_idents)*
#ref_count_field
#aggregate_fields
#user_fields
};
Box::new(out)
}
)
}
fn gen_set_aggregate_fns(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
let mut fns = Vec::new();
for (aggr_field_ident, aggr_base_interface_idents) in aggr_map.iter() {
for base in aggr_base_interface_idents {
let set_aggregate_fn_ident = macro_utils::set_aggregate_fn_ident(&base);
fns.push(quote!(
fn #set_aggregate_fn_ident(&mut self, aggr: *mut <dyn com::IUnknown as com::ComInterface>::VPtr) {
// TODO: What happens if we are overwriting an existing aggregate?
self.#aggr_field_ident = aggr
}
));
}
}
quote!(#(#fns)*)
}

Просмотреть файл

@ -1,41 +0,0 @@
use proc_macro2::TokenStream as HelperTokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{Ident, ItemStruct};
pub fn generate(
aggr_map: &HashMap<Ident, Vec<Ident>>,
base_interface_idents: &[Ident],
struct_item: &ItemStruct,
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident();
let box_from_raws = base_interface_idents.iter().map(|base| {
let vptr_field_ident = macro_utils::vptr_field_ident(&base);
quote!(
Box::from_raw(self.#vptr_field_ident as *mut <dyn #base as com::ComInterface>::VTable);
)
});
let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
if !self.#aggr_field_ident.is_null() {
let mut aggr_interface_ptr: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void);
aggr_interface_ptr.release();
core::mem::forget(aggr_interface_ptr);
}
)
});
quote!(
impl std::ops::Drop for #struct_ident {
fn drop(&mut self) {
let _ = unsafe {
#(#aggregate_drops)*
#(#box_from_raws)*
Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut <dyn com::IUnknown as com::ComInterface>::VTable)
};
}
}
)
}

Просмотреть файл

@ -12,32 +12,33 @@ use syn::ItemStruct;
pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident();
let ptr_casting = quote! { as *const winapi::ctypes::c_void as *mut winapi::ctypes::c_void };
quote!(
impl com::IUnknown for #struct_ident {
unsafe fn query_interface(
&mut self,
&self,
riid: *const winapi::shared::guiddef::IID,
ppv: *mut *mut winapi::ctypes::c_void
) -> winapi::shared::winerror::HRESULT {
println!("Delegating QI");
let mut iunknown_to_use: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#iunknown_to_use_field_ident as *mut winapi::ctypes::c_void);
let mut iunknown_to_use: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#iunknown_to_use_field_ident #ptr_casting);
let hr = iunknown_to_use.query_interface(riid, ppv);
core::mem::forget(iunknown_to_use);
hr
}
fn add_ref(&mut self) -> u32 {
let mut iunknown_to_use: com::ComPtr<dyn com::IUnknown> = unsafe { com::ComPtr::new(self.#iunknown_to_use_field_ident as *mut winapi::ctypes::c_void) };
fn add_ref(&self) -> u32 {
let mut iunknown_to_use: com::ComPtr<dyn com::IUnknown> = unsafe { com::ComPtr::new(self.#iunknown_to_use_field_ident #ptr_casting) };
let res = iunknown_to_use.add_ref();
core::mem::forget(iunknown_to_use);
res
}
unsafe fn release(&mut self) -> u32 {
let mut iunknown_to_use: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#iunknown_to_use_field_ident as *mut winapi::ctypes::c_void);
unsafe fn release(&self) -> u32 {
let mut iunknown_to_use: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#iunknown_to_use_field_ident #ptr_casting);
let res = iunknown_to_use.release();
core::mem::forget(iunknown_to_use);

Просмотреть файл

@ -7,20 +7,16 @@ use std::iter::FromIterator;
mod class_factory;
mod com_struct;
mod com_struct_impl;
mod drop_impl;
mod iunknown_impl;
pub fn expand_aggr_co_class(input: &ItemStruct, attr_args: &AttributeArgs) -> TokenStream {
let base_interface_idents = macro_utils::base_interface_idents(attr_args);
let aggr_interface_idents = macro_utils::get_aggr_map(attr_args);
let mut out = Vec::<TokenStream>::new();
let mut out: Vec<TokenStream> = Vec::new();
out.push(com_struct::generate(&aggr_interface_idents, &base_interface_idents, input).into());
out.push(
com_struct_impl::generate(&base_interface_idents, &aggr_interface_idents, input).into(),
);
out.push(com_struct_impl::generate(&base_interface_idents, &aggr_interface_idents, input).into());
out.push(iunknown_impl::generate(input).into());
out.push(drop_impl::generate(&aggr_interface_idents, &base_interface_idents, input).into());
out.push(class_factory::generate(input).into());
TokenStream::from_iter(out)

Просмотреть файл

@ -1,24 +1,43 @@
use proc_macro2::{Ident, TokenStream as HelperTokenStream};
use quote::quote;
use quote::{format_ident, quote};
use std::collections::HashMap;
use syn::ItemStruct;
fn get_iclass_factory_interface_ident() -> Ident {
format_ident!("IClassFactory")
}
pub fn get_class_factory_base_interface_idents() -> Vec<Ident> {
vec![get_iclass_factory_interface_ident()]
}
pub fn get_class_factory_aggr_map() -> HashMap<Ident, Vec<Ident>> {
HashMap::new()
}
// We manually generate a ClassFactory without macros, otherwise
// it leads to an infinite loop.
pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream {
// Manually define base_interface_idents and aggr_map usually obtained by
// parsing attributes.
let base_interface_idents = get_class_factory_base_interface_idents();
let aggr_map = get_class_factory_aggr_map();
let struct_ident = &struct_item.ident;
let class_factory_ident = macro_utils::class_factory_ident(&struct_ident);
let struct_definition = gen_class_factory_struct_definition(&class_factory_ident);
let lock_server = gen_lock_server();
let iunknown_impl = gen_iunknown_impl(&class_factory_ident);
let class_factory_impl = gen_class_factory_impl(&class_factory_ident);
let iunknown_impl = gen_iunknown_impl(&base_interface_idents, &aggr_map, &class_factory_ident);
let class_factory_impl = gen_class_factory_impl(&base_interface_idents, &class_factory_ident);
quote! {
#struct_definition
impl com::IClassFactory for #class_factory_ident {
unsafe fn create_instance(
&mut self,
&self,
aggr: *mut <dyn com::IUnknown as com::ComInterface>::VPtr,
riid: winapi::shared::guiddef::REFIID,
ppv: *mut *mut winapi::ctypes::c_void,
@ -49,13 +68,16 @@ pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream {
}
}
// Can't use gen_base_fields here, since user might not have imported com::IClassFactory.
pub fn gen_class_factory_struct_definition(class_factory_ident: &Ident) -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
let ref_count_field = crate::com_struct::gen_ref_count_field();
let interface_ident = get_iclass_factory_interface_ident();
let vptr_field_ident = macro_utils::vptr_field_ident(&interface_ident);
quote! {
#[repr(C)]
pub struct #class_factory_ident {
inner: <dyn com::IClassFactory as com::ComInterface>::VPtr,
#ref_count_ident: u32,
#vptr_field_ident: <dyn com::IClassFactory as com::ComInterface>::VPtr,
#ref_count_field
}
}
}
@ -63,17 +85,21 @@ pub fn gen_class_factory_struct_definition(class_factory_ident: &Ident) -> Helpe
pub fn gen_lock_server() -> HelperTokenStream {
quote! {
// TODO: Implement correctly
fn lock_server(&mut self, _increment: winapi::shared::minwindef::BOOL) -> winapi::shared::winerror::HRESULT {
fn lock_server(&self, _increment: winapi::shared::minwindef::BOOL) -> winapi::shared::winerror::HRESULT {
println!("LockServer called");
winapi::shared::winerror::S_OK
}
}
}
pub fn gen_iunknown_impl(class_factory_ident: &Ident) -> HelperTokenStream {
pub fn gen_iunknown_impl(
base_interface_idents: &[Ident],
aggr_map: &HashMap<Ident, Vec<Ident>>,
class_factory_ident: &Ident,
) -> HelperTokenStream {
let query_interface = gen_query_interface(class_factory_ident);
let add_ref = crate::iunknown_impl::gen_add_ref();
let release = crate::iunknown_impl::gen_release(class_factory_ident);
let release = gen_release(&base_interface_idents, &aggr_map, class_factory_ident);
quote! {
impl com::IUnknown for #class_factory_ident {
#query_interface
@ -83,9 +109,43 @@ pub fn gen_iunknown_impl(class_factory_ident: &Ident) -> HelperTokenStream {
}
}
fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream {
pub fn gen_release(
base_interface_idents: &[Ident],
aggr_map: &HashMap<Ident, Vec<Ident>>,
struct_ident: &Ident,
) -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
let release_decrement = crate::iunknown_impl::gen_release_decrement(&ref_count_ident);
let release_assign_new_count_to_var = crate::iunknown_impl::gen_release_assign_new_count_to_var(
&ref_count_ident,
&ref_count_ident,
);
let release_new_count_var_zero_check =
crate::iunknown_impl::gen_new_count_var_zero_check(&ref_count_ident);
let release_drops =
crate::iunknown_impl::gen_release_drops(base_interface_idents, aggr_map, struct_ident);
quote! {
unsafe fn query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT {
unsafe fn release(&self) -> u32 {
use com::IClassFactory;
#release_decrement
#release_assign_new_count_to_var
if #release_new_count_var_zero_check {
#release_drops
}
#ref_count_ident
}
}
}
fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream {
let vptr_field_ident = macro_utils::vptr_field_ident(&get_iclass_factory_interface_ident());
quote! {
unsafe fn query_interface(&self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT {
// Bringing trait into scope to access add_ref method.
use com::IUnknown;
@ -93,7 +153,7 @@ fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream {
let riid = &*riid;
if winapi::shared::guiddef::IsEqualGUID(riid, &<dyn com::IUnknown as com::ComInterface>::IID) | winapi::shared::guiddef::IsEqualGUID(riid, &<dyn com::IClassFactory as com::ComInterface>::IID) {
*ppv = &self.inner as *const _ as *mut winapi::ctypes::c_void;
*ppv = &self.#vptr_field_ident as *const _ as *mut winapi::ctypes::c_void;
self.add_ref();
winapi::shared::winerror::NOERROR
} else {
@ -104,22 +164,29 @@ fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream {
}
}
pub fn gen_class_factory_impl(class_factory_ident: &Ident) -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
pub fn gen_class_factory_impl(
base_interface_idents: &[Ident],
class_factory_ident: &Ident,
) -> HelperTokenStream {
let ref_count_field = crate::com_struct_impl::gen_allocate_ref_count_field();
let base_fields = crate::com_struct_impl::gen_allocate_base_fields(base_interface_idents);
let base_inits =
crate::com_struct_impl::gen_allocate_base_inits(class_factory_ident, base_interface_idents);
quote! {
impl #class_factory_ident {
pub(crate) fn new() -> Box<#class_factory_ident> {
use com::IClassFactory;
println!("Allocating new Vtable for {}...", stringify!(#class_factory_ident));
let class_vtable = com::vtable!(#class_factory_ident: IClassFactory);
// allocate directly since no macros generated an `allocate` function
let vptr = Box::into_raw(Box::new(class_vtable));
let class_factory = #class_factory_ident {
inner: vptr,
#ref_count_ident: 0,
println!("Allocating new Vtable for {}...", stringify!(#class_factory_ident));
#base_inits
let out = #class_factory_ident {
#base_fields
#ref_count_field
};
Box::new(class_factory)
Box::new(out)
}
}
}

Просмотреть файл

@ -18,31 +18,50 @@ pub fn generate(
let struct_ident = &struct_item.ident;
let vis = &struct_item.vis;
let base_fields = gen_base_fields(base_interface_idents);
let ref_count_field = gen_ref_count_field();
let user_fields = gen_user_fields(struct_item);
let aggregate_fields = gen_aggregate_fields(aggr_map);
quote!(
#[repr(C)]
#vis struct #struct_ident {
#base_fields
#ref_count_field
#aggregate_fields
#user_fields
}
)
}
pub fn gen_base_fields(base_interface_idents: &[Ident]) -> HelperTokenStream {
let bases_interface_idents = base_interface_idents.iter().map(|base| {
let field_ident = macro_utils::vptr_field_ident(&base);
quote!(#field_ident: <dyn #base as com::ComInterface>::VPtr)
});
quote!(#(#bases_interface_idents,)*)
}
pub fn gen_ref_count_field() -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
quote!(#ref_count_ident: std::cell::Cell<u32>,)
}
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
pub fn gen_aggregate_fields(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
let aggregates = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
#aggr_field_ident: *mut <dyn com::IUnknown as com::ComInterface>::VPtr
)
});
quote!(
#[repr(C)]
#vis struct #struct_ident {
#(#bases_interface_idents,)*
#ref_count_ident: u32,
#(#aggregates,)*
#fields
}
)
quote!(#(#aggregates,)*)
}
pub fn gen_user_fields(struct_item: &ItemStruct) -> HelperTokenStream {
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
quote!(#fields)
}

Просмотреть файл

@ -3,9 +3,6 @@ use quote::{format_ident, quote};
use std::collections::HashMap;
use syn::{Fields, Ident, ItemStruct};
/// Generates the allocate and get_class_object function for the COM object.
/// allocate: instantiates the COM fields, such as vpointers for the COM object.
/// get_class_object: Instantiate an instance to the class object.
pub fn generate(
aggr_map: &HashMap<Ident, Vec<Ident>>,
base_interface_idents: &[Ident],
@ -13,7 +10,112 @@ pub fn generate(
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
// Allocate stuff
let allocate_fn = gen_allocate_fn(aggr_map, base_interface_idents, struct_item);
let set_aggregate_fns = gen_set_aggregate_fns(aggr_map);
let get_class_object_fn = gen_get_class_object_fn(struct_item);
quote!(
impl #struct_ident {
#allocate_fn
#get_class_object_fn
#set_aggregate_fns
}
)
}
/// Function used to instantiate the COM fields, such as vpointers for the COM object.
pub fn gen_allocate_fn(
aggr_map: &HashMap<Ident, Vec<Ident>>,
base_interface_idents: &[Ident],
struct_item: &ItemStruct,
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let base_inits = gen_allocate_base_inits(struct_ident, base_interface_idents);
// Allocate function signature
let allocate_parameters = gen_allocate_function_parameters_signature(struct_item);
// Syntax for instantiating the fields of the struct.
let base_fields = gen_allocate_base_fields(base_interface_idents);
let ref_count_field = gen_allocate_ref_count_field();
let user_fields = gen_allocate_user_fields(struct_item);
let aggregate_fields = gen_allocate_aggregate_fields(aggr_map);
// Initialise all aggregated objects as NULL.
quote!(
fn allocate(#allocate_parameters) -> Box<#struct_ident> {
println!("Allocating new VTable for {}", stringify!(#struct_ident));
#base_inits
let out = #struct_ident {
#base_fields
#ref_count_field
#aggregate_fields
#user_fields
};
Box::new(out)
}
)
}
pub fn gen_allocate_function_parameters_signature(struct_item: &ItemStruct) -> HelperTokenStream {
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
quote!(#fields)
}
pub fn gen_allocate_aggregate_fields(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
let aggregate_inits = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
#aggr_field_ident: std::ptr::null_mut()
)
});
quote!(#(#aggregate_inits,)*)
}
// User field input as parameters to the allocate function.
pub fn gen_allocate_user_fields(struct_item: &ItemStruct) -> HelperTokenStream {
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
let field_idents = fields.iter().map(|field| {
let field_ident = field.ident.as_ref().unwrap().clone();
quote!(#field_ident)
});
quote!(#(#field_idents)*)
}
// Reference count field initialisation.
pub fn gen_allocate_ref_count_field() -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
quote!(
#ref_count_ident: std::cell::Cell::new(0),
)
}
// Generate the vptr field idents needed in the instantiation syntax of the COM struct.
pub fn gen_allocate_base_fields(base_interface_idents: &[Ident]) -> HelperTokenStream {
let base_fields = base_interface_idents.iter().map(|base| {
let vptr_field_ident = macro_utils::vptr_field_ident(base);
quote!(#vptr_field_ident)
});
quote!(#(#base_fields,)*)
}
// Initialise VTables with the correct adjustor thunks, through the vtable! macro.
pub fn gen_allocate_base_inits(
struct_ident: &Ident,
base_interface_idents: &[Ident],
) -> HelperTokenStream {
let mut offset_count: usize = 0;
let base_inits = base_interface_idents.iter().map(|base| {
let vtable_var_ident = format_ident!("{}_vtable", base.to_string().to_lowercase());
@ -27,63 +129,28 @@ pub fn generate(
offset_count += 1;
out
});
let base_fields = base_interface_idents.iter().map(|base| {
let vptr_field_ident = macro_utils::vptr_field_ident(base);
quote!(#vptr_field_ident)
});
let ref_count_ident = macro_utils::ref_count_ident();
// GetClassObject stuff
quote!(#(#base_inits)*)
}
/// Function used by in-process DLL macro to get an instance of the
/// class object.
pub fn gen_get_class_object_fn(struct_item: &ItemStruct) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let class_factory_ident = macro_utils::class_factory_ident(&struct_ident);
let fields = match &struct_item.fields {
Fields::Named(f) => &f.named,
_ => panic!("Found non Named fields in struct."),
};
let field_idents = fields.iter().map(|field| {
let field_ident = field.ident.as_ref().unwrap().clone();
quote!(#field_ident)
});
let aggregate_inits = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
#aggr_field_ident: std::ptr::null_mut()
)
});
let set_aggregate_fns = gen_set_aggregate_fns(aggr_map);
quote!(
impl #struct_ident {
fn allocate(#fields) -> Box<#struct_ident> {
println!("Allocating new VTable for {}", stringify!(#struct_ident));
#(#base_inits)*
let out = #struct_ident {
#(#base_fields,)*
#ref_count_ident: 0,
#(#aggregate_inits,)*
#(#field_idents)*
};
Box::new(out)
}
pub fn get_class_object() -> Box<#class_factory_ident> {
<#class_factory_ident>::new()
}
#set_aggregate_fns
pub fn get_class_object() -> Box<#class_factory_ident> {
<#class_factory_ident>::new()
}
)
}
fn gen_set_aggregate_fns(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
pub fn gen_set_aggregate_fns(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
let mut fns = Vec::new();
for (aggr_field_ident, aggr_base_interface_idents) in aggr_map.iter() {
for base in aggr_base_interface_idents {
let set_aggregate_fn_ident = format_ident!(
"set_aggregate_{}",
macro_utils::camel_to_snake(&base.to_string())
);
let set_aggregate_fn_ident = macro_utils::set_aggregate_fn_ident(&base);
fns.push(quote!(
fn #set_aggregate_fn_ident(&mut self, aggr: *mut <dyn com::IUnknown as com::ComInterface>::VPtr) {
// TODO: What happens if we are overwriting an existing aggregate?

Просмотреть файл

@ -1,41 +0,0 @@
use proc_macro2::TokenStream as HelperTokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{Ident, ItemStruct};
pub fn generate(
aggr_map: &HashMap<Ident, Vec<Ident>>,
base_interface_idents: &[Ident],
struct_item: &ItemStruct,
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let box_from_raws = base_interface_idents.iter().map(|base| {
let vptr_field_ident = macro_utils::vptr_field_ident(&base);
quote!(
Box::from_raw(self.#vptr_field_ident as *mut <dyn #base as com::ComInterface>::VTable);
)
});
let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
if !self.#aggr_field_ident.is_null() {
let mut aggr_interface_ptr: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void);
aggr_interface_ptr.release();
core::mem::forget(aggr_interface_ptr);
}
)
});
quote!(
impl std::ops::Drop for #struct_ident {
fn drop(&mut self) {
use com::IUnknown;
let _ = unsafe {
#(#aggregate_drops)*
#(#box_from_raws)*
};
}
}
)
}

Просмотреть файл

@ -8,14 +8,14 @@ use syn::{Ident, ItemStruct};
/// any interfaces exposed through an aggregated object.
pub fn generate(
base_interface_idents: &[Ident],
aggr_interface_idents: &HashMap<Ident, Vec<Ident>>,
aggr_map: &HashMap<Ident, Vec<Ident>>,
struct_item: &ItemStruct,
) -> HelperTokenStream {
let struct_ident = &struct_item.ident;
let query_interface = gen_query_interface(base_interface_idents, aggr_interface_idents);
let query_interface = gen_query_interface(base_interface_idents, aggr_map);
let add_ref = gen_add_ref();
let release = gen_release(struct_ident);
let release = gen_release(base_interface_idents, aggr_map, struct_ident);
quote!(
impl com::IUnknown for #struct_ident {
@ -27,35 +27,125 @@ pub fn generate(
}
pub fn gen_add_ref() -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
let add_ref_implementation = gen_add_ref_implementation();
quote! {
fn add_ref(&mut self) -> u32 {
self.#ref_count_ident = self.#ref_count_ident.checked_add(1).expect("Overflow of reference count");
println!("Count now {}", self.#ref_count_ident);
self.#ref_count_ident
fn add_ref(&self) -> u32 {
#add_ref_implementation
}
}
}
pub fn gen_release(struct_ident: &Ident) -> HelperTokenStream {
pub fn gen_add_ref_implementation() -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
quote! {
unsafe fn release(&mut self) -> u32 {
self.#ref_count_ident = self.#ref_count_ident.checked_sub(1).expect("Underflow of reference count");
println!("Count now {}", self.#ref_count_ident);
let count = self.#ref_count_ident;
if count == 0 {
println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident));
Box::from_raw(self as *const _ as *mut #struct_ident);
}
count
}
}
quote!(
let value = self.#ref_count_ident.get().checked_add(1).expect("Overflow of reference count");
self.#ref_count_ident.set(value);
println!("Count now {}", value);
value
)
}
fn gen_query_interface(
pub fn gen_release(
base_interface_idents: &[Ident],
aggr_interface_idents: &HashMap<Ident, Vec<Ident>>,
aggr_map: &HashMap<Ident, Vec<Ident>>,
struct_ident: &Ident,
) -> HelperTokenStream {
let ref_count_ident = macro_utils::ref_count_ident();
let release_decrement = gen_release_decrement(&ref_count_ident);
let release_assign_new_count_to_var =
gen_release_assign_new_count_to_var(&ref_count_ident, &ref_count_ident);
let release_new_count_var_zero_check = gen_new_count_var_zero_check(&ref_count_ident);
let release_drops = gen_release_drops(base_interface_idents, aggr_map, struct_ident);
quote! {
unsafe fn release(&self) -> u32 {
#release_decrement
#release_assign_new_count_to_var
if #release_new_count_var_zero_check {
#release_drops
}
#ref_count_ident
}
}
}
pub fn gen_release_drops(
base_interface_idents: &[Ident],
aggr_map: &HashMap<Ident, Vec<Ident>>,
struct_ident: &Ident,
) -> HelperTokenStream {
let vptr_drops = gen_vptr_drops(base_interface_idents);
let aggregate_drops = gen_aggregate_drops(aggr_map);
let com_object_drop = gen_com_object_drop(struct_ident);
quote!(
#vptr_drops
#aggregate_drops
#com_object_drop
)
}
fn gen_aggregate_drops(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| {
quote!(
if !self.#aggr_field_ident.is_null() {
let mut aggr_interface_ptr: com::ComPtr<dyn com::IUnknown> = com::ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void);
aggr_interface_ptr.release();
core::mem::forget(aggr_interface_ptr);
}
)
});
quote!(#(#aggregate_drops)*)
}
fn gen_vptr_drops(base_interface_idents: &[Ident]) -> HelperTokenStream {
let vptr_drops = base_interface_idents.iter().map(|base| {
let vptr_field_ident = macro_utils::vptr_field_ident(&base);
quote!(
Box::from_raw(self.#vptr_field_ident as *mut <dyn #base as com::ComInterface>::VTable);
)
});
quote!(#(#vptr_drops)*)
}
fn gen_com_object_drop(struct_ident: &Ident) -> HelperTokenStream {
quote!(
println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident));
Box::from_raw(self as *const _ as *mut #struct_ident);
)
}
pub fn gen_release_decrement(ref_count_ident: &Ident) -> HelperTokenStream {
quote!(
let value = self.#ref_count_ident.get().checked_sub(1).expect("Underflow of reference count");
self.#ref_count_ident.set(value);
)
}
pub fn gen_release_assign_new_count_to_var(
ref_count_ident: &Ident,
new_count_ident: &Ident,
) -> HelperTokenStream {
quote!(
let #new_count_ident = self.#ref_count_ident.get();
println!("Count now {}", #new_count_ident);
)
}
pub fn gen_new_count_var_zero_check(new_count_ident: &Ident) -> HelperTokenStream {
quote!(
#new_count_ident == 0
)
}
pub fn gen_query_interface(
base_interface_idents: &[Ident],
aggr_map: &HashMap<Ident, Vec<Ident>>,
) -> HelperTokenStream {
let first_vptr_field = macro_utils::vptr_field_ident(&base_interface_idents[0]);
@ -63,11 +153,11 @@ fn gen_query_interface(
let base_match_arms = gen_base_match_arms(base_interface_idents);
// Generate match arms for aggregated interfaces
let aggr_match_arms = gen_aggregate_match_arms(aggr_interface_idents);
let aggr_match_arms = gen_aggregate_match_arms(aggr_map);
quote!(
unsafe fn query_interface(
&mut self,
&self,
riid: *const winapi::shared::guiddef::IID,
ppv: *mut *mut winapi::ctypes::c_void
) -> winapi::shared::winerror::HRESULT {
@ -88,7 +178,7 @@ fn gen_query_interface(
)
}
fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream {
pub fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream {
// Generate match arms for implemented interfaces
let base_match_arms = base_interface_idents.iter().map(|base| {
let match_condition =
@ -105,10 +195,8 @@ fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream {
quote!(#(#base_match_arms)*)
}
fn gen_aggregate_match_arms(
aggr_interface_idents: &HashMap<Ident, Vec<Ident>>,
) -> HelperTokenStream {
let aggr_match_arms = aggr_interface_idents.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| {
pub fn gen_aggregate_match_arms(aggr_map: &HashMap<Ident, Vec<Ident>>) -> HelperTokenStream {
let aggr_match_arms = aggr_map.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| {
// Construct the OR match conditions for a single aggregated object.
let first_base_interface_ident = &aggr_base_interface_idents[0];

Просмотреть файл

@ -5,22 +5,18 @@ use syn::{AttributeArgs, ItemStruct};
use std::iter::FromIterator;
pub mod class_factory;
mod com_struct;
mod com_struct_impl;
mod drop_impl;
pub mod com_struct;
pub mod com_struct_impl;
pub mod iunknown_impl;
pub fn expand_co_class(input: &ItemStruct, attr_args: &AttributeArgs) -> TokenStream {
let base_interface_idents = macro_utils::base_interface_idents(attr_args);
let aggr_interface_idents = macro_utils::get_aggr_map(attr_args);
let mut out = Vec::<TokenStream>::new();
let mut out: Vec<TokenStream> = Vec::new();
out.push(com_struct::generate(&aggr_interface_idents, &base_interface_idents, input).into());
out.push(
com_struct_impl::generate(&aggr_interface_idents, &base_interface_idents, input).into(),
);
out.push(com_struct_impl::generate(&aggr_interface_idents, &base_interface_idents, input).into());
out.push(iunknown_impl::generate(&base_interface_idents, &aggr_interface_idents, input).into());
out.push(drop_impl::generate(&aggr_interface_idents, &base_interface_idents, input).into());
out.push(class_factory::generate(input).into());
TokenStream::from_iter(out)

Просмотреть файл

@ -102,7 +102,15 @@ fn gen_raw_params(interface_ident: &Ident, method: &TraitItemMethod) -> HelperTo
let vptr_ident = vptr::ident(&interface_ident.to_string());
for param in method.sig.inputs.iter() {
match param {
FnArg::Receiver(_n) => {
FnArg::Receiver(s) => {
assert!(
s.reference.is_some(),
"COM interface methods cannot take ownership of self"
);
assert!(
s.mutability.is_none(),
"COM interface methods cannot take mutable reference to self"
);
params.push(quote!(
*mut #vptr_ident,
));

Просмотреть файл

@ -96,7 +96,7 @@ fn gen_vtable_function(
let return_type = &fun.output;
quote! {
unsafe extern "stdcall" fn #function_ident<C: #interface_ident, O: com::offset::Offset>(#(#params)*) #return_type {
let this = arg0.sub(O::VALUE) as *mut C;
let this = arg0.sub(O::VALUE) as *const C as *mut C;
(*this).#method_name(#(#args)*)
}
}

Просмотреть файл

@ -37,13 +37,13 @@ impl<T: ?Sized + ComInterface> ComPtr<T> {
self.ptr
}
fn cast_and_add_ref(&mut self) {
fn cast_and_add_ref(&self) {
unsafe {
(*(self as *const _ as *mut ComPtr<dyn IUnknown>)).add_ref();
}
}
pub fn get_interface<S: ComInterface + ?Sized>(&mut self) -> Option<ComPtr<S>> {
pub fn get_interface<S: ComInterface + ?Sized>(&self) -> Option<ComPtr<S>> {
let mut ppv = std::ptr::null_mut::<c_void>();
let hr = unsafe {
(*(self as *const _ as *mut ComPtr<dyn IUnknown>))
@ -59,7 +59,6 @@ impl<T: ?Sized + ComInterface> ComPtr<T> {
impl<T: ComInterface + ?Sized> Drop for ComPtr<T> {
fn drop(&mut self) {
println!("Dropped!");
unsafe {
(*(self as *const _ as *mut ComPtr<dyn IUnknown>)).release();
}
@ -68,7 +67,7 @@ impl<T: ComInterface + ?Sized> Drop for ComPtr<T> {
impl<T: ComInterface> Clone for ComPtr<T> {
fn clone(&self) -> Self {
let mut new_ptr = ComPtr {
let new_ptr = ComPtr {
ptr: self.ptr,
phantom: PhantomData,
};

Просмотреть файл

@ -17,16 +17,16 @@ use crate::{
#[com_interface(00000001-0000-0000-c000-000000000046)]
pub trait IClassFactory: IUnknown {
unsafe fn create_instance(
&mut self,
&self,
aggr: *mut IUnknownVPtr,
riid: REFIID,
ppv: *mut *mut c_void,
) -> HRESULT;
fn lock_server(&mut self, increment: BOOL) -> HRESULT;
fn lock_server(&self, increment: BOOL) -> HRESULT;
}
impl ComPtr<dyn IClassFactory> {
pub fn get_instance<T: ComInterface + ?Sized>(&mut self) -> Option<ComPtr<T>> {
pub fn get_instance<T: ComInterface + ?Sized>(&self) -> Option<ComPtr<T>> {
let mut ppv = std::ptr::null_mut::<c_void>();
let aggr = std::ptr::null_mut();
let hr = unsafe { self.create_instance(aggr, &T::IID as *const IID, &mut ppv) };

Просмотреть файл

@ -14,7 +14,7 @@ pub trait IUnknown {
///
/// [`QueryInterface` Method]: https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)
/// [`IUnknown::get_interface`]: trait.IUnknown.html#method.get_interface
unsafe fn query_interface(&mut self, riid: REFIID, ppv: *mut *mut c_void) -> HRESULT;
unsafe fn query_interface(&self, riid: REFIID, ppv: *mut *mut c_void) -> HRESULT;
/// The COM [`AddRef` Method]
///
@ -23,7 +23,7 @@ pub trait IUnknown {
///
/// [`AddRef` Method]: https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref
/// [`ComPtr`]: struct.ComPtr.html
fn add_ref(&mut self) -> u32;
fn add_ref(&self) -> u32;
/// The COM [`Release` Method]
///
@ -32,5 +32,5 @@ pub trait IUnknown {
///
/// [`Release` Method]: https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-release
/// [`ComPtr`]: struct.ComPtr.html
unsafe fn release(&mut self) -> u32;
unsafe fn release(&self) -> u32;
}