From 38533355703416d3a6016ab1fca10304455dbec3 Mon Sep 17 00:00:00 2001 From: adrianwithah Date: Thu, 12 Sep 2019 18:27:54 +0100 Subject: [PATCH 1/7] Refactor aggr_co_class com_struct_impl to use co_class generation methods --- macros/aggr_co_class/src/com_struct_impl.rs | 159 +++++-------------- macros/co_class/src/com_struct_impl.rs | 163 ++++++++++++++------ macros/co_class/src/iunknown_impl.rs | 4 +- macros/co_class/src/lib.rs | 2 +- 4 files changed, 158 insertions(+), 170 deletions(-) diff --git a/macros/aggr_co_class/src/com_struct_impl.rs b/macros/aggr_co_class/src/com_struct_impl.rs index 2b786e3..3240247 100644 --- a/macros/aggr_co_class/src/com_struct_impl.rs +++ b/macros/aggr_co_class/src/com_struct_impl.rs @@ -18,8 +18,8 @@ pub fn generate( let allocate_fn = gen_allocate_fn(aggr_map, base_interface_idents, struct_item); let set_iunknown_fn = gen_set_iunknown_fn(); let inner_iunknown_fns = gen_inner_iunknown_fns(base_interface_idents, aggr_map, struct_item); - let get_class_object_fn = gen_get_class_object_fn(struct_item); - let set_aggregate_fns = gen_set_aggregate_fns(aggr_map); + let get_class_object_fn = co_class::com_struct_impl::gen_get_class_object_fn(struct_item); + let set_aggregate_fns = co_class::com_struct_impl::gen_set_aggregate_fns(aggr_map); quote!( impl #struct_ident { @@ -32,19 +32,6 @@ pub fn generate( ) } -/// Function used by in-process DLL macro to get an instance of the -/// class object. -fn gen_get_class_object_fn(struct_item: &ItemStruct) -> HelperTokenStream { - let struct_ident = &struct_item.ident; - let class_factory_ident = macro_utils::class_factory_ident(&struct_ident); - - quote!( - pub fn get_class_object() -> Box<#class_factory_ident> { - <#class_factory_ident>::new() - } - ) -} - /// Function that should only be used by Class Object, to set the /// object's iunknown_to_use, if the object is going to get aggregated. fn gen_set_iunknown_fn() -> HelperTokenStream { @@ -70,30 +57,42 @@ fn gen_inner_iunknown_fns( struct_item: &ItemStruct, ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - let ref_count_ident = macro_utils::ref_count_ident(); let inner_query_interface = gen_inner_query_interface(base_interface_idents, aggr_map); + let inner_add_ref = gen_inner_add_ref(); + let inner_release = gen_inner_release(struct_ident); quote!( #inner_query_interface + #inner_add_ref + #inner_release + ) +} +pub fn gen_inner_add_ref() -> HelperTokenStream { + let ref_count_ident = macro_utils::ref_count_ident(); + quote! { pub(crate) fn inner_add_ref(&mut self) -> u32 { - self.#ref_count_ident += 1; + self.#ref_count_ident = self.#ref_count_ident.checked_add(1).expect("Overflow of reference count"); println!("Count now {}", self.#ref_count_ident); self.#ref_count_ident } + } +} - pub(crate) fn inner_release(&mut self) -> u32 { - self.#ref_count_ident -= 1; +pub fn gen_inner_release(struct_ident: &Ident) -> HelperTokenStream { + let ref_count_ident = macro_utils::ref_count_ident(); + quote! { + pub(crate) unsafe fn inner_release(&mut self) -> u32 { + self.#ref_count_ident = self.#ref_count_ident.checked_sub(1).expect("Underflow of reference count"); println!("Count now {}", self.#ref_count_ident); let count = self.#ref_count_ident; if count == 0 { println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident)); - // drop(self) - unsafe { Box::from_raw(self as *const _ as *mut #struct_ident); } + Box::from_raw(self as *const _ as *mut #struct_ident); } count } - ) + } } /// Non-delegating query interface @@ -104,47 +103,10 @@ fn gen_inner_query_interface( let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); // Generate match arms for implemented interfaces - let match_arms = base_interface_idents.iter().map(|base| { - let match_condition = - quote!(::is_iid_in_inheritance_chain(riid)); - let vptr_field_ident = macro_utils::vptr_field_ident(&base); - - quote!( - else if #match_condition { - *ppv = &self.#vptr_field_ident as *const _ as *mut winapi::ctypes::c_void; - } - ) - }); + let base_match_arms = co_class::iunknown_impl::gen_base_match_arms(base_interface_idents); // Generate match arms for aggregated interfaces - let aggr_match_arms = aggr_map.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| { - - // Construct the OR match conditions for a single aggregated object. - let first_base_interface_ident = &aggr_base_interface_idents[0]; - let first_aggr_match_condition = quote!( - ::is_iid_in_inheritance_chain(riid) - ); - let rem_aggr_match_conditions = aggr_base_interface_idents.iter().skip(1).map(|base| { - quote!(|| ::is_iid_in_inheritance_chain(riid)) - }); - - quote!( - else if #first_aggr_match_condition #(#rem_aggr_match_conditions)* { - let mut aggr_interface_ptr: ComPtr = ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void); - let hr = aggr_interface_ptr.query_interface(riid, ppv); - if com::failed(hr) { - return winapi::shared::winerror::E_NOINTERFACE; - } - - // We release it as the previous call add_ref-ed the inner object. - // The intention is to transfer reference counting logic to the - // outer object. - aggr_interface_ptr.release(); - - core::mem::forget(aggr_interface_ptr); - } - ) - }); + let aggr_match_arms = co_class::iunknown_impl::gen_aggregate_match_arms(aggr_map); quote!( pub(crate) fn inner_query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> HRESULT { @@ -155,7 +117,7 @@ fn gen_inner_query_interface( if winapi::shared::guiddef::IsEqualGUID(riid, &com::IID_IUNKNOWN) { *ppv = &self.#non_delegating_iunknown_field_ident as *const _ as *mut winapi::ctypes::c_void; - } #(#match_arms)* #(#aggr_match_arms)* else { + } #base_match_arms #aggr_match_arms else { *ppv = std::ptr::null_mut::(); println!("Returning NO INTERFACE."); return winapi::shared::winerror::E_NOINTERFACE; @@ -179,45 +141,24 @@ fn gen_allocate_fn( ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - let mut offset_count: usize = 0; - let base_inits = base_interface_idents.iter().map(|base| { - let vtable_var_ident = quote::format_ident!("{}_vtable", base.to_string().to_lowercase()); - let vptr_field_ident = macro_utils::vptr_field_ident(&base); + let base_inits = co_class::com_struct_impl::gen_allocate_base_inits(struct_ident, base_interface_idents); - let out = quote!( - let #vtable_var_ident = com::vtable!(#struct_ident: #base, #offset_count); - let #vptr_field_ident = Box::into_raw(Box::new(#vtable_var_ident)); - ); + // Allocate function signature + let allocate_parameters = co_class::com_struct_impl::gen_allocate_function_parameters_signature(struct_item); - offset_count += 1; - out - }); - let base_fields = base_interface_idents.iter().map(|base| { - let vptr_field_ident = macro_utils::vptr_field_ident(base); - quote!(#vptr_field_ident) - }); - let ref_count_ident = macro_utils::ref_count_ident(); + // Syntax for instantiating the fields of the struct. + let base_fields = co_class::com_struct_impl::gen_allocate_base_fields(base_interface_idents); + let ref_count_field = co_class::com_struct_impl::gen_allocate_ref_count_field(); + let user_fields = co_class::com_struct_impl::gen_allocate_user_fields(struct_item); + let aggregate_fields = co_class::com_struct_impl::gen_allocate_aggregate_fields(aggr_map); + + // Aggregable COM struct specific fields let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident(); let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); let non_delegating_iunknown_offset = base_interface_idents.len(); - let fields = match &struct_item.fields { - Fields::Named(f) => &f.named, - _ => panic!("Found non Named fields in struct."), - }; - let field_idents = fields.iter().map(|field| { - let field_ident = field.ident.as_ref().unwrap().clone(); - quote!(#field_ident) - }); - - let aggregate_inits = aggr_map.iter().map(|(aggr_field_ident, _)| { - quote!( - #aggr_field_ident: std::ptr::null_mut() - ) - }); - quote!( - fn allocate(#fields) -> Box<#struct_ident> { + fn allocate(#allocate_parameters) -> Box<#struct_ident> { println!("Allocating new VTable for {}", stringify!(#struct_ident)); // Non-delegating methods. @@ -253,33 +194,17 @@ fn gen_allocate_fn( }; let #non_delegating_iunknown_field_ident = Box::into_raw(Box::new(__non_delegating_iunknown_vtable)); - #(#base_inits)* + #base_inits + let out = #struct_ident { - #(#base_fields,)* + #base_fields #non_delegating_iunknown_field_ident, #iunknown_to_use_field_ident: std::ptr::null_mut::<::VPtr>(), - #ref_count_ident: 0, - #(#aggregate_inits,)* - #(#field_idents)* + #ref_count_field + #aggregate_fields + #user_fields }; Box::new(out) } ) -} - -fn gen_set_aggregate_fns(aggr_map: &HashMap>) -> HelperTokenStream { - let mut fns = Vec::new(); - for (aggr_field_ident, aggr_base_interface_idents) in aggr_map.iter() { - for base in aggr_base_interface_idents { - let set_aggregate_fn_ident = macro_utils::set_aggregate_fn_ident(&base); - fns.push(quote!( - fn #set_aggregate_fn_ident(&mut self, aggr: *mut ::VPtr) { - // TODO: What happens if we are overwriting an existing aggregate? - self.#aggr_field_ident = aggr - } - )); - } - } - - quote!(#(#fns)*) -} +} \ No newline at end of file diff --git a/macros/co_class/src/com_struct_impl.rs b/macros/co_class/src/com_struct_impl.rs index 43e171a..93f5615 100644 --- a/macros/co_class/src/com_struct_impl.rs +++ b/macros/co_class/src/com_struct_impl.rs @@ -3,9 +3,6 @@ use quote::{format_ident, quote}; use std::collections::HashMap; use syn::{Fields, Ident, ItemStruct}; -/// Generates the allocate and get_class_object function for the COM object. -/// allocate: instantiates the COM fields, such as vpointers for the COM object. -/// get_class_object: Instantiate an instance to the class object. pub fn generate( aggr_map: &HashMap>, base_interface_idents: &[Ident], @@ -13,7 +10,108 @@ pub fn generate( ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - // Allocate stuff + let allocate_fn = gen_allocate_fn(aggr_map, base_interface_idents, struct_item); + let set_aggregate_fns = gen_set_aggregate_fns(aggr_map); + let get_class_object_fn = gen_get_class_object_fn(struct_item); + + quote!( + impl #struct_ident { + #allocate_fn + #get_class_object_fn + #set_aggregate_fns + } + ) +} + +/// Function used to instantiate the COM fields, such as vpointers for the COM object. +pub fn gen_allocate_fn( + aggr_map: &HashMap>, + base_interface_idents: &[Ident], + struct_item: &ItemStruct, +) -> HelperTokenStream { + let struct_ident = &struct_item.ident; + + let base_inits = gen_allocate_base_inits(struct_ident, base_interface_idents); + + // Allocate function signature + let allocate_parameters = gen_allocate_function_parameters_signature(struct_item); + + // Syntax for instantiating the fields of the struct. + let base_fields = gen_allocate_base_fields(base_interface_idents); + let ref_count_field = gen_allocate_ref_count_field(); + let user_fields = gen_allocate_user_fields(struct_item); + let aggregate_fields = gen_allocate_aggregate_fields(aggr_map); + + // Initialise all aggregated objects as NULL. + quote!( + fn allocate(#allocate_parameters) -> Box<#struct_ident> { + println!("Allocating new VTable for {}", stringify!(#struct_ident)); + #base_inits + let out = #struct_ident { + #base_fields + #ref_count_field + #aggregate_fields + #user_fields + }; + Box::new(out) + } + ) +} + +pub fn gen_allocate_function_parameters_signature(struct_item: &ItemStruct) -> HelperTokenStream { + let fields = match &struct_item.fields { + Fields::Named(f) => &f.named, + _ => panic!("Found non Named fields in struct."), + }; + + quote!(#fields) +} + +pub fn gen_allocate_aggregate_fields(aggr_map: &HashMap>) -> HelperTokenStream { + let aggregate_inits = aggr_map.iter().map(|(aggr_field_ident, _)| { + quote!( + #aggr_field_ident: std::ptr::null_mut() + ) + }); + + quote!(#(#aggregate_inits,)*) + +} + +// User field input as parameters to the allocate function. +pub fn gen_allocate_user_fields(struct_item: &ItemStruct) -> HelperTokenStream { + let fields = match &struct_item.fields { + Fields::Named(f) => &f.named, + _ => panic!("Found non Named fields in struct."), + }; + let field_idents = fields.iter().map(|field| { + let field_ident = field.ident.as_ref().unwrap().clone(); + quote!(#field_ident) + }); + + quote!(#(#field_idents)*) +} + +// Reference count field initialisation. +pub fn gen_allocate_ref_count_field() -> HelperTokenStream { + let ref_count_ident = macro_utils::ref_count_ident(); + quote!( + #ref_count_ident: 0, + ) +} + +// Generate the vptr field idents needed in the instantiation syntax of the COM struct. +pub fn gen_allocate_base_fields(base_interface_idents: &[Ident]) -> HelperTokenStream { + let base_fields = base_interface_idents.iter().map(|base| { + let vptr_field_ident = macro_utils::vptr_field_ident(base); + quote!(#vptr_field_ident) + }); + + quote!(#(#base_fields,)*) +} + +// Initialise VTables with the correct adjustor thunks, through the vtable! macro. +pub fn gen_allocate_base_inits(struct_ident: &Ident, base_interface_idents: &[Ident]) -> HelperTokenStream { let mut offset_count: usize = 0; let base_inits = base_interface_idents.iter().map(|base| { let vtable_var_ident = format_ident!("{}_vtable", base.to_string().to_lowercase()); @@ -27,63 +125,28 @@ pub fn generate( offset_count += 1; out }); - let base_fields = base_interface_idents.iter().map(|base| { - let vptr_field_ident = macro_utils::vptr_field_ident(base); - quote!(#vptr_field_ident) - }); - let ref_count_ident = macro_utils::ref_count_ident(); - // GetClassObject stuff + quote!(#(#base_inits)*) +} + +/// Function used by in-process DLL macro to get an instance of the +/// class object. +pub fn gen_get_class_object_fn(struct_item: &ItemStruct) -> HelperTokenStream { + let struct_ident = &struct_item.ident; let class_factory_ident = macro_utils::class_factory_ident(&struct_ident); - let fields = match &struct_item.fields { - Fields::Named(f) => &f.named, - _ => panic!("Found non Named fields in struct."), - }; - let field_idents = fields.iter().map(|field| { - let field_ident = field.ident.as_ref().unwrap().clone(); - quote!(#field_ident) - }); - - let aggregate_inits = aggr_map.iter().map(|(aggr_field_ident, _)| { - quote!( - #aggr_field_ident: std::ptr::null_mut() - ) - }); - - let set_aggregate_fns = gen_set_aggregate_fns(aggr_map); - quote!( - impl #struct_ident { - fn allocate(#fields) -> Box<#struct_ident> { - println!("Allocating new VTable for {}", stringify!(#struct_ident)); - #(#base_inits)* - let out = #struct_ident { - #(#base_fields,)* - #ref_count_ident: 0, - #(#aggregate_inits,)* - #(#field_idents)* - }; - Box::new(out) - } - - pub fn get_class_object() -> Box<#class_factory_ident> { - <#class_factory_ident>::new() - } - - #set_aggregate_fns + pub fn get_class_object() -> Box<#class_factory_ident> { + <#class_factory_ident>::new() } ) } -fn gen_set_aggregate_fns(aggr_map: &HashMap>) -> HelperTokenStream { +pub fn gen_set_aggregate_fns(aggr_map: &HashMap>) -> HelperTokenStream { let mut fns = Vec::new(); for (aggr_field_ident, aggr_base_interface_idents) in aggr_map.iter() { for base in aggr_base_interface_idents { - let set_aggregate_fn_ident = format_ident!( - "set_aggregate_{}", - macro_utils::camel_to_snake(&base.to_string()) - ); + let set_aggregate_fn_ident = macro_utils::set_aggregate_fn_ident(&base); fns.push(quote!( fn #set_aggregate_fn_ident(&mut self, aggr: *mut ::VPtr) { // TODO: What happens if we are overwriting an existing aggregate? diff --git a/macros/co_class/src/iunknown_impl.rs b/macros/co_class/src/iunknown_impl.rs index 0525164..089aa30 100644 --- a/macros/co_class/src/iunknown_impl.rs +++ b/macros/co_class/src/iunknown_impl.rs @@ -88,7 +88,7 @@ fn gen_query_interface( ) } -fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream { +pub fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream { // Generate match arms for implemented interfaces let base_match_arms = base_interface_idents.iter().map(|base| { let match_condition = @@ -105,7 +105,7 @@ fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream { quote!(#(#base_match_arms)*) } -fn gen_aggregate_match_arms( +pub fn gen_aggregate_match_arms( aggr_interface_idents: &HashMap>, ) -> HelperTokenStream { let aggr_match_arms = aggr_interface_idents.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| { diff --git a/macros/co_class/src/lib.rs b/macros/co_class/src/lib.rs index 77db959..a134d5a 100644 --- a/macros/co_class/src/lib.rs +++ b/macros/co_class/src/lib.rs @@ -6,7 +6,7 @@ use std::iter::FromIterator; pub mod class_factory; mod com_struct; -mod com_struct_impl; +pub mod com_struct_impl; mod drop_impl; pub mod iunknown_impl; From 4eb76077b5195380277f18f8dbea5219819047e8 Mon Sep 17 00:00:00 2001 From: adrianwithah Date: Thu, 12 Sep 2019 18:50:40 +0100 Subject: [PATCH 2/7] Refactor aggr_co_class com_struct to re-use gen methods from co_class --- macros/aggr_co_class/src/com_struct.rs | 28 +++++---------- macros/co_class/src/com_struct.rs | 47 ++++++++++++++++++-------- macros/co_class/src/com_struct_impl.rs | 2 ++ macros/co_class/src/lib.rs | 4 +-- 4 files changed, 45 insertions(+), 36 deletions(-) diff --git a/macros/aggr_co_class/src/com_struct.rs b/macros/aggr_co_class/src/com_struct.rs index 3d468d3..b75429a 100644 --- a/macros/aggr_co_class/src/com_struct.rs +++ b/macros/aggr_co_class/src/com_struct.rs @@ -14,36 +14,24 @@ pub fn generate( let struct_ident = &struct_item.ident; let vis = &struct_item.vis; - let bases_interface_idents = base_interface_idents.iter().map(|base| { - let field_ident = macro_utils::vptr_field_ident(&base); - quote!(#field_ident: ::VPtr) - }); + let base_fields = co_class::com_struct::gen_base_fields(base_interface_idents); + let ref_count_field = co_class::com_struct::gen_ref_count_field(); + let user_fields = co_class::com_struct::gen_user_fields(struct_item); + let aggregate_fields = co_class::com_struct::gen_aggregate_fields(aggr_map); - let ref_count_ident = macro_utils::ref_count_ident(); let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident(); - let fields = match &struct_item.fields { - Fields::Named(f) => &f.named, - _ => panic!("Found non Named fields in struct."), - }; - - let aggregates = aggr_map.iter().map(|(aggr_field_ident, _)| { - quote!( - #aggr_field_ident: *mut ::VPtr - ) - }); - quote!( #[repr(C)] #vis struct #struct_ident { - #(#bases_interface_idents,)* + #base_fields #non_delegating_iunknown_field_ident: ::VPtr, // Non-reference counted interface pointer to outer IUnknown. #iunknown_to_use_field_ident: *mut ::VPtr, - #ref_count_ident: u32, - #(#aggregates,)* - #fields + #ref_count_field + #aggregate_fields + #user_fields } ) } diff --git a/macros/co_class/src/com_struct.rs b/macros/co_class/src/com_struct.rs index 5e74bd2..4a7fa3b 100644 --- a/macros/co_class/src/com_struct.rs +++ b/macros/co_class/src/com_struct.rs @@ -18,31 +18,50 @@ pub fn generate( let struct_ident = &struct_item.ident; let vis = &struct_item.vis; + let base_fields = gen_base_fields(base_interface_idents); + let ref_count_field = gen_ref_count_field(); + let user_fields = gen_user_fields(struct_item); + let aggregate_fields = gen_aggregate_fields(aggr_map); + + quote!( + #[repr(C)] + #vis struct #struct_ident { + #base_fields + #ref_count_field + #aggregate_fields + #user_fields + } + ) +} + +pub fn gen_base_fields(base_interface_idents: &[Ident]) -> HelperTokenStream { let bases_interface_idents = base_interface_idents.iter().map(|base| { let field_ident = macro_utils::vptr_field_ident(&base); quote!(#field_ident: ::VPtr) }); + quote!(#(#bases_interface_idents,)*) +} +pub fn gen_ref_count_field() -> HelperTokenStream { let ref_count_ident = macro_utils::ref_count_ident(); + quote!(#ref_count_ident: u32,) +} - let fields = match &struct_item.fields { - Fields::Named(f) => &f.named, - _ => panic!("Found non Named fields in struct."), - }; - +pub fn gen_aggregate_fields(aggr_map: &HashMap>) -> HelperTokenStream { let aggregates = aggr_map.iter().map(|(aggr_field_ident, _)| { quote!( #aggr_field_ident: *mut ::VPtr ) }); - quote!( - #[repr(C)] - #vis struct #struct_ident { - #(#bases_interface_idents,)* - #ref_count_ident: u32, - #(#aggregates,)* - #fields - } - ) + quote!(#(#aggregates,)*) +} + +pub fn gen_user_fields(struct_item: &ItemStruct) -> HelperTokenStream { + let fields = match &struct_item.fields { + Fields::Named(f) => &f.named, + _ => panic!("Found non Named fields in struct."), + }; + + quote!(#fields) } diff --git a/macros/co_class/src/com_struct_impl.rs b/macros/co_class/src/com_struct_impl.rs index 93f5615..a7575d6 100644 --- a/macros/co_class/src/com_struct_impl.rs +++ b/macros/co_class/src/com_struct_impl.rs @@ -46,7 +46,9 @@ pub fn gen_allocate_fn( quote!( fn allocate(#allocate_parameters) -> Box<#struct_ident> { println!("Allocating new VTable for {}", stringify!(#struct_ident)); + #base_inits + let out = #struct_ident { #base_fields #ref_count_field diff --git a/macros/co_class/src/lib.rs b/macros/co_class/src/lib.rs index a134d5a..dd22f44 100644 --- a/macros/co_class/src/lib.rs +++ b/macros/co_class/src/lib.rs @@ -5,9 +5,9 @@ use syn::{AttributeArgs, ItemStruct}; use std::iter::FromIterator; pub mod class_factory; -mod com_struct; +pub mod com_struct; pub mod com_struct_impl; -mod drop_impl; +pub mod drop_impl; pub mod iunknown_impl; // Macro expansion entry point. From 7038c9d186950ef552dd7b986f7767d0a3bd69d5 Mon Sep 17 00:00:00 2001 From: adrianwithah Date: Thu, 12 Sep 2019 18:59:24 +0100 Subject: [PATCH 3/7] Refactor aggr_co_class drop_impl to use gen methods from co_class --- macros/aggr_co_class/src/com_struct.rs | 1 + macros/aggr_co_class/src/drop_impl.rs | 33 ++++++++---------- macros/co_class/src/drop_impl.rs | 48 ++++++++++++++++---------- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/macros/aggr_co_class/src/com_struct.rs b/macros/aggr_co_class/src/com_struct.rs index b75429a..681be91 100644 --- a/macros/aggr_co_class/src/com_struct.rs +++ b/macros/aggr_co_class/src/com_struct.rs @@ -19,6 +19,7 @@ pub fn generate( let user_fields = co_class::com_struct::gen_user_fields(struct_item); let aggregate_fields = co_class::com_struct::gen_aggregate_fields(aggr_map); + // COM Fields for an aggregable coclass. let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident(); diff --git a/macros/aggr_co_class/src/drop_impl.rs b/macros/aggr_co_class/src/drop_impl.rs index 377c4c6..8063244 100644 --- a/macros/aggr_co_class/src/drop_impl.rs +++ b/macros/aggr_co_class/src/drop_impl.rs @@ -9,33 +9,28 @@ pub fn generate( struct_item: &ItemStruct, ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); - let box_from_raws = base_interface_idents.iter().map(|base| { - let vptr_field_ident = macro_utils::vptr_field_ident(&base); - quote!( - Box::from_raw(self.#vptr_field_ident as *mut ::VTable); - ) - }); - let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| { - quote!( - if !self.#aggr_field_ident.is_null() { - let mut aggr_interface_ptr: com::ComPtr = com::ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void); - aggr_interface_ptr.release(); - core::mem::forget(aggr_interface_ptr); - } - ) - }); + let aggregate_drops = co_class::drop_impl::gen_aggregate_drops(aggr_map); + let vptr_drops = co_class::drop_impl::gen_vptr_drops(base_interface_idents); + let non_delegating_iunknown_drop = gen_non_delegating_iunknown_drop(); quote!( impl std::ops::Drop for #struct_ident { fn drop(&mut self) { let _ = unsafe { - #(#aggregate_drops)* - #(#box_from_raws)* - Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut ::VTable) + #aggregate_drops + #vptr_drops + #non_delegating_iunknown_drop }; } } ) } + + +fn gen_non_delegating_iunknown_drop() -> HelperTokenStream { + let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); + quote!( + Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut ::VTable) + ) +} \ No newline at end of file diff --git a/macros/co_class/src/drop_impl.rs b/macros/co_class/src/drop_impl.rs index 3e09d63..a42bb12 100644 --- a/macros/co_class/src/drop_impl.rs +++ b/macros/co_class/src/drop_impl.rs @@ -9,13 +9,25 @@ pub fn generate( struct_item: &ItemStruct, ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - let box_from_raws = base_interface_idents.iter().map(|base| { - let vptr_field_ident = macro_utils::vptr_field_ident(&base); - quote!( - Box::from_raw(self.#vptr_field_ident as *mut ::VTable); - ) - }); + let vptr_drops = gen_vptr_drops(base_interface_idents); + let aggregate_drops = gen_aggregate_drops(aggr_map); + + quote!( + impl std::ops::Drop for #struct_ident { + fn drop(&mut self) { + use com::IUnknown; + + let _ = unsafe { + #aggregate_drops + #vptr_drops + }; + } + } + ) +} + +pub fn gen_aggregate_drops(aggr_map: &HashMap>) -> HelperTokenStream { let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| { quote!( if !self.#aggr_field_ident.is_null() { @@ -26,16 +38,16 @@ pub fn generate( ) }); - quote!( - impl std::ops::Drop for #struct_ident { - fn drop(&mut self) { - use com::IUnknown; - - let _ = unsafe { - #(#aggregate_drops)* - #(#box_from_raws)* - }; - } - } - ) + quote!(#(#aggregate_drops)*) +} + +pub fn gen_vptr_drops(base_interface_idents: &[Ident]) -> HelperTokenStream { + let vptr_drops = base_interface_idents.iter().map(|base| { + let vptr_field_ident = macro_utils::vptr_field_ident(&base); + quote!( + Box::from_raw(self.#vptr_field_ident as *mut ::VTable); + ) + }); + + quote!(#(#vptr_drops)*) } From e02f59e463266034ad7be44fdb81cf1a5b21614f Mon Sep 17 00:00:00 2001 From: adrianwithah Date: Thu, 12 Sep 2019 19:05:43 +0100 Subject: [PATCH 4/7] Address warnings, ran cargo fmt --- macros/aggr_co_class/src/com_struct.rs | 4 ++-- macros/aggr_co_class/src/com_struct_impl.rs | 12 +++++++----- macros/aggr_co_class/src/drop_impl.rs | 3 +-- macros/co_class/src/com_struct.rs | 2 +- macros/co_class/src/com_struct_impl.rs | 8 +++++--- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/macros/aggr_co_class/src/com_struct.rs b/macros/aggr_co_class/src/com_struct.rs index 681be91..6d5a104 100644 --- a/macros/aggr_co_class/src/com_struct.rs +++ b/macros/aggr_co_class/src/com_struct.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream as HelperTokenStream; use quote::quote; use std::collections::HashMap; -use syn::{Fields, Ident, ItemStruct}; +use syn::{Ident, ItemStruct}; /// As an aggregable COM object, you need to have an inner non-delegating IUnknown vtable. /// All IUnknown calls to this COM object will delegate to the IUnknown interface pointer @@ -17,7 +17,7 @@ pub fn generate( let base_fields = co_class::com_struct::gen_base_fields(base_interface_idents); let ref_count_field = co_class::com_struct::gen_ref_count_field(); let user_fields = co_class::com_struct::gen_user_fields(struct_item); - let aggregate_fields = co_class::com_struct::gen_aggregate_fields(aggr_map); + let aggregate_fields = co_class::com_struct::gen_aggregate_fields(aggr_map); // COM Fields for an aggregable coclass. let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); diff --git a/macros/aggr_co_class/src/com_struct_impl.rs b/macros/aggr_co_class/src/com_struct_impl.rs index 3240247..13bb7dd 100644 --- a/macros/aggr_co_class/src/com_struct_impl.rs +++ b/macros/aggr_co_class/src/com_struct_impl.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream as HelperTokenStream; use quote::quote; use std::collections::HashMap; -use syn::{Fields, Ident, ItemStruct}; +use syn::{Ident, ItemStruct}; /// Generates the methods that the com struct needs to have. These include: /// allocate: To initialise the vtables, including the non_delegatingegating_iunknown one. @@ -141,10 +141,12 @@ fn gen_allocate_fn( ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - let base_inits = co_class::com_struct_impl::gen_allocate_base_inits(struct_ident, base_interface_idents); + let base_inits = + co_class::com_struct_impl::gen_allocate_base_inits(struct_ident, base_interface_idents); // Allocate function signature - let allocate_parameters = co_class::com_struct_impl::gen_allocate_function_parameters_signature(struct_item); + let allocate_parameters = + co_class::com_struct_impl::gen_allocate_function_parameters_signature(struct_item); // Syntax for instantiating the fields of the struct. let base_fields = co_class::com_struct_impl::gen_allocate_base_fields(base_interface_idents); @@ -152,7 +154,7 @@ fn gen_allocate_fn( let user_fields = co_class::com_struct_impl::gen_allocate_user_fields(struct_item); let aggregate_fields = co_class::com_struct_impl::gen_allocate_aggregate_fields(aggr_map); - // Aggregable COM struct specific fields + // Aggregable COM struct specific fields let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident(); let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); let non_delegating_iunknown_offset = base_interface_idents.len(); @@ -207,4 +209,4 @@ fn gen_allocate_fn( Box::new(out) } ) -} \ No newline at end of file +} diff --git a/macros/aggr_co_class/src/drop_impl.rs b/macros/aggr_co_class/src/drop_impl.rs index 8063244..cae5a2e 100644 --- a/macros/aggr_co_class/src/drop_impl.rs +++ b/macros/aggr_co_class/src/drop_impl.rs @@ -27,10 +27,9 @@ pub fn generate( ) } - fn gen_non_delegating_iunknown_drop() -> HelperTokenStream { let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); quote!( Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut ::VTable) ) -} \ No newline at end of file +} diff --git a/macros/co_class/src/com_struct.rs b/macros/co_class/src/com_struct.rs index 4a7fa3b..6c58a24 100644 --- a/macros/co_class/src/com_struct.rs +++ b/macros/co_class/src/com_struct.rs @@ -21,7 +21,7 @@ pub fn generate( let base_fields = gen_base_fields(base_interface_idents); let ref_count_field = gen_ref_count_field(); let user_fields = gen_user_fields(struct_item); - let aggregate_fields = gen_aggregate_fields(aggr_map); + let aggregate_fields = gen_aggregate_fields(aggr_map); quote!( #[repr(C)] diff --git a/macros/co_class/src/com_struct_impl.rs b/macros/co_class/src/com_struct_impl.rs index a7575d6..9eeb0f0 100644 --- a/macros/co_class/src/com_struct_impl.rs +++ b/macros/co_class/src/com_struct_impl.rs @@ -48,7 +48,7 @@ pub fn gen_allocate_fn( println!("Allocating new VTable for {}", stringify!(#struct_ident)); #base_inits - + let out = #struct_ident { #base_fields #ref_count_field @@ -77,7 +77,6 @@ pub fn gen_allocate_aggregate_fields(aggr_map: &HashMap>) -> H }); quote!(#(#aggregate_inits,)*) - } // User field input as parameters to the allocate function. @@ -113,7 +112,10 @@ pub fn gen_allocate_base_fields(base_interface_idents: &[Ident]) -> HelperTokenS } // Initialise VTables with the correct adjustor thunks, through the vtable! macro. -pub fn gen_allocate_base_inits(struct_ident: &Ident, base_interface_idents: &[Ident]) -> HelperTokenStream { +pub fn gen_allocate_base_inits( + struct_ident: &Ident, + base_interface_idents: &[Ident], +) -> HelperTokenStream { let mut offset_count: usize = 0; let base_inits = base_interface_idents.iter().map(|base| { let vtable_var_ident = format_ident!("{}_vtable", base.to_string().to_lowercase()); From bab1d095c129b7247998d3abf1cb11310b7fa764 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Fri, 13 Sep 2019 15:35:15 +0200 Subject: [PATCH 5/7] Change interfaces to not have exclusive access of self --- examples/basic/client/src/main.rs | 19 +++++--------- examples/basic/interface/src/ianimal.rs | 2 +- examples/basic/interface/src/icat.rs | 2 +- .../basic/interface/src/idomesticanimal.rs | 2 +- .../server/src/british_short_hair_cat.rs | 6 ++--- macros/aggr_co_class/src/class_factory.rs | 2 +- macros/aggr_co_class/src/com_struct.rs | 2 +- macros/aggr_co_class/src/com_struct_impl.rs | 26 +++++++++---------- macros/aggr_co_class/src/iunknown_impl.rs | 13 +++++----- macros/co_class/src/class_factory.rs | 10 +++---- macros/co_class/src/com_struct.rs | 2 +- macros/co_class/src/com_struct_impl.rs | 2 +- macros/co_class/src/iunknown_impl.rs | 20 +++++++------- .../src/vtable_macro.rs | 2 +- src/comptr.rs | 7 +++-- src/iclassfactory.rs | 6 ++--- src/iunknown.rs | 6 ++--- 17 files changed, 63 insertions(+), 66 deletions(-) diff --git a/examples/basic/client/src/main.rs b/examples/basic/client/src/main.rs index b4779a8..087afb3 100644 --- a/examples/basic/client/src/main.rs +++ b/examples/basic/client/src/main.rs @@ -10,7 +10,7 @@ fn main() { } }; - let mut factory = match runtime.get_class_object(&CLSID_CAT_CLASS) { + let factory = match runtime.get_class_object(&CLSID_CAT_CLASS) { Ok(factory) => { println!("Got cat class object"); factory @@ -21,7 +21,7 @@ fn main() { } }; - let mut unknown = match factory.get_instance::() { + let unknown = match factory.get_instance::() { Some(unknown) => { println!("Got IUnknown"); unknown @@ -32,7 +32,7 @@ fn main() { } }; - let mut animal = match unknown.get_interface::() { + let animal = match unknown.get_interface::() { Some(animal) => { println!("Got IAnimal"); animal @@ -46,7 +46,7 @@ fn main() { animal.eat(); // Test cross-vtable interface queries for both directions. - let mut domestic_animal = match animal.get_interface::() { + let domestic_animal = match animal.get_interface::() { Some(domestic_animal) => { println!("Got IDomesticAnimal"); domestic_animal @@ -59,7 +59,7 @@ fn main() { domestic_animal.train(); - let mut new_cat = match domestic_animal.get_interface::() { + let new_cat = match domestic_animal.get_interface::() { Some(new_cat) => { println!("Got ICat"); new_cat @@ -72,7 +72,7 @@ fn main() { new_cat.ignore_humans(); // Test querying within second vtable. - let mut domestic_animal_two = match domestic_animal.get_interface::() { + let domestic_animal_two = match domestic_animal.get_interface::() { Some(domestic_animal_two) => { println!("Got IDomesticAnimal"); domestic_animal_two @@ -84,12 +84,7 @@ fn main() { }; domestic_animal_two.train(); - // These doesn't compile - // animal.ignore_humans(); - // animal.raw_add_ref(); - // animal.add_ref(); - - let mut cat = match runtime.create_instance::(&CLSID_CAT_CLASS) { + let cat = match runtime.create_instance::(&CLSID_CAT_CLASS) { Ok(cat) => { println!("Got another cat"); cat diff --git a/examples/basic/interface/src/ianimal.rs b/examples/basic/interface/src/ianimal.rs index 04b9de6..60c96dc 100644 --- a/examples/basic/interface/src/ianimal.rs +++ b/examples/basic/interface/src/ianimal.rs @@ -3,5 +3,5 @@ use winapi::um::winnt::HRESULT; #[com_interface(EFF8970E-C50F-45E0-9284-291CE5A6F771)] pub trait IAnimal: IUnknown { - fn eat(&mut self) -> HRESULT; + fn eat(&self) -> HRESULT; } diff --git a/examples/basic/interface/src/icat.rs b/examples/basic/interface/src/icat.rs index fc05f10..796495b 100644 --- a/examples/basic/interface/src/icat.rs +++ b/examples/basic/interface/src/icat.rs @@ -5,5 +5,5 @@ use crate::IAnimal; #[com_interface(F5353C58-CFD9-4204-8D92-D274C7578B53)] pub trait ICat: IAnimal { - fn ignore_humans(&mut self) -> HRESULT; + fn ignore_humans(&self) -> HRESULT; } diff --git a/examples/basic/interface/src/idomesticanimal.rs b/examples/basic/interface/src/idomesticanimal.rs index 0a48eae..8ecd36b 100644 --- a/examples/basic/interface/src/idomesticanimal.rs +++ b/examples/basic/interface/src/idomesticanimal.rs @@ -5,5 +5,5 @@ use crate::IAnimal; #[com_interface(C22425DF-EFB2-4B85-933E-9CF7B23459E8)] pub trait IDomesticAnimal: IAnimal { - fn train(&mut self) -> HRESULT; + fn train(&self) -> HRESULT; } diff --git a/examples/basic/server/src/british_short_hair_cat.rs b/examples/basic/server/src/british_short_hair_cat.rs index df5cffa..8aba4d8 100644 --- a/examples/basic/server/src/british_short_hair_cat.rs +++ b/examples/basic/server/src/british_short_hair_cat.rs @@ -12,21 +12,21 @@ pub struct BritishShortHairCat { } impl IDomesticAnimal for BritishShortHairCat { - fn train(&mut self) -> HRESULT { + fn train(&self) -> HRESULT { println!("Training..."); NOERROR } } impl ICat for BritishShortHairCat { - fn ignore_humans(&mut self) -> HRESULT { + fn ignore_humans(&self) -> HRESULT { println!("Ignoring Humans..."); NOERROR } } impl IAnimal for BritishShortHairCat { - fn eat(&mut self) -> HRESULT { + fn eat(&self) -> HRESULT { println!("Eating..."); NOERROR } diff --git a/macros/aggr_co_class/src/class_factory.rs b/macros/aggr_co_class/src/class_factory.rs index c31fc95..141ba6f 100644 --- a/macros/aggr_co_class/src/class_factory.rs +++ b/macros/aggr_co_class/src/class_factory.rs @@ -19,7 +19,7 @@ pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream { impl com::IClassFactory for #class_factory_ident { unsafe fn create_instance( - &mut self, + &self, aggr: *mut ::VPtr, riid: winapi::shared::guiddef::REFIID, ppv: *mut *mut winapi::ctypes::c_void, diff --git a/macros/aggr_co_class/src/com_struct.rs b/macros/aggr_co_class/src/com_struct.rs index 3d468d3..3c53288 100644 --- a/macros/aggr_co_class/src/com_struct.rs +++ b/macros/aggr_co_class/src/com_struct.rs @@ -41,7 +41,7 @@ pub fn generate( #non_delegating_iunknown_field_ident: ::VPtr, // Non-reference counted interface pointer to outer IUnknown. #iunknown_to_use_field_ident: *mut ::VPtr, - #ref_count_ident: u32, + #ref_count_ident: std::cell::Cell, #(#aggregates,)* #fields } diff --git a/macros/aggr_co_class/src/com_struct_impl.rs b/macros/aggr_co_class/src/com_struct_impl.rs index 2b786e3..d52c931 100644 --- a/macros/aggr_co_class/src/com_struct_impl.rs +++ b/macros/aggr_co_class/src/com_struct_impl.rs @@ -76,22 +76,22 @@ fn gen_inner_iunknown_fns( quote!( #inner_query_interface - pub(crate) fn inner_add_ref(&mut self) -> u32 { - self.#ref_count_ident += 1; - println!("Count now {}", self.#ref_count_ident); - self.#ref_count_ident + pub(crate) fn inner_add_ref(&self) -> u32 { + let value = self.#ref_count_ident.get().checked_add(1).expect("Overflow of reference count"); + self.#ref_count_ident.set(value); + println!("Count now {}", value); + value } - pub(crate) fn inner_release(&mut self) -> u32 { - self.#ref_count_ident -= 1; - println!("Count now {}", self.#ref_count_ident); - let count = self.#ref_count_ident; - if count == 0 { + pub(crate) fn inner_release(&self) -> u32 { + let value = self.#ref_count_ident.get().checked_sub(1).expect("Underflow of reference count"); + println!("Count now {}", value); + self.#ref_count_ident.set(value); + if value == 0 { println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident)); - // drop(self) unsafe { Box::from_raw(self as *const _ as *mut #struct_ident); } } - count + value } ) } @@ -147,7 +147,7 @@ fn gen_inner_query_interface( }); quote!( - pub(crate) fn inner_query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> HRESULT { + pub(crate) fn inner_query_interface(&self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> HRESULT { println!("Non delegating QI"); unsafe { @@ -258,7 +258,7 @@ fn gen_allocate_fn( #(#base_fields,)* #non_delegating_iunknown_field_ident, #iunknown_to_use_field_ident: std::ptr::null_mut::<::VPtr>(), - #ref_count_ident: 0, + #ref_count_ident: std::cell::Cell::new(0), #(#aggregate_inits,)* #(#field_idents)* }; diff --git a/macros/aggr_co_class/src/iunknown_impl.rs b/macros/aggr_co_class/src/iunknown_impl.rs index 35b2442..3c7affc 100644 --- a/macros/aggr_co_class/src/iunknown_impl.rs +++ b/macros/aggr_co_class/src/iunknown_impl.rs @@ -12,32 +12,33 @@ use syn::ItemStruct; pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream { let struct_ident = &struct_item.ident; let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident(); + let ptr_casting = quote! { as *const winapi::ctypes::c_void as *mut winapi::ctypes::c_void }; quote!( impl com::IUnknown for #struct_ident { unsafe fn query_interface( - &mut self, + &self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void ) -> winapi::shared::winerror::HRESULT { println!("Delegating QI"); - let mut iunknown_to_use: com::ComPtr = com::ComPtr::new(self.#iunknown_to_use_field_ident as *mut winapi::ctypes::c_void); + let mut iunknown_to_use: com::ComPtr = com::ComPtr::new(self.#iunknown_to_use_field_ident #ptr_casting); let hr = iunknown_to_use.query_interface(riid, ppv); core::mem::forget(iunknown_to_use); hr } - fn add_ref(&mut self) -> u32 { - let mut iunknown_to_use: com::ComPtr = unsafe { com::ComPtr::new(self.#iunknown_to_use_field_ident as *mut winapi::ctypes::c_void) }; + fn add_ref(&self) -> u32 { + let mut iunknown_to_use: com::ComPtr = unsafe { com::ComPtr::new(self.#iunknown_to_use_field_ident #ptr_casting) }; let res = iunknown_to_use.add_ref(); core::mem::forget(iunknown_to_use); res } - unsafe fn release(&mut self) -> u32 { - let mut iunknown_to_use: com::ComPtr = com::ComPtr::new(self.#iunknown_to_use_field_ident as *mut winapi::ctypes::c_void); + unsafe fn release(&self) -> u32 { + let mut iunknown_to_use: com::ComPtr = com::ComPtr::new(self.#iunknown_to_use_field_ident #ptr_casting); let res = iunknown_to_use.release(); core::mem::forget(iunknown_to_use); diff --git a/macros/co_class/src/class_factory.rs b/macros/co_class/src/class_factory.rs index 91849bb..6065c76 100644 --- a/macros/co_class/src/class_factory.rs +++ b/macros/co_class/src/class_factory.rs @@ -18,7 +18,7 @@ pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream { impl com::IClassFactory for #class_factory_ident { unsafe fn create_instance( - &mut self, + &self, aggr: *mut ::VPtr, riid: winapi::shared::guiddef::REFIID, ppv: *mut *mut winapi::ctypes::c_void, @@ -55,7 +55,7 @@ pub fn gen_class_factory_struct_definition(class_factory_ident: &Ident) -> Helpe #[repr(C)] pub struct #class_factory_ident { inner: ::VPtr, - #ref_count_ident: u32, + #ref_count_ident: std::cell::Cell, } } } @@ -63,7 +63,7 @@ pub fn gen_class_factory_struct_definition(class_factory_ident: &Ident) -> Helpe pub fn gen_lock_server() -> HelperTokenStream { quote! { // TODO: Implement correctly - fn lock_server(&mut self, _increment: winapi::shared::minwindef::BOOL) -> winapi::shared::winerror::HRESULT { + fn lock_server(&self, _increment: winapi::shared::minwindef::BOOL) -> winapi::shared::winerror::HRESULT { println!("LockServer called"); winapi::shared::winerror::S_OK } @@ -85,7 +85,7 @@ pub fn gen_iunknown_impl(class_factory_ident: &Ident) -> HelperTokenStream { fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream { quote! { - unsafe fn query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT { + unsafe fn query_interface(&self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT { // Bringing trait into scope to access add_ref method. use com::IUnknown; @@ -117,7 +117,7 @@ pub fn gen_class_factory_impl(class_factory_ident: &Ident) -> HelperTokenStream let vptr = Box::into_raw(Box::new(class_vtable)); let class_factory = #class_factory_ident { inner: vptr, - #ref_count_ident: 0, + #ref_count_ident: std::cell::Cell::new(0), }; Box::new(class_factory) } diff --git a/macros/co_class/src/com_struct.rs b/macros/co_class/src/com_struct.rs index 5e74bd2..64b924a 100644 --- a/macros/co_class/src/com_struct.rs +++ b/macros/co_class/src/com_struct.rs @@ -40,7 +40,7 @@ pub fn generate( #[repr(C)] #vis struct #struct_ident { #(#bases_interface_idents,)* - #ref_count_ident: u32, + #ref_count_ident: std::cell::Cell, #(#aggregates,)* #fields } diff --git a/macros/co_class/src/com_struct_impl.rs b/macros/co_class/src/com_struct_impl.rs index 43e171a..f2c97a4 100644 --- a/macros/co_class/src/com_struct_impl.rs +++ b/macros/co_class/src/com_struct_impl.rs @@ -60,7 +60,7 @@ pub fn generate( #(#base_inits)* let out = #struct_ident { #(#base_fields,)* - #ref_count_ident: 0, + #ref_count_ident: std::cell::Cell::new(0), #(#aggregate_inits,)* #(#field_idents)* }; diff --git a/macros/co_class/src/iunknown_impl.rs b/macros/co_class/src/iunknown_impl.rs index 0525164..730bd60 100644 --- a/macros/co_class/src/iunknown_impl.rs +++ b/macros/co_class/src/iunknown_impl.rs @@ -29,10 +29,11 @@ pub fn generate( pub fn gen_add_ref() -> HelperTokenStream { let ref_count_ident = macro_utils::ref_count_ident(); quote! { - fn add_ref(&mut self) -> u32 { - self.#ref_count_ident = self.#ref_count_ident.checked_add(1).expect("Overflow of reference count"); - println!("Count now {}", self.#ref_count_ident); - self.#ref_count_ident + fn add_ref(&self) -> u32 { + let value = self.#ref_count_ident.get().checked_add(1).expect("Overflow of reference count"); + self.#ref_count_ident.set(value); + println!("Count now {}", value); + value } } } @@ -40,10 +41,11 @@ pub fn gen_add_ref() -> HelperTokenStream { pub fn gen_release(struct_ident: &Ident) -> HelperTokenStream { let ref_count_ident = macro_utils::ref_count_ident(); quote! { - unsafe fn release(&mut self) -> u32 { - self.#ref_count_ident = self.#ref_count_ident.checked_sub(1).expect("Underflow of reference count"); - println!("Count now {}", self.#ref_count_ident); - let count = self.#ref_count_ident; + unsafe fn release(&self) -> u32 { + let value = self.#ref_count_ident.get().checked_sub(1).expect("Underflow of reference count"); + self.#ref_count_ident.set(value); + let count = self.#ref_count_ident.get(); + println!("Count now {}", count); if count == 0 { println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident)); Box::from_raw(self as *const _ as *mut #struct_ident); @@ -67,7 +69,7 @@ fn gen_query_interface( quote!( unsafe fn query_interface( - &mut self, + &self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void ) -> winapi::shared::winerror::HRESULT { diff --git a/macros/com_interface_attribute/src/vtable_macro.rs b/macros/com_interface_attribute/src/vtable_macro.rs index e422382..6637070 100644 --- a/macros/com_interface_attribute/src/vtable_macro.rs +++ b/macros/com_interface_attribute/src/vtable_macro.rs @@ -96,7 +96,7 @@ fn gen_vtable_function( let return_type = &fun.output; quote! { unsafe extern "stdcall" fn #function_ident(#(#params)*) #return_type { - let this = arg0.sub(O::VALUE) as *mut C; + let this = arg0.sub(O::VALUE) as *const C as *mut C; (*this).#method_name(#(#args)*) } } diff --git a/src/comptr.rs b/src/comptr.rs index 533f786..7d84c12 100644 --- a/src/comptr.rs +++ b/src/comptr.rs @@ -37,13 +37,13 @@ impl ComPtr { self.ptr } - fn cast_and_add_ref(&mut self) { + fn cast_and_add_ref(&self) { unsafe { (*(self as *const _ as *mut ComPtr)).add_ref(); } } - pub fn get_interface(&mut self) -> Option> { + pub fn get_interface(&self) -> Option> { let mut ppv = std::ptr::null_mut::(); let hr = unsafe { (*(self as *const _ as *mut ComPtr)) @@ -59,7 +59,6 @@ impl ComPtr { impl Drop for ComPtr { fn drop(&mut self) { - println!("Dropped!"); unsafe { (*(self as *const _ as *mut ComPtr)).release(); } @@ -68,7 +67,7 @@ impl Drop for ComPtr { impl Clone for ComPtr { fn clone(&self) -> Self { - let mut new_ptr = ComPtr { + let new_ptr = ComPtr { ptr: self.ptr, phantom: PhantomData, }; diff --git a/src/iclassfactory.rs b/src/iclassfactory.rs index 978495d..a9dbeff 100644 --- a/src/iclassfactory.rs +++ b/src/iclassfactory.rs @@ -17,16 +17,16 @@ use crate::{ #[com_interface(00000001-0000-0000-c000-000000000046)] pub trait IClassFactory: IUnknown { unsafe fn create_instance( - &mut self, + &self, aggr: *mut IUnknownVPtr, riid: REFIID, ppv: *mut *mut c_void, ) -> HRESULT; - fn lock_server(&mut self, increment: BOOL) -> HRESULT; + fn lock_server(&self, increment: BOOL) -> HRESULT; } impl ComPtr { - pub fn get_instance(&mut self) -> Option> { + pub fn get_instance(&self) -> Option> { let mut ppv = std::ptr::null_mut::(); let aggr = std::ptr::null_mut(); let hr = unsafe { self.create_instance(aggr, &T::IID as *const IID, &mut ppv) }; diff --git a/src/iunknown.rs b/src/iunknown.rs index aff4ef4..1237bed 100644 --- a/src/iunknown.rs +++ b/src/iunknown.rs @@ -14,7 +14,7 @@ pub trait IUnknown { /// /// [`QueryInterface` Method]: https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void) /// [`IUnknown::get_interface`]: trait.IUnknown.html#method.get_interface - unsafe fn query_interface(&mut self, riid: REFIID, ppv: *mut *mut c_void) -> HRESULT; + unsafe fn query_interface(&self, riid: REFIID, ppv: *mut *mut c_void) -> HRESULT; /// The COM [`AddRef` Method] /// @@ -23,7 +23,7 @@ pub trait IUnknown { /// /// [`AddRef` Method]: https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref /// [`ComPtr`]: struct.ComPtr.html - fn add_ref(&mut self) -> u32; + fn add_ref(&self) -> u32; /// The COM [`Release` Method] /// @@ -32,5 +32,5 @@ pub trait IUnknown { /// /// [`Release` Method]: https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-release /// [`ComPtr`]: struct.ComPtr.html - unsafe fn release(&mut self) -> u32; + unsafe fn release(&self) -> u32; } From aca1cbd00f542e1504240f667493be399945eb4a Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Fri, 13 Sep 2019 15:41:25 +0200 Subject: [PATCH 6/7] Ensure interface methods only take &self --- examples/aggregation/client/src/main.rs | 6 +++--- examples/aggregation/interface/src/ifile_manager.rs | 2 +- .../aggregation/interface/src/ilocal_file_manager.rs | 2 +- examples/aggregation/server/src/local_file_manager.rs | 2 +- .../aggregation/server/src/windows_file_manager.rs | 2 +- macros/com_interface_attribute/src/vtable.rs | 10 +++++++++- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/aggregation/client/src/main.rs b/examples/aggregation/client/src/main.rs index b2e686f..481633d 100644 --- a/examples/aggregation/client/src/main.rs +++ b/examples/aggregation/client/src/main.rs @@ -21,7 +21,7 @@ fn main() { fn run_aggr_test(runtime: Runtime) { let result = runtime.create_instance::(&CLSID_WINDOWS_FILE_MANAGER_CLASS); - let mut filemanager = match result { + let filemanager = match result { Ok(filemanager) => filemanager, Err(e) => { println!("Failed to get filemanager, {:x}", e as u32); @@ -32,7 +32,7 @@ fn run_aggr_test(runtime: Runtime) { filemanager.delete_all(); let result = filemanager.get_interface::(); - let mut lfm = match result { + let lfm = match result { Some(lfm) => lfm, None => { println!("Failed to get Local File Manager!"); @@ -43,7 +43,7 @@ fn run_aggr_test(runtime: Runtime) { lfm.delete_local(); let result = runtime.create_instance::(&CLSID_LOCAL_FILE_MANAGER_CLASS); - let mut localfilemanager = match result { + let localfilemanager = match result { Ok(localfilemanager) => localfilemanager, Err(e) => { println!("Failed to get localfilemanager, {:x}", e as u32); diff --git a/examples/aggregation/interface/src/ifile_manager.rs b/examples/aggregation/interface/src/ifile_manager.rs index da581ac..3914942 100644 --- a/examples/aggregation/interface/src/ifile_manager.rs +++ b/examples/aggregation/interface/src/ifile_manager.rs @@ -4,5 +4,5 @@ use winapi::shared::winerror::HRESULT; #[com_interface(25A41124-23D0-46BE-8351-044889D5E37E)] pub trait IFileManager: IUnknown { - fn delete_all(&mut self) -> HRESULT; + fn delete_all(&self) -> HRESULT; } diff --git a/examples/aggregation/interface/src/ilocal_file_manager.rs b/examples/aggregation/interface/src/ilocal_file_manager.rs index 7971a14..f2d663e 100644 --- a/examples/aggregation/interface/src/ilocal_file_manager.rs +++ b/examples/aggregation/interface/src/ilocal_file_manager.rs @@ -4,5 +4,5 @@ use winapi::shared::winerror::HRESULT; #[com_interface(4FC333E3-C389-4C48-B108-7895B0AF21AD)] pub trait ILocalFileManager: IUnknown { - fn delete_local(&mut self) -> HRESULT; + fn delete_local(&self) -> HRESULT; } diff --git a/examples/aggregation/server/src/local_file_manager.rs b/examples/aggregation/server/src/local_file_manager.rs index 5de6906..1f9de07 100644 --- a/examples/aggregation/server/src/local_file_manager.rs +++ b/examples/aggregation/server/src/local_file_manager.rs @@ -11,7 +11,7 @@ pub struct LocalFileManager { } impl ILocalFileManager for LocalFileManager { - fn delete_local(&mut self) -> HRESULT { + fn delete_local(&self) -> HRESULT { println!("Deleting Locally..."); NOERROR } diff --git a/examples/aggregation/server/src/windows_file_manager.rs b/examples/aggregation/server/src/windows_file_manager.rs index 07ff88c..9cb28b0 100644 --- a/examples/aggregation/server/src/windows_file_manager.rs +++ b/examples/aggregation/server/src/windows_file_manager.rs @@ -23,7 +23,7 @@ pub struct WindowsFileManager { } impl IFileManager for WindowsFileManager { - fn delete_all(&mut self) -> HRESULT { + fn delete_all(&self) -> HRESULT { println!("Deleting all by delegating to Local and Remote File Managers..."); NOERROR } diff --git a/macros/com_interface_attribute/src/vtable.rs b/macros/com_interface_attribute/src/vtable.rs index b081a67..7a188c1 100644 --- a/macros/com_interface_attribute/src/vtable.rs +++ b/macros/com_interface_attribute/src/vtable.rs @@ -102,7 +102,15 @@ fn gen_raw_params(interface_ident: &Ident, method: &TraitItemMethod) -> HelperTo let vptr_ident = vptr::ident(&interface_ident.to_string()); for param in method.sig.inputs.iter() { match param { - FnArg::Receiver(_n) => { + FnArg::Receiver(s) => { + assert!( + s.reference.is_some(), + "COM interface methods cannot take ownership of self" + ); + assert!( + s.mutability.is_none(), + "COM interface methods cannot take mutable reference to self" + ); params.push(quote!( *mut #vptr_ident, )); From 8d497eac3290646d7e3a5d7b185c10081e8f6a73 Mon Sep 17 00:00:00 2001 From: adrianwithah Date: Fri, 13 Sep 2019 16:53:43 +0100 Subject: [PATCH 7/7] Refactor IUnknown::release method, remove Drop impl --- macros/aggr_co_class/src/class_factory.rs | 8 +- macros/aggr_co_class/src/com_struct_impl.rs | 50 +++++--- macros/aggr_co_class/src/drop_impl.rs | 35 ------ macros/aggr_co_class/src/lib.rs | 8 +- macros/co_class/src/class_factory.rs | 97 ++++++++++++--- macros/co_class/src/drop_impl.rs | 53 --------- macros/co_class/src/iunknown_impl.rs | 124 ++++++++++++++++---- macros/co_class/src/lib.rs | 10 +- 8 files changed, 229 insertions(+), 156 deletions(-) delete mode 100644 macros/aggr_co_class/src/drop_impl.rs delete mode 100644 macros/co_class/src/drop_impl.rs diff --git a/macros/aggr_co_class/src/class_factory.rs b/macros/aggr_co_class/src/class_factory.rs index c31fc95..255f251 100644 --- a/macros/aggr_co_class/src/class_factory.rs +++ b/macros/aggr_co_class/src/class_factory.rs @@ -5,14 +5,18 @@ use syn::ItemStruct; // We manually generate a ClassFactory without macros, otherwise // it leads to an infinite loop. pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream { + + let base_interface_idents = co_class::class_factory::get_class_factory_base_interface_idents(); + let aggr_map = co_class::class_factory::get_class_factory_aggr_map(); + let struct_ident = &struct_item.ident; let class_factory_ident = macro_utils::class_factory_ident(&struct_ident); let struct_definition = co_class::class_factory::gen_class_factory_struct_definition(&class_factory_ident); let lock_server = co_class::class_factory::gen_lock_server(); - let iunknown_impl = co_class::class_factory::gen_iunknown_impl(&class_factory_ident); - let class_factory_impl = co_class::class_factory::gen_class_factory_impl(&class_factory_ident); + let iunknown_impl = co_class::class_factory::gen_iunknown_impl(&base_interface_idents, &aggr_map, &class_factory_ident); + let class_factory_impl = co_class::class_factory::gen_class_factory_impl(&base_interface_idents, &class_factory_ident); quote! { #struct_definition diff --git a/macros/aggr_co_class/src/com_struct_impl.rs b/macros/aggr_co_class/src/com_struct_impl.rs index 13bb7dd..d02d4ce 100644 --- a/macros/aggr_co_class/src/com_struct_impl.rs +++ b/macros/aggr_co_class/src/com_struct_impl.rs @@ -17,7 +17,7 @@ pub fn generate( let struct_ident = &struct_item.ident; let allocate_fn = gen_allocate_fn(aggr_map, base_interface_idents, struct_item); let set_iunknown_fn = gen_set_iunknown_fn(); - let inner_iunknown_fns = gen_inner_iunknown_fns(base_interface_idents, aggr_map, struct_item); + let inner_iunknown_fns = gen_inner_iunknown_fns(base_interface_idents, aggr_map, struct_ident); let get_class_object_fn = co_class::com_struct_impl::gen_get_class_object_fn(struct_item); let set_aggregate_fns = co_class::com_struct_impl::gen_set_aggregate_fns(aggr_map); @@ -54,12 +54,11 @@ fn gen_set_iunknown_fn() -> HelperTokenStream { fn gen_inner_iunknown_fns( base_interface_idents: &[Ident], aggr_map: &HashMap>, - struct_item: &ItemStruct, + struct_ident: &Ident, ) -> HelperTokenStream { - let struct_ident = &struct_item.ident; let inner_query_interface = gen_inner_query_interface(base_interface_idents, aggr_map); let inner_add_ref = gen_inner_add_ref(); - let inner_release = gen_inner_release(struct_ident); + let inner_release = gen_inner_release(base_interface_idents, aggr_map, struct_ident); quote!( #inner_query_interface @@ -69,32 +68,49 @@ fn gen_inner_iunknown_fns( } pub fn gen_inner_add_ref() -> HelperTokenStream { - let ref_count_ident = macro_utils::ref_count_ident(); + let add_ref_implementation = co_class::iunknown_impl::gen_add_ref_implementation(); + quote! { pub(crate) fn inner_add_ref(&mut self) -> u32 { - self.#ref_count_ident = self.#ref_count_ident.checked_add(1).expect("Overflow of reference count"); - println!("Count now {}", self.#ref_count_ident); - self.#ref_count_ident + #add_ref_implementation } } } -pub fn gen_inner_release(struct_ident: &Ident) -> HelperTokenStream { +pub fn gen_inner_release( + base_interface_idents: &[Ident], + aggr_map: &HashMap>, + struct_ident: &Ident, +) -> HelperTokenStream { let ref_count_ident = macro_utils::ref_count_ident(); + + let release_decrement = co_class::iunknown_impl::gen_release_decrement(&ref_count_ident); + let release_assign_new_count_to_var = co_class::iunknown_impl::gen_release_assign_new_count_to_var(&ref_count_ident, &ref_count_ident); + let release_new_count_var_zero_check = co_class::iunknown_impl::gen_new_count_var_zero_check(&ref_count_ident); + let release_drops = co_class::iunknown_impl::gen_release_drops(base_interface_idents, aggr_map, struct_ident); + let non_delegating_iunknown_drop = gen_non_delegating_iunknown_drop(); + quote! { - pub(crate) unsafe fn inner_release(&mut self) -> u32 { - self.#ref_count_ident = self.#ref_count_ident.checked_sub(1).expect("Underflow of reference count"); - println!("Count now {}", self.#ref_count_ident); - let count = self.#ref_count_ident; - if count == 0 { - println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident)); - Box::from_raw(self as *const _ as *mut #struct_ident); + unsafe fn inner_release(&mut self) -> u32 { + #release_decrement + #release_assign_new_count_to_var + if #release_new_count_var_zero_check { + #non_delegating_iunknown_drop + #release_drops } - count + + #ref_count_ident } } } +fn gen_non_delegating_iunknown_drop() -> HelperTokenStream { + let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); + quote!( + Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut ::VTable); + ) +} + /// Non-delegating query interface fn gen_inner_query_interface( base_interface_idents: &[Ident], diff --git a/macros/aggr_co_class/src/drop_impl.rs b/macros/aggr_co_class/src/drop_impl.rs deleted file mode 100644 index cae5a2e..0000000 --- a/macros/aggr_co_class/src/drop_impl.rs +++ /dev/null @@ -1,35 +0,0 @@ -use proc_macro2::TokenStream as HelperTokenStream; -use quote::quote; -use std::collections::HashMap; -use syn::{Ident, ItemStruct}; - -pub fn generate( - aggr_map: &HashMap>, - base_interface_idents: &[Ident], - struct_item: &ItemStruct, -) -> HelperTokenStream { - let struct_ident = &struct_item.ident; - - let aggregate_drops = co_class::drop_impl::gen_aggregate_drops(aggr_map); - let vptr_drops = co_class::drop_impl::gen_vptr_drops(base_interface_idents); - let non_delegating_iunknown_drop = gen_non_delegating_iunknown_drop(); - - quote!( - impl std::ops::Drop for #struct_ident { - fn drop(&mut self) { - let _ = unsafe { - #aggregate_drops - #vptr_drops - #non_delegating_iunknown_drop - }; - } - } - ) -} - -fn gen_non_delegating_iunknown_drop() -> HelperTokenStream { - let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); - quote!( - Box::from_raw(self.#non_delegating_iunknown_field_ident as *mut ::VTable) - ) -} diff --git a/macros/aggr_co_class/src/lib.rs b/macros/aggr_co_class/src/lib.rs index 63d7df3..5d2dacd 100644 --- a/macros/aggr_co_class/src/lib.rs +++ b/macros/aggr_co_class/src/lib.rs @@ -7,7 +7,6 @@ use std::iter::FromIterator; mod class_factory; mod com_struct; mod com_struct_impl; -mod drop_impl; mod iunknown_impl; // Macro expansion entry point. @@ -18,15 +17,14 @@ pub fn expand_aggr_co_class(attr: TokenStream, item: TokenStream) -> TokenStream // Parse attributes let base_interface_idents = macro_utils::base_interface_idents(&attr_args); - let aggr_interface_idents = macro_utils::get_aggr_map(&attr_args); + let aggr_map = macro_utils::get_aggr_map(&attr_args); let mut out: Vec = Vec::new(); - out.push(com_struct::generate(&aggr_interface_idents, &base_interface_idents, &input).into()); + out.push(com_struct::generate(&aggr_map, &base_interface_idents, &input).into()); out.push( - com_struct_impl::generate(&base_interface_idents, &aggr_interface_idents, &input).into(), + com_struct_impl::generate(&base_interface_idents, &aggr_map, &input).into(), ); out.push(iunknown_impl::generate(&input).into()); - out.push(drop_impl::generate(&aggr_interface_idents, &base_interface_idents, &input).into()); out.push(class_factory::generate(&input).into()); TokenStream::from_iter(out) diff --git a/macros/co_class/src/class_factory.rs b/macros/co_class/src/class_factory.rs index 91849bb..bbb8704 100644 --- a/macros/co_class/src/class_factory.rs +++ b/macros/co_class/src/class_factory.rs @@ -1,17 +1,36 @@ use proc_macro2::{Ident, TokenStream as HelperTokenStream}; -use quote::quote; +use quote::{quote, format_ident,}; use syn::ItemStruct; +use std::collections::HashMap; + +fn get_iclass_factory_interface_ident() -> Ident { + format_ident!("IClassFactory") +} + +pub fn get_class_factory_base_interface_idents() -> Vec { + vec![get_iclass_factory_interface_ident()] +} + +pub fn get_class_factory_aggr_map() -> HashMap> { + HashMap::new() +} // We manually generate a ClassFactory without macros, otherwise // it leads to an infinite loop. pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream { + // Manually define base_interface_idents and aggr_map usually obtained by + // parsing attributes. + + let base_interface_idents = get_class_factory_base_interface_idents(); + let aggr_map = get_class_factory_aggr_map(); + let struct_ident = &struct_item.ident; let class_factory_ident = macro_utils::class_factory_ident(&struct_ident); let struct_definition = gen_class_factory_struct_definition(&class_factory_ident); let lock_server = gen_lock_server(); - let iunknown_impl = gen_iunknown_impl(&class_factory_ident); - let class_factory_impl = gen_class_factory_impl(&class_factory_ident); + let iunknown_impl = gen_iunknown_impl(&base_interface_idents, &aggr_map, &class_factory_ident); + let class_factory_impl = gen_class_factory_impl(&base_interface_idents, &class_factory_ident); quote! { #struct_definition @@ -49,13 +68,16 @@ pub fn generate(struct_item: &ItemStruct) -> HelperTokenStream { } } +// Can't use gen_base_fields here, since user might not have imported com::IClassFactory. pub fn gen_class_factory_struct_definition(class_factory_ident: &Ident) -> HelperTokenStream { - let ref_count_ident = macro_utils::ref_count_ident(); + let ref_count_field = crate::com_struct::gen_ref_count_field(); + let interface_ident = get_iclass_factory_interface_ident(); + let vptr_field_ident = macro_utils::vptr_field_ident(&interface_ident); quote! { #[repr(C)] pub struct #class_factory_ident { - inner: ::VPtr, - #ref_count_ident: u32, + #vptr_field_ident: ::VPtr, + #ref_count_field } } } @@ -70,10 +92,14 @@ pub fn gen_lock_server() -> HelperTokenStream { } } -pub fn gen_iunknown_impl(class_factory_ident: &Ident) -> HelperTokenStream { +pub fn gen_iunknown_impl( + base_interface_idents: &[Ident], + aggr_map: &HashMap>, + class_factory_ident: &Ident, +) -> HelperTokenStream { let query_interface = gen_query_interface(class_factory_ident); let add_ref = crate::iunknown_impl::gen_add_ref(); - let release = crate::iunknown_impl::gen_release(class_factory_ident); + let release = gen_release(&base_interface_idents, &aggr_map, class_factory_ident); quote! { impl com::IUnknown for #class_factory_ident { #query_interface @@ -83,7 +109,36 @@ pub fn gen_iunknown_impl(class_factory_ident: &Ident) -> HelperTokenStream { } } +pub fn gen_release( + base_interface_idents: &[Ident], + aggr_map: &HashMap>, + struct_ident: &Ident, +) -> HelperTokenStream { + let ref_count_ident = macro_utils::ref_count_ident(); + + let release_decrement = crate::iunknown_impl::gen_release_decrement(&ref_count_ident); + let release_assign_new_count_to_var = crate::iunknown_impl::gen_release_assign_new_count_to_var(&ref_count_ident, &ref_count_ident); + let release_new_count_var_zero_check = crate::iunknown_impl::gen_new_count_var_zero_check(&ref_count_ident); + let release_drops = crate::iunknown_impl::gen_release_drops(base_interface_idents, aggr_map, struct_ident); + + quote! { + unsafe fn release(&mut self) -> u32 { + use com::IClassFactory; + + #release_decrement + #release_assign_new_count_to_var + if #release_new_count_var_zero_check { + #release_drops + } + + #ref_count_ident + } + } +} + fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream { + let vptr_field_ident = macro_utils::vptr_field_ident(&get_iclass_factory_interface_ident()); + quote! { unsafe fn query_interface(&mut self, riid: *const winapi::shared::guiddef::IID, ppv: *mut *mut winapi::ctypes::c_void) -> winapi::shared::winerror::HRESULT { // Bringing trait into scope to access add_ref method. @@ -93,7 +148,7 @@ fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream { let riid = &*riid; if winapi::shared::guiddef::IsEqualGUID(riid, &::IID) | winapi::shared::guiddef::IsEqualGUID(riid, &::IID) { - *ppv = &self.inner as *const _ as *mut winapi::ctypes::c_void; + *ppv = &self.#vptr_field_ident as *const _ as *mut winapi::ctypes::c_void; self.add_ref(); winapi::shared::winerror::NOERROR } else { @@ -104,22 +159,28 @@ fn gen_query_interface(class_factory_ident: &Ident) -> HelperTokenStream { } } -pub fn gen_class_factory_impl(class_factory_ident: &Ident) -> HelperTokenStream { - let ref_count_ident = macro_utils::ref_count_ident(); +pub fn gen_class_factory_impl( + base_interface_idents: &[Ident], + class_factory_ident: &Ident, +) -> HelperTokenStream { + let ref_count_field = crate::com_struct_impl::gen_allocate_ref_count_field(); + let base_fields = crate::com_struct_impl::gen_allocate_base_fields(base_interface_idents); + let base_inits = crate::com_struct_impl::gen_allocate_base_inits(class_factory_ident, base_interface_idents); + quote! { impl #class_factory_ident { pub(crate) fn new() -> Box<#class_factory_ident> { use com::IClassFactory; - println!("Allocating new Vtable for {}...", stringify!(#class_factory_ident)); - let class_vtable = com::vtable!(#class_factory_ident: IClassFactory); // allocate directly since no macros generated an `allocate` function - let vptr = Box::into_raw(Box::new(class_vtable)); - let class_factory = #class_factory_ident { - inner: vptr, - #ref_count_ident: 0, + println!("Allocating new Vtable for {}...", stringify!(#class_factory_ident)); + #base_inits + + let out = #class_factory_ident { + #base_fields + #ref_count_field }; - Box::new(class_factory) + Box::new(out) } } } diff --git a/macros/co_class/src/drop_impl.rs b/macros/co_class/src/drop_impl.rs deleted file mode 100644 index a42bb12..0000000 --- a/macros/co_class/src/drop_impl.rs +++ /dev/null @@ -1,53 +0,0 @@ -use proc_macro2::TokenStream as HelperTokenStream; -use quote::quote; -use std::collections::HashMap; -use syn::{Ident, ItemStruct}; - -pub fn generate( - aggr_map: &HashMap>, - base_interface_idents: &[Ident], - struct_item: &ItemStruct, -) -> HelperTokenStream { - let struct_ident = &struct_item.ident; - - let vptr_drops = gen_vptr_drops(base_interface_idents); - let aggregate_drops = gen_aggregate_drops(aggr_map); - - quote!( - impl std::ops::Drop for #struct_ident { - fn drop(&mut self) { - use com::IUnknown; - - let _ = unsafe { - #aggregate_drops - #vptr_drops - }; - } - } - ) -} - -pub fn gen_aggregate_drops(aggr_map: &HashMap>) -> HelperTokenStream { - let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| { - quote!( - if !self.#aggr_field_ident.is_null() { - let mut aggr_interface_ptr: com::ComPtr = com::ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void); - aggr_interface_ptr.release(); - core::mem::forget(aggr_interface_ptr); - } - ) - }); - - quote!(#(#aggregate_drops)*) -} - -pub fn gen_vptr_drops(base_interface_idents: &[Ident]) -> HelperTokenStream { - let vptr_drops = base_interface_idents.iter().map(|base| { - let vptr_field_ident = macro_utils::vptr_field_ident(&base); - quote!( - Box::from_raw(self.#vptr_field_ident as *mut ::VTable); - ) - }); - - quote!(#(#vptr_drops)*) -} diff --git a/macros/co_class/src/iunknown_impl.rs b/macros/co_class/src/iunknown_impl.rs index 089aa30..6f2f83c 100644 --- a/macros/co_class/src/iunknown_impl.rs +++ b/macros/co_class/src/iunknown_impl.rs @@ -8,14 +8,14 @@ use syn::{Ident, ItemStruct}; /// any interfaces exposed through an aggregated object. pub fn generate( base_interface_idents: &[Ident], - aggr_interface_idents: &HashMap>, + aggr_map: &HashMap>, struct_item: &ItemStruct, ) -> HelperTokenStream { let struct_ident = &struct_item.ident; - let query_interface = gen_query_interface(base_interface_idents, aggr_interface_idents); + let query_interface = gen_query_interface(base_interface_idents, aggr_map); let add_ref = gen_add_ref(); - let release = gen_release(struct_ident); + let release = gen_release(base_interface_idents, aggr_map, struct_ident); quote!( impl com::IUnknown for #struct_ident { @@ -27,35 +27,119 @@ pub fn generate( } pub fn gen_add_ref() -> HelperTokenStream { - let ref_count_ident = macro_utils::ref_count_ident(); + let add_ref_implementation = gen_add_ref_implementation(); + quote! { fn add_ref(&mut self) -> u32 { - self.#ref_count_ident = self.#ref_count_ident.checked_add(1).expect("Overflow of reference count"); - println!("Count now {}", self.#ref_count_ident); - self.#ref_count_ident + #add_ref_implementation } } } -pub fn gen_release(struct_ident: &Ident) -> HelperTokenStream { +pub fn gen_add_ref_implementation() -> HelperTokenStream { let ref_count_ident = macro_utils::ref_count_ident(); + quote!( + self.#ref_count_ident = self.#ref_count_ident.checked_add(1).expect("Overflow of reference count"); + println!("Count now {}", self.#ref_count_ident); + self.#ref_count_ident + ) +} + +pub fn gen_release( + base_interface_idents: &[Ident], + aggr_map: &HashMap>, + struct_ident: &Ident, +) -> HelperTokenStream { + let ref_count_ident = macro_utils::ref_count_ident(); + + let release_decrement = gen_release_decrement(&ref_count_ident); + let release_assign_new_count_to_var = gen_release_assign_new_count_to_var(&ref_count_ident, &ref_count_ident); + let release_new_count_var_zero_check = gen_new_count_var_zero_check(&ref_count_ident); + let release_drops = gen_release_drops(base_interface_idents, aggr_map, struct_ident); + quote! { unsafe fn release(&mut self) -> u32 { - self.#ref_count_ident = self.#ref_count_ident.checked_sub(1).expect("Underflow of reference count"); - println!("Count now {}", self.#ref_count_ident); - let count = self.#ref_count_ident; - if count == 0 { - println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident)); - Box::from_raw(self as *const _ as *mut #struct_ident); + #release_decrement + #release_assign_new_count_to_var + if #release_new_count_var_zero_check { + #release_drops } - count + + #ref_count_ident } } } -fn gen_query_interface( +pub fn gen_release_drops( base_interface_idents: &[Ident], - aggr_interface_idents: &HashMap>, + aggr_map: &HashMap>, + struct_ident: &Ident, +) -> HelperTokenStream { + let vptr_drops = gen_vptr_drops(base_interface_idents); + let aggregate_drops = gen_aggregate_drops(aggr_map); + let com_object_drop = gen_com_object_drop(struct_ident); + + quote!( + #vptr_drops + #aggregate_drops + #com_object_drop + ) +} + +fn gen_aggregate_drops(aggr_map: &HashMap>) -> HelperTokenStream { + let aggregate_drops = aggr_map.iter().map(|(aggr_field_ident, _)| { + quote!( + if !self.#aggr_field_ident.is_null() { + let mut aggr_interface_ptr: com::ComPtr = com::ComPtr::new(self.#aggr_field_ident as *mut winapi::ctypes::c_void); + aggr_interface_ptr.release(); + core::mem::forget(aggr_interface_ptr); + } + ) + }); + + quote!(#(#aggregate_drops)*) +} + +fn gen_vptr_drops(base_interface_idents: &[Ident]) -> HelperTokenStream { + let vptr_drops = base_interface_idents.iter().map(|base| { + let vptr_field_ident = macro_utils::vptr_field_ident(&base); + quote!( + Box::from_raw(self.#vptr_field_ident as *mut ::VTable); + ) + }); + + quote!(#(#vptr_drops)*) +} + +fn gen_com_object_drop(struct_ident: &Ident) -> HelperTokenStream { + quote!( + println!("Count is 0 for {}. Freeing memory...", stringify!(#struct_ident)); + Box::from_raw(self as *const _ as *mut #struct_ident); + ) +} + +pub fn gen_release_decrement(ref_count_ident: &Ident) -> HelperTokenStream { + quote!( + self.#ref_count_ident = self.#ref_count_ident.checked_sub(1).expect("Underflow of reference count"); + println!("Count now {}", self.#ref_count_ident); + ) +} + +pub fn gen_release_assign_new_count_to_var(ref_count_ident: &Ident, new_count_ident: &Ident) -> HelperTokenStream { + quote!( + let #new_count_ident = self.#ref_count_ident; + ) +} + +pub fn gen_new_count_var_zero_check(new_count_ident: &Ident) -> HelperTokenStream { + quote!( + #new_count_ident == 0 + ) +} + +pub fn gen_query_interface( + base_interface_idents: &[Ident], + aggr_map: &HashMap>, ) -> HelperTokenStream { let first_vptr_field = macro_utils::vptr_field_ident(&base_interface_idents[0]); @@ -63,7 +147,7 @@ fn gen_query_interface( let base_match_arms = gen_base_match_arms(base_interface_idents); // Generate match arms for aggregated interfaces - let aggr_match_arms = gen_aggregate_match_arms(aggr_interface_idents); + let aggr_match_arms = gen_aggregate_match_arms(aggr_map); quote!( unsafe fn query_interface( @@ -106,9 +190,9 @@ pub fn gen_base_match_arms(base_interface_idents: &[Ident]) -> HelperTokenStream } pub fn gen_aggregate_match_arms( - aggr_interface_idents: &HashMap>, + aggr_map: &HashMap>, ) -> HelperTokenStream { - let aggr_match_arms = aggr_interface_idents.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| { + let aggr_match_arms = aggr_map.iter().map(|(aggr_field_ident, aggr_base_interface_idents)| { // Construct the OR match conditions for a single aggregated object. let first_base_interface_ident = &aggr_base_interface_idents[0]; diff --git a/macros/co_class/src/lib.rs b/macros/co_class/src/lib.rs index dd22f44..da3cac9 100644 --- a/macros/co_class/src/lib.rs +++ b/macros/co_class/src/lib.rs @@ -7,7 +7,6 @@ use std::iter::FromIterator; pub mod class_factory; pub mod com_struct; pub mod com_struct_impl; -pub mod drop_impl; pub mod iunknown_impl; // Macro expansion entry point. @@ -17,17 +16,16 @@ pub fn expand_co_class(attr: TokenStream, item: TokenStream) -> TokenStream { // Parse attributes let base_interface_idents = macro_utils::base_interface_idents(&attr_args); - let aggr_interface_idents = macro_utils::get_aggr_map(&attr_args); + let aggr_map = macro_utils::get_aggr_map(&attr_args); let mut out: Vec = Vec::new(); - out.push(com_struct::generate(&aggr_interface_idents, &base_interface_idents, &input).into()); + out.push(com_struct::generate(&aggr_map, &base_interface_idents, &input).into()); out.push( - com_struct_impl::generate(&aggr_interface_idents, &base_interface_idents, &input).into(), + com_struct_impl::generate(&aggr_map, &base_interface_idents, &input).into(), ); out.push( - iunknown_impl::generate(&base_interface_idents, &aggr_interface_idents, &input).into(), + iunknown_impl::generate(&base_interface_idents, &aggr_map, &input).into(), ); - out.push(drop_impl::generate(&aggr_interface_idents, &base_interface_idents, &input).into()); out.push(class_factory::generate(&input).into()); TokenStream::from_iter(out)