Equality Generator now supports baseclasses
This commit is contained in:
Родитель
12544b29d7
Коммит
f4f6b8ba7a
|
@ -62,6 +62,12 @@ namespace Uno.CodeGen.Tests
|
|||
internal ICollection<TSomething> H { get; set; }
|
||||
}
|
||||
|
||||
[GeneratedEquality]
|
||||
internal partial class DerivedEqualityClass : MyEqualityClass<int>
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
[GeneratedEquality]
|
||||
internal partial struct MyEqualityStruct
|
||||
{
|
||||
|
|
|
@ -32,11 +32,16 @@ namespace Uno
|
|||
[SourceGeneratorDependency("Uno.ImmutableGenerator")]
|
||||
public class EqualityGenerator : SourceGenerator
|
||||
{
|
||||
private INamedTypeSymbol _objectSymbol;
|
||||
private INamedTypeSymbol _valueTypeSymbol;
|
||||
private INamedTypeSymbol _boolSymbol;
|
||||
private INamedTypeSymbol _intSymbol;
|
||||
private INamedTypeSymbol _arraySymbol;
|
||||
private INamedTypeSymbol _collectionSymbol;
|
||||
private INamedTypeSymbol _collectionGenericSymbol;
|
||||
private INamedTypeSymbol _iEquatableSymbol;
|
||||
private INamedTypeSymbol _iKeyEquatableSymbol;
|
||||
private INamedTypeSymbol _iKeyEquatableGenericSymbol;
|
||||
private INamedTypeSymbol _generatedEqualityAttributeSymbol;
|
||||
private INamedTypeSymbol _ignoreForEqualityAttributeSymbol;
|
||||
private INamedTypeSymbol _equalityHashCodeAttributeSymbol;
|
||||
|
@ -72,11 +77,16 @@ namespace Uno
|
|||
public override void Execute(SourceGeneratorContext context)
|
||||
{
|
||||
_context = context;
|
||||
_objectSymbol = context.Compilation.GetTypeByMetadataName("System.Object");
|
||||
_valueTypeSymbol = context.Compilation.GetTypeByMetadataName("System.ValueType");
|
||||
_boolSymbol = context.Compilation.GetTypeByMetadataName("System.Bool");
|
||||
_intSymbol = context.Compilation.GetTypeByMetadataName("System.Int32");
|
||||
_arraySymbol = context.Compilation.GetTypeByMetadataName("System.Array");
|
||||
_collectionSymbol = context.Compilation.GetTypeByMetadataName("System.Collections.ICollection");
|
||||
_collectionGenericSymbol = context.Compilation.GetTypeByMetadataName("System.Collections.Generic.ICollection`1");
|
||||
_iEquatableSymbol = context.Compilation.GetTypeByMetadataName("System.IEquatable`1");
|
||||
_iKeyEquatableSymbol = context.Compilation.GetTypeByMetadataName("Uno.Equality.IKeyEquatable");
|
||||
_iKeyEquatableGenericSymbol = context.Compilation.GetTypeByMetadataName("Uno.Equality.IKeyEquatable`1");
|
||||
_generatedEqualityAttributeSymbol = context.Compilation.GetTypeByMetadataName("Uno.GeneratedEqualityAttribute");
|
||||
_ignoreForEqualityAttributeSymbol = context.Compilation.GetTypeByMetadataName("Uno.EqualityIgnoreAttribute");
|
||||
_equalityHashCodeAttributeSymbol = context.Compilation.GetTypeByMetadataName("Uno.EqualityHashAttribute");
|
||||
|
@ -94,7 +104,8 @@ namespace Uno
|
|||
|
||||
var (symbolName, symbolNameWithGenerics, symbolNameForXml, symbolNameDefinition, resultFileName) = typeSymbol.GetSymbolNames();
|
||||
var (equalityMembers, hashMembers, keyEqualityMembers) = GetEqualityMembers(typeSymbol);
|
||||
var generateKeyEquals = keyEqualityMembers.Any();
|
||||
var baseTypeInfo = GetBaseTypeInfo(typeSymbol);
|
||||
var generateKeyEquals = baseTypeInfo.baseImplementsKeyEquals || baseTypeInfo.baseImplementsKeyEqualsT || keyEqualityMembers.Any();
|
||||
|
||||
builder.AppendLine("using System;");
|
||||
builder.AppendLine();
|
||||
|
@ -112,6 +123,15 @@ namespace Uno
|
|||
builder.AppendLineInvariant($"#warning {nameof(EqualityGenerator)}: you should add the partial modifier to the class {symbolNameWithGenerics}.");
|
||||
}
|
||||
|
||||
if (baseTypeInfo.isBaseType && !baseTypeInfo.baseOverridesEquals)
|
||||
{
|
||||
builder.AppendLineInvariant($"#warning {nameof(EqualityGenerator)}: base type {typeSymbol.BaseType} does not override .Equals() method. It could lead to erronous results.");
|
||||
}
|
||||
if (baseTypeInfo.isBaseType && !baseTypeInfo.baseOverridesGetHashCode)
|
||||
{
|
||||
builder.AppendLineInvariant($"#warning {nameof(EqualityGenerator)}: base type {typeSymbol.BaseType} does not override .GetHashCode() method. It could lead to erronous results.");
|
||||
}
|
||||
|
||||
var classOrStruct = typeSymbol.IsReferenceType ? "class" : "struct";
|
||||
|
||||
var keyEqualsInterfaces = generateKeyEquals
|
||||
|
@ -184,7 +204,11 @@ namespace Uno
|
|||
{
|
||||
builder.AppendLineInvariant("if (other.GetHashCode() != GetHashCode()) return false;");
|
||||
|
||||
GenerateEqualCalculation(typeSymbol, builder, equalityMembers);
|
||||
var baseCall = baseTypeInfo.baseOverridesEquals
|
||||
? "base.KeyEquals(other)"
|
||||
: null;
|
||||
|
||||
GenerateEqualLogic(typeSymbol, builder, keyEqualityMembers, baseCall);
|
||||
}
|
||||
|
||||
builder.AppendLineInvariant("#endregion");
|
||||
|
@ -222,7 +246,10 @@ namespace Uno
|
|||
|
||||
using (builder.BlockInvariant("private int ComputeHashCode()"))
|
||||
{
|
||||
GenerateHashCalculation(typeSymbol, builder, hashMembers);
|
||||
var baseCall = baseTypeInfo.baseOverridesGetHashCode
|
||||
? "base.GetHashCode()"
|
||||
: null;
|
||||
GenerateHashLogic(typeSymbol, builder, hashMembers, baseCall);
|
||||
}
|
||||
|
||||
builder.AppendLineInvariant("#endregion");
|
||||
|
@ -265,7 +292,11 @@ namespace Uno
|
|||
{
|
||||
builder.AppendLineInvariant("if (other.GetKeyHashCode() != GetKeyHashCode()) return false;");
|
||||
|
||||
GenerateEqualCalculation(typeSymbol, builder, keyEqualityMembers);
|
||||
var baseCall = baseTypeInfo.baseImplementsKeyEquals || baseTypeInfo.baseImplementsKeyEqualsT
|
||||
? "base.KeyEquals(other)"
|
||||
: null;
|
||||
|
||||
GenerateEqualLogic(typeSymbol, builder, keyEqualityMembers, baseCall);
|
||||
}
|
||||
|
||||
builder.AppendLine();
|
||||
|
@ -281,7 +312,10 @@ namespace Uno
|
|||
|
||||
using (builder.BlockInvariant("private int ComputeKeyHashCode()"))
|
||||
{
|
||||
GenerateHashCalculation(typeSymbol, builder, keyEqualityMembers);
|
||||
var baseCall = baseTypeInfo.baseImplementsKeyEquals || baseTypeInfo.baseImplementsKeyEqualsT
|
||||
? "base.GetKeyHashCode()"
|
||||
: null;
|
||||
GenerateHashLogic(typeSymbol, builder, keyEqualityMembers, baseCall);
|
||||
}
|
||||
|
||||
builder.AppendLine();
|
||||
|
@ -297,11 +331,58 @@ namespace Uno
|
|||
_context.AddCompilationUnit(resultFileName, builder.ToString());
|
||||
}
|
||||
|
||||
private void GenerateEqualCalculation(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] equalityMembers)
|
||||
private (bool isBaseType, bool baseOverridesGetHashCode, bool baseOverridesEquals, bool baseImplementsIEquatable, bool baseImplementsKeyEquals, bool baseImplementsKeyEqualsT)
|
||||
GetBaseTypeInfo(INamedTypeSymbol typeSymbol)
|
||||
{
|
||||
if (equalityMembers.Length == 0)
|
||||
var baseType = typeSymbol.BaseType;
|
||||
if (baseType.Equals(_objectSymbol) || baseType.Equals(_valueTypeSymbol))
|
||||
{
|
||||
builder.AppendLineInvariant("#error No fields or properties used for equality check.");
|
||||
return (false, false, false, false, false, false);
|
||||
}
|
||||
|
||||
var isBaseTypeInSources = baseType.Locations.Any(l => l.IsInSource);
|
||||
|
||||
var baseTypeWillBeGenerated = isBaseTypeInSources
|
||||
&& baseType.FindAttributeFlattened(_generatedEqualityAttributeSymbol) != null;
|
||||
|
||||
var baseOverridesGetHashCode = baseTypeWillBeGenerated
|
||||
|| baseType.GetMethods()
|
||||
.Any(m => m.Name.Equals("GetHashCode")
|
||||
&& !m.IsStatic
|
||||
&& m.IsOverride
|
||||
&& m.ReturnType.Equals(_intSymbol)
|
||||
&& m.Parameters.Length == 0);
|
||||
|
||||
var baseOverridesEquals = baseTypeWillBeGenerated
|
||||
|| baseType.GetMethods()
|
||||
.Any(m => m.Name.Equals("Equals")
|
||||
&& !m.IsStatic
|
||||
&& m.IsOverride
|
||||
&& m.ReturnType.Equals(_boolSymbol)
|
||||
&& m.Parameters.Length == 1
|
||||
&& m.Parameters[0].Type.Equals(_objectSymbol));
|
||||
|
||||
var baseImplementsIEquatable = baseTypeWillBeGenerated
|
||||
|| baseType
|
||||
.Interfaces
|
||||
.Any(i => i.OriginalDefinition.Equals(_iEquatableSymbol));
|
||||
|
||||
var baseImplementsKeyEquals = baseType
|
||||
.Interfaces
|
||||
.Any(i => i.OriginalDefinition.Equals(_iKeyEquatableSymbol));
|
||||
|
||||
var baseImplementsKeyEqualsT = baseType
|
||||
.Interfaces
|
||||
.Any(i => i.OriginalDefinition.Equals(_iKeyEquatableGenericSymbol));
|
||||
|
||||
return (true, baseOverridesGetHashCode, baseOverridesEquals, baseImplementsIEquatable, baseImplementsKeyEquals, baseImplementsKeyEqualsT);
|
||||
}
|
||||
|
||||
private void GenerateEqualLogic(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] equalityMembers, string baseCall)
|
||||
{
|
||||
if (baseCall == null && equalityMembers.Length == 0)
|
||||
{
|
||||
builder.AppendLineInvariant("#warning No fields or properties used for equality check.");
|
||||
}
|
||||
|
||||
foreach (var member in equalityMembers)
|
||||
|
@ -334,19 +415,33 @@ namespace Uno
|
|||
}
|
||||
}
|
||||
}
|
||||
builder.AppendLineInvariant("return true; // no differences found");
|
||||
if (baseCall != null)
|
||||
{
|
||||
builder.AppendLineInvariant($"return {baseCall}; // no differences found, check with base");
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.AppendLineInvariant("return true; // no differences found");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateHashCalculation(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] hashMembers)
|
||||
private void GenerateHashLogic(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] hashMembers, string baseCall)
|
||||
{
|
||||
if (hashMembers.Length == 0)
|
||||
if (baseCall == null && hashMembers.Length == 0)
|
||||
{
|
||||
builder.AppendLineInvariant("#warning There is no members marked with [Uno.EqualityHash] or [Uno.EqualityKey]. You should add at least one. Documentation: https://github.com/nventive/Uno.CodeGen/blob/master/doc/Equality%20Generation.md");
|
||||
builder.AppendLineInvariant("return 0; // no members to compute hash");
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.AppendLineInvariant("int hash = 104729; // 10 000th prime number");
|
||||
if (baseCall != null)
|
||||
{
|
||||
builder.AppendLineInvariant($"int hash = {baseCall}; // start with hash from base");
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.AppendLineInvariant("int hash = 104729; // 10 000th prime number");
|
||||
}
|
||||
using (builder.BlockInvariant("unchecked"))
|
||||
{
|
||||
for (var i = 0; i < hashMembers.Length; i++)
|
||||
|
@ -401,6 +496,10 @@ namespace Uno
|
|||
{
|
||||
getHashCode = $"({member.Name} ? 1 : 0)";
|
||||
}
|
||||
else if (definition.Equals(_intSymbol))
|
||||
{
|
||||
getHashCode = $"{member.Name}";
|
||||
}
|
||||
else if (definition.DerivesFromType(_arraySymbol))
|
||||
{
|
||||
getHashCode = $"{member.Name}.Length";
|
||||
|
|
Загрузка…
Ссылка в новой задаче