diff --git a/container/intsets/sparse.go b/container/intsets/sparse.go index fa32a0f15..8847febf1 100644 --- a/container/intsets/sparse.go +++ b/container/intsets/sparse.go @@ -14,8 +14,6 @@ package intsets // import "golang.org/x/tools/container/intsets" // TODO(adonovan): -// - Add SymmetricDifference(x, y *Sparse), i.e. x ∆ y. -// - Add SubsetOf (x∖y=∅) and Intersects (x∩y≠∅) predicates. // - Add InsertAll(...int), RemoveAll(...int) // - Add 'bool changed' results for {Intersection,Difference}With too. // @@ -485,6 +483,29 @@ func (s *Sparse) Intersection(x, y *Sparse) { s.discardTail(sb) } +// Intersects reports whether s ∩ x ≠ ∅. +func (s *Sparse) Intersects(x *Sparse) bool { + sb := s.start() + xb := x.start() + for sb != &s.root && xb != &x.root { + switch { + case xb.offset < sb.offset: + xb = xb.next + case xb.offset > sb.offset: + sb = sb.next + default: + for i := range sb.bits { + if sb.bits[i]&xb.bits[i] != 0 { + return true + } + } + sb = sb.next + xb = xb.next + } + } + return false +} + // UnionWith sets s to the union s ∪ x, and reports whether s grew. func (s *Sparse) UnionWith(x *Sparse) bool { if s == x { @@ -667,6 +688,146 @@ func (s *Sparse) Difference(x, y *Sparse) { s.discardTail(sb) } +// SymmetricDifferenceWith sets s to the symmetric difference s ∆ x. +func (s *Sparse) SymmetricDifferenceWith(x *Sparse) { + if s == x { + s.Clear() + return + } + + sb := s.start() + xb := x.start() + for xb != &x.root && sb != &s.root { + switch { + case sb.offset < xb.offset: + sb = sb.next + case xb.offset < sb.offset: + nb := s.insertBlockBefore(sb) + nb.offset = xb.offset + nb.bits = xb.bits + xb = xb.next + default: + var sum word + for i := range sb.bits { + r := sb.bits[i] ^ xb.bits[i] + sb.bits[i] = r + sum |= r + } + sb = sb.next + xb = xb.next + if sum == 0 { + s.removeBlock(sb.prev) + } + } + } + + for xb != &x.root { // append the tail of x to s + sb = s.insertBlockBefore(sb) + sb.offset = xb.offset + sb.bits = xb.bits + sb = sb.next + xb = xb.next + } +} + +// SymmetricDifference sets s to the symmetric difference x ∆ y. +func (s *Sparse) SymmetricDifference(x, y *Sparse) { + switch { + case x == y: + s.Clear() + return + case s == x: + s.SymmetricDifferenceWith(y) + return + case s == y: + s.SymmetricDifferenceWith(x) + return + } + + sb := s.start() + xb := x.start() + yb := y.start() + for xb != &x.root && yb != &y.root { + if sb == &s.root { + sb = s.insertBlockBefore(sb) + } + switch { + case yb.offset < xb.offset: + sb.offset = yb.offset + sb.bits = yb.bits + sb = sb.next + yb = yb.next + case xb.offset < yb.offset: + sb.offset = xb.offset + sb.bits = xb.bits + sb = sb.next + xb = xb.next + default: + var sum word + for i := range sb.bits { + r := xb.bits[i] ^ yb.bits[i] + sb.bits[i] = r + sum |= r + } + if sum != 0 { + sb.offset = xb.offset + sb = sb.next + } + xb = xb.next + yb = yb.next + } + } + + for xb != &x.root { // append the tail of x to s + if sb == &s.root { + sb = s.insertBlockBefore(sb) + } + sb.offset = xb.offset + sb.bits = xb.bits + sb = sb.next + xb = xb.next + } + + for yb != &y.root { // append the tail of y to s + if sb == &s.root { + sb = s.insertBlockBefore(sb) + } + sb.offset = yb.offset + sb.bits = yb.bits + sb = sb.next + yb = yb.next + } + + s.discardTail(sb) +} + +// SubsetOf reports whether s ∖ x = ∅. +func (s *Sparse) SubsetOf(x *Sparse) bool { + if s == x { + return true + } + + sb := s.start() + xb := x.start() + for sb != &s.root { + switch { + case xb == &x.root || xb.offset > sb.offset: + return false + case xb.offset < sb.offset: + xb = xb.next + default: + for i := range sb.bits { + if sb.bits[i]&^xb.bits[i] != 0 { + return false + } + } + sb = sb.next + xb = xb.next + } + } + return true +} + // Equals reports whether the sets s and t have the same elements. func (s *Sparse) Equals(t *Sparse) bool { if s == t { diff --git a/container/intsets/sparse_test.go b/container/intsets/sparse_test.go index ec303d913..34b9a4e7f 100644 --- a/container/intsets/sparse_test.go +++ b/container/intsets/sparse_test.go @@ -397,6 +397,47 @@ func TestSetOperations(t *testing.T) { D.bits.Copy(&X.bits) D.bits.DifferenceWith(&D.bits) D.check(t, "D.DifferenceWith(D)") + + // SD.SymmetricDifference(X, Y) + SD := makePset() + SD.bits.SymmetricDifference(&X.bits, &Y.bits) + for n := range X.hash { + if !Y.hash[n] { + SD.hash[n] = true + } + } + for n := range Y.hash { + if !X.hash[n] { + SD.hash[n] = true + } + } + SD.check(t, "SD.SymmetricDifference(X, Y)") + + // X.SymmetricDifferenceWith(Y) + SD.bits.Copy(&X.bits) + SD.bits.SymmetricDifferenceWith(&Y.bits) + SD.check(t, "X.SymmetricDifference(Y)") + + // Y.SymmetricDifferenceWith(X) + SD.bits.Copy(&Y.bits) + SD.bits.SymmetricDifferenceWith(&X.bits) + SD.check(t, "Y.SymmetricDifference(X)") + + // SD.SymmetricDifference(X, X) + SD.bits.SymmetricDifference(&X.bits, &X.bits) + SD.hash = nil + SD.check(t, "SD.SymmetricDifference(X, X)") + + // SD.SymmetricDifference(X, Copy(X)) + X2 := makePset() + X2.bits.Copy(&X.bits) + SD.bits.SymmetricDifference(&X.bits, &X2.bits) + SD.check(t, "SD.SymmetricDifference(X, Copy(X))") + + // Copy(X).SymmetricDifferenceWith(X) + SD.bits.Copy(&X.bits) + SD.bits.SymmetricDifferenceWith(&X.bits) + SD.check(t, "Copy(X).SymmetricDifferenceWith(X)") } } @@ -417,6 +458,82 @@ func TestIntersectionWith(t *testing.T) { } } +func TestIntersects(t *testing.T) { + prng := rand.New(rand.NewSource(0)) + + for i := uint(0); i < 12; i++ { + X, Y := randomPset(prng, 1<