diff --git a/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj b/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj
index eb0d953..be288e6 100644
--- a/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj
+++ b/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.csproj
@@ -32,7 +32,7 @@
-
+
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs
new file mode 100644
index 0000000..5deff69
--- /dev/null
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs
@@ -0,0 +1,34 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.CodeAnalysis;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;
+
+///
+/// Extension methods for the type.
+///
+internal static class ISymbolExtensions
+{
+ ///
+ /// Gets the fully qualified name for a given symbol.
+ ///
+ /// The input instance.
+ /// The fully qualified name for .
+ public static string GetFullyQualifiedName(this ISymbol symbol)
+ {
+ return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ }
+
+ ///
+ /// Checks whether or not a given type symbol has a specified full name.
+ ///
+ /// The input instance to check.
+ /// The full name to check.
+ /// Whether has a full name equals to .
+ public static bool HasFullyQualifiedName(this ISymbol symbol, string name)
+ {
+ return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == name;
+ }
+}
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Incremental.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Incremental.cs
new file mode 100644
index 0000000..be90232
--- /dev/null
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Incremental.cs
@@ -0,0 +1,593 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Immutable;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Text;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Text;
+using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics;
+using CommunityToolkit.Mvvm.SourceGenerators.Extensions;
+using CommunityToolkit.Mvvm.SourceGenerators.Input.Models;
+using CommunityToolkit.Mvvm.SourceGenerators.Models;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators;
+
+///
+/// A source generator for generating command properties from annotated methods.
+///
+[Generator(LanguageNames.CSharp)]
+public sealed partial class ICommandGenerator2 : IIncrementalGenerator
+{
+ ///
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ // Get all method declarations with at least one attribute
+ IncrementalValuesProvider methodSymbols =
+ context.SyntaxProvider
+ .CreateSyntaxProvider(
+ static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 },
+ static (context, _) => (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!);
+
+ // Filter the methods using [ICommand]
+ IncrementalValuesProvider<(IMethodSymbol Symbol, AttributeData Attribute)> methodSymbolsWithAttributeData =
+ methodSymbols
+ .Select(static (item, _) => (
+ item,
+ Attribute: item.GetAttributes().FirstOrDefault(a => a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.Input.ICommandAttribute") == true)))
+ .Where(static item => item.Attribute is not null)!;
+
+ // Gather info for all annotated command methods
+ IncrementalValuesProvider> commandInfoWithErrors =
+ methodSymbolsWithAttributeData
+ .Select(static (item, _) =>
+ {
+ CommandInfo? commandInfo = Execute.GetInfo(item.Symbol, item.Attribute, out ImmutableArray diagnostics);
+
+ return new Result(commandInfo, diagnostics);
+ });
+
+ // Output the diagnostics
+ context.ReportDiagnostics(commandInfoWithErrors.Select(static (item, _) => item.Errors));
+
+ // Get the filtered sequence to enable caching
+ IncrementalValuesProvider commandInfo =
+ commandInfoWithErrors
+ .Select(static (item, _) => item.Value)
+ .Where(static item => item is not null)!
+ .WithComparer(CommandInfo.Comparer.Default);
+
+ // Generate the commands
+ context.RegisterSourceOutput(commandInfo, static (context, item) =>
+ {
+ ImmutableArray memberDeclarations = Execute.GetSyntax(item);
+ CompilationUnitSyntax compilationUnit = item.Hierarchy.GetCompilationUnit(memberDeclarations);
+
+ context.AddSource(
+ hintName: $"{item.Hierarchy.FilenameHint}.{item.MethodName}.cs",
+ sourceText: SourceText.From(compilationUnit.ToFullString(), Encoding.UTF8));
+ });
+ }
+
+ ///
+ /// A container for all the logic for .
+ ///
+ private static class Execute
+ {
+ ///
+ /// Processes a given target method.
+ ///
+ /// The input instance to process.
+ /// The instance the method was annotated with.
+ /// The resulting diagnostics from the processing operation.
+ /// The resulting instance for , if available.
+ public static CommandInfo? GetInfo(IMethodSymbol methodSymbol, AttributeData attributeData, out ImmutableArray diagnostics)
+ {
+ ImmutableArray.Builder builder = ImmutableArray.CreateBuilder();
+
+ // Get the command field and property names
+ (string fieldName, string propertyName) = GetGeneratedFieldAndPropertyNames(methodSymbol);
+
+ // Get the command type symbols
+ if (!TryMapCommandTypesFromMethod(
+ methodSymbol,
+ builder,
+ out string? commandInterfaceType,
+ out string? commandClassType,
+ out string? delegateType,
+ out ImmutableArray commandTypeArguments,
+ out ImmutableArray delegateTypeArguments))
+ {
+ goto Failure;
+ }
+
+ // Check the switch to allow concurrent executions
+ if (!TryGetAllowConcurrentExecutionsSwitch(
+ methodSymbol,
+ attributeData,
+ commandClassType,
+ builder,
+ out bool allowConcurrentExecutions))
+ {
+ goto Failure;
+ }
+
+ // Get the CanExecute expression type, if any
+ if (!TryGetCanExecuteExpressionType(
+ methodSymbol,
+ attributeData,
+ commandTypeArguments,
+ builder,
+ out string? canExecuteMemberName,
+ out CanExecuteExpressionType? canExecuteExpressionType))
+ {
+ goto Failure;
+ }
+
+ diagnostics = builder.ToImmutable();
+
+ return new(
+ HierarchyInfo.From(methodSymbol.ContainingType),
+ methodSymbol.Name,
+ fieldName,
+ propertyName,
+ commandInterfaceType,
+ commandClassType,
+ delegateType,
+ commandTypeArguments,
+ delegateTypeArguments,
+ canExecuteMemberName,
+ canExecuteExpressionType,
+ allowConcurrentExecutions);
+
+ Failure:
+ diagnostics = builder.ToImmutable();
+
+ return null;
+ }
+
+ ///
+ /// Creates the instances for a specified command.
+ ///
+ /// The input instance with the info to generate the command.
+ /// The instances for the input command.
+ public static ImmutableArray GetSyntax(CommandInfo commandInfo)
+ {
+ // Prepare all necessary type names with type arguments
+ string commandInterfaceTypeXmlName = commandInfo.CommandTypeArguments.IsEmpty
+ ? commandInfo.CommandInterfaceType
+ : commandInfo.CommandInterfaceType + "{T}";
+ string commandClassTypeName = commandInfo.CommandTypeArguments.IsEmpty
+ ? commandInfo.CommandClassType
+ : $"{commandInfo.CommandClassType}<{string.Join(", ", commandInfo.CommandTypeArguments)}>";
+ string commandInterfaceTypeName = commandInfo.CommandTypeArguments.IsEmpty
+ ? commandInfo.CommandInterfaceType
+ : $"{commandInfo.CommandInterfaceType}<{string.Join(", ", commandInfo.CommandTypeArguments)}>";
+ string delegateTypeName = commandInfo.DelegateTypeArguments.IsEmpty
+ ? commandInfo.DelegateType
+ : $"{commandInfo.DelegateType}<{string.Join(", ", commandInfo.DelegateTypeArguments)}>";
+
+ // Construct the generated field as follows:
+ //
+ // The backing field for
+ // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")]
+ // private ? ;
+ FieldDeclarationSyntax fieldDeclaration =
+ FieldDeclaration(
+ VariableDeclaration(NullableType(IdentifierName(commandClassTypeName)))
+ .AddVariables(VariableDeclarator(Identifier(commandInfo.FieldName))))
+ .AddModifiers(Token(SyntaxKind.PrivateKeyword))
+ .AddAttributeLists(
+ AttributeList(SingletonSeparatedList(
+ Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode"))
+ .AddArgumentListArguments(
+ AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).FullName))),
+ AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).Assembly.GetName().Version.ToString()))))))
+ .WithOpenBracketToken(Token(TriviaList(Comment($"/// The backing field for .")), SyntaxKind.OpenBracketToken, TriviaList())));
+
+ // Prepares the argument to pass the underlying method to invoke
+ ImmutableArray.Builder commandCreationArguments = ImmutableArray.CreateBuilder();
+
+ // The first argument is the execute method, which is always present
+ commandCreationArguments.Add(
+ Argument(
+ ObjectCreationExpression(IdentifierName(delegateTypeName))
+ .AddArgumentListArguments(Argument(IdentifierName(commandInfo.MethodName)))));
+
+ // Get the can execute expression, if available
+ ExpressionSyntax? canExecuteExpression = commandInfo.CanExecuteExpressionType switch
+ {
+ // Create a lambda expression ignoring the input value:
+ //
+ // new (, _ => ());
+ CanExecuteExpressionType.MethodInvocationLambdaWithDiscard =>
+ SimpleLambdaExpression(
+ Parameter(Identifier(TriviaList(), SyntaxKind.UnderscoreToken, "_", "_", TriviaList())))
+ .WithExpressionBody(InvocationExpression(IdentifierName(commandInfo.CanExecuteMemberName!))),
+
+ // Create a lambda expression returning the property value:
+ //
+ // new (, () => );
+ CanExecuteExpressionType.PropertyAccessLambda =>
+ ParenthesizedLambdaExpression()
+ .WithExpressionBody(IdentifierName(commandInfo.CanExecuteMemberName!)),
+
+ // Create a lambda expression again, but discarding the input value:
+ //
+ // new (, _ => );
+ CanExecuteExpressionType.PropertyAccessLambdaWithDiscard =>
+ SimpleLambdaExpression(
+ Parameter(Identifier(TriviaList(), SyntaxKind.UnderscoreToken, "_", "_", TriviaList())))
+ .WithExpressionBody(IdentifierName(commandInfo.CanExecuteMemberName!)),
+
+ // Create a method groupd expression, which will become:
+ //
+ // new (, );
+ CanExecuteExpressionType.MethodGroup => IdentifierName(commandInfo.CanExecuteMemberName!),
+ _ => null
+ };
+
+ // Add the can execute expression to the arguments, if available
+ if (canExecuteExpression is not null)
+ {
+ commandCreationArguments.Add(Argument(canExecuteExpression));
+ }
+
+ // Disable concurrent executions, if requested
+ if (!commandInfo.AllowConcurrentExecutions)
+ {
+ commandCreationArguments.Add(Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression)));
+ }
+
+ // Construct the generated property as follows (the explicit delegate cast is needed to avoid overload resolution conflicts):
+ //
+ // Gets an and .
+ // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")]
+ // [global::System.Diagnostics.DebuggerNonUserCode]
+ // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ // public => ??= new ();
+ PropertyDeclarationSyntax propertyDeclaration =
+ PropertyDeclaration(
+ IdentifierName(commandInterfaceTypeName),
+ Identifier(commandInfo.PropertyName))
+ .AddModifiers(Token(SyntaxKind.PublicKeyword))
+ .AddAttributeLists(
+ AttributeList(SingletonSeparatedList(
+ Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode"))
+ .AddArgumentListArguments(
+ AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).FullName))),
+ AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).Assembly.GetName().Version.ToString()))))))
+ .WithOpenBracketToken(Token(TriviaList(Comment(
+ $"/// Gets an instance wrapping .")),
+ SyntaxKind.OpenBracketToken,
+ TriviaList())),
+ AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))),
+ AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))))
+ .WithExpressionBody(
+ ArrowExpressionClause(
+ AssignmentExpression(
+ SyntaxKind.CoalesceAssignmentExpression,
+ IdentifierName(commandInfo.FieldName),
+ ObjectCreationExpression(IdentifierName(commandClassTypeName))
+ .AddArgumentListArguments(commandCreationArguments.ToArray()))))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+
+ return ImmutableArray.Create(fieldDeclaration, propertyDeclaration);
+ }
+
+ ///
+ /// Get the generated field and property names for the input method.
+ ///
+ /// The input instance to process.
+ /// The generated field and property names for .
+ private static (string FieldName, string PropertyName) GetGeneratedFieldAndPropertyNames(IMethodSymbol methodSymbol)
+ {
+ string propertyName = methodSymbol.Name;
+
+ if (methodSymbol.ReturnType.HasFullyQualifiedName("global::System.Threading.Tasks.Task") &&
+ methodSymbol.Name.EndsWith("Async"))
+ {
+ propertyName = propertyName.Substring(0, propertyName.Length - "Async".Length);
+ }
+
+ propertyName += "Command";
+
+ string fieldName = $"{char.ToLower(propertyName[0])}{propertyName.Substring(1)}";
+
+ return (fieldName, propertyName);
+ }
+
+ ///
+ /// Gets the type symbols for the input method, if supported.
+ ///
+ /// The input instance to process.
+ /// The current collection of gathered diagnostics.
+ /// The command interface type name.
+ /// The command class type name.
+ /// The delegate type name for the wrapped method.
+ /// The type arguments for and , if any.
+ /// The type arguments for , if any.
+ /// Whether or not was valid and the requested types have been set.
+ private static bool TryMapCommandTypesFromMethod(
+ IMethodSymbol methodSymbol,
+ ImmutableArray.Builder diagnostics,
+ [NotNullWhen(true)] out string? commandInterfaceType,
+ [NotNullWhen(true)] out string? commandClassType,
+ [NotNullWhen(true)] out string? delegateType,
+ out ImmutableArray commandTypeArguments,
+ out ImmutableArray delegateTypeArguments)
+ {
+ // Map to IRelayCommand, RelayCommand, Action
+ if (methodSymbol.ReturnsVoid && methodSymbol.Parameters.Length == 0)
+ {
+ commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IRelayCommand";
+ commandClassType = "global::CommunityToolkit.Mvvm.Input.RelayCommand";
+ delegateType = "global::System.Action";
+ commandTypeArguments = ImmutableArray.Empty;
+ delegateTypeArguments = ImmutableArray.Empty;
+
+ return true;
+ }
+
+ // Map to IRelayCommand, RelayCommand, Action
+ if (methodSymbol.ReturnsVoid &&
+ methodSymbol.Parameters.Length == 1 &&
+ methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } parameter)
+ {
+ commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IRelayCommand";
+ commandClassType = "global::CommunityToolkit.Mvvm.Input.RelayCommand";
+ delegateType = "global::System.Action";
+ commandTypeArguments = ImmutableArray.Create(parameter.Type.GetFullyQualifiedName());
+ delegateTypeArguments = ImmutableArray.Create(parameter.Type.GetFullyQualifiedName());
+
+ return true;
+ }
+
+ if (methodSymbol.ReturnType.HasFullyQualifiedName("global::System.Threading.Tasks.Task"))
+ {
+ // Map to IAsyncRelayCommand, AsyncRelayCommand, Func
+ if (methodSymbol.Parameters.Length == 0)
+ {
+ commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand";
+ commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand";
+ delegateType = "global::System.Func";
+ commandTypeArguments = ImmutableArray.Empty;
+ delegateTypeArguments = ImmutableArray.Create("global::System.Threading.Tasks.Task");
+
+ return true;
+ }
+
+ if (methodSymbol.Parameters.Length == 1 &&
+ methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } singleParameter)
+ {
+ // Map to IAsyncRelayCommand, AsyncRelayCommand, Func
+ if (singleParameter.Type.HasFullyQualifiedName("global::System.Threading.CancellationToken"))
+ {
+ commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand";
+ commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand";
+ delegateType = "global::System.Func";
+ commandTypeArguments = ImmutableArray.Empty;
+ delegateTypeArguments = ImmutableArray.Create("global::System.Threading.CancellationToken", "global::System.Threading.Tasks.Task");
+
+ return true;
+ }
+
+ // Map to IAsyncRelayCommand, AsyncRelayCommand, Func
+ commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand";
+ commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand";
+ delegateType = "global::System.Func";
+ commandTypeArguments = ImmutableArray.Create(singleParameter.Type.GetFullyQualifiedName());
+ delegateTypeArguments = ImmutableArray.Create(singleParameter.Type.GetFullyQualifiedName(), "global::System.Threading.Tasks.Task");
+
+ return true;
+ }
+
+ // Map to IAsyncRelayCommand, AsyncRelayCommand, Func
+ if (methodSymbol.Parameters.Length == 2 &&
+ methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } firstParameter &&
+ methodSymbol.Parameters[1] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } secondParameter &&
+ secondParameter.Type.HasFullyQualifiedName("global::System.Threading.CancellationToken"))
+ {
+ commandInterfaceType = "global::CommunityToolkit.Mvvm.Input.IAsyncRelayCommand";
+ commandClassType = "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand";
+ delegateType = "global::System.Func";
+ commandTypeArguments = ImmutableArray.Create(firstParameter.Type.GetFullyQualifiedName());
+ delegateTypeArguments = ImmutableArray.Create(firstParameter.Type.GetFullyQualifiedName(), secondParameter.Type.GetFullyQualifiedName(), "global::System.Threading.Tasks.Task");
+
+ return true;
+ }
+ }
+
+ diagnostics.Add(InvalidICommandMethodSignatureError, methodSymbol, methodSymbol.ContainingType, methodSymbol);
+
+ commandInterfaceType = null;
+ commandClassType = null;
+ delegateType = null;
+ commandTypeArguments = ImmutableArray.Empty;
+ delegateTypeArguments = ImmutableArray.Empty;
+
+ return false;
+ }
+
+ ///
+ /// Checks whether or not the user has requested to configure the handling of concurrent executions.
+ ///
+ /// The input instance to process.
+ /// The instance the method was annotated with.
+ /// The command class type name.
+ /// The current collection of gathered diagnostics.
+ /// Whether or not concurrent executions have been disabled.
+ /// Whether or not a value for could be retrieved successfully.
+ private static bool TryGetAllowConcurrentExecutionsSwitch(
+ IMethodSymbol methodSymbol,
+ AttributeData attributeData,
+ string commandClassType,
+ ImmutableArray.Builder diagnostics,
+ out bool allowConcurrentExecutions)
+ {
+ // Try to get the custom switch for concurrent executions. If the switch is not present, the
+ // default value is set to true, to avoid breaking backwards compatibility with the first release.
+ if (!attributeData.TryGetNamedArgument("AllowConcurrentExecutions", out allowConcurrentExecutions))
+ {
+ allowConcurrentExecutions = true;
+
+ return true;
+ }
+
+ // If the current type is an async command type and concurrent execution is disabled, pass that value to the constructor.
+ // If concurrent executions are allowed, there is no need to add any additional argument, as that is the default value.
+ if (commandClassType is "global::CommunityToolkit.Mvvm.Input.AsyncRelayCommand")
+ {
+ return true;
+ }
+ else
+ {
+ diagnostics.Add(InvalidConcurrentExecutionsParameterError, methodSymbol, methodSymbol.ContainingType, methodSymbol);
+
+ return false;
+ }
+ }
+
+ ///
+ /// Tries to get the expression type for the "CanExecute" property, if available.
+ ///
+ /// The input instance to process.
+ /// The instance for .
+ /// The command type arguments, if any.
+ /// The current collection of gathered diagnostics.
+ /// The resulting can execute member name, if available.
+ /// The resulting expression type, if available.
+ /// Whether or not a value for and could be determined (may include ).
+ private static bool TryGetCanExecuteExpressionType(
+ IMethodSymbol methodSymbol,
+ AttributeData attributeData,
+ ImmutableArray commandTypeArguments,
+ ImmutableArray.Builder diagnostics,
+ out string? canExecuteMemberName,
+ out CanExecuteExpressionType? canExecuteExpressionType)
+ {
+ // Get the can execute member, if any
+ if (!attributeData.TryGetNamedArgument("CanExecute", out string? memberName))
+ {
+ canExecuteMemberName = null;
+ canExecuteExpressionType = null;
+
+ return true;
+ }
+
+ if (memberName is null)
+ {
+ diagnostics.Add(InvalidCanExecuteMemberName, methodSymbol, memberName ?? string.Empty, methodSymbol.ContainingType);
+
+ goto Failure;
+ }
+
+ ImmutableArray canExecuteSymbols = methodSymbol.ContainingType!.GetMembers(memberName);
+
+ if (canExecuteSymbols.IsEmpty)
+ {
+ diagnostics.Add(InvalidCanExecuteMemberName, methodSymbol, memberName, methodSymbol.ContainingType);
+ }
+ else if (canExecuteSymbols.Length > 1)
+ {
+ diagnostics.Add(MultipleCanExecuteMemberNameMatches, methodSymbol, memberName, methodSymbol.ContainingType);
+ }
+ else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], commandTypeArguments, out canExecuteExpressionType))
+ {
+ canExecuteMemberName = memberName;
+
+ return true;
+ }
+ else
+ {
+ diagnostics.Add(InvalidCanExecuteMember, methodSymbol, memberName, methodSymbol.ContainingType);
+ }
+
+ Failure:
+ canExecuteMemberName = null;
+ canExecuteExpressionType = null;
+
+ return false;
+ }
+
+ ///
+ /// Gets the expression type for the can execute logic, if possible.
+ ///
+ /// The can execute member symbol (either a method or a property).
+ /// The type arguments for the command interface, if any.
+ /// The resulting can execute expression type, if available.
+ /// Whether or not was set and the input symbol was valid.
+ private static bool TryGetCanExecuteExpressionFromSymbol(
+ ISymbol canExecuteSymbol,
+ ImmutableArray commandTypeArguments,
+ [NotNullWhen(true)] out CanExecuteExpressionType? canExecuteExpressionType)
+ {
+ if (canExecuteSymbol is IMethodSymbol canExecuteMethodSymbol)
+ {
+ // The return type must always be a bool
+ if (!canExecuteMethodSymbol.ReturnType.HasFullyQualifiedName("bool"))
+ {
+ goto Failure;
+ }
+
+ // Parameterless methods are always valid
+ if (canExecuteMethodSymbol.Parameters.IsEmpty)
+ {
+ // If the command is generic, the input value is ignored
+ if (commandTypeArguments.Length > 0)
+ {
+ canExecuteExpressionType = CanExecuteExpressionType.MethodInvocationLambdaWithDiscard;
+ }
+ else
+ {
+ canExecuteExpressionType = CanExecuteExpressionType.MethodGroup;
+ }
+
+ return true;
+ }
+
+ // If the method has parameters, it has to have a single one matching the command type
+ if (canExecuteMethodSymbol.Parameters.Length == 1 &&
+ commandTypeArguments.Length == 1 &&
+ canExecuteMethodSymbol.Parameters[0].Type.HasFullyQualifiedName(commandTypeArguments[0]))
+ {
+ // Create a method group expression again
+ canExecuteExpressionType = CanExecuteExpressionType.MethodGroup;
+
+ return true;
+ }
+ }
+ else if (canExecuteSymbol is IPropertySymbol { GetMethod: not null } canExecutePropertySymbol)
+ {
+ // The property type must always be a bool
+ if (!canExecutePropertySymbol.Type.HasFullyQualifiedName("bool"))
+ {
+ goto Failure;
+ }
+
+ if (commandTypeArguments.Length > 0)
+ {
+ canExecuteExpressionType = CanExecuteExpressionType.PropertyAccessLambdaWithDiscard;
+ }
+ else
+ {
+ canExecuteExpressionType = CanExecuteExpressionType.PropertyAccessLambda;
+ }
+
+ return true;
+ }
+
+ Failure:
+ canExecuteExpressionType = null;
+
+ return false;
+ }
+ }
+}
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs
index d22d7c4..3198a0c 100644
--- a/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs
@@ -24,7 +24,6 @@ namespace CommunityToolkit.Mvvm.SourceGenerators;
///
/// A source generator for generating command properties from annotated methods.
///
-[Generator]
public sealed partial class ICommandGenerator : ISourceGenerator
{
///
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CanExecuteExpressionType.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CanExecuteExpressionType.cs
new file mode 100644
index 0000000..887a422
--- /dev/null
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CanExecuteExpressionType.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Input.Models;
+
+///
+/// A type describing the type of expression for the "CanExecute" property of a command.
+///
+public enum CanExecuteExpressionType
+{
+ ///
+ /// A method invocation lambda with discard: _ => Method().
+ ///
+ MethodInvocationLambdaWithDiscard,
+
+ ///
+ /// A property access lambda: () => Property.
+ ///
+ PropertyAccessLambda,
+
+ ///
+ /// A property access lambda with discard: _ => Property.
+ ///
+ PropertyAccessLambdaWithDiscard,
+
+ ///
+ /// A method group expression: Method.
+ ///
+ MethodGroup
+}
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs
new file mode 100644
index 0000000..b3728c6
--- /dev/null
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs
@@ -0,0 +1,107 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using CommunityToolkit.Mvvm.SourceGenerators.Extensions;
+using CommunityToolkit.Mvvm.SourceGenerators.Models;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Input.Models;
+
+///
+/// A model with gathered info on a given command method.
+///
+/// The hierarchy info for the containing type.
+/// The name of the target method.
+/// The resulting field name for the generated command.
+/// The resulting property name for the generated command.
+/// The command interface type name.
+/// The command class type name.
+/// The delegate type name for the wrapped method.
+/// The type arguments for and , if any.
+/// The type arguments for , if any.
+/// The member name for the can execute check, if available.
+/// The can execute expression type, if available.
+/// Whether or not concurrent executions have been disabled.
+internal sealed record CommandInfo(
+ HierarchyInfo Hierarchy,
+ string MethodName,
+ string FieldName,
+ string PropertyName,
+ string CommandInterfaceType,
+ string CommandClassType,
+ string DelegateType,
+ ImmutableArray CommandTypeArguments,
+ ImmutableArray DelegateTypeArguments,
+ string? CanExecuteMemberName,
+ CanExecuteExpressionType? CanExecuteExpressionType,
+ bool AllowConcurrentExecutions)
+{
+ ///
+ /// An implementation for .
+ ///
+ public sealed class Comparer : IEqualityComparer
+ {
+ ///
+ /// The singleton instance.
+ ///
+ public static Comparer Default { get; } = new();
+
+ ///
+ public bool Equals(CommandInfo x, CommandInfo y)
+ {
+ if (x is null && y is null)
+ {
+ return true;
+ }
+
+ if (x is null || y is null)
+ {
+ return false;
+ }
+
+ if (ReferenceEquals(x, y))
+ {
+ return true;
+ }
+
+ return
+ HierarchyInfo.Comparer.Default.Equals(x.Hierarchy, y.Hierarchy) &&
+ x.MethodName == y.MethodName &&
+ x.FieldName == y.FieldName &&
+ x.PropertyName == y.PropertyName &&
+ x.CommandInterfaceType == y.CommandInterfaceType &&
+ x.CommandClassType == y.CommandClassType &&
+ x.DelegateType == y.DelegateType &&
+ x.CommandTypeArguments.SequenceEqual(y.CommandTypeArguments) &&
+ x.DelegateTypeArguments.SequenceEqual(y.CommandTypeArguments) &&
+ x.CanExecuteMemberName == y.CanExecuteMemberName &&
+ x.CanExecuteExpressionType == y.CanExecuteExpressionType &&
+ x.AllowConcurrentExecutions == y.AllowConcurrentExecutions;
+ }
+
+ ///
+ public int GetHashCode(CommandInfo obj)
+ {
+ HashCode hashCode = default;
+
+ hashCode.Add(obj.Hierarchy, HierarchyInfo.Comparer.Default);
+ hashCode.Add(obj.MethodName);
+ hashCode.Add(obj.FieldName);
+ hashCode.Add(obj.PropertyName);
+ hashCode.Add(obj.CommandInterfaceType);
+ hashCode.Add(obj.CommandClassType);
+ hashCode.Add(obj.DelegateType);
+ hashCode.AddRange(obj.CommandTypeArguments);
+ hashCode.AddRange(obj.DelegateTypeArguments);
+ hashCode.Add(obj.CanExecuteMemberName);
+ hashCode.Add(obj.CanExecuteExpressionType);
+ hashCode.Add(obj.AllowConcurrentExecutions);
+
+ return hashCode.ToHashCode();
+ }
+ }
+}
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.Syntax.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.Syntax.cs
new file mode 100644
index 0000000..01ee7ce
--- /dev/null
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.Syntax.cs
@@ -0,0 +1,69 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+// This file is ported and adapted from ComputeSharp (Sergio0694/ComputeSharp),
+// more info in ThirdPartyNotices.txt in the root of the project.
+
+using System.Collections.Immutable;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Models;
+
+///
+internal sealed partial record HierarchyInfo
+{
+ ///
+ /// Creates a instance wrapping the given members.
+ ///
+ /// The input instances to use.
+ /// A object wrapping .
+ public CompilationUnitSyntax GetCompilationUnit(ImmutableArray memberDeclarations)
+ {
+ // Create the partial type declaration with the given member declarations.
+ // This code produces a class declaration as follows:
+ //
+ // partial class
+ // {
+ //
+ // }
+ ClassDeclarationSyntax classDeclarationSyntax =
+ ClassDeclaration(Names[0])
+ .AddModifiers(Token(SyntaxKind.PartialKeyword))
+ .AddMembers(memberDeclarations.ToArray());
+
+ TypeDeclarationSyntax typeDeclarationSyntax = classDeclarationSyntax;
+
+ // Add all parent types in ascending order, if any
+ foreach (string parentType in Names.AsSpan().Slice(1))
+ {
+ typeDeclarationSyntax =
+ ClassDeclaration(parentType)
+ .AddModifiers(Token(SyntaxKind.PartialKeyword))
+ .AddMembers(typeDeclarationSyntax);
+ }
+
+ // Create the compilation unit with disabled warnings, target namespace and generated type.
+ // This will produce code as follows:
+ //
+ //
+ // #pragma warning disable
+ //
+ // namespace
+ // {
+ //
+ // }
+ return
+ CompilationUnit().AddMembers(
+ NamespaceDeclaration(IdentifierName(Namespace))
+ .WithLeadingTrivia(TriviaList(
+ Comment("// "),
+ Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))))
+ .AddMembers(typeDeclarationSyntax))
+ .NormalizeWhitespace(eol: "\n");
+ }
+}
diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs
index aa7ba22..555372e 100644
--- a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs
+++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs
@@ -22,7 +22,7 @@ namespace CommunityToolkit.Mvvm.SourceGenerators.Models;
/// The metadata name for the current type.
/// Gets the namespace for the current type.
/// Gets the sequence of type definitions containing the current type.
-internal sealed record HierarchyInfo(string FilenameHint, string MetadataName, string Namespace, ImmutableArray Names)
+internal sealed partial record HierarchyInfo(string FilenameHint, string MetadataName, string Namespace, ImmutableArray Names)
{
///
/// Creates a new instance from a given .