diff --git a/macros/aggr_co_class/src/com_struct.rs b/macros/aggr_co_class/src/com_struct.rs index 3d468d3..b75429a 100644 --- a/macros/aggr_co_class/src/com_struct.rs +++ b/macros/aggr_co_class/src/com_struct.rs @@ -14,36 +14,24 @@ pub fn generate( let struct_ident = &struct_item.ident; let vis = &struct_item.vis; - let bases_interface_idents = base_interface_idents.iter().map(|base| { - let field_ident = macro_utils::vptr_field_ident(&base); - quote!(#field_ident: ::VPtr) - }); + let base_fields = co_class::com_struct::gen_base_fields(base_interface_idents); + let ref_count_field = co_class::com_struct::gen_ref_count_field(); + let user_fields = co_class::com_struct::gen_user_fields(struct_item); + let aggregate_fields = co_class::com_struct::gen_aggregate_fields(aggr_map); - let ref_count_ident = macro_utils::ref_count_ident(); let non_delegating_iunknown_field_ident = macro_utils::non_delegating_iunknown_field_ident(); let iunknown_to_use_field_ident = macro_utils::iunknown_to_use_field_ident(); - let fields = match &struct_item.fields { - Fields::Named(f) => &f.named, - _ => panic!("Found non Named fields in struct."), - }; - - let aggregates = aggr_map.iter().map(|(aggr_field_ident, _)| { - quote!( - #aggr_field_ident: *mut ::VPtr - ) - }); - quote!( #[repr(C)] #vis struct #struct_ident { - #(#bases_interface_idents,)* + #base_fields #non_delegating_iunknown_field_ident: ::VPtr, // Non-reference counted interface pointer to outer IUnknown. #iunknown_to_use_field_ident: *mut ::VPtr, - #ref_count_ident: u32, - #(#aggregates,)* - #fields + #ref_count_field + #aggregate_fields + #user_fields } ) } diff --git a/macros/co_class/src/com_struct.rs b/macros/co_class/src/com_struct.rs index 5e74bd2..4a7fa3b 100644 --- a/macros/co_class/src/com_struct.rs +++ b/macros/co_class/src/com_struct.rs @@ -18,31 +18,50 @@ pub fn generate( let struct_ident = &struct_item.ident; let vis = &struct_item.vis; + let base_fields = gen_base_fields(base_interface_idents); + let ref_count_field = gen_ref_count_field(); + let user_fields = gen_user_fields(struct_item); + let aggregate_fields = gen_aggregate_fields(aggr_map); + + quote!( + #[repr(C)] + #vis struct #struct_ident { + #base_fields + #ref_count_field + #aggregate_fields + #user_fields + } + ) +} + +pub fn gen_base_fields(base_interface_idents: &[Ident]) -> HelperTokenStream { let bases_interface_idents = base_interface_idents.iter().map(|base| { let field_ident = macro_utils::vptr_field_ident(&base); quote!(#field_ident: ::VPtr) }); + quote!(#(#bases_interface_idents,)*) +} +pub fn gen_ref_count_field() -> HelperTokenStream { let ref_count_ident = macro_utils::ref_count_ident(); + quote!(#ref_count_ident: u32,) +} - let fields = match &struct_item.fields { - Fields::Named(f) => &f.named, - _ => panic!("Found non Named fields in struct."), - }; - +pub fn gen_aggregate_fields(aggr_map: &HashMap>) -> HelperTokenStream { let aggregates = aggr_map.iter().map(|(aggr_field_ident, _)| { quote!( #aggr_field_ident: *mut ::VPtr ) }); - quote!( - #[repr(C)] - #vis struct #struct_ident { - #(#bases_interface_idents,)* - #ref_count_ident: u32, - #(#aggregates,)* - #fields - } - ) + quote!(#(#aggregates,)*) +} + +pub fn gen_user_fields(struct_item: &ItemStruct) -> HelperTokenStream { + let fields = match &struct_item.fields { + Fields::Named(f) => &f.named, + _ => panic!("Found non Named fields in struct."), + }; + + quote!(#fields) } diff --git a/macros/co_class/src/com_struct_impl.rs b/macros/co_class/src/com_struct_impl.rs index 93f5615..a7575d6 100644 --- a/macros/co_class/src/com_struct_impl.rs +++ b/macros/co_class/src/com_struct_impl.rs @@ -46,7 +46,9 @@ pub fn gen_allocate_fn( quote!( fn allocate(#allocate_parameters) -> Box<#struct_ident> { println!("Allocating new VTable for {}", stringify!(#struct_ident)); + #base_inits + let out = #struct_ident { #base_fields #ref_count_field diff --git a/macros/co_class/src/lib.rs b/macros/co_class/src/lib.rs index a134d5a..dd22f44 100644 --- a/macros/co_class/src/lib.rs +++ b/macros/co_class/src/lib.rs @@ -5,9 +5,9 @@ use syn::{AttributeArgs, ItemStruct}; use std::iter::FromIterator; pub mod class_factory; -mod com_struct; +pub mod com_struct; pub mod com_struct_impl; -mod drop_impl; +pub mod drop_impl; pub mod iunknown_impl; // Macro expansion entry point.