From c86beb6ff4798992962069cca3685c45c563b2f3 Mon Sep 17 00:00:00 2001 From: Simone Busoli Date: Wed, 19 Sep 2012 22:38:42 +0200 Subject: [PATCH] Supporting recursion by detecting recursive enumerable comparisons and considering them as not equal --- .../Constraints/NUnitEqualityComparer.cs | 246 ++++++++++-------- .../Constraints/CollectionConstraintTests.cs | 17 ++ .../Constraints/NUnitEqualityComparerTests.cs | 24 +- 3 files changed, 178 insertions(+), 109 deletions(-) diff --git a/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs b/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs index 1d006347..923fa625 100644 --- a/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs +++ b/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs @@ -4,7 +4,8 @@ // obtain a copy of the license at http://nunit.org // **************************************************************** -using System; +using System; +using System.Collections.ObjectModel; using System.IO; using System.Collections; #if CLR_2_0 || CLR_4_0 @@ -96,71 +97,76 @@ namespace NUnit.Framework.Constraints /// /// Compares two objects for equality within a tolerance. /// - public bool AreEqual(object x, object y, ref Tolerance tolerance) - { - this.failurePoints = new ArrayList(); - - if (x == null && y == null) - return true; - - if (x == null || y == null) - return false; - - if (object.ReferenceEquals(x, y)) - return true; - - Type xType = x.GetType(); - Type yType = y.GetType(); - - EqualityAdapter externalComparer = GetExternalComparer(x, y); - if (externalComparer != null) - return externalComparer.AreEqual(x, y); - - if (xType.IsArray && yType.IsArray && !compareAsCollection) - return ArraysEqual((Array)x, (Array)y, ref tolerance); - - if (x is IDictionary && y is IDictionary) - return DictionariesEqual((IDictionary)x, (IDictionary)y, ref tolerance); - - //if (x is ICollection && y is ICollection) - // return CollectionsEqual((ICollection)x, (ICollection)y, ref tolerance); - - if (x is IEnumerable && y is IEnumerable && !(x is string && y is string)) - return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, ref tolerance); - - if (x is string && y is string) - return StringsEqual((string)x, (string)y); - - if (x is Stream && y is Stream) - return StreamsEqual((Stream)x, (Stream)y); - - if (x is DirectoryInfo && y is DirectoryInfo) - return DirectoriesEqual((DirectoryInfo)x, (DirectoryInfo)y); - - if (Numerics.IsNumericType(x) && Numerics.IsNumericType(y)) - return Numerics.AreEqual(x, y, ref tolerance); - - if (tolerance != null && tolerance.Value is TimeSpan) - { - TimeSpan amount = (TimeSpan)tolerance.Value; - - if (x is DateTime && y is DateTime) - return ((DateTime)x - (DateTime)y).Duration() <= amount; - - if (x is TimeSpan && y is TimeSpan) - return ((TimeSpan)x - (TimeSpan)y).Duration() <= amount; - } - -#if CLR_2_0 || CLR_4_0 - if (FirstImplementsIEquatableOfSecond(xType, yType)) - return InvokeFirstIEquatableEqualsSecond(x, y); - else if (FirstImplementsIEquatableOfSecond(yType, xType)) - return InvokeFirstIEquatableEqualsSecond(y, x); -#endif - - return x.Equals(y); - } - + public bool AreEqual(object x, object y, ref Tolerance tolerance) + { + return AreEqual(x, y, new EnumerableRecursionHelper(), ref tolerance); + } + + private bool AreEqual(object x, object y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) + { + this.failurePoints = new ArrayList(); + + if (x == null && y == null) + return true; + + if (x == null || y == null) + return false; + + if (object.ReferenceEquals(x, y)) + return true; + + Type xType = x.GetType(); + Type yType = y.GetType(); + + EqualityAdapter externalComparer = GetExternalComparer(x, y); + if (externalComparer != null) + return externalComparer.AreEqual(x, y); + + if (xType.IsArray && yType.IsArray && !compareAsCollection) + return ArraysEqual((Array) x, (Array) y, recursionHelper, ref tolerance); + + if (x is IDictionary && y is IDictionary) + return DictionariesEqual((IDictionary) x, (IDictionary) y, recursionHelper, ref tolerance); + + //if (x is ICollection && y is ICollection) + // return CollectionsEqual((ICollection)x, (ICollection)y, ref tolerance); + + if (x is IEnumerable && y is IEnumerable && !(x is string && y is string)) + return EnumerablesEqual((IEnumerable) x, (IEnumerable) y, recursionHelper, ref tolerance); + + if (x is string && y is string) + return StringsEqual((string) x, (string) y); + + if (x is Stream && y is Stream) + return StreamsEqual((Stream) x, (Stream) y); + + if (x is DirectoryInfo && y is DirectoryInfo) + return DirectoriesEqual((DirectoryInfo) x, (DirectoryInfo) y); + + if (Numerics.IsNumericType(x) && Numerics.IsNumericType(y)) + return Numerics.AreEqual(x, y, ref tolerance); + + if (tolerance != null && tolerance.Value is TimeSpan) + { + TimeSpan amount = (TimeSpan) tolerance.Value; + + if (x is DateTime && y is DateTime) + return ((DateTime) x - (DateTime) y).Duration() <= amount; + + if (x is TimeSpan && y is TimeSpan) + return ((TimeSpan) x - (TimeSpan) y).Duration() <= amount; + } + +#if CLR_2_0 || CLR_4_0 + if (FirstImplementsIEquatableOfSecond(xType, yType)) + return InvokeFirstIEquatableEqualsSecond(x, y); + else if (FirstImplementsIEquatableOfSecond(yType, xType)) + return InvokeFirstIEquatableEqualsSecond(y, x); +#endif + + return x.Equals(y); + } + #if CLR_2_0 || CLR_4_0 private static bool FirstImplementsIEquatableOfSecond(Type first, Type second) { @@ -211,7 +217,7 @@ namespace NUnit.Framework.Constraints /// /// Helper method to compare two arrays /// - private bool ArraysEqual(Array x, Array y, ref Tolerance tolerance) + private bool ArraysEqual(Array x, Array y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) { int rank = x.Rank; @@ -222,10 +228,10 @@ namespace NUnit.Framework.Constraints if (x.GetLength(r) != y.GetLength(r)) return false; - return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, ref tolerance); + return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, recursionHelper, ref tolerance); } - private bool DictionariesEqual(IDictionary x, IDictionary y, ref Tolerance tolerance) + private bool DictionariesEqual(IDictionary x, IDictionary y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) { if (x.Count != y.Count) return false; @@ -235,43 +241,12 @@ namespace NUnit.Framework.Constraints return false; foreach (object key in x.Keys) - if (!AreEqual(x[key], y[key], ref tolerance)) + if (!AreEqual(x[key], y[key], recursionHelper, ref tolerance)) return false; return true; } - private bool CollectionsEqual(ICollection x, ICollection y, ref Tolerance tolerance) - { - IEnumerator expectedEnum = x.GetEnumerator(); - IEnumerator actualEnum = y.GetEnumerator(); - - int count; - for (count = 0; ; count++) - { - bool expectedHasData = expectedEnum.MoveNext(); - bool actualHasData = actualEnum.MoveNext(); - - if (!expectedHasData && !actualHasData) - return true; - - if (expectedHasData != actualHasData || - !AreEqual(expectedEnum.Current, actualEnum.Current, ref tolerance)) - { - FailurePoint fp = new FailurePoint(); - fp.Position = count; - fp.ExpectedHasData = expectedHasData; - if (expectedHasData) - fp.ExpectedValue = expectedEnum.Current; - fp.ActualHasData = actualHasData; - if (actualHasData) - fp.ActualValue = actualEnum.Current; - failurePoints.Insert(0, fp); - return false; - } - } - } - private bool StringsEqual(string x, string y) { string s1 = caseInsensitive ? x.ToLower() : x; @@ -279,11 +254,14 @@ namespace NUnit.Framework.Constraints return s1.Equals(s2); } + + private bool EnumerablesEqual(IEnumerable expected, IEnumerable actual, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) + { + if (recursionHelper.CheckRecursion(expected, actual)) + return false; - private bool EnumerablesEqual(IEnumerable x, IEnumerable y, ref Tolerance tolerance) - { - IEnumerator expectedEnum = x.GetEnumerator(); - IEnumerator actualEnum = y.GetEnumerator(); + IEnumerator expectedEnum = expected.GetEnumerator(); + IEnumerator actualEnum = actual.GetEnumerator(); int count; for (count = 0; ; count++) @@ -295,7 +273,7 @@ namespace NUnit.Framework.Constraints return true; if (expectedHasData != actualHasData || - !AreEqual(expectedEnum.Current, actualEnum.Current, ref tolerance)) + !AreEqual(expectedEnum.Current, actualEnum.Current, recursionHelper, ref tolerance)) { FailurePoint fp = new FailurePoint(); fp.Position = count; @@ -420,6 +398,58 @@ namespace NUnit.Framework.Constraints public bool ActualHasData; } - #endregion + #endregion + + #region Nested EnumerableRecursionHelper class + + class EnumerableRecursionHelper + { + readonly Hashtable table = new Hashtable(); + + public bool CheckRecursion(IEnumerable expected, IEnumerable actual) + { + var key = new UnorderedReferencePair(expected, actual); + + if (table.Contains(key)) + return true; + + table.Add(key, null); + return false; + } + + class UnorderedReferencePair : IEquatable + { + private readonly object first; + private readonly object second; + + public UnorderedReferencePair(object first, object second) + { + this.first = first; + this.second = second; + } + + public bool Equals(UnorderedReferencePair other) + { + return (Equals(first, other.first) && Equals(second, other.second)) || + (Equals(first, other.second) && Equals(second, other.first)); + } + + public override bool Equals(object obj) + { + if (ReferenceEquals(null, obj)) return false; + return obj is UnorderedReferencePair && Equals((UnorderedReferencePair) obj); + } + + public override int GetHashCode() + { + unchecked + { + return ((first != null ? first.GetHashCode() : 0)*397) ^ ((second != null ? second.GetHashCode() : 0)*397); + } + } + } + } + + #endregion } -} +} \ No newline at end of file diff --git a/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs b/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs index d9d727c1..ce7dfe4a 100644 --- a/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs +++ b/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs @@ -307,6 +307,23 @@ namespace NUnit.Framework.Constraints } } + [Test] + public void ContainsWithRecursiveStructure() + { + SelfRecursiveEnumerable item = new SelfRecursiveEnumerable(); + SelfRecursiveEnumerable[] container = new SelfRecursiveEnumerable[] {new SelfRecursiveEnumerable(), item}; + + Assert.That(container, new CollectionContainsConstraint(item)); + } + + class SelfRecursiveEnumerable : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return this; + } + } + #if CS_3_0 || CS_4_0 [Test] public void UsesProvidedLambdaExpression() diff --git a/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs b/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs index 8bd29196..5887ea04 100644 --- a/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs +++ b/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs @@ -30,7 +30,29 @@ namespace NUnit.Framework.Constraints Assert.True(comparer.AreEqual(array, array, ref tolerance)); } -#if CLR_2_0 || CLR_4_0 +#if CLR_2_0 || CLR_4_0 + [Test] + public void RecursiveEnumerablesAreNotEqual() + { + var a1 = new object[1]; + a1[0] = a1; + var a2 = new object[1]; + a2[0] = a2; + + Assert.False(comparer.AreEqual(a1, a2, ref tolerance)); + } + + [Test] + public void RecursiveEnumerablesAreNotEqual2() + { + var a1 = new object[1]; + var a2 = new object[1]; + a1[0] = a2; + a2[0] = a1; + + Assert.False(comparer.AreEqual(a1, a2, ref tolerance)); + } + [Test] public void IEquatableSuccess() {