Simplified LU sequence by removing unnecessary code. Also use wstring to support reading wchar strings. Need to make sure that large amount of data can be read, because reader reads at most CACHE_BLOG_SIZE data.

This commit is contained in:
kaisheny 2015-04-20 16:04:39 -07:00
Родитель 3d958beb7a
Коммит fad6322a29
4 изменённых файлов: 286 добавлений и 1369 удалений

Просмотреть файл

@ -13,218 +13,6 @@
namespace Microsoft { namespace MSR { namespace CNTK {
// SetState for a particular value
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::SetState(int value, ParseState m_current_state, ParseState next_state)
{
DWORD ul = (DWORD)next_state;
int range_shift = ((int)m_current_state) << 8;
m_stateTable[range_shift+value] = ul;
}
// SetStateRange - set states transitions for a range of values
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::SetStateRange(int value1, int value2, ParseState m_current_state, ParseState next_state)
{
DWORD ul = (DWORD)next_state;
int range_shift = ((int)m_current_state) << 8;
for (int value = value1; value <= value2; value++)
{
m_stateTable[range_shift+value] = ul;
}
}
// SetupStateTables - setup state transition tables for each state
// each state has a block of 256 states indexed by the incoming character
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::SetupStateTables()
{
//=========================
// STATE = WHITESPACE
//=========================
SetStateRange(0,255, Whitespace, Label);
SetStateRange('0', '9', Whitespace, WholeNumber);
SetState('-', Whitespace, Sign);
SetState('+', Whitespace, Sign);
// whitespace
SetState(' ', Whitespace, Whitespace);
SetState('\t', Whitespace, Whitespace);
SetState('\r', Whitespace, Whitespace);
SetState(':', Whitespace, Whitespace); // intepret ':' as white space because it's a divider
SetState('\n', Whitespace, EndOfLine);
//=========================
// STATE = NEGATIVE_SIGN
//=========================
SetStateRange( 0, 255, Sign, Label);
SetStateRange( '0', '9', Sign, WholeNumber);
// whitespace
SetState(' ', Sign, Whitespace);
SetState('\t', Sign, Whitespace);
SetState('\r', Sign, Whitespace);
SetState('\n', Sign, EndOfLine);
//=========================
// STATE = NUMBER
//=========================
SetStateRange( 0, 255, WholeNumber, Label);
SetStateRange( '0', '9', WholeNumber, WholeNumber);
SetState('.', WholeNumber, Period);
SetState('e', WholeNumber, TheLetterE);
SetState('E', WholeNumber, TheLetterE);
// whitespace
SetState(' ', WholeNumber, Whitespace);
SetState('\t', WholeNumber, Whitespace);
SetState('\r', WholeNumber, Whitespace);
SetState(':', WholeNumber, Whitespace); // Add for 1234:0.9 usage in Sequences
SetState('\n', WholeNumber, EndOfLine);
//=========================
// STATE = PERIOD
//=========================
SetStateRange(0, 255, Period, Label);
SetStateRange('0', '9', Period, Remainder);
// whitespace
SetState(' ', Period, Whitespace);
SetState('\t', Period, Whitespace);
SetState('\r', Period, Whitespace);
SetState('\n', Period, EndOfLine);
//=========================
// STATE = REMAINDER
//=========================
SetStateRange(0, 255, Remainder, Label);
SetStateRange('0', '9', Remainder, Remainder);
SetState('e', Remainder, TheLetterE);
SetState('E', Remainder, TheLetterE);
// whitespace
SetState(' ', Remainder, Whitespace);
SetState('\t', Remainder, Whitespace);
SetState('\r', Remainder, Whitespace);
SetState(':', Remainder, Whitespace); // Add for 1234:0.9 usage in Sequences
SetState('\n', Remainder, EndOfLine);
//=========================
// STATE = THE_LETTER_E
//=========================
SetStateRange(0, 255, TheLetterE, Label);
SetStateRange('0', '9', TheLetterE, Exponent);
SetState('-', TheLetterE, ExponentSign);
SetState('+', TheLetterE, ExponentSign);
// whitespace
SetState(' ', TheLetterE, Whitespace);
SetState('\t', TheLetterE, Whitespace);
SetState('\r', TheLetterE, Whitespace);
SetState('\n', TheLetterE, EndOfLine);
//=========================
// STATE = EXPONENT_NEGATIVE_SIGN
//=========================
SetStateRange(0, 255, ExponentSign, Label);
SetStateRange('0', '9', ExponentSign, Exponent);
// whitespace
SetState(' ', ExponentSign, Whitespace);
SetState('\t', ExponentSign, Whitespace);
SetState('\r', ExponentSign, Whitespace);
SetState('\n', ExponentSign, EndOfLine);
//=========================
// STATE = EXPONENT
//=========================
SetStateRange(0, 255, Exponent, Label);
SetStateRange('0', '9', Exponent, Exponent);
// whitespace
SetState(' ', Exponent, Whitespace);
SetState('\t', Exponent, Whitespace);
SetState('\r', Exponent, Whitespace);
SetState(':', Exponent, Whitespace);
SetState('\n', Exponent, EndOfLine);
//=========================
// STATE = END_OF_LINE
//=========================
SetStateRange(0, 255, EndOfLine, Label);
SetStateRange( '0', '9', EndOfLine, WholeNumber);
SetState( '-', EndOfLine, Sign);
SetState( '\n', EndOfLine, EndOfLine);
// whitespace
SetState(' ', EndOfLine, Whitespace);
SetState('\t', EndOfLine, Whitespace);
SetState('\r', EndOfLine, Whitespace);
//=========================
// STATE = LABEL
//=========================
SetStateRange(0, 255, Label, Label);
SetState('\n', Label, EndOfLine);
// whitespace
SetState(' ', Label, Whitespace);
SetState('\t', Label, Whitespace);
SetState('\r', Label, Whitespace);
SetState(':', Label, Whitespace);
//=========================
// STATE = LINE_COUNT_EOL
//=========================
SetStateRange(0, 255, LineCountEOL, LineCountOther);
SetState( '\n', LineCountEOL, LineCountEOL);
//=========================
// STATE = LINE_COUNT_OTHER
//=========================
SetStateRange(0, 255, LineCountOther, LineCountOther);
SetState('\n', LineCountOther, LineCountEOL);
}
// reset all line state variables
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::PrepareStartLine()
{
m_numbersConvertedThisLine = 0;
m_labelsConvertedThisLine = 0;
m_elementsConvertedThisLine = 0;
m_spaceDelimitedStart = m_byteCounter;
m_spaceDelimitedMax = m_byteCounter;
m_lastLabelIsString = false;
m_beginSequence = m_endSequence = false;
}
// reset all number accumulation variables
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::PrepareStartNumber()
{
m_partialResult = 0;
m_builtUpNumber = 0;
m_divider = 0;
m_wholeNumberMultiplier = 1;
m_exponentMultiplier = 1;
}
// reset all state variables to start reading at a new position
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::PrepareStartPosition(size_t position)
{
m_current_state = Whitespace;
m_byteCounter = position; // must come before PrepareStartLine...
m_bufferStart = position;
// prepare state machine for new number and new line
PrepareStartNumber();
PrepareStartLine();
m_totalNumbersConverted = 0;
m_totalLabelsConverted = 0;
}
// LUSequenceParser constructor
template <typename NumType, typename LabelType>
LUSequenceParser<NumType, LabelType>::LUSequenceParser()
@ -236,121 +24,88 @@ LUSequenceParser<NumType, LabelType>::LUSequenceParser()
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::Init()
{
PrepareStartPosition(0);
m_fileBuffer = NULL;
m_pFile = NULL;
m_stateTable = new DWORD[AllStateMax * 256];
SetupStateTables();
}
// Parser destructor
template <typename NumType, typename LabelType>
LUSequenceParser<NumType, LabelType>::~LUSequenceParser()
{
delete m_stateTable;
delete m_fileBuffer;
if (m_pFile)
fclose(m_pFile);
}
// UpdateBuffer - load the next buffer full of data
// returns - number of records read
template <typename NumType, typename LabelType>
size_t LUSequenceParser<NumType, LabelType>::UpdateBuffer()
{
// state machine might want to look back this far, so copy to beginning
size_t saveBytes = m_byteCounter-m_spaceDelimitedStart;
assert(saveBytes < m_bufferSize);
if (saveBytes)
{
memcpy_s(m_fileBuffer, m_bufferSize, &m_fileBuffer[m_byteCounter-m_bufferStart-saveBytes], saveBytes);
m_bufferStart = m_byteCounter-saveBytes;
}
// read the next block
size_t bytesToRead = min(m_bufferSize, m_fileSize-m_bufferStart)-saveBytes;
size_t bytesRead = fread(m_fileBuffer+saveBytes, 1, bytesToRead, m_pFile);
if (bytesRead == 0 && ferror(m_pFile))
RuntimeError("LUSequenceParser::UpdateBuffer - error reading file");
return bytesRead;
}
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::SetParseMode(ParseMode mode)
{
// if already in this mode, nothing to do
if (m_parseMode == mode)
return;
// switching modes
if (mode == ParseLineCount)
m_current_state = LineCountOther;
else
{
m_current_state = Whitespace;
PrepareStartLine();
PrepareStartNumber();
}
m_parseMode = mode;
}
// SetTraceLevel - Set the level of screen output
// traceLevel - traceLevel, zero means no output, 1 epoch related output, > 1 all output
template <typename NumType, typename LabelType>
void LUSequenceParser<NumType, LabelType>::SetTraceLevel(int traceLevel)
{
m_traceLevel = traceLevel;
}
// NOTE: Current code is identical to float, don't know how to specialize with template parameter that only covers one parameter
#ifdef STANDALONE
int wmain(int argc, wchar_t* argv[])
{
LUSequenceParser<double, int> parser;
std::vector<double> values;
values.reserve(784000*6);
std::vector<int> labels;
labels.reserve(60000);
parser.ParseInit(L"c:\\speech\\mnist\\mnist_train.txt", LabelFirst);
//parser.ParseInit("c:\\speech\\parseTest.txt", LabelNone);
int records = 0;
do
{
int recordsRead = parser.Parse(10000, &values, &labels);
if (recordsRead < 10000)
parser.SetFilePosition(0); // go around again
records += recordsRead;
values.clear();
labels.clear();
}
while (records < 150000);
return records;
}
#endif
// instantiate UCI parsers for supported types
template class LUSequenceParser<float, int>;
template class LUSequenceParser<float, float>;
template class LUSequenceParser<float, std::string>;
template class LUSequenceParser<float, std::wstring>;
template class LUSequenceParser<double, int>;
template class LUSequenceParser<double, double>;
template class LUSequenceParser<double, std::string>;
template class LUSequenceParser<double, std::wstring>;
template <typename NumType, typename LabelType>
void LUBatchLUSequenceParser<NumType, LabelType>::ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn = "<s>", std::string endSequenceIn = "</s>", std::string beginSequenceOut = "O", std::string endSequenceOut = "O")
template<class NumType, class LabelType>
long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested, std::vector<long> *labels, std::vector<vector<long>> *input, std::vector<SequencePosition> *seqPos, const map<wstring, long>& inputlabel2id, const map<wstring, long>& outputlabel2id)
{
LULUSequenceParser<NumType, LabelType>::ParseInit(fileName, dimLabelsIn, dimLabelsOut, beginSequenceIn, endSequenceIn, beginSequenceOut, endSequenceOut);
}
// transfer to member variables
m_inputs = input;
m_labels = labels;
template <typename NumType, typename LabelType>
long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested, std::vector<LabelType> *labels, std::vector<vector<LabelType>> *inputs, std::vector<SequencePosition> *seqPos)
{
long linecnt;
linecnt = LULUSequenceParser<NumType, LabelType>::Parse(recordsRequested, labels, inputs, seqPos);
long recordCount = 0;
long orgRecordCount = (long)labels->size();
long lineCount = 0;
bool bAtEOS = false; /// whether the reader is at the end of sentence position
SequencePosition sequencePositionLast(0, 0, 0);
/// get line
wstring ch;
int prvat = 0;
while (lineCount < recordsRequested && mFile.good())
{
getline(mFile, ch);
ch = wtrim(ch);
if (mFile.eof())
ParseReset(); /// restart from the corpus begining
std::vector<wstring> vstr;
bool bBlankLine = (ch.length() == 0);
if (bBlankLine && !bAtEOS && input->size() > 0 && labels->size() > 0)
{
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
bAtEOS = true;
continue;
}
vstr = wsep_string(ch, L" ");
if (vstr.size() < 2)
continue;
bAtEOS = false;
vector<long> vtmp;
for (size_t i = 0; i < vstr.size() - 1; i++)
{
if (inputlabel2id.find(vstr[i]) == inputlabel2id.end())
LogicError("cannot find item %s in input label", vstr[i]);
vtmp.push_back(inputlabel2id.find(vstr[i])->second);
}
if (outputlabel2id.find(vstr[vstr.size() - 1]) == outputlabel2id.end())
LogicError("cannot find item %s in output label", vstr[vstr.size() - 1]);
labels->push_back(outputlabel2id.find(vstr[vstr.size() - 1])->second);
input->push_back(vtmp);
if ((vstr[vstr.size() - 1] == m_endSequenceOut ||
/// below is for backward support
vstr[0] == m_endTag) && input->size() > 0 && labels->size() > 0)
{
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
bAtEOS = true;
}
} // while
int prvat = 0;
size_t i = 0;
for (auto ptr = seqPos->begin(); ptr != seqPos->end(); ptr++, i++)
{
@ -359,15 +114,18 @@ long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested,
stinfo.sLen = iln;
stinfo.sBegin = prvat;
stinfo.sEnd = (int)ptr->labelPos;
mSentenceIndex2SentenceInfo.push_back(stinfo);
mSentenceIndex2SentenceInfo.push_back(stinfo);
prvat = (int)ptr->labelPos;
}
assert(mSentenceIndex2SentenceInfo.size() == linecnt);
return linecnt;
fprintf(stderr, "LUBatchLUSequenceParser: parse %d lines\n", lineCount);
return lineCount;
}
template class LUBatchLUSequenceParser<float, std::string>;
template class LUBatchLUSequenceParser<double, std::string>;
template class LUBatchLUSequenceParser<float, std::wstring>;
template class LUBatchLUSequenceParser<double, std::wstring>;
}}}

Просмотреть файл

@ -20,151 +20,57 @@ using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
#define MAXSTRING 2048
// UCI label location types
enum LabelMode
{
LabelNone = 0,
LabelFirst = 1,
LabelLast = 2,
};
enum ParseMode
{
ParseNormal = 0,
ParseLineCount = 1
};
enum SequenceFlags
{
seqFlagNull = 0,
seqFlagLineBreak = 1, // line break on the parsed line
seqFlagEmptyLine = 2, // empty line
seqFlagStartLabel = 4,
seqFlagStopLabel = 8
};
// SequencePosition, save the ending indexes into the array for a sequence
struct SequencePosition
{
size_t inputPos; // max position in the number array for this sequence
size_t labelPos; // max position in the label array for this sequence
long inputPos; // max position in the number array for this sequence
long labelPos; // max position in the label array for this sequence
unsigned flags; // flags that apply to this sequence
SequencePosition(size_t inPos, size_t labelPos, unsigned flags):
inputPos(inPos), labelPos(labelPos), flags(flags)
SequencePosition(long inPos, long labelPos, unsigned flags) :
inputPos(inPos), labelPos(labelPos), flags(flags)
{}
};
// LUSequenceParser - the parser for the UCI format files
// LUSequenceParser - the parser for the UCI format files
// for ultimate speed, this class implements a state machine to read these format files
template <typename NumType, typename LabelType=int>
template <typename NumType, typename LabelType = wstring>
class LUSequenceParser
{
public:
using LabelIdType = long;
protected:
enum ParseState
{
WholeNumber = 0,
Remainder = 1,
Exponent = 2,
Whitespace = 3,
Sign = 4,
ExponentSign = 5,
Period = 6,
TheLetterE = 7,
EndOfLine = 8,
Label = 9, // any non-number things we run into
ParseStateMax = 10, // number of parse states
LineCountEOL = 10,
LineCountOther = 11,
AllStateMax = 12
};
// type of label processing
ParseMode m_parseMode;
// definition of label and feature dimensions
size_t m_dimFeatures;
size_t m_dimLabelsIn;
std::string m_beginSequenceIn; // starting sequence string (i.e. <s>)
std::string m_endSequenceIn; // ending sequence string (i.e. </s>)
wstring m_beginSequenceIn; // starting sequence string (i.e. <s>)
wstring m_endSequenceIn; // ending sequence string (i.e. </s>)
size_t m_dimLabelsOut;
std::string m_beginSequenceOut; // starting sequence string (i.e. 'O')
std::string m_endSequenceOut; // ending sequence string (i.e. 'O')
wstring m_beginSequenceOut; // starting sequence string (i.e. 'O')
wstring m_endSequenceOut; // ending sequence string (i.e. 'O')
// level of screen output
int m_traceLevel;
// current state of the state machine
ParseState m_current_state;
// state tables
DWORD *m_stateTable;
// numeric state machine variables
double m_partialResult;
double m_builtUpNumber;
double m_divider;
double m_wholeNumberMultiplier;
double m_exponentMultiplier;
// label state machine variables
size_t m_spaceDelimitedStart;
size_t m_spaceDelimitedMax; // start of the next whitespace sequence (one past the end of the last word)
int m_numbersConvertedThisLine;
int m_labelsConvertedThisLine;
int m_elementsConvertedThisLine;
// sequence state machine variables
bool m_beginSequence;
bool m_endSequence;
std::string m_beginTag;
std::string m_endTag;
// global stats
int m_totalNumbersConverted;
int m_totalLabelsConverted;
wstring m_beginTag;
wstring m_endTag;
// file positions/buffer
FILE * m_pFile;
int64_t m_byteCounter;
int64_t m_fileSize;
BYTE * m_fileBuffer;
size_t m_bufferStart;
size_t m_bufferSize;
// last label was a string (for last label processing)
bool m_lastLabelIsString;
// vectors to append to
std::vector<vector<LabelType>>* m_inputs; // pointer to vectors to append with numbers
std::vector<LabelType>* m_labels; // pointer to vector to append with labels (may be numeric)
std::vector<vector<LabelIdType>>* m_inputs; // pointer to vectors to append with numbers
std::vector<LabelIdType>* m_labels; // pointer to vector to append with labels (may be numeric)
// FUTURE: do we want a vector to collect string labels in the non string label case? (signifies an error)
// SetState for a particular value
void SetState(int value, ParseState m_current_state, ParseState next_state);
// SetStateRange - set states transitions for a range of values
void SetStateRange(int value1, int value2, ParseState m_current_state, ParseState next_state);
// SetupStateTables - setup state transition tables for each state
// each state has a block of 256 states indexed by the incoming character
void SetupStateTables();
// reset all line state variables
void PrepareStartLine();
// reset all number accumulation variables
void PrepareStartNumber();
// reset all state variables to start reading at a new position
void PrepareStartPosition(size_t position);
// UpdateBuffer - load the next buffer full of data
// returns - number of records read
size_t UpdateBuffer();
public:
// LUSequenceParser constructor
@ -176,14 +82,6 @@ public:
~LUSequenceParser();
public:
// SetParseMode - Set the parsing mode
// mode - set mode to either ParseLineCount, or ParseNormal
void SetParseMode(ParseMode mode);
// SetTraceLevel - Set the level of screen output
// traceLevel - traceLevel, zero means no output, 1 epoch related output, > 1 all output
void SetTraceLevel(int traceLevel);
// ParseInit - Initialize a parse of a file
// fileName - path to the file to open
@ -196,7 +94,7 @@ public:
// endSequenceOut - endSequence output label
// bufferSize - size of temporary buffer to store reads
// startPosition - file position on which we should start
void ParseInit(LPCWSTR fileName, size_t dimFeatures, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn="<s>", std::string endSequenceIn="</s>", std::string beginSequenceOut="O", std::string endSequenceOut="O", size_t bufferSize=1024*256, size_t startPosition=0)
void ParseInit(LPCWSTR fileName, size_t dimFeatures, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut )
{
assert(fileName != NULL);
m_dimFeatures = dimFeatures;
@ -207,10 +105,7 @@ public:
m_beginSequenceOut = beginSequenceOut;
m_endSequenceOut = endSequenceOut;
m_parseMode = ParseNormal;
m_traceLevel = 0;
m_bufferSize = bufferSize;
m_bufferStart = startPosition;
m_beginTag = m_beginSequenceIn;
m_endTag = m_endSequenceIn;
@ -225,18 +120,23 @@ public:
int rc = _fseeki64(m_pFile, 0, SEEK_END);
if (rc)
RuntimeError("LUSequenceParser::ParseInit - error seeking in file");
m_fileBuffer = new BYTE[m_bufferSize];
}
};
/// language model sequence parser
template <typename NumType, typename LabelType>
class LULUSequenceParser : public LUSequenceParser<NumType, LabelType>
typedef struct{
size_t sLen;
int sBegin;
int sEnd;
} stSentenceInfo;
template <typename NumType, typename LabelType = wstring>
class LUBatchLUSequenceParser : public LUSequenceParser<NumType, LabelType>
{
protected:
FILE * mFile;
public:
wifstream mFile;
std::wstring mFileName;
vector<stSentenceInfo> mSentenceIndex2SentenceInfo;
public:
using LUSequenceParser<NumType, LabelType>::m_dimFeatures;
@ -246,27 +146,21 @@ public:
using LUSequenceParser<NumType, LabelType>::m_dimLabelsOut;
using LUSequenceParser<NumType, LabelType>::m_beginSequenceOut;
using LUSequenceParser<NumType, LabelType>::m_endSequenceOut;
using LUSequenceParser<NumType, LabelType>::m_parseMode;
using LUSequenceParser<NumType, LabelType>::m_traceLevel;
using LUSequenceParser<NumType, LabelType>::m_bufferSize;
using LUSequenceParser<NumType, LabelType>::m_bufferStart;
using LUSequenceParser<NumType, LabelType>::m_beginTag;
using LUSequenceParser<NumType, LabelType>::m_endTag;
using LUSequenceParser<NumType, LabelType>::m_fileBuffer;
using LUSequenceParser<NumType, LabelType>::m_fileSize;
using LUSequenceParser<NumType, LabelType>::m_inputs;
using LUSequenceParser<NumType, LabelType>::m_labels;
using LUSequenceParser<NumType, LabelType>::m_beginSequence;
using LUSequenceParser<NumType, LabelType>::m_endSequence;
using LUSequenceParser<NumType, LabelType>::m_totalNumbersConverted;
LULUSequenceParser() {
mFile = nullptr;
LUBatchLUSequenceParser() {
};
~LULUSequenceParser() {
if (mFile) fclose(mFile);
~LUBatchLUSequenceParser() {
mFile.close();
}
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn = "<s>", std::string endSequenceIn = "</s>", std::string beginSequenceOut = "O", std::string endSequenceOut = "O")
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut)
{
assert(fileName != NULL);
mFileName = fileName;
@ -277,33 +171,27 @@ public:
m_beginSequenceOut = beginSequenceOut;
m_endSequenceOut = endSequenceOut;
m_parseMode = ParseNormal;
m_traceLevel = 0;
m_bufferSize = 0;
m_bufferStart = 0;
m_beginTag = m_beginSequenceIn;
m_endTag = m_endSequenceIn;
m_fileSize = -1;
m_fileBuffer = NULL;
mFile.close();
if (mFile) fclose(mFile);
if (_wfopen_s(&mFile, fileName, L"rt") != 0)
mFile.open(fileName, wifstream::in);
if (!mFile.good())
RuntimeError("cannot open file %ls", fileName);
}
void ParseReset()
{
if (mFile) fseek(mFile, 0, SEEK_SET);
mFile.seekg(0, mFile.beg);
}
void AddOneItem(std::vector<LabelType> *labels, std::vector<vector<LabelType>> *input, std::vector<SequencePosition> *seqPos, long& lineCount,
void AddOneItem(std::vector<long> *labels, std::vector<vector<long>> *input, std::vector<SequencePosition> *seqPos, long& lineCount,
long & recordCount, long orgRecordCount, SequencePosition& sequencePositionLast)
{
SequencePosition sequencePos(input->size(), labels->size(),
m_beginSequence ? seqFlagStartLabel : 0 | m_endSequence ? seqFlagStopLabel : 0 | seqFlagLineBreak);
SequencePosition sequencePos((long)input->size(), (long)labels->size(), 1);
seqPos->push_back(sequencePos);
sequencePositionLast = sequencePos;
@ -323,94 +211,8 @@ public:
// numbers - pointer to vector to return the numbers
// seqPos - pointers to the other two arrays showing positions of each sequence
// returns - number of records actually read, if the end of file is reached the return value will be < requested records
long Parse(size_t recordsRequested, std::vector<LabelType> *labels, std::vector<vector<LabelType>> *input, std::vector<SequencePosition> *seqPos)
{
assert(labels != NULL || m_dimLabelsIn == 0 && m_dimLabelsOut == 0 || m_parseMode == ParseLineCount);
// transfer to member variables
m_inputs = input;
m_labels = labels;
long TickStart = GetTickCount();
long recordCount = 0;
long orgRecordCount = (long)labels->size();
long lineCount = 0;
bool bAtEOS = false; /// whether the reader is at the end of sentence position
SequencePosition sequencePositionLast(0, 0, seqFlagNull);
/// get line
char ch2[MAXSTRING];
while (lineCount < recordsRequested && fgets(ch2, MAXSTRING, mFile) != nullptr)
{
string ch = ch2;
std::vector<string> vstr;
bool bBlankLine = (trim(ch).length() == 0);
if (bBlankLine && !bAtEOS && input->size() > 0 && labels->size() > 0)
{
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
bAtEOS = true;
continue;
}
vstr = sep_string(ch, " ");
if (vstr.size() < 2)
continue;
bAtEOS = false;
vector<LabelType> vtmp;
for (size_t i = 0; i < vstr.size() - 1; i++)
{
vtmp.push_back(vstr[i]);
}
labels->push_back(vstr[vstr.size() - 1]);
input->push_back(vtmp);
if ((vstr[vstr.size() - 1] == m_endSequenceOut ||
/// below is for backward support
vstr[0] == m_endTag) && input->size() > 0 && labels->size() > 0)
{
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
bAtEOS = true;
}
} // while
long TickStop = GetTickCount();
long TickDelta = TickStop - TickStart;
if (m_traceLevel > 2)
fprintf(stderr, "\n%d ms, %d numbers parsed\n\n", TickDelta, m_totalNumbersConverted);
return lineCount;
}
long Parse(size_t recordsRequested, std::vector<long> *labels, std::vector<vector<long>> *input, std::vector<SequencePosition> *seqPos, const map<wstring, long>& inputlabel2id, const map<wstring, long>& outputlabel2id);
};
typedef struct{
size_t sLen;
int sBegin;
int sEnd;
} stSentenceInfo;
/// language model sequence parser
template <typename NumType, typename LabelType>
class LUBatchLUSequenceParser: public LULUSequenceParser<NumType, LabelType>
{
public:
vector<stSentenceInfo> mSentenceIndex2SentenceInfo;
public:
LUBatchLUSequenceParser() { };
~LUBatchLUSequenceParser() { }
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn = "<s>", std::string endSequenceIn = "</s>", std::string beginSequenceOut = "O", std::string endSequenceOut = "O");
// Parse - Parse the data
// recordsRequested - number of records requested
// labels - pointer to vector to return the labels
// numbers - pointer to vector to return the numbers
// seqPos - pointers to the other two arrays showing positions of each sequence
// returns - number of records actually read, if the end of file is reached the return value will be < requested records
long Parse(size_t recordsRequested, std::vector<LabelType> *labels, std::vector<vector<LabelType>> *inputs, std::vector<SequencePosition> *seqPos);
};
}}};

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -18,12 +18,16 @@
namespace Microsoft { namespace MSR { namespace CNTK {
#ifdef DBG_SMT
#define CACHE_BLOG_SIZE 100
#else
#define CACHE_BLOG_SIZE 50000
#endif
#define STRIDX2CLS L"idx2cls"
#define CLASSINFO L"classinfo"
#define MAX_STRING 2048
#define MAX_STRING 100000
#define NULLLABEL 65532
@ -44,22 +48,18 @@ protected:
std::wstring m_file;
public:
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
int nwords, dims, nsamps, nglen, nmefeats;
using LabelType = wstring;
using LabelIdType = long;
long nwords, dims, nsamps, nglen, nmefeats;
int m_seed;
bool mRandomize;
int class_size;
map<int, vector<int>> class_words;
vector<int>class_cn;
public:
/// deal with OOV
map<string, string> mWordMapping;
map<LabelType, LabelType> mWordMapping;
string mWordMappingFn;
string mUnkStr;
LabelType mUnkStr;
public:
/// accumulated number of sentneces read so far
@ -67,8 +67,7 @@ public:
protected:
LULUSequenceParser<ElemType, LabelType> m_parser;
// LUBatchLUSequenceParser<ElemType, LabelType> m_parser;
LUBatchLUSequenceParser<ElemType, LabelType> m_parser;
size_t m_mbSize; // size of minibatch requested
size_t m_mbStartSample; // starting sample # of the next minibatch
size_t m_epochSize; // size of an epoch
@ -119,12 +118,12 @@ protected:
LabelKind type; // labels are categories, create mapping table
std::map<LabelIdType, LabelType> mapIdToLabel;
std::map<LabelType, LabelIdType> mapLabelToId;
map<string, int> word4idx;
map<int, string> idx4word;
map<LabelType, LabelIdType> word4idx;
map<LabelIdType, LabelType> idx4word;
LabelIdType idMax; // maximum label ID we have encountered so far
LabelIdType dim; // maximum label ID we will ever see (used for array dimensions)
std::string beginSequence; // starting sequence string (i.e. <s>)
std::string endSequence; // ending sequence string (i.e. </s>)
long dim; // maximum label ID we will ever see (used for array dimensions)
LabelType beginSequence; // starting sequence string (i.e. <s>)
LabelType endSequence; // ending sequence string (i.e. </s>)
bool busewordmap; /// whether using wordmap to map unseen words to unk
std::wstring mapName;
std::wstring fileToWrite; // set to the path if we need to write out the label file
@ -146,37 +145,34 @@ protected:
void WriteLabelFile();
void LoadLabelFile(const std::wstring &filePath, std::vector<LabelType>& retLabels);
LabelIdType GetIdFromLabel(const std::string& label, LabelInfo& labelInfo);
bool GetIdFromLabel(const vector<string>& label, LabelInfo& labelInfo, vector<LabelIdType>& val);
bool CheckIdFromLabel(const std::string& labelValue, const LabelInfo& labelInfo, unsigned & labelId);
LabelIdType GetIdFromLabel(const LabelType& label, LabelInfo& labelInfo);
bool GetIdFromLabel(const vector<LabelIdType>& label, vector<LabelIdType>& val);
bool CheckIdFromLabel(const LabelType& labelValue, const LabelInfo& labelInfo, unsigned & labelId);
virtual bool ReadRecord(size_t readSample);
bool SentenceEnd();
public:
virtual void Init(const ConfigParameters& config);
void ReadLabelInfo(const wstring & vocfile, map<string, int> & word4idx,
map<int, string>& idx4word) ;
void ChangeMaping(const map<string, string>& maplist,
const string & unkstr ,
map<string, int> & word4idx);
void Init(const ConfigParameters& ){};
void ReadLabelInfo(const wstring & vocfile, map<LabelType, LabelIdType> & word4idx,
map<LabelIdType, LabelType>& idx4word);
void ChangeMaping(const map<LabelType, LabelType>& maplist,
const LabelType& unkstr,
map<LabelType, LabelIdType> & word4idx);
void ReadWord(char *wrod, FILE *fin);
void Destroy() {};
virtual void Destroy();
LUSequenceReader() {
m_featuresBuffer=NULL; m_labelsBuffer=NULL; m_clsinfoRead = false; m_idx2clsRead = false;
}
virtual ~LUSequenceReader();
virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
~LUSequenceReader(){};
void StartMinibatchLoop(size_t , size_t , size_t = requestDataSize) {};
void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {};
void SentenceEnd(std::vector<size_t> &/*sentenceEnd*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
public:
int GetSentenceEndIdFromOutputLabel();
};
@ -185,9 +181,9 @@ template<class ElemType>
class BatchLUSequenceReader : public LUSequenceReader<ElemType>
{
public:
using LabelType = typename IDataReader<ElemType>::LabelType;
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
using LUSequenceReader<ElemType>::mWordMappingFn;
using LabelType = wstring;
using LabelIdType = long;
using LUSequenceReader<ElemType>::mWordMappingFn;
using LUSequenceReader<ElemType>::m_cachingReader;
using LUSequenceReader<ElemType>::mWordMapping;
using LUSequenceReader<ElemType>::mUnkStr;
@ -197,7 +193,6 @@ public:
using LUSequenceReader<ElemType>::labelInfoMin;
using LUSequenceReader<ElemType>::labelInfoMax;
using LUSequenceReader<ElemType>::m_featureDim;
using LUSequenceReader<ElemType>::class_size;
using LUSequenceReader<ElemType>::m_labelInfo;
// using LUSequenceReader<ElemType>::m_labelInfoIn;
using LUSequenceReader<ElemType>::m_mbStartSample;
@ -247,8 +242,8 @@ private:
size_t mLastPosInSentence;
size_t mNumRead ;
std::vector<vector<LabelType>> m_featureTemp;
std::vector<LabelType> m_labelTemp;
std::vector<vector<LabelIdType>> m_featureTemp;
std::vector<LabelIdType> m_labelTemp;
bool mSentenceEnd;
bool mSentenceBegin;