Added TokenEnumeration, which provides efficient direct enumeration of the Tokenizer's PLDHashTable. r/sr=brendan, a=asa [not part of build]

This commit is contained in:
beard%netscape.com 2002-10-18 06:31:47 +00:00
Родитель 09444442fd
Коммит 737a579e5e
2 изменённых файлов: 105 добавлений и 64 удалений

Просмотреть файл

@ -57,6 +57,39 @@ struct Token : public PLDHashEntryHdr {
double mProbability; // TODO: cache probabilities
};
TokenEnumeration::TokenEnumeration(PLDHashTable* table)
: mEntrySize(table->entrySize),
mEntryCount(table->entryCount),
mEntryOffset(0),
mEntryAddr(table->entryStore)
{
PRUint32 capacity = PL_DHASH_TABLE_SIZE(table);
mEntryLimit = mEntryAddr + capacity * mEntrySize;
}
inline bool TokenEnumeration::hasMoreTokens()
{
return (mEntryOffset < mEntryCount);
}
inline Token* TokenEnumeration::nextToken()
{
Token* token = NULL;
PRUint32 entrySize = mEntrySize;
char *entryAddr = mEntryAddr, *entryLimit = mEntryLimit;
while (entryAddr < entryLimit) {
PLDHashEntryHdr* entry = (PLDHashEntryHdr*) entryAddr;
entryAddr += entrySize;
if (PL_DHASH_ENTRY_IS_LIVE(entry)) {
token = NS_STATIC_CAST(Token*, entry);
++mEntryOffset;
break;
}
}
mEntryAddr = entryAddr;
return token;
}
// PLDHashTable operation callbacks
static const void* PR_CALLBACK GetKey(PLDHashTable* table, PLDHashEntryHdr* entry)
@ -235,44 +268,37 @@ void Tokenizer::visit(bool (*f) (Token*, void*), void* data)
NS_ASSERTION(visitCount == mTokenTable.entryCount, "visitCount != entryCount!");
}
struct GatherClosure {
PRUint32 count;
PRUint32 offset;
Token* tokens;
};
static bool gatherTokens(Token* token, void* data)
{
GatherClosure* closure = NS_REINTERPRET_CAST(GatherClosure*, data);
NS_ASSERTION(closure->offset < closure->count, "too many tokens");
if (closure->offset < closure->count)
closure->tokens[closure->offset++] = *token;
return true;
}
inline PRUint32 Tokenizer::countTokens()
{
return mTokenTable.entryCount;
}
Token* Tokenizer::getTokens()
Token* Tokenizer::copyTokens()
{
PRUint32 count = countTokens();
if (count > 0) {
Token* tokens = new Token[count];
if (tokens) {
GatherClosure closure = { count, 0, tokens };
visit(gatherTokens, &closure);
Token* tp = tokens;
TokenEnumeration e(&mTokenTable);
while (e.hasMoreTokens())
*tp++ = *e.nextToken();
}
return tokens;
}
return NULL;
}
inline TokenEnumeration Tokenizer::getTokens()
{
return TokenEnumeration(&mTokenTable);
}
class TokenAnalyzer {
public:
virtual ~TokenAnalyzer() {}
virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[]) = 0;
virtual void analyzeTokens(const char* source, Tokenizer& tokenizer) = 0;
};
/**
@ -356,20 +382,20 @@ NS_IMETHODIMP TokenStreamListener::OnDataAvailable(nsIRequest *aRequest, nsISupp
/* consume the tokens up to the last legal token delimiter in the buffer. */
totalCount = (readCount + mLeftOverCount);
buffer[totalCount] = '\0';
char* last_delimiter = NULL;
char* lastDelimiter = NULL;
char* scan = buffer + totalCount;
while (scan > buffer) {
if (strchr(kBayesianFilterTokenDelimiters, *--scan)) {
last_delimiter = scan;
lastDelimiter = scan;
break;
}
}
if (last_delimiter) {
*last_delimiter = '\0';
if (lastDelimiter) {
*lastDelimiter = '\0';
mTokenizer.tokenize(buffer);
PRUint32 consumedCount = 1 + (last_delimiter - buffer);
PRUint32 consumedCount = 1 + (lastDelimiter - buffer);
mLeftOverCount = totalCount - consumedCount;
if (mLeftOverCount)
memmove(buffer, buffer + consumedCount, mLeftOverCount);
@ -402,14 +428,8 @@ NS_IMETHODIMP TokenStreamListener::OnStopRequest(nsIRequest *aRequest, nsISuppor
}
/* finally, analyze the tokenized message. */
if (mAnalyzer) {
PRUint32 count = mTokenizer.countTokens();
Token* tokens = mTokenizer.getTokens();
if (count && tokens) {
mAnalyzer->analyzeTokens(mTokenSource.get(), count, tokens);
delete[] tokens;
}
}
if (mAnalyzer)
mAnalyzer->analyzeTokens(mTokenSource.get(), mTokenizer);
return NS_OK;
}
@ -438,9 +458,9 @@ public:
{
}
virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[])
virtual void analyzeTokens(const char* source, Tokenizer& tokenizer)
{
mFilter->classifyMessage(count, tokens, source, mListener);
mFilter->classifyMessage(tokenizer, source, mListener);
}
private:
@ -498,11 +518,14 @@ static int compareTokens(const void* p1, const void* p2, void* /* data */)
inline double max(double x, double y) { return (x > y ? x : y); }
inline double min(double x, double y) { return (x < y ? x : y); }
void nsBayesianFilter::classifyMessage(PRUint32 count, Token tokens[], const char* messageURI,
void nsBayesianFilter::classifyMessage(Tokenizer& tokenizer, const char* messageURI,
nsIJunkMailClassificationListener* listener)
{
Token* tokens = tokenizer.copyTokens();
if (!tokens) return;
/* run the kernel of the Graham filter algorithm here. */
PRUint32 i;
PRUint32 i, count = tokenizer.countTokens();
double ngood = mGoodCount, nbad = mBadCount;
for (i = 0; i < count; ++i) {
Token& token = tokens[i];
@ -546,6 +569,8 @@ void nsBayesianFilter::classifyMessage(PRUint32 count, Token tokens[], const cha
double prob = (prod1 / (prod1 + prod2));
bool isJunk = (prob >= 0.90);
delete[] tokens;
if (listener)
listener->OnMessageClassified(messageURI, isJunk ? nsMsgJunkStatus(nsIJunkMailPlugin::JUNK) : nsMsgJunkStatus(nsIJunkMailPlugin::GOOD));
}
@ -576,7 +601,7 @@ NS_IMETHODIMP nsBayesianFilter::SetBatchUpdate(PRBool aBatchUpdate)
{
mBatchUpdate = aBatchUpdate;
if (mBatchUpdate && mTrainingDataDirty)
if (!mBatchUpdate && mTrainingDataDirty)
writeTrainingData();
return NS_OK;
@ -614,9 +639,9 @@ public:
{
}
virtual void analyzeTokens(const char* source, PRUint32 count, Token tokens[])
virtual void analyzeTokens(const char* source, Tokenizer& tokenizer)
{
mFilter->observeMessage(count, tokens, source, mOldClassification,
mFilter->observeMessage(tokenizer, source, mOldClassification,
mNewClassification, mListener);
}
@ -628,32 +653,33 @@ private:
nsMsgJunkStatus mNewClassification;
};
static void forgetTokens(Tokenizer& corpus, Token tokens[], PRUint32 count)
static void forgetTokens(Tokenizer& corpus, TokenEnumeration tokens)
{
for (PRUint32 i = 0; i < count; ++i) {
Token& token = tokens[i];
corpus.remove(token.mWord, token.mCount);
while (tokens.hasMoreTokens()) {
Token* token = tokens.nextToken();
corpus.remove(token->mWord, token->mCount);
}
}
static void rememberTokens(Tokenizer& corpus, Token tokens[], PRUint32 count)
static void rememberTokens(Tokenizer& corpus, TokenEnumeration tokens)
{
for (PRUint32 i = 0; i < count; ++i) {
Token& token = tokens[i];
corpus.add(token.mWord, token.mCount);
while (tokens.hasMoreTokens()) {
Token* token = tokens.nextToken();
corpus.add(token->mWord, token->mCount);
}
}
void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char* messageURL,
void nsBayesianFilter::observeMessage(Tokenizer& tokenizer, const char* messageURL,
nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification,
nsIJunkMailClassificationListener* listener)
{
TokenEnumeration tokens = tokenizer.getTokens();
switch (oldClassification) {
case nsIJunkMailPlugin::JUNK:
// remove tokens from junk corpus.
if (mBadCount > 0) {
--mBadCount;
forgetTokens(mBadTokens, tokens, count);
forgetTokens(mBadTokens, tokens);
mTrainingDataDirty = PR_TRUE;
}
break;
@ -661,7 +687,7 @@ void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char
// remove tokens from good corpus.
if (mGoodCount > 0) {
--mGoodCount;
forgetTokens(mGoodTokens, tokens, count);
forgetTokens(mGoodTokens, tokens);
mTrainingDataDirty = PR_TRUE;
}
break;
@ -670,13 +696,13 @@ void nsBayesianFilter::observeMessage(PRUint32 count, Token tokens[], const char
case nsIJunkMailPlugin::JUNK:
// put tokens into junk corpus.
++mBadCount;
rememberTokens(mBadTokens, tokens, count);
rememberTokens(mBadTokens, tokens);
mTrainingDataDirty = PR_TRUE;
break;
case nsIJunkMailPlugin::GOOD:
// put tokens into good corpus.
++mGoodCount;
rememberTokens(mGoodTokens, tokens, count);
rememberTokens(mGoodTokens, tokens);
mTrainingDataDirty = PR_TRUE;
break;
}
@ -744,21 +770,17 @@ static bool writeTokens(FILE* stream, Tokenizer& tokenizer)
return false;
if (tokenCount > 0) {
Token* tokens = tokenizer.getTokens();
if (!tokens) return false;
TokenEnumeration tokens = tokenizer.getTokens();
for (PRUint32 i = 0; i < tokenCount; ++i) {
Token& token = tokens[i];
if (writeUInt32(stream, token.mCount) != 1)
Token* token = tokens.nextToken();
if (writeUInt32(stream, token->mCount) != 1)
break;
PRUint32 tokenLength = token.mLength;
PRUint32 tokenLength = token->mLength;
if (writeUInt32(stream, tokenLength) != 1)
break;
if (fwrite(token.mWord, tokenLength, 1, stream) != 1)
if (fwrite(token->mWord, tokenLength, 1, stream) != 1)
break;
}
delete[] tokens;
}
return true;

Просмотреть файл

@ -48,8 +48,26 @@
#include "plarena.h"
class Token;
class TokenEnumeration;
class TokenAnalyzer;
/**
* Helper class to enumerate Token objects in a PLDHashTable
* safely and without copying (see bugzilla #174859). The
* enumeration is safe to use until a PL_DHASH_ADD
* or PL_DHASH_REMOVE is performed on the table.
*/
class TokenEnumeration {
public:
TokenEnumeration(PLDHashTable* table);
bool hasMoreTokens();
Token* nextToken();
private:
PRUint32 mEntrySize, mEntryCount, mEntryOffset;
char *mEntryAddr, *mEntryLimit;
};
class Tokenizer {
public:
Tokenizer();
@ -62,7 +80,8 @@ public:
void remove(const char* word, PRUint32 count = 1);
PRUint32 countTokens();
Token* getTokens();
Token* copyTokens();
TokenEnumeration getTokens();
/**
* Assumes that text is mutable and
@ -98,8 +117,8 @@ public:
virtual ~nsBayesianFilter();
nsresult tokenizeMessage(const char* messageURI, TokenAnalyzer* analyzer);
void classifyMessage(PRUint32 count, Token tokens[], const char* messageURI, nsIJunkMailClassificationListener* listener);
void observeMessage(PRUint32 count, Token tokens[], const char* messageURI, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener);
void classifyMessage(Tokenizer& tokens, const char* messageURI, nsIJunkMailClassificationListener* listener);
void observeMessage(Tokenizer& tokens, const char* messageURI, nsMsgJunkStatus oldClassification, nsMsgJunkStatus newClassification, nsIJunkMailClassificationListener* listener);
void writeTrainingData();
void readTrainingData();