Equality Generator now supports baseclasses

This commit is contained in:
Carl de Billy 2018-02-02 16:38:53 -05:00
Родитель 12544b29d7
Коммит f4f6b8ba7a
2 изменённых файлов: 117 добавлений и 12 удалений

Просмотреть файл

@ -62,6 +62,12 @@ namespace Uno.CodeGen.Tests
internal ICollection<TSomething> H { get; set; }
}
[GeneratedEquality]
internal partial class DerivedEqualityClass : MyEqualityClass<int>
{
}
[GeneratedEquality]
internal partial struct MyEqualityStruct
{

Просмотреть файл

@ -32,11 +32,16 @@ namespace Uno
[SourceGeneratorDependency("Uno.ImmutableGenerator")]
public class EqualityGenerator : SourceGenerator
{
private INamedTypeSymbol _objectSymbol;
private INamedTypeSymbol _valueTypeSymbol;
private INamedTypeSymbol _boolSymbol;
private INamedTypeSymbol _intSymbol;
private INamedTypeSymbol _arraySymbol;
private INamedTypeSymbol _collectionSymbol;
private INamedTypeSymbol _collectionGenericSymbol;
private INamedTypeSymbol _iEquatableSymbol;
private INamedTypeSymbol _iKeyEquatableSymbol;
private INamedTypeSymbol _iKeyEquatableGenericSymbol;
private INamedTypeSymbol _generatedEqualityAttributeSymbol;
private INamedTypeSymbol _ignoreForEqualityAttributeSymbol;
private INamedTypeSymbol _equalityHashCodeAttributeSymbol;
@ -72,11 +77,16 @@ namespace Uno
public override void Execute(SourceGeneratorContext context)
{
_context = context;
_objectSymbol = context.Compilation.GetTypeByMetadataName("System.Object");
_valueTypeSymbol = context.Compilation.GetTypeByMetadataName("System.ValueType");
_boolSymbol = context.Compilation.GetTypeByMetadataName("System.Bool");
_intSymbol = context.Compilation.GetTypeByMetadataName("System.Int32");
_arraySymbol = context.Compilation.GetTypeByMetadataName("System.Array");
_collectionSymbol = context.Compilation.GetTypeByMetadataName("System.Collections.ICollection");
_collectionGenericSymbol = context.Compilation.GetTypeByMetadataName("System.Collections.Generic.ICollection`1");
_iEquatableSymbol = context.Compilation.GetTypeByMetadataName("System.IEquatable`1");
_iKeyEquatableSymbol = context.Compilation.GetTypeByMetadataName("Uno.Equality.IKeyEquatable");
_iKeyEquatableGenericSymbol = context.Compilation.GetTypeByMetadataName("Uno.Equality.IKeyEquatable`1");
_generatedEqualityAttributeSymbol = context.Compilation.GetTypeByMetadataName("Uno.GeneratedEqualityAttribute");
_ignoreForEqualityAttributeSymbol = context.Compilation.GetTypeByMetadataName("Uno.EqualityIgnoreAttribute");
_equalityHashCodeAttributeSymbol = context.Compilation.GetTypeByMetadataName("Uno.EqualityHashAttribute");
@ -94,7 +104,8 @@ namespace Uno
var (symbolName, symbolNameWithGenerics, symbolNameForXml, symbolNameDefinition, resultFileName) = typeSymbol.GetSymbolNames();
var (equalityMembers, hashMembers, keyEqualityMembers) = GetEqualityMembers(typeSymbol);
var generateKeyEquals = keyEqualityMembers.Any();
var baseTypeInfo = GetBaseTypeInfo(typeSymbol);
var generateKeyEquals = baseTypeInfo.baseImplementsKeyEquals || baseTypeInfo.baseImplementsKeyEqualsT || keyEqualityMembers.Any();
builder.AppendLine("using System;");
builder.AppendLine();
@ -112,6 +123,15 @@ namespace Uno
builder.AppendLineInvariant($"#warning {nameof(EqualityGenerator)}: you should add the partial modifier to the class {symbolNameWithGenerics}.");
}
if (baseTypeInfo.isBaseType && !baseTypeInfo.baseOverridesEquals)
{
builder.AppendLineInvariant($"#warning {nameof(EqualityGenerator)}: base type {typeSymbol.BaseType} does not override .Equals() method. It could lead to erronous results.");
}
if (baseTypeInfo.isBaseType && !baseTypeInfo.baseOverridesGetHashCode)
{
builder.AppendLineInvariant($"#warning {nameof(EqualityGenerator)}: base type {typeSymbol.BaseType} does not override .GetHashCode() method. It could lead to erronous results.");
}
var classOrStruct = typeSymbol.IsReferenceType ? "class" : "struct";
var keyEqualsInterfaces = generateKeyEquals
@ -184,7 +204,11 @@ namespace Uno
{
builder.AppendLineInvariant("if (other.GetHashCode() != GetHashCode()) return false;");
GenerateEqualCalculation(typeSymbol, builder, equalityMembers);
var baseCall = baseTypeInfo.baseOverridesEquals
? "base.KeyEquals(other)"
: null;
GenerateEqualLogic(typeSymbol, builder, keyEqualityMembers, baseCall);
}
builder.AppendLineInvariant("#endregion");
@ -222,7 +246,10 @@ namespace Uno
using (builder.BlockInvariant("private int ComputeHashCode()"))
{
GenerateHashCalculation(typeSymbol, builder, hashMembers);
var baseCall = baseTypeInfo.baseOverridesGetHashCode
? "base.GetHashCode()"
: null;
GenerateHashLogic(typeSymbol, builder, hashMembers, baseCall);
}
builder.AppendLineInvariant("#endregion");
@ -265,7 +292,11 @@ namespace Uno
{
builder.AppendLineInvariant("if (other.GetKeyHashCode() != GetKeyHashCode()) return false;");
GenerateEqualCalculation(typeSymbol, builder, keyEqualityMembers);
var baseCall = baseTypeInfo.baseImplementsKeyEquals || baseTypeInfo.baseImplementsKeyEqualsT
? "base.KeyEquals(other)"
: null;
GenerateEqualLogic(typeSymbol, builder, keyEqualityMembers, baseCall);
}
builder.AppendLine();
@ -281,7 +312,10 @@ namespace Uno
using (builder.BlockInvariant("private int ComputeKeyHashCode()"))
{
GenerateHashCalculation(typeSymbol, builder, keyEqualityMembers);
var baseCall = baseTypeInfo.baseImplementsKeyEquals || baseTypeInfo.baseImplementsKeyEqualsT
? "base.GetKeyHashCode()"
: null;
GenerateHashLogic(typeSymbol, builder, keyEqualityMembers, baseCall);
}
builder.AppendLine();
@ -297,11 +331,58 @@ namespace Uno
_context.AddCompilationUnit(resultFileName, builder.ToString());
}
private void GenerateEqualCalculation(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] equalityMembers)
private (bool isBaseType, bool baseOverridesGetHashCode, bool baseOverridesEquals, bool baseImplementsIEquatable, bool baseImplementsKeyEquals, bool baseImplementsKeyEqualsT)
GetBaseTypeInfo(INamedTypeSymbol typeSymbol)
{
if (equalityMembers.Length == 0)
var baseType = typeSymbol.BaseType;
if (baseType.Equals(_objectSymbol) || baseType.Equals(_valueTypeSymbol))
{
builder.AppendLineInvariant("#error No fields or properties used for equality check.");
return (false, false, false, false, false, false);
}
var isBaseTypeInSources = baseType.Locations.Any(l => l.IsInSource);
var baseTypeWillBeGenerated = isBaseTypeInSources
&& baseType.FindAttributeFlattened(_generatedEqualityAttributeSymbol) != null;
var baseOverridesGetHashCode = baseTypeWillBeGenerated
|| baseType.GetMethods()
.Any(m => m.Name.Equals("GetHashCode")
&& !m.IsStatic
&& m.IsOverride
&& m.ReturnType.Equals(_intSymbol)
&& m.Parameters.Length == 0);
var baseOverridesEquals = baseTypeWillBeGenerated
|| baseType.GetMethods()
.Any(m => m.Name.Equals("Equals")
&& !m.IsStatic
&& m.IsOverride
&& m.ReturnType.Equals(_boolSymbol)
&& m.Parameters.Length == 1
&& m.Parameters[0].Type.Equals(_objectSymbol));
var baseImplementsIEquatable = baseTypeWillBeGenerated
|| baseType
.Interfaces
.Any(i => i.OriginalDefinition.Equals(_iEquatableSymbol));
var baseImplementsKeyEquals = baseType
.Interfaces
.Any(i => i.OriginalDefinition.Equals(_iKeyEquatableSymbol));
var baseImplementsKeyEqualsT = baseType
.Interfaces
.Any(i => i.OriginalDefinition.Equals(_iKeyEquatableGenericSymbol));
return (true, baseOverridesGetHashCode, baseOverridesEquals, baseImplementsIEquatable, baseImplementsKeyEquals, baseImplementsKeyEqualsT);
}
private void GenerateEqualLogic(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] equalityMembers, string baseCall)
{
if (baseCall == null && equalityMembers.Length == 0)
{
builder.AppendLineInvariant("#warning No fields or properties used for equality check.");
}
foreach (var member in equalityMembers)
@ -334,19 +415,33 @@ namespace Uno
}
}
}
builder.AppendLineInvariant("return true; // no differences found");
if (baseCall != null)
{
builder.AppendLineInvariant($"return {baseCall}; // no differences found, check with base");
}
else
{
builder.AppendLineInvariant("return true; // no differences found");
}
}
private void GenerateHashCalculation(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] hashMembers)
private void GenerateHashLogic(INamedTypeSymbol typeSymbol, IndentedStringBuilder builder, ISymbol[] hashMembers, string baseCall)
{
if (hashMembers.Length == 0)
if (baseCall == null && hashMembers.Length == 0)
{
builder.AppendLineInvariant("#warning There is no members marked with [Uno.EqualityHash] or [Uno.EqualityKey]. You should add at least one. Documentation: https://github.com/nventive/Uno.CodeGen/blob/master/doc/Equality%20Generation.md");
builder.AppendLineInvariant("return 0; // no members to compute hash");
}
else
{
builder.AppendLineInvariant("int hash = 104729; // 10 000th prime number");
if (baseCall != null)
{
builder.AppendLineInvariant($"int hash = {baseCall}; // start with hash from base");
}
else
{
builder.AppendLineInvariant("int hash = 104729; // 10 000th prime number");
}
using (builder.BlockInvariant("unchecked"))
{
for (var i = 0; i < hashMembers.Length; i++)
@ -401,6 +496,10 @@ namespace Uno
{
getHashCode = $"({member.Name} ? 1 : 0)";
}
else if (definition.Equals(_intSymbol))
{
getHashCode = $"{member.Name}";
}
else if (definition.DerivesFromType(_arraySymbol))
{
getHashCode = $"{member.Name}.Length";