diff --git a/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs b/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs
index 1d006347..e4bc9a56 100644
--- a/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs
+++ b/src/NUnitFramework/framework/Constraints/NUnitEqualityComparer.cs
@@ -5,6 +5,7 @@
// ****************************************************************
using System;
+using System.Collections.ObjectModel;
using System.IO;
using System.Collections;
#if CLR_2_0 || CLR_4_0
@@ -97,6 +98,11 @@ namespace NUnit.Framework.Constraints
/// Compares two objects for equality within a tolerance.
///
public bool AreEqual(object x, object y, ref Tolerance tolerance)
+ {
+ return AreEqual(x, y, new EnumerableRecursionHelper(), ref tolerance);
+ }
+
+ private bool AreEqual(object x, object y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance)
{
this.failurePoints = new ArrayList();
@@ -105,9 +111,9 @@ namespace NUnit.Framework.Constraints
if (x == null || y == null)
return false;
-
- if (object.ReferenceEquals(x, y))
- return true;
+
+ if (object.ReferenceEquals(x, y))
+ return true;
Type xType = x.GetType();
Type yType = y.GetType();
@@ -117,38 +123,38 @@ namespace NUnit.Framework.Constraints
return externalComparer.AreEqual(x, y);
if (xType.IsArray && yType.IsArray && !compareAsCollection)
- return ArraysEqual((Array)x, (Array)y, ref tolerance);
+ return ArraysEqual((Array) x, (Array) y, recursionHelper, ref tolerance);
if (x is IDictionary && y is IDictionary)
- return DictionariesEqual((IDictionary)x, (IDictionary)y, ref tolerance);
+ return DictionariesEqual((IDictionary) x, (IDictionary) y, recursionHelper, ref tolerance);
//if (x is ICollection && y is ICollection)
// return CollectionsEqual((ICollection)x, (ICollection)y, ref tolerance);
if (x is IEnumerable && y is IEnumerable && !(x is string && y is string))
- return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, ref tolerance);
+ return EnumerablesEqual((IEnumerable) x, (IEnumerable) y, recursionHelper, ref tolerance);
if (x is string && y is string)
- return StringsEqual((string)x, (string)y);
+ return StringsEqual((string) x, (string) y);
if (x is Stream && y is Stream)
- return StreamsEqual((Stream)x, (Stream)y);
+ return StreamsEqual((Stream) x, (Stream) y);
if (x is DirectoryInfo && y is DirectoryInfo)
- return DirectoriesEqual((DirectoryInfo)x, (DirectoryInfo)y);
+ return DirectoriesEqual((DirectoryInfo) x, (DirectoryInfo) y);
if (Numerics.IsNumericType(x) && Numerics.IsNumericType(y))
return Numerics.AreEqual(x, y, ref tolerance);
if (tolerance != null && tolerance.Value is TimeSpan)
{
- TimeSpan amount = (TimeSpan)tolerance.Value;
+ TimeSpan amount = (TimeSpan) tolerance.Value;
if (x is DateTime && y is DateTime)
- return ((DateTime)x - (DateTime)y).Duration() <= amount;
+ return ((DateTime) x - (DateTime) y).Duration() <= amount;
if (x is TimeSpan && y is TimeSpan)
- return ((TimeSpan)x - (TimeSpan)y).Duration() <= amount;
+ return ((TimeSpan) x - (TimeSpan) y).Duration() <= amount;
}
#if CLR_2_0 || CLR_4_0
@@ -211,7 +217,7 @@ namespace NUnit.Framework.Constraints
///
/// Helper method to compare two arrays
///
- private bool ArraysEqual(Array x, Array y, ref Tolerance tolerance)
+ private bool ArraysEqual(Array x, Array y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance)
{
int rank = x.Rank;
@@ -222,10 +228,10 @@ namespace NUnit.Framework.Constraints
if (x.GetLength(r) != y.GetLength(r))
return false;
- return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, ref tolerance);
+ return EnumerablesEqual((IEnumerable)x, (IEnumerable)y, recursionHelper, ref tolerance);
}
- private bool DictionariesEqual(IDictionary x, IDictionary y, ref Tolerance tolerance)
+ private bool DictionariesEqual(IDictionary x, IDictionary y, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance)
{
if (x.Count != y.Count)
return false;
@@ -235,43 +241,12 @@ namespace NUnit.Framework.Constraints
return false;
foreach (object key in x.Keys)
- if (!AreEqual(x[key], y[key], ref tolerance))
+ if (!AreEqual(x[key], y[key], recursionHelper, ref tolerance))
return false;
return true;
}
- private bool CollectionsEqual(ICollection x, ICollection y, ref Tolerance tolerance)
- {
- IEnumerator expectedEnum = x.GetEnumerator();
- IEnumerator actualEnum = y.GetEnumerator();
-
- int count;
- for (count = 0; ; count++)
- {
- bool expectedHasData = expectedEnum.MoveNext();
- bool actualHasData = actualEnum.MoveNext();
-
- if (!expectedHasData && !actualHasData)
- return true;
-
- if (expectedHasData != actualHasData ||
- !AreEqual(expectedEnum.Current, actualEnum.Current, ref tolerance))
- {
- FailurePoint fp = new FailurePoint();
- fp.Position = count;
- fp.ExpectedHasData = expectedHasData;
- if (expectedHasData)
- fp.ExpectedValue = expectedEnum.Current;
- fp.ActualHasData = actualHasData;
- if (actualHasData)
- fp.ActualValue = actualEnum.Current;
- failurePoints.Insert(0, fp);
- return false;
- }
- }
- }
-
private bool StringsEqual(string x, string y)
{
string s1 = caseInsensitive ? x.ToLower() : x;
@@ -279,11 +254,14 @@ namespace NUnit.Framework.Constraints
return s1.Equals(s2);
}
-
- private bool EnumerablesEqual(IEnumerable x, IEnumerable y, ref Tolerance tolerance)
+
+ private bool EnumerablesEqual(IEnumerable expected, IEnumerable actual, EnumerableRecursionHelper recursionHelper, ref Tolerance tolerance)
{
- IEnumerator expectedEnum = x.GetEnumerator();
- IEnumerator actualEnum = y.GetEnumerator();
+ if (recursionHelper.CheckRecursion(expected, actual))
+ return false;
+
+ IEnumerator expectedEnum = expected.GetEnumerator();
+ IEnumerator actualEnum = actual.GetEnumerator();
int count;
for (count = 0; ; count++)
@@ -295,7 +273,7 @@ namespace NUnit.Framework.Constraints
return true;
if (expectedHasData != actualHasData ||
- !AreEqual(expectedEnum.Current, actualEnum.Current, ref tolerance))
+ !AreEqual(expectedEnum.Current, actualEnum.Current, recursionHelper, ref tolerance))
{
FailurePoint fp = new FailurePoint();
fp.Position = count;
@@ -421,5 +399,57 @@ namespace NUnit.Framework.Constraints
}
#endregion
+
+ #region Nested EnumerableRecursionHelper class
+
+ class EnumerableRecursionHelper
+ {
+ readonly Hashtable table = new Hashtable();
+
+ public bool CheckRecursion(IEnumerable expected, IEnumerable actual)
+ {
+ var key = new UnorderedReferencePair(expected, actual);
+
+ if (table.Contains(key))
+ return true;
+
+ table.Add(key, null);
+ return false;
+ }
+
+ class UnorderedReferencePair : IEquatable
+ {
+ private readonly object first;
+ private readonly object second;
+
+ public UnorderedReferencePair(object first, object second)
+ {
+ this.first = first;
+ this.second = second;
+ }
+
+ public bool Equals(UnorderedReferencePair other)
+ {
+ return (Equals(first, other.first) && Equals(second, other.second)) ||
+ (Equals(first, other.second) && Equals(second, other.first));
+ }
+
+ public override bool Equals(object obj)
+ {
+ if (ReferenceEquals(null, obj)) return false;
+ return obj is UnorderedReferencePair && Equals((UnorderedReferencePair) obj);
+ }
+
+ public override int GetHashCode()
+ {
+ unchecked
+ {
+ return ((first != null ? first.GetHashCode() : 0)*397) ^ ((second != null ? second.GetHashCode() : 0)*397);
+ }
+ }
+ }
+ }
+
+ #endregion
}
-}
+}
\ No newline at end of file
diff --git a/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs b/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs
index 5532138a..44e696c5 100644
--- a/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs
+++ b/src/NUnitFramework/tests/Constraints/CollectionConstraintTests.cs
@@ -307,6 +307,23 @@ namespace NUnit.Framework.Constraints
}
}
+ [Test]
+ public void ContainsWithRecursiveStructure()
+ {
+ SelfRecursiveEnumerable item = new SelfRecursiveEnumerable();
+ SelfRecursiveEnumerable[] container = new SelfRecursiveEnumerable[] {new SelfRecursiveEnumerable(), item};
+
+ Assert.That(container, new CollectionContainsConstraint(item));
+ }
+
+ class SelfRecursiveEnumerable : IEnumerable
+ {
+ public IEnumerator GetEnumerator()
+ {
+ yield return this;
+ }
+ }
+
#if CS_3_0 || CS_4_0 || CS_5_0
[Test]
public void UsesProvidedLambdaExpression()
diff --git a/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs b/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs
index 8bd29196..d8177ab6 100644
--- a/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs
+++ b/src/NUnitFramework/tests/Constraints/NUnitEqualityComparerTests.cs
@@ -31,6 +31,28 @@ namespace NUnit.Framework.Constraints
}
#if CLR_2_0 || CLR_4_0
+ [Test]
+ public void RecursiveEnumerablesAreNotEqual()
+ {
+ var a1 = new object[1];
+ a1[0] = a1;
+ var a2 = new object[1];
+ a2[0] = a2;
+
+ Assert.False(comparer.AreEqual(a1, a2, ref tolerance));
+ }
+
+ [Test]
+ public void RecursiveEnumerablesAreNotEqual2()
+ {
+ var a1 = new object[1];
+ var a2 = new object[1];
+ a1[0] = a2;
+ a2[0] = a1;
+
+ Assert.False(comparer.AreEqual(a1, a2, ref tolerance));
+ }
+
[Test]
public void IEquatableSuccess()
{