From 9cb0b76439bf54228b1f085ddb716c9d0fbaee7d Mon Sep 17 00:00:00 2001 From: Jan Varga Date: Fri, 2 Feb 2024 09:43:10 +0000 Subject: [PATCH] Bug 1873140 - Avoid using invalid enum values during EnumSet iteration; r=glandium Differential Revision: https://phabricator.services.mozilla.com/D197824 --- mfbt/EnumSet.h | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/mfbt/EnumSet.h b/mfbt/EnumSet.h index 1b0ef2e5d12c..712e03d3f3d9 100644 --- a/mfbt/EnumSet.h +++ b/mfbt/EnumSet.h @@ -182,9 +182,7 @@ class EnumSet { /** * Test is an element is contained in the set. */ - bool contains(T aEnum) const { - return static_cast(mBitField & BitFor(aEnum)); - } + bool contains(T aEnum) const { return HasBitFor(aEnum); } /** * Test if a set is contained in the set. @@ -241,7 +239,7 @@ class EnumSet { mVersion = mSet->mVersion; #endif MOZ_ASSERT(aPos <= kMaxBits); - if (aPos != kMaxBits && !mSet->contains(T(mPos))) { + if (aPos != kMaxBits && !mSet->HasBitAt(mPos)) { ++*this; } } @@ -278,7 +276,7 @@ class EnumSet { T operator*() const { MOZ_ASSERT(mSet); MOZ_ASSERT(mPos < kMaxBits); - MOZ_ASSERT(mSet->contains(T(mPos))); + MOZ_ASSERT(mSet->HasBitAt(mPos)); checkVersion(); return T(mPos); } @@ -289,7 +287,7 @@ class EnumSet { checkVersion(); do { mPos++; - } while (mPos < kMaxBits && !mSet->contains(T(mPos))); + } while (mPos < kMaxBits && !mSet->HasBitAt(mPos)); return *this; } }; @@ -300,17 +298,30 @@ class EnumSet { private: constexpr static Serialized BitFor(T aEnum) { - auto bitNumber = static_cast(aEnum); - MOZ_DIAGNOSTIC_ASSERT(bitNumber < kMaxBits); + const auto pos = static_cast(aEnum); + return BitAt(pos); + } + + constexpr static Serialized BitAt(size_t aPos) { + MOZ_DIAGNOSTIC_ASSERT(aPos < kMaxBits); if constexpr (std::is_unsigned_v) { - return static_cast(Serialized{1} << bitNumber); + return static_cast(Serialized{1} << aPos); } else { Serialized bitField; - bitField[bitNumber] = true; + bitField[aPos] = true; return bitField; } } + constexpr bool HasBitFor(T aEnum) const { + const auto pos = static_cast(aEnum); + return HasBitAt(pos); + } + + constexpr bool HasBitAt(size_t aPos) const { + return static_cast(mBitField & BitAt(aPos)); + } + constexpr void IncVersion() { #ifdef DEBUG mVersion++;