diff --git a/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs b/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs index 1d006347..e4bc9a56 100644 --- a/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs +++ b/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs @@ -5,6 +5,7 @@ // **************************************************************** using System; +using System.Collections.ObjectModel; using System.IO; using System.Collections; #if CLR_2_0 || CLR_4_0 @@ -97,6 +98,11 @@ namespace NUnit.Framework.Constraints /// Compares two objects for equality within a tolerance. /// public bool AreEqual(object x, object y, ref Tolerance tolerance) + { + return AreEqual(x, y, new EnumerableRecursionHelper(), ref tolerance); + } + + private bool AreEqual(object x, object y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) { this.failurePoints = new ArrayList(); @@ -105,9 +111,9 @@ namespace NUnit.Framework.Constraints if (x == null || y == null) return false; - - if (object.ReferenceEquals(x, y)) - return true; + + if (object.ReferenceEquals(x, y)) + return true; Type xType = x.GetType(); Type yType = y.GetType(); @@ -117,38 +123,38 @@ namespace NUnit.Framework.Constraints return externalComparer.AreEqual(x, y); if (xType.IsArray && yType.IsArray && !compareAsCollection) - return ArraysEqual((Array)x, (Array)y, ref tolerance); + return ArraysEqual((Array) x, (Array) y, recursionHelper, ref tolerance); if (x is IDictionary && y is IDictionary) - return DictionariesEqual((IDictionary)x, (IDictionary)y, ref tolerance); + return DictionariesEqual((IDictionary) x, (IDictionary) y, recursionHelper, ref tolerance); //if (x is ICollection && y is ICollection) // return CollectionsEqual((ICollection)x, (ICollection)y, ref tolerance); if (x is IEnumerable && y is IEnumerable && !(x is string && y is string)) - return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, ref tolerance); + return EnumerablesEqual((IEnumerable) x, (IEnumerable) y, recursionHelper, ref tolerance); if (x is string && y is string) - return StringsEqual((string)x, (string)y); + return StringsEqual((string) x, (string) y); if (x is Stream && y is Stream) - return StreamsEqual((Stream)x, (Stream)y); + return StreamsEqual((Stream) x, (Stream) y); if (x is DirectoryInfo && y is DirectoryInfo) - return DirectoriesEqual((DirectoryInfo)x, (DirectoryInfo)y); + return DirectoriesEqual((DirectoryInfo) x, (DirectoryInfo) y); if (Numerics.IsNumericType(x) && Numerics.IsNumericType(y)) return Numerics.AreEqual(x, y, ref tolerance); if (tolerance != null && tolerance.Value is TimeSpan) { - TimeSpan amount = (TimeSpan)tolerance.Value; + TimeSpan amount = (TimeSpan) tolerance.Value; if (x is DateTime && y is DateTime) - return ((DateTime)x - (DateTime)y).Duration() <= amount; + return ((DateTime) x - (DateTime) y).Duration() <= amount; if (x is TimeSpan && y is TimeSpan) - return ((TimeSpan)x - (TimeSpan)y).Duration() <= amount; + return ((TimeSpan) x - (TimeSpan) y).Duration() <= amount; } #if CLR_2_0 || CLR_4_0 @@ -211,7 +217,7 @@ namespace NUnit.Framework.Constraints /// /// Helper method to compare two arrays /// - private bool ArraysEqual(Array x, Array y, ref Tolerance tolerance) + private bool ArraysEqual(Array x, Array y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) { int rank = x.Rank; @@ -222,10 +228,10 @@ namespace NUnit.Framework.Constraints if (x.GetLength(r) != y.GetLength(r)) return false; - return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, ref tolerance); + return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, recursionHelper, ref tolerance); } - private bool DictionariesEqual(IDictionary x, IDictionary y, ref Tolerance tolerance) + private bool DictionariesEqual(IDictionary x, IDictionary y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) { if (x.Count != y.Count) return false; @@ -235,43 +241,12 @@ namespace NUnit.Framework.Constraints return false; foreach (object key in x.Keys) - if (!AreEqual(x[key], y[key], ref tolerance)) + if (!AreEqual(x[key], y[key], recursionHelper, ref tolerance)) return false; return true; } - private bool CollectionsEqual(ICollection x, ICollection y, ref Tolerance tolerance) - { - IEnumerator expectedEnum = x.GetEnumerator(); - IEnumerator actualEnum = y.GetEnumerator(); - - int count; - for (count = 0; ; count++) - { - bool expectedHasData = expectedEnum.MoveNext(); - bool actualHasData = actualEnum.MoveNext(); - - if (!expectedHasData && !actualHasData) - return true; - - if (expectedHasData != actualHasData || - !AreEqual(expectedEnum.Current, actualEnum.Current, ref tolerance)) - { - FailurePoint fp = new FailurePoint(); - fp.Position = count; - fp.ExpectedHasData = expectedHasData; - if (expectedHasData) - fp.ExpectedValue = expectedEnum.Current; - fp.ActualHasData = actualHasData; - if (actualHasData) - fp.ActualValue = actualEnum.Current; - failurePoints.Insert(0, fp); - return false; - } - } - } - private bool StringsEqual(string x, string y) { string s1 = caseInsensitive ? x.ToLower() : x; @@ -279,11 +254,14 @@ namespace NUnit.Framework.Constraints return s1.Equals(s2); } - - private bool EnumerablesEqual(IEnumerable x, IEnumerable y, ref Tolerance tolerance) + + private bool EnumerablesEqual(IEnumerable expected, IEnumerable actual, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance) { - IEnumerator expectedEnum = x.GetEnumerator(); - IEnumerator actualEnum = y.GetEnumerator(); + if (recursionHelper.CheckRecursion(expected, actual)) + return false; + + IEnumerator expectedEnum = expected.GetEnumerator(); + IEnumerator actualEnum = actual.GetEnumerator(); int count; for (count = 0; ; count++) @@ -295,7 +273,7 @@ namespace NUnit.Framework.Constraints return true; if (expectedHasData != actualHasData || - !AreEqual(expectedEnum.Current, actualEnum.Current, ref tolerance)) + !AreEqual(expectedEnum.Current, actualEnum.Current, recursionHelper, ref tolerance)) { FailurePoint fp = new FailurePoint(); fp.Position = count; @@ -421,5 +399,57 @@ namespace NUnit.Framework.Constraints } #endregion + + #region Nested EnumerableRecursionHelper class + + class EnumerableRecursionHelper + { + readonly Hashtable table = new Hashtable(); + + public bool CheckRecursion(IEnumerable expected, IEnumerable actual) + { + var key = new UnorderedReferencePair(expected, actual); + + if (table.Contains(key)) + return true; + + table.Add(key, null); + return false; + } + + class UnorderedReferencePair : IEquatable + { + private readonly object first; + private readonly object second; + + public UnorderedReferencePair(object first, object second) + { + this.first = first; + this.second = second; + } + + public bool Equals(UnorderedReferencePair other) + { + return (Equals(first, other.first) && Equals(second, other.second)) || + (Equals(first, other.second) && Equals(second, other.first)); + } + + public override bool Equals(object obj) + { + if (ReferenceEquals(null, obj)) return false; + return obj is UnorderedReferencePair && Equals((UnorderedReferencePair) obj); + } + + public override int GetHashCode() + { + unchecked + { + return ((first != null ? first.GetHashCode() : 0)*397) ^ ((second != null ? second.GetHashCode() : 0)*397); + } + } + } + } + + #endregion } -} +} \ No newline at end of file diff --git a/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs b/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs index 5532138a..44e696c5 100644 --- a/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs +++ b/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs @@ -307,6 +307,23 @@ namespace NUnit.Framework.Constraints } } + [Test] + public void ContainsWithRecursiveStructure() + { + SelfRecursiveEnumerable item = new SelfRecursiveEnumerable(); + SelfRecursiveEnumerable[] container = new SelfRecursiveEnumerable[] {new SelfRecursiveEnumerable(), item}; + + Assert.That(container, new CollectionContainsConstraint(item)); + } + + class SelfRecursiveEnumerable : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return this; + } + } + #if CS_3_0 || CS_4_0 || CS_5_0 [Test] public void UsesProvidedLambdaExpression() diff --git a/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs b/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs index 8bd29196..d8177ab6 100644 --- a/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs +++ b/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs @@ -31,6 +31,28 @@ namespace NUnit.Framework.Constraints } #if CLR_2_0 || CLR_4_0 + [Test] + public void RecursiveEnumerablesAreNotEqual() + { + var a1 = new object[1]; + a1[0] = a1; + var a2 = new object[1]; + a2[0] = a2; + + Assert.False(comparer.AreEqual(a1, a2, ref tolerance)); + } + + [Test] + public void RecursiveEnumerablesAreNotEqual2() + { + var a1 = new object[1]; + var a2 = new object[1]; + a1[0] = a2; + a2[0] = a1; + + Assert.False(comparer.AreEqual(a1, a2, ref tolerance)); + } + [Test] public void IEquatableSuccess() {