diff --git a/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.cpp b/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.cpp index 94be7bc2d23..63d98e2353e 100644 --- a/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.cpp +++ b/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.cpp @@ -57,6 +57,39 @@ struct Token : public PLDHashEntryHdr { double mProbability; // TODO: cache probabilities }; +TokenEnumeration::TokenEnumeration(PLDHashTable* table) + : mEntrySize(table->entrySize), + mEntryCount(table->entryCount), + mEntryOffset(0), + mEntryAddr(table->entryStore) +{ + PRUint32 capacity = PL_DHASH_TABLE_SIZE(table); + mEntryLimit = mEntryAddr + capacity * mEntrySize; +} + +inline bool TokenEnumeration::hasMoreTokens() +{ + return (mEntryOffset < mEntryCount); +} + +inline Token* TokenEnumeration::nextToken() +{ + Token* token = NULL; + PRUint32 entrySize = mEntrySize; + char *entryAddr = mEntryAddr, *entryLimit = mEntryLimit; + while (entryAddr < entryLimit) { + PLDHashEntryHdr* entry = (PLDHashEntryHdr*) entryAddr; + entryAddr += entrySize; + if (PL_DHASH_ENTRY_IS_LIVE(entry)) { + token = NS_STATIC_CAST(Token*, entry); + ++mEntryOffset; + break; + } + } + mEntryAddr = entryAddr; + return token; +} + // PLDHashTable operation callbacks static const void* PR_CALLBACK GetKey(PLDHashTable* table, PLDHashEntryHdr* entry) @@ -235,44 +268,37 @@ void Tokenizer::visit(bool (*f) (Token*, void*), void* data) NS_ASSERTION(visitCount == mTokenTable.entryCount, "visitCount != entryCount!"); } -struct GatherClosure { - PRUint32 count; - PRUint32 offset; - Token* tokens; -}; -static bool gatherTokens(Token* token, void* data) -{ - GatherClosure* closure = NS_REINTERPRET_CAST(GatherClosure*, data); - NS_ASSERTION(closure->offset < closure->count, "too many tokens"); - if (closure->offset < closure->count) - closure->tokens[closure->offset++] = *token; - return true; -} - inline PRUint32 Tokenizer::countTokens() { return mTokenTable.entryCount; } -Token* Tokenizer::getTokens() +Token* Tokenizer::copyTokens() { PRUint32 count = countTokens(); if (count > 0) { Token* tokens = new Token[count]; if (tokens) { - GatherClosure closure = { count, 0, tokens }; - visit(gatherTokens, &closure); + Token* tp = tokens; + TokenEnumeration e(&mTokenTable); + while (e.hasMoreTokens()) + *tp++ = *e.nextToken(); } return tokens; } return NULL; } +inline TokenEnumeration Tokenizer::getTokens() +{ + return TokenEnumeration(&mTokenTable); +} + class TokenAnalyzer { public: virtual ~TokenAnalyzer() {} - virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) = 0; + virtual void analyzeTokens(const char* source, Tokenizer& tokenizer) = 0; }; /** @@ -356,20 +382,20 @@ NS_IMETHODIMP TokenStreamListener::OnDataAvailable(nsIRequest *aRequest, nsISupp /* consume the tokens up to the last legal token delimiter in the buffer. */ totalCount = (readCount + mLeftOverCount); buffer[totalCount] = '\0'; - char* last_delimiter = NULL; + char* lastDelimiter = NULL; char* scan = buffer + totalCount; while (scan > buffer) { if (strchr(kBayesianFilterTokenDelimiters, *--scan)) { - last_delimiter = scan; + lastDelimiter = scan; break; } } - if (last_delimiter) { - *last_delimiter = '\0'; + if (lastDelimiter) { + *lastDelimiter = '\0'; mTokenizer.tokenize(buffer); - PRUint32 consumedCount = 1 + (last_delimiter - buffer); + PRUint32 consumedCount = 1 + (lastDelimiter - buffer); mLeftOverCount = totalCount - consumedCount; if (mLeftOverCount) memmove(buffer, buffer + consumedCount, mLeftOverCount); @@ -402,14 +428,8 @@ NS_IMETHODIMP TokenStreamListener::OnStopRequest(nsIRequest *aRequest, nsISuppor } /* finally, analyze the tokenized message. */ - if (mAnalyzer) { - PRUint32 count = mTokenizer.countTokens(); - Token* tokens = mTokenizer.getTokens(); - if (count && tokens) { - mAnalyzer->analyzeTokens(mTokenSource.get(), count, tokens); - delete[] tokens; - } - } + if (mAnalyzer) + mAnalyzer->analyzeTokens(mTokenSource.get(), mTokenizer); return NS_OK; } @@ -438,9 +458,9 @@ public: { } - virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) + virtual void analyzeTokens(const char* source, Tokenizer& tokenizer) { - mFilter->classifyMessage(count, tokens, source, mListener); + mFilter->classifyMessage(tokenizer, source, mListener); } private: @@ -498,11 +518,14 @@ static int compareTokens(const void* p1, const void* p2, void* /* data */) inline double max(double x, double y) { return (x > y ? x : y); } inline double min(double x, double y) { return (x < y ? x : y); } -void nsBayesianFilter::classifyMessage(PRUint32 count, Token tokens[], const char* messageURI, +void nsBayesianFilter::classifyMessage(Tokenizer& tokenizer, const char* messageURI, nsIJunkMailClassificationListener* listener) { + Token* tokens = tokenizer.copyTokens(); + if (!tokens) return; + /* run the kernel of the Graham filter algorithm here. */ - PRUint32 i; + PRUint32 i, count = tokenizer.countTokens(); double ngood = mGoodCount, nbad = mBadCount; for (i = 0; i < count; ++i) { Token& token = tokens[i]; @@ -546,6 +569,8 @@ void nsBayesianFilter::classifyMessage(PRUint32 count, Token tokens[], const cha double prob = (prod1 / (prod1 + prod2)); bool isJunk = (prob >= 0.90); + delete[] tokens; + if (listener) listener->OnMessageClassified(messageURI, isJunk ? nsMsgJunkStatus(nsIJunkMailPlugin::JUNK) : nsMsgJunkStatus(nsIJunkMailPlugin::GOOD)); } @@ -576,9 +601,9 @@ NS_IMETHODIMP nsBayesianFilter::SetBatchUpdate(PRBool aBatchUpdate) { mBatchUpdate = aBatchUpdate; - if (mBatchUpdate && mTrainingDataDirty) + if (!mBatchUpdate && mTrainingDataDirty) writeTrainingData(); - + return NS_OK; } @@ -614,9 +639,9 @@ public: { } - virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) + virtual void analyzeTokens(const char* source, Tokenizer& tokenizer) { - mFilter->observeMessage(count, tokens, source, mOldClassification, + mFilter->observeMessage(tokenizer, source, mOldClassification, mNewClassification, mListener); } @@ -628,32 +653,33 @@ private: nsMsgJunkStatus mNewClassification; }; -static void forgetTokens(Tokenizer& corpus, Token tokens[], PRUint32 count) +static void forgetTokens(Tokenizer& corpus, TokenEnumeration tokens) { - for (PRUint32 i = 0; i < count; ++i) { - Token& token = tokens[i]; - corpus.remove(token.mWord, token.mCount); + while (tokens.hasMoreTokens()) { + Token* token = tokens.nextToken(); + corpus.remove(token->mWord, token->mCount); } } -static void rememberTokens(Tokenizer& corpus, Token tokens[], PRUint32 count) +static void rememberTokens(Tokenizer& corpus, TokenEnumeration tokens) { - for (PRUint32 i = 0; i < count; ++i) { - Token& token = tokens[i]; - corpus.add(token.mWord, token.mCount); + while (tokens.hasMoreTokens()) { + Token* token = tokens.nextToken(); + corpus.add(token->mWord, token->mCount); } } -void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char* messageURL, +void nsBayesianFilter::observeMessage(Tokenizer& tokenizer, const char* messageURL, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener) { + TokenEnumeration tokens = tokenizer.getTokens(); switch (oldClassification) { case nsIJunkMailPlugin::JUNK: // remove tokens from junk corpus. if (mBadCount > 0) { --mBadCount; - forgetTokens(mBadTokens, tokens, count); + forgetTokens(mBadTokens, tokens); mTrainingDataDirty = PR_TRUE; } break; @@ -661,7 +687,7 @@ void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char // remove tokens from good corpus. if (mGoodCount > 0) { --mGoodCount; - forgetTokens(mGoodTokens, tokens, count); + forgetTokens(mGoodTokens, tokens); mTrainingDataDirty = PR_TRUE; } break; @@ -670,13 +696,13 @@ void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char case nsIJunkMailPlugin::JUNK: // put tokens into junk corpus. ++mBadCount; - rememberTokens(mBadTokens, tokens, count); + rememberTokens(mBadTokens, tokens); mTrainingDataDirty = PR_TRUE; break; case nsIJunkMailPlugin::GOOD: // put tokens into good corpus. ++mGoodCount; - rememberTokens(mGoodTokens, tokens, count); + rememberTokens(mGoodTokens, tokens); mTrainingDataDirty = PR_TRUE; break; } @@ -744,21 +770,17 @@ static bool writeTokens(FILE* stream, Tokenizer& tokenizer) return false; if (tokenCount > 0) { - Token* tokens = tokenizer.getTokens(); - if (!tokens) return false; - + TokenEnumeration tokens = tokenizer.getTokens(); for (PRUint32 i = 0; i < tokenCount; ++i) { - Token& token = tokens[i]; - if (writeUInt32(stream, token.mCount) != 1) + Token* token = tokens.nextToken(); + if (writeUInt32(stream, token->mCount) != 1) break; - PRUint32 tokenLength = token.mLength; + PRUint32 tokenLength = token->mLength; if (writeUInt32(stream, tokenLength) != 1) break; - if (fwrite(token.mWord, tokenLength, 1, stream) != 1) + if (fwrite(token->mWord, tokenLength, 1, stream) != 1) break; } - - delete[] tokens; } return true; diff --git a/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.h b/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.h index 581167811c7..03dd98bd30b 100644 --- a/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.h +++ b/mailnews/extensions/bayesian-spam-filter/src/nsBayesianFilter.h @@ -48,8 +48,26 @@ #include "plarena.h" class Token; +class TokenEnumeration; class TokenAnalyzer; +/** + * Helper class to enumerate Token objects in a PLDHashTable + * safely and without copying (see bugzilla #174859). The + * enumeration is safe to use until a PL_DHASH_ADD + * or PL_DHASH_REMOVE is performed on the table. + */ +class TokenEnumeration { +public: + TokenEnumeration(PLDHashTable* table); + bool hasMoreTokens(); + Token* nextToken(); + +private: + PRUint32 mEntrySize, mEntryCount, mEntryOffset; + char *mEntryAddr, *mEntryLimit; +}; + class Tokenizer { public: Tokenizer(); @@ -62,7 +80,8 @@ public: void remove(const char* word, PRUint32 count = 1); PRUint32 countTokens(); - Token* getTokens(); + Token* copyTokens(); + TokenEnumeration getTokens(); /** * Assumes that text is mutable and @@ -98,8 +117,8 @@ public: virtual ~nsBayesianFilter(); nsresult tokenizeMessage(const char* messageURI, TokenAnalyzer* analyzer); - void classifyMessage(PRUint32 count, Token tokens[], const char* messageURI, nsIJunkMailClassificationListener* listener); - void observeMessage(PRUint32 count, Token tokens[], const char* messageURI, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener); + void classifyMessage(Tokenizer& tokens, const char* messageURI, nsIJunkMailClassificationListener* listener); + void observeMessage(Tokenizer& tokens, const char* messageURI, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener); void writeTrainingData(); void readTrainingData();