diff --git a/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj b/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj index eb0d953..be288e6 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj +++ b/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj @@ -32,7 +32,7 @@ - + diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs new file mode 100644 index 0000000..5deff69 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.CodeAnalysis; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions; + +/// +/// Extension methods for the type. +/// +internal static class ISymbolExtensions +{ + /// + /// Gets the fully qualified name for a given symbol. + /// + /// The input instance. + /// The fully qualified name for . + public static string GetFullyQualifiedName(this ISymbol symbol) + { + return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + } + + /// + /// Checks whether or not a given type symbol has a specified full name. + /// + /// The input instance to check. + /// The full name to check. + /// Whether has a full name equals to . + public static bool HasFullyQualifiedName(this ISymbol symbol, string name) + { + return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == name; + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Incremental.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Incremental.cs new file mode 100644 index 0000000..be90232 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Incremental.cs @@ -0,0 +1,593 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; +using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Input.Models; +using CommunityToolkit.Mvvm.SourceGenerators.Models; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; + +namespace CommunityToolkit.Mvvm.SourceGenerators; + +/// +/// A source generator for generating command properties from annotated methods. +/// +[Generator(LanguageNames.CSharp)] +public sealed partial class ICommandGenerator2 : IIncrementalGenerator +{ + /// + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // Get all method declarations with at least one attribute + IncrementalValuesProvider methodSymbols = + context.SyntaxProvider + .CreateSyntaxProvider( + static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 }, + static (context, _) => (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!); + + // Filter the methods using [ICommand] + IncrementalValuesProvider<(IMethodSymbol Symbol, AttributeData Attribute)> methodSymbolsWithAttributeData = + methodSymbols + .Select(static (item, _) => ( + item, + Attribute: item.GetAttributes().FirstOrDefault(a => a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.Input.ICommandAttribute") == true))) + .Where(static item => item.Attribute is not null)!; + + // Gather info for all annotated command methods + IncrementalValuesProvider> commandInfoWithErrors = + methodSymbolsWithAttributeData + .Select(static (item, _) => + { + CommandInfo? commandInfo = Execute.GetInfo(item.Symbol, item.Attribute, out ImmutableArray diagnostics); + + return new Result(commandInfo, diagnostics); + }); + + // Output the diagnostics + context.ReportDiagnostics(commandInfoWithErrors.Select(static (item, _) => item.Errors)); + + // Get the filtered sequence to enable caching + IncrementalValuesProvider commandInfo = + commandInfoWithErrors + .Select(static (item, _) => item.Value) + .Where(static item => item is not null)! + .WithComparer(CommandInfo.Comparer.Default); + + // Generate the commands + context.RegisterSourceOutput(commandInfo, static (context, item) => + { + ImmutableArray memberDeclarations = Execute.GetSyntax(item); + CompilationUnitSyntax compilationUnit = item.Hierarchy.GetCompilationUnit(memberDeclarations); + + context.AddSource( + hintName: $"{item.Hierarchy.FilenameHint}.{item.MethodName}.cs", + sourceText: SourceText.From(compilationUnit.ToFullString(), Encoding.UTF8)); + }); + } + + /// + /// A container for all the logic for . + /// + private static class Execute + { + /// + /// Processes a given target method. + /// + /// The input instance to process. + /// The instance the method was annotated with. + /// The resulting diagnostics from the processing operation. + /// The resulting instance for , if available. + public static CommandInfo? GetInfo(IMethodSymbol methodSymbol, AttributeData attributeData, out ImmutableArray diagnostics) + { + ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + + // Get the command field and property names + (string fieldName, string propertyName) = GetGeneratedFieldAndPropertyNames(methodSymbol); + + // Get the command type symbols + if (!TryMapCommandTypesFromMethod( + methodSymbol, + builder, + out string? commandInterfaceType, + out string? commandClassType, + out string? delegateType, + out ImmutableArray commandTypeArguments, + out ImmutableArray delegateTypeArguments)) + { + goto Failure; + } + + // Check the switch to allow concurrent executions + if (!TryGetAllowConcurrentExecutionsSwitch( + methodSymbol, + attributeData, + commandClassType, + builder, + out bool allowConcurrentExecutions)) + { + goto Failure; + } + + // Get the CanExecute expression type, if any + if (!TryGetCanExecuteExpressionType( + methodSymbol, + attributeData, + commandTypeArguments, + builder, + out string? canExecuteMemberName, + out CanExecuteExpressionType? canExecuteExpressionType)) + { + goto Failure; + } + + diagnostics = builder.ToImmutable(); + + return new( + HierarchyInfo.From(methodSymbol.ContainingType), + methodSymbol.Name, + fieldName, + propertyName, + commandInterfaceType, + commandClassType, + delegateType, + commandTypeArguments, + delegateTypeArguments, + canExecuteMemberName, + canExecuteExpressionType, + allowConcurrentExecutions); + + Failure: + diagnostics = builder.ToImmutable(); + + return null; + } + + /// + /// Creates the instances for a specified command. + /// + /// The input instance with the info to generate the command. + /// The instances for the input command. + public static ImmutableArray GetSyntax(CommandInfo commandInfo) + { + // Prepare all necessary type names with type arguments + string commandInterfaceTypeXmlName = commandInfo.CommandTypeArguments.IsEmpty + ? commandInfo.CommandInterfaceType + : commandInfo.CommandInterfaceType + "{T}"; + string commandClassTypeName = commandInfo.CommandTypeArguments.IsEmpty + ? commandInfo.CommandClassType + : $"{commandInfo.CommandClassType}<{string.Join(", ", commandInfo.CommandTypeArguments)}>"; + string commandInterfaceTypeName = commandInfo.CommandTypeArguments.IsEmpty + ? commandInfo.CommandInterfaceType + : $"{commandInfo.CommandInterfaceType}<{string.Join(", ", commandInfo.CommandTypeArguments)}>"; + string delegateTypeName = commandInfo.DelegateTypeArguments.IsEmpty + ? commandInfo.DelegateType + : $"{commandInfo.DelegateType}<{string.Join(", ", commandInfo.DelegateTypeArguments)}>"; + + // Construct the generated field as follows: + // + // The backing field for + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // private ? ; + FieldDeclarationSyntax fieldDeclaration = + FieldDeclaration( + VariableDeclaration(NullableType(IdentifierName(commandClassTypeName))) + .AddVariables(VariableDeclarator(Identifier(commandInfo.FieldName)))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword)) + .AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).Assembly.GetName().Version.ToString())))))) + .WithOpenBracketToken(Token(TriviaList(Comment($"/// The backing field for .")), SyntaxKind.OpenBracketToken, TriviaList()))); + + // Prepares the argument to pass the underlying method to invoke + ImmutableArray.Builder commandCreationArguments = ImmutableArray.CreateBuilder(); + + // The first argument is the execute method, which is always present + commandCreationArguments.Add( + Argument( + ObjectCreationExpression(IdentifierName(delegateTypeName)) + .AddArgumentListArguments(Argument(IdentifierName(commandInfo.MethodName))))); + + // Get the can execute expression, if available + ExpressionSyntax? canExecuteExpression = commandInfo.CanExecuteExpressionType switch + { + // Create a lambda expression ignoring the input value: + // + // new (, _ => ()); + CanExecuteExpressionType.MethodInvocationLambdaWithDiscard => + SimpleLambdaExpression( + Parameter(Identifier(TriviaList(), SyntaxKind.UnderscoreToken, "_", "_", TriviaList()))) + .WithExpressionBody(InvocationExpression(IdentifierName(commandInfo.CanExecuteMemberName!))), + + // Create a lambda expression returning the property value: + // + // new (, () => ); + CanExecuteExpressionType.PropertyAccessLambda => + ParenthesizedLambdaExpression() + .WithExpressionBody(IdentifierName(commandInfo.CanExecuteMemberName!)), + + // Create a lambda expression again, but discarding the input value: + // + // new (, _ => ); + CanExecuteExpressionType.PropertyAccessLambdaWithDiscard => + SimpleLambdaExpression( + Parameter(Identifier(TriviaList(), SyntaxKind.UnderscoreToken, "_", "_", TriviaList()))) + .WithExpressionBody(IdentifierName(commandInfo.CanExecuteMemberName!)), + + // Create a method groupd expression, which will become: + // + // new (, ); + CanExecuteExpressionType.MethodGroup => IdentifierName(commandInfo.CanExecuteMemberName!), + _ => null + }; + + // Add the can execute expression to the arguments, if available + if (canExecuteExpression is not null) + { + commandCreationArguments.Add(Argument(canExecuteExpression)); + } + + // Disable concurrent executions, if requested + if (!commandInfo.AllowConcurrentExecutions) + { + commandCreationArguments.Add(Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))); + } + + // Construct the generated property as follows (the explicit delegate cast is needed to avoid overload resolution conflicts): + // + // Gets an and . + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // [global::System.Diagnostics.DebuggerNonUserCode] + // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + // public => ??= new (); + PropertyDeclarationSyntax propertyDeclaration = + PropertyDeclaration( + IdentifierName(commandInterfaceTypeName), + Identifier(commandInfo.PropertyName)) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).Assembly.GetName().Version.ToString())))))) + .WithOpenBracketToken(Token(TriviaList(Comment( + $"/// Gets an instance wrapping .")), + SyntaxKind.OpenBracketToken, + TriviaList())), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage"))))) + .WithExpressionBody( + ArrowExpressionClause( + AssignmentExpression( + SyntaxKind.CoalesceAssignmentExpression, + IdentifierName(commandInfo.FieldName), + ObjectCreationExpression(IdentifierName(commandClassTypeName)) + .AddArgumentListArguments(commandCreationArguments.ToArray())))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + return ImmutableArray.Create(fieldDeclaration, propertyDeclaration); + } + + /// + /// Get the generated field and property names for the input method. + /// + /// The input instance to process. + /// The generated field and property names for . + private static (string FieldName, string PropertyName) GetGeneratedFieldAndPropertyNames(IMethodSymbol methodSymbol) + { + string propertyName = methodSymbol.Name; + + if (methodSymbol.ReturnType.HasFullyQualifiedName("global::System.Threading.Tasks.Task") && + methodSymbol.Name.EndsWith("Async")) + { + propertyName = propertyName.Substring(0, propertyName.Length - "Async".Length); + } + + propertyName += "Command"; + + string fieldName = $"{char.ToLower(propertyName[0])}{propertyName.Substring(1)}"; + + return (fieldName, propertyName); + } + + /// + /// Gets the type symbols for the input method, if supported. + /// + /// The input instance to process. + /// The current collection of gathered diagnostics. + /// The command interface type name. + /// The command class type name. + /// The delegate type name for the wrapped method. + /// The type arguments for and , if any. + /// The type arguments for , if any. + /// Whether or not was valid and the requested types have been set. + private static bool TryMapCommandTypesFromMethod( + IMethodSymbol methodSymbol, + ImmutableArray.Builder diagnostics, + [NotNullWhen(true)] out string? commandInterfaceType, + [NotNullWhen(true)] out string? commandClassType, + [NotNullWhen(true)] out string? delegateType, + out ImmutableArray commandTypeArguments, + out ImmutableArray delegateTypeArguments) + { + // Map to IRelayCommand, RelayCommand, Action + if (methodSymbol.ReturnsVoid && methodSymbol.Parameters.Length == 0) + { + commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IRelayCommand"; + commandClassType = "global::CommunityToolkit.Mvvm.Input.RelayCommand"; + delegateType = "global::System.Action"; + commandTypeArguments = ImmutableArray.Empty; + delegateTypeArguments = ImmutableArray.Empty; + + return true; + } + + // Map to IRelayCommand, RelayCommand, Action + if (methodSymbol.ReturnsVoid && + methodSymbol.Parameters.Length == 1 && + methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } parameter) + { + commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IRelayCommand"; + commandClassType = "global::CommunityToolkit.Mvvm.Input.RelayCommand"; + delegateType = "global::System.Action"; + commandTypeArguments = ImmutableArray.Create(parameter.Type.GetFullyQualifiedName()); + delegateTypeArguments = ImmutableArray.Create(parameter.Type.GetFullyQualifiedName()); + + return true; + } + + if (methodSymbol.ReturnType.HasFullyQualifiedName("global::System.Threading.Tasks.Task")) + { + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + if (methodSymbol.Parameters.Length == 0) + { + commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand"; + commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand"; + delegateType = "global::System.Func"; + commandTypeArguments = ImmutableArray.Empty; + delegateTypeArguments = ImmutableArray.Create("global::System.Threading.Tasks.Task"); + + return true; + } + + if (methodSymbol.Parameters.Length == 1 && + methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } singleParameter) + { + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + if (singleParameter.Type.HasFullyQualifiedName("global::System.Threading.CancellationToken")) + { + commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand"; + commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand"; + delegateType = "global::System.Func"; + commandTypeArguments = ImmutableArray.Empty; + delegateTypeArguments = ImmutableArray.Create("global::System.Threading.CancellationToken", "global::System.Threading.Tasks.Task"); + + return true; + } + + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand"; + commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand"; + delegateType = "global::System.Func"; + commandTypeArguments = ImmutableArray.Create(singleParameter.Type.GetFullyQualifiedName()); + delegateTypeArguments = ImmutableArray.Create(singleParameter.Type.GetFullyQualifiedName(), "global::System.Threading.Tasks.Task"); + + return true; + } + + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + if (methodSymbol.Parameters.Length == 2 && + methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } firstParameter && + methodSymbol.Parameters[1] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } secondParameter && + secondParameter.Type.HasFullyQualifiedName("global::System.Threading.CancellationToken")) + { + commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand"; + commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand"; + delegateType = "global::System.Func"; + commandTypeArguments = ImmutableArray.Create(firstParameter.Type.GetFullyQualifiedName()); + delegateTypeArguments = ImmutableArray.Create(firstParameter.Type.GetFullyQualifiedName(), secondParameter.Type.GetFullyQualifiedName(), "global::System.Threading.Tasks.Task"); + + return true; + } + } + + diagnostics.Add(InvalidICommandMethodSignatureError, methodSymbol, methodSymbol.ContainingType, methodSymbol); + + commandInterfaceType = null; + commandClassType = null; + delegateType = null; + commandTypeArguments = ImmutableArray.Empty; + delegateTypeArguments = ImmutableArray.Empty; + + return false; + } + + /// + /// Checks whether or not the user has requested to configure the handling of concurrent executions. + /// + /// The input instance to process. + /// The instance the method was annotated with. + /// The command class type name. + /// The current collection of gathered diagnostics. + /// Whether or not concurrent executions have been disabled. + /// Whether or not a value for could be retrieved successfully. + private static bool TryGetAllowConcurrentExecutionsSwitch( + IMethodSymbol methodSymbol, + AttributeData attributeData, + string commandClassType, + ImmutableArray.Builder diagnostics, + out bool allowConcurrentExecutions) + { + // Try to get the custom switch for concurrent executions. If the switch is not present, the + // default value is set to true, to avoid breaking backwards compatibility with the first release. + if (!attributeData.TryGetNamedArgument("AllowConcurrentExecutions", out allowConcurrentExecutions)) + { + allowConcurrentExecutions = true; + + return true; + } + + // If the current type is an async command type and concurrent execution is disabled, pass that value to the constructor. + // If concurrent executions are allowed, there is no need to add any additional argument, as that is the default value. + if (commandClassType is "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand") + { + return true; + } + else + { + diagnostics.Add(InvalidConcurrentExecutionsParameterError, methodSymbol, methodSymbol.ContainingType, methodSymbol); + + return false; + } + } + + /// + /// Tries to get the expression type for the "CanExecute" property, if available. + /// + /// The input instance to process. + /// The instance for . + /// The command type arguments, if any. + /// The current collection of gathered diagnostics. + /// The resulting can execute member name, if available. + /// The resulting expression type, if available. + /// Whether or not a value for and could be determined (may include ). + private static bool TryGetCanExecuteExpressionType( + IMethodSymbol methodSymbol, + AttributeData attributeData, + ImmutableArray commandTypeArguments, + ImmutableArray.Builder diagnostics, + out string? canExecuteMemberName, + out CanExecuteExpressionType? canExecuteExpressionType) + { + // Get the can execute member, if any + if (!attributeData.TryGetNamedArgument("CanExecute", out string? memberName)) + { + canExecuteMemberName = null; + canExecuteExpressionType = null; + + return true; + } + + if (memberName is null) + { + diagnostics.Add(InvalidCanExecuteMemberName, methodSymbol, memberName ?? string.Empty, methodSymbol.ContainingType); + + goto Failure; + } + + ImmutableArray canExecuteSymbols = methodSymbol.ContainingType!.GetMembers(memberName); + + if (canExecuteSymbols.IsEmpty) + { + diagnostics.Add(InvalidCanExecuteMemberName, methodSymbol, memberName, methodSymbol.ContainingType); + } + else if (canExecuteSymbols.Length > 1) + { + diagnostics.Add(MultipleCanExecuteMemberNameMatches, methodSymbol, memberName, methodSymbol.ContainingType); + } + else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], commandTypeArguments, out canExecuteExpressionType)) + { + canExecuteMemberName = memberName; + + return true; + } + else + { + diagnostics.Add(InvalidCanExecuteMember, methodSymbol, memberName, methodSymbol.ContainingType); + } + + Failure: + canExecuteMemberName = null; + canExecuteExpressionType = null; + + return false; + } + + /// + /// Gets the expression type for the can execute logic, if possible. + /// + /// The can execute member symbol (either a method or a property). + /// The type arguments for the command interface, if any. + /// The resulting can execute expression type, if available. + /// Whether or not was set and the input symbol was valid. + private static bool TryGetCanExecuteExpressionFromSymbol( + ISymbol canExecuteSymbol, + ImmutableArray commandTypeArguments, + [NotNullWhen(true)] out CanExecuteExpressionType? canExecuteExpressionType) + { + if (canExecuteSymbol is IMethodSymbol canExecuteMethodSymbol) + { + // The return type must always be a bool + if (!canExecuteMethodSymbol.ReturnType.HasFullyQualifiedName("bool")) + { + goto Failure; + } + + // Parameterless methods are always valid + if (canExecuteMethodSymbol.Parameters.IsEmpty) + { + // If the command is generic, the input value is ignored + if (commandTypeArguments.Length > 0) + { + canExecuteExpressionType = CanExecuteExpressionType.MethodInvocationLambdaWithDiscard; + } + else + { + canExecuteExpressionType = CanExecuteExpressionType.MethodGroup; + } + + return true; + } + + // If the method has parameters, it has to have a single one matching the command type + if (canExecuteMethodSymbol.Parameters.Length == 1 && + commandTypeArguments.Length == 1 && + canExecuteMethodSymbol.Parameters[0].Type.HasFullyQualifiedName(commandTypeArguments[0])) + { + // Create a method group expression again + canExecuteExpressionType = CanExecuteExpressionType.MethodGroup; + + return true; + } + } + else if (canExecuteSymbol is IPropertySymbol { GetMethod: not null } canExecutePropertySymbol) + { + // The property type must always be a bool + if (!canExecutePropertySymbol.Type.HasFullyQualifiedName("bool")) + { + goto Failure; + } + + if (commandTypeArguments.Length > 0) + { + canExecuteExpressionType = CanExecuteExpressionType.PropertyAccessLambdaWithDiscard; + } + else + { + canExecuteExpressionType = CanExecuteExpressionType.PropertyAccessLambda; + } + + return true; + } + + Failure: + canExecuteExpressionType = null; + + return false; + } + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs index d22d7c4..3198a0c 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs @@ -24,7 +24,6 @@ namespace CommunityToolkit.Mvvm.SourceGenerators; /// /// A source generator for generating command properties from annotated methods. /// -[Generator] public sealed partial class ICommandGenerator : ISourceGenerator { /// diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CanExecuteExpressionType.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CanExecuteExpressionType.cs new file mode 100644 index 0000000..887a422 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CanExecuteExpressionType.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace CommunityToolkit.Mvvm.SourceGenerators.Input.Models; + +/// +/// A type describing the type of expression for the "CanExecute" property of a command. +/// +public enum CanExecuteExpressionType +{ + /// + /// A method invocation lambda with discard: _ => Method(). + /// + MethodInvocationLambdaWithDiscard, + + /// + /// A property access lambda: () => Property. + /// + PropertyAccessLambda, + + /// + /// A property access lambda with discard: _ => Property. + /// + PropertyAccessLambdaWithDiscard, + + /// + /// A method group expression: Method. + /// + MethodGroup +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs new file mode 100644 index 0000000..b3728c6 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Models; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Input.Models; + +/// +/// A model with gathered info on a given command method. +/// +/// The hierarchy info for the containing type. +/// The name of the target method. +/// The resulting field name for the generated command. +/// The resulting property name for the generated command. +/// The command interface type name. +/// The command class type name. +/// The delegate type name for the wrapped method. +/// The type arguments for and , if any. +/// The type arguments for , if any. +/// The member name for the can execute check, if available. +/// The can execute expression type, if available. +/// Whether or not concurrent executions have been disabled. +internal sealed record CommandInfo( + HierarchyInfo Hierarchy, + string MethodName, + string FieldName, + string PropertyName, + string CommandInterfaceType, + string CommandClassType, + string DelegateType, + ImmutableArray CommandTypeArguments, + ImmutableArray DelegateTypeArguments, + string? CanExecuteMemberName, + CanExecuteExpressionType? CanExecuteExpressionType, + bool AllowConcurrentExecutions) +{ + /// + /// An implementation for . + /// + public sealed class Comparer : IEqualityComparer + { + /// + /// The singleton instance. + /// + public static Comparer Default { get; } = new(); + + /// + public bool Equals(CommandInfo x, CommandInfo y) + { + if (x is null && y is null) + { + return true; + } + + if (x is null || y is null) + { + return false; + } + + if (ReferenceEquals(x, y)) + { + return true; + } + + return + HierarchyInfo.Comparer.Default.Equals(x.Hierarchy, y.Hierarchy) && + x.MethodName == y.MethodName && + x.FieldName == y.FieldName && + x.PropertyName == y.PropertyName && + x.CommandInterfaceType == y.CommandInterfaceType && + x.CommandClassType == y.CommandClassType && + x.DelegateType == y.DelegateType && + x.CommandTypeArguments.SequenceEqual(y.CommandTypeArguments) && + x.DelegateTypeArguments.SequenceEqual(y.CommandTypeArguments) && + x.CanExecuteMemberName == y.CanExecuteMemberName && + x.CanExecuteExpressionType == y.CanExecuteExpressionType && + x.AllowConcurrentExecutions == y.AllowConcurrentExecutions; + } + + /// + public int GetHashCode(CommandInfo obj) + { + HashCode hashCode = default; + + hashCode.Add(obj.Hierarchy, HierarchyInfo.Comparer.Default); + hashCode.Add(obj.MethodName); + hashCode.Add(obj.FieldName); + hashCode.Add(obj.PropertyName); + hashCode.Add(obj.CommandInterfaceType); + hashCode.Add(obj.CommandClassType); + hashCode.Add(obj.DelegateType); + hashCode.AddRange(obj.CommandTypeArguments); + hashCode.AddRange(obj.DelegateTypeArguments); + hashCode.Add(obj.CanExecuteMemberName); + hashCode.Add(obj.CanExecuteExpressionType); + hashCode.Add(obj.AllowConcurrentExecutions); + + return hashCode.ToHashCode(); + } + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.Syntax.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.Syntax.cs new file mode 100644 index 0000000..01ee7ce --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.Syntax.cs @@ -0,0 +1,69 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// This file is ported and adapted from ComputeSharp (Sergio0694/ComputeSharp), +// more info in ThirdPartyNotices.txt in the root of the project. + +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Models; + +/// +internal sealed partial record HierarchyInfo +{ + /// + /// Creates a instance wrapping the given members. + /// + /// The input instances to use. + /// A object wrapping . + public CompilationUnitSyntax GetCompilationUnit(ImmutableArray memberDeclarations) + { + // Create the partial type declaration with the given member declarations. + // This code produces a class declaration as follows: + // + // partial class + // { + // + // } + ClassDeclarationSyntax classDeclarationSyntax = + ClassDeclaration(Names[0]) + .AddModifiers(Token(SyntaxKind.PartialKeyword)) + .AddMembers(memberDeclarations.ToArray()); + + TypeDeclarationSyntax typeDeclarationSyntax = classDeclarationSyntax; + + // Add all parent types in ascending order, if any + foreach (string parentType in Names.AsSpan().Slice(1)) + { + typeDeclarationSyntax = + ClassDeclaration(parentType) + .AddModifiers(Token(SyntaxKind.PartialKeyword)) + .AddMembers(typeDeclarationSyntax); + } + + // Create the compilation unit with disabled warnings, target namespace and generated type. + // This will produce code as follows: + // + // + // #pragma warning disable + // + // namespace + // { + // + // } + return + CompilationUnit().AddMembers( + NamespaceDeclaration(IdentifierName(Namespace)) + .WithLeadingTrivia(TriviaList( + Comment("// "), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))) + .AddMembers(typeDeclarationSyntax)) + .NormalizeWhitespace(eol: "\n"); + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs index aa7ba22..555372e 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs @@ -22,7 +22,7 @@ namespace CommunityToolkit.Mvvm.SourceGenerators.Models; /// The metadata name for the current type. /// Gets the namespace for the current type. /// Gets the sequence of type definitions containing the current type. -internal sealed record HierarchyInfo(string FilenameHint, string MetadataName, string Namespace, ImmutableArray Names) +internal sealed partial record HierarchyInfo(string FilenameHint, string MetadataName, string Namespace, ImmutableArray Names) { /// /// Creates a new instance from a given .