Added TokenEnumeration, which provides efficient direct enumeration of the Tokenizer's PLDHashTable. r/sr=brendan, a=asa [not part of build]

This commit is contained in:
beard%netscape.com 2002-10-18 06:31:47 +00:00
Родитель 09444442fd
Коммит 737a579e5e
2 изменённых файлов: 105 добавлений и 64 удалений

Просмотреть файл

@ -57,6 +57,39 @@ struct Token : public PLDHashEntryHdr {
double mProbability; // TODO: cache probabilities double mProbability; // TODO: cache probabilities
}; };
// Snapshots the table's entry-store geometry so tokens can be walked
// directly, without further calls into PLDHashTable. The enumeration
// remains valid only until the table is next modified (add/remove).
TokenEnumeration::TokenEnumeration(PLDHashTable* table)
    : mEntrySize(table->entrySize),
      mEntryCount(table->entryCount),
      mEntryOffset(0),
      mEntryAddr(table->entryStore),
      // one past the last entry slot: base + capacity * entry size
      mEntryLimit(table->entryStore + PL_DHASH_TABLE_SIZE(table) * table->entrySize)
{
}
// True while nextToken() still has live entries left to hand out.
inline bool TokenEnumeration::hasMoreTokens()
{
    bool exhausted = (mEntryOffset >= mEntryCount);
    return !exhausted;
}
// Scans forward through the snapshotted entry store to the next live
// entry (free and removed slots are skipped) and returns it as a
// Token*, or NULL once the store has been exhausted.
inline Token* TokenEnumeration::nextToken()
{
    Token* result = NULL;
    char* scanAddr = mEntryAddr;
    while (scanAddr < mEntryLimit && result == NULL) {
        PLDHashEntryHdr* hdr = NS_REINTERPRET_CAST(PLDHashEntryHdr*, scanAddr);
        scanAddr += mEntrySize;
        if (PL_DHASH_ENTRY_IS_LIVE(hdr)) {
            // count this entry as consumed, for hasMoreTokens().
            ++mEntryOffset;
            result = NS_STATIC_CAST(Token*, hdr);
        }
    }
    mEntryAddr = scanAddr;
    return result;
}
// PLDHashTable operation callbacks // PLDHashTable operation callbacks
static const void* PR_CALLBACK GetKey(PLDHashTable* table, PLDHashEntryHdr* entry) static const void* PR_CALLBACK GetKey(PLDHashTable* table, PLDHashEntryHdr* entry)
@ -235,44 +268,37 @@ void Tokenizer::visit(bool (*f) (Token*, void*), void* data)
NS_ASSERTION(visitCount == mTokenTable.entryCount, "visitCount != entryCount!"); NS_ASSERTION(visitCount == mTokenTable.entryCount, "visitCount != entryCount!");
} }
struct GatherClosure {
PRUint32 count;
PRUint32 offset;
Token* tokens;
};
static bool gatherTokens(Token* token, void* data)
{
GatherClosure* closure = NS_REINTERPRET_CAST(GatherClosure*, data);
NS_ASSERTION(closure->offset < closure->count, "too many tokens");
if (closure->offset < closure->count)
closure->tokens[closure->offset++] = *token;
return true;
}
inline PRUint32 Tokenizer::countTokens() inline PRUint32 Tokenizer::countTokens()
{ {
return mTokenTable.entryCount; return mTokenTable.entryCount;
} }
Token* Tokenizer::getTokens() Token* Tokenizer::copyTokens()
{ {
PRUint32 count = countTokens(); PRUint32 count = countTokens();
if (count > 0) { if (count > 0) {
Token* tokens = new Token[count]; Token* tokens = new Token[count];
if (tokens) { if (tokens) {
GatherClosure closure = { count, 0, tokens }; Token* tp = tokens;
visit(gatherTokens, &closure); TokenEnumeration e(&mTokenTable);
while (e.hasMoreTokens())
*tp++ = *e.nextToken();
} }
return tokens; return tokens;
} }
return NULL; return NULL;
} }
inline TokenEnumeration Tokenizer::getTokens()
{
return TokenEnumeration(&mTokenTable);
}
class TokenAnalyzer { class TokenAnalyzer {
public: public:
virtual ~TokenAnalyzer() {} virtual ~TokenAnalyzer() {}
virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) = 0; virtual void analyzeTokens(const char* source, Tokenizer& tokenizer) = 0;
}; };
/** /**
@ -356,20 +382,20 @@ NS_IMETHODIMP TokenStreamListener::OnDataAvailable(nsIRequest *aRequest, nsISupp
/* consume the tokens up to the last legal token delimiter in the buffer. */ /* consume the tokens up to the last legal token delimiter in the buffer. */
totalCount = (readCount + mLeftOverCount); totalCount = (readCount + mLeftOverCount);
buffer[totalCount] = '\0'; buffer[totalCount] = '\0';
char* last_delimiter = NULL; char* lastDelimiter = NULL;
char* scan = buffer + totalCount; char* scan = buffer + totalCount;
while (scan > buffer) { while (scan > buffer) {
if (strchr(kBayesianFilterTokenDelimiters, *--scan)) { if (strchr(kBayesianFilterTokenDelimiters, *--scan)) {
last_delimiter = scan; lastDelimiter = scan;
break; break;
} }
} }
if (last_delimiter) { if (lastDelimiter) {
*last_delimiter = '\0'; *lastDelimiter = '\0';
mTokenizer.tokenize(buffer); mTokenizer.tokenize(buffer);
PRUint32 consumedCount = 1 + (last_delimiter - buffer); PRUint32 consumedCount = 1 + (lastDelimiter - buffer);
mLeftOverCount = totalCount - consumedCount; mLeftOverCount = totalCount - consumedCount;
if (mLeftOverCount) if (mLeftOverCount)
memmove(buffer, buffer + consumedCount, mLeftOverCount); memmove(buffer, buffer + consumedCount, mLeftOverCount);
@ -402,14 +428,8 @@ NS_IMETHODIMP TokenStreamListener::OnStopRequest(nsIRequest *aRequest, nsISuppor
} }
/* finally, analyze the tokenized message. */ /* finally, analyze the tokenized message. */
if (mAnalyzer) { if (mAnalyzer)
PRUint32 count = mTokenizer.countTokens(); mAnalyzer->analyzeTokens(mTokenSource.get(), mTokenizer);
Token* tokens = mTokenizer.getTokens();
if (count && tokens) {
mAnalyzer->analyzeTokens(mTokenSource.get(), count, tokens);
delete[] tokens;
}
}
return NS_OK; return NS_OK;
} }
@ -438,9 +458,9 @@ public:
{ {
} }
virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) virtual void analyzeTokens(const char* source, Tokenizer& tokenizer)
{ {
mFilter->classifyMessage(count, tokens, source, mListener); mFilter->classifyMessage(tokenizer, source, mListener);
} }
private: private:
@ -498,11 +518,14 @@ static int compareTokens(const void* p1, const void* p2, void* /* data */)
inline double max(double x, double y) { return (x > y ? x : y); } inline double max(double x, double y) { return (x > y ? x : y); }
inline double min(double x, double y) { return (x < y ? x : y); } inline double min(double x, double y) { return (x < y ? x : y); }
void nsBayesianFilter::classifyMessage(PRUint32 count, Token tokens[], const char* messageURI, void nsBayesianFilter::classifyMessage(Tokenizer& tokenizer, const char* messageURI,
nsIJunkMailClassificationListener* listener) nsIJunkMailClassificationListener* listener)
{ {
Token* tokens = tokenizer.copyTokens();
if (!tokens) return;
/* run the kernel of the Graham filter algorithm here. */ /* run the kernel of the Graham filter algorithm here. */
PRUint32 i; PRUint32 i, count = tokenizer.countTokens();
double ngood = mGoodCount, nbad = mBadCount; double ngood = mGoodCount, nbad = mBadCount;
for (i = 0; i < count; ++i) { for (i = 0; i < count; ++i) {
Token& token = tokens[i]; Token& token = tokens[i];
@ -546,6 +569,8 @@ void nsBayesianFilter::classifyMessage(PRUint32 count, Token tokens[], const cha
double prob = (prod1 / (prod1 + prod2)); double prob = (prod1 / (prod1 + prod2));
bool isJunk = (prob >= 0.90); bool isJunk = (prob >= 0.90);
delete[] tokens;
if (listener) if (listener)
listener->OnMessageClassified(messageURI, isJunk ? nsMsgJunkStatus(nsIJunkMailPlugin::JUNK) : nsMsgJunkStatus(nsIJunkMailPlugin::GOOD)); listener->OnMessageClassified(messageURI, isJunk ? nsMsgJunkStatus(nsIJunkMailPlugin::JUNK) : nsMsgJunkStatus(nsIJunkMailPlugin::GOOD));
} }
@ -576,7 +601,7 @@ NS_IMETHODIMP nsBayesianFilter::SetBatchUpdate(PRBool aBatchUpdate)
{ {
mBatchUpdate = aBatchUpdate; mBatchUpdate = aBatchUpdate;
if (mBatchUpdate && mTrainingDataDirty) if (!mBatchUpdate && mTrainingDataDirty)
writeTrainingData(); writeTrainingData();
return NS_OK; return NS_OK;
@ -614,9 +639,9 @@ public:
{ {
} }
virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) virtual void analyzeTokens(const char* source, Tokenizer& tokenizer)
{ {
mFilter->observeMessage(count, tokens, source, mOldClassification, mFilter->observeMessage(tokenizer, source, mOldClassification,
mNewClassification, mListener); mNewClassification, mListener);
} }
@ -628,32 +653,33 @@ private:
nsMsgJunkStatus mNewClassification; nsMsgJunkStatus mNewClassification;
}; };
static void forgetTokens(Tokenizer& corpus, Token tokens[], PRUint32 count) static void forgetTokens(Tokenizer& corpus, TokenEnumeration tokens)
{ {
for (PRUint32 i = 0; i < count; ++i) { while (tokens.hasMoreTokens()) {
Token& token = tokens[i]; Token* token = tokens.nextToken();
corpus.remove(token.mWord, token.mCount); corpus.remove(token->mWord, token->mCount);
} }
} }
static void rememberTokens(Tokenizer& corpus, Token tokens[], PRUint32 count) static void rememberTokens(Tokenizer& corpus, TokenEnumeration tokens)
{ {
for (PRUint32 i = 0; i < count; ++i) { while (tokens.hasMoreTokens()) {
Token& token = tokens[i]; Token* token = tokens.nextToken();
corpus.add(token.mWord, token.mCount); corpus.add(token->mWord, token->mCount);
} }
} }
void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char* messageURL, void nsBayesianFilter::observeMessage(Tokenizer& tokenizer, const char* messageURL,
nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification,
nsIJunkMailClassificationListener* listener) nsIJunkMailClassificationListener* listener)
{ {
TokenEnumeration tokens = tokenizer.getTokens();
switch (oldClassification) { switch (oldClassification) {
case nsIJunkMailPlugin::JUNK: case nsIJunkMailPlugin::JUNK:
// remove tokens from junk corpus. // remove tokens from junk corpus.
if (mBadCount > 0) { if (mBadCount > 0) {
--mBadCount; --mBadCount;
forgetTokens(mBadTokens, tokens, count); forgetTokens(mBadTokens, tokens);
mTrainingDataDirty = PR_TRUE; mTrainingDataDirty = PR_TRUE;
} }
break; break;
@ -661,7 +687,7 @@ void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char
// remove tokens from good corpus. // remove tokens from good corpus.
if (mGoodCount > 0) { if (mGoodCount > 0) {
--mGoodCount; --mGoodCount;
forgetTokens(mGoodTokens, tokens, count); forgetTokens(mGoodTokens, tokens);
mTrainingDataDirty = PR_TRUE; mTrainingDataDirty = PR_TRUE;
} }
break; break;
@ -670,13 +696,13 @@ void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char
case nsIJunkMailPlugin::JUNK: case nsIJunkMailPlugin::JUNK:
// put tokens into junk corpus. // put tokens into junk corpus.
++mBadCount; ++mBadCount;
rememberTokens(mBadTokens, tokens, count); rememberTokens(mBadTokens, tokens);
mTrainingDataDirty = PR_TRUE; mTrainingDataDirty = PR_TRUE;
break; break;
case nsIJunkMailPlugin::GOOD: case nsIJunkMailPlugin::GOOD:
// put tokens into good corpus. // put tokens into good corpus.
++mGoodCount; ++mGoodCount;
rememberTokens(mGoodTokens, tokens, count); rememberTokens(mGoodTokens, tokens);
mTrainingDataDirty = PR_TRUE; mTrainingDataDirty = PR_TRUE;
break; break;
} }
@ -744,21 +770,17 @@ static bool writeTokens(FILE* stream, Tokenizer& tokenizer)
return false; return false;
if (tokenCount > 0) { if (tokenCount > 0) {
Token* tokens = tokenizer.getTokens(); TokenEnumeration tokens = tokenizer.getTokens();
if (!tokens) return false;
for (PRUint32 i = 0; i < tokenCount; ++i) { for (PRUint32 i = 0; i < tokenCount; ++i) {
Token& token = tokens[i]; Token* token = tokens.nextToken();
if (writeUInt32(stream, token.mCount) != 1) if (writeUInt32(stream, token->mCount) != 1)
break; break;
PRUint32 tokenLength = token.mLength; PRUint32 tokenLength = token->mLength;
if (writeUInt32(stream, tokenLength) != 1) if (writeUInt32(stream, tokenLength) != 1)
break; break;
if (fwrite(token.mWord, tokenLength, 1, stream) != 1) if (fwrite(token->mWord, tokenLength, 1, stream) != 1)
break; break;
} }
delete[] tokens;
} }
return true; return true;

Просмотреть файл

@ -48,8 +48,26 @@
#include "plarena.h" #include "plarena.h"
class Token; class Token;
class TokenEnumeration;
class TokenAnalyzer; class TokenAnalyzer;
/**
 * Helper class to enumerate Token objects in a PLDHashTable
 * safely and without copying (see bugzilla #174859). The
 * enumeration is safe to use until a PL_DHASH_ADD
 * or PL_DHASH_REMOVE is performed on the table.
 */
class TokenEnumeration {
public:
// Snapshots the table's entry-store geometry; O(1), no copying.
TokenEnumeration(PLDHashTable* table);
// Returns true while live entries remain to be returned.
bool hasMoreTokens();
// Returns the next live entry as a Token*, or NULL when exhausted.
Token* nextToken();
private:
// Byte size of one hash entry, number of live entries in the table,
// and how many of them have been handed out so far.
PRUint32 mEntrySize, mEntryCount, mEntryOffset;
// Current scan position and one-past-the-end of the entry store.
char *mEntryAddr, *mEntryLimit;
};
class Tokenizer { class Tokenizer {
public: public:
Tokenizer(); Tokenizer();
@ -62,7 +80,8 @@ public:
void remove(const char* word, PRUint32 count = 1); void remove(const char* word, PRUint32 count = 1);
PRUint32 countTokens(); PRUint32 countTokens();
Token* getTokens(); Token* copyTokens();
TokenEnumeration getTokens();
/** /**
* Assumes that text is mutable and * Assumes that text is mutable and
@ -98,8 +117,8 @@ public:
virtual ~nsBayesianFilter(); virtual ~nsBayesianFilter();
nsresult tokenizeMessage(const char* messageURI, TokenAnalyzer* analyzer); nsresult tokenizeMessage(const char* messageURI, TokenAnalyzer* analyzer);
void classifyMessage(PRUint32 count, Token tokens[], const char* messageURI, nsIJunkMailClassificationListener* listener); void classifyMessage(Tokenizer& tokens, const char* messageURI, nsIJunkMailClassificationListener* listener);
void observeMessage(PRUint32 count, Token tokens[], const char* messageURI, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener); void observeMessage(Tokenizer& tokens, const char* messageURI, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener);
void writeTrainingData(); void writeTrainingData();
void readTrainingData(); void readTrainingData();