Simplified LU sequence by removing unnecessary code. Also use wstring to support reading wchar strings. Need to make sure that large amount of data can be read, because reader reads at most CACHE_BLOG_SIZE data.
This commit is contained in:
Родитель
3d958beb7a
Коммит
fad6322a29
|
@ -13,218 +13,6 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// SetState for a particular value
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::SetState(int value, ParseState m_current_state, ParseState next_state)
|
||||
{
|
||||
DWORD ul = (DWORD)next_state;
|
||||
int range_shift = ((int)m_current_state) << 8;
|
||||
m_stateTable[range_shift+value] = ul;
|
||||
}
|
||||
|
||||
// SetStateRange - set states transitions for a range of values
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::SetStateRange(int value1, int value2, ParseState m_current_state, ParseState next_state)
|
||||
{
|
||||
DWORD ul = (DWORD)next_state;
|
||||
int range_shift = ((int)m_current_state) << 8;
|
||||
for (int value = value1; value <= value2; value++)
|
||||
{
|
||||
m_stateTable[range_shift+value] = ul;
|
||||
}
|
||||
}
|
||||
|
||||
// SetupStateTables - setup state transition tables for each state
|
||||
// each state has a block of 256 states indexed by the incoming character
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::SetupStateTables()
|
||||
{
|
||||
//=========================
|
||||
// STATE = WHITESPACE
|
||||
//=========================
|
||||
|
||||
SetStateRange(0,255, Whitespace, Label);
|
||||
SetStateRange('0', '9', Whitespace, WholeNumber);
|
||||
SetState('-', Whitespace, Sign);
|
||||
SetState('+', Whitespace, Sign);
|
||||
// whitespace
|
||||
SetState(' ', Whitespace, Whitespace);
|
||||
SetState('\t', Whitespace, Whitespace);
|
||||
SetState('\r', Whitespace, Whitespace);
|
||||
SetState(':', Whitespace, Whitespace); // intepret ':' as white space because it's a divider
|
||||
SetState('\n', Whitespace, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = NEGATIVE_SIGN
|
||||
//=========================
|
||||
|
||||
SetStateRange( 0, 255, Sign, Label);
|
||||
SetStateRange( '0', '9', Sign, WholeNumber);
|
||||
// whitespace
|
||||
SetState(' ', Sign, Whitespace);
|
||||
SetState('\t', Sign, Whitespace);
|
||||
SetState('\r', Sign, Whitespace);
|
||||
SetState('\n', Sign, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = NUMBER
|
||||
//=========================
|
||||
|
||||
SetStateRange( 0, 255, WholeNumber, Label);
|
||||
SetStateRange( '0', '9', WholeNumber, WholeNumber);
|
||||
SetState('.', WholeNumber, Period);
|
||||
SetState('e', WholeNumber, TheLetterE);
|
||||
SetState('E', WholeNumber, TheLetterE);
|
||||
// whitespace
|
||||
SetState(' ', WholeNumber, Whitespace);
|
||||
SetState('\t', WholeNumber, Whitespace);
|
||||
SetState('\r', WholeNumber, Whitespace);
|
||||
SetState(':', WholeNumber, Whitespace); // Add for 1234:0.9 usage in Sequences
|
||||
SetState('\n', WholeNumber, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = PERIOD
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, Period, Label);
|
||||
SetStateRange('0', '9', Period, Remainder);
|
||||
// whitespace
|
||||
SetState(' ', Period, Whitespace);
|
||||
SetState('\t', Period, Whitespace);
|
||||
SetState('\r', Period, Whitespace);
|
||||
SetState('\n', Period, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = REMAINDER
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, Remainder, Label);
|
||||
SetStateRange('0', '9', Remainder, Remainder);
|
||||
SetState('e', Remainder, TheLetterE);
|
||||
SetState('E', Remainder, TheLetterE);
|
||||
// whitespace
|
||||
SetState(' ', Remainder, Whitespace);
|
||||
SetState('\t', Remainder, Whitespace);
|
||||
SetState('\r', Remainder, Whitespace);
|
||||
SetState(':', Remainder, Whitespace); // Add for 1234:0.9 usage in Sequences
|
||||
SetState('\n', Remainder, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = THE_LETTER_E
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, TheLetterE, Label);
|
||||
SetStateRange('0', '9', TheLetterE, Exponent);
|
||||
SetState('-', TheLetterE, ExponentSign);
|
||||
SetState('+', TheLetterE, ExponentSign);
|
||||
// whitespace
|
||||
SetState(' ', TheLetterE, Whitespace);
|
||||
SetState('\t', TheLetterE, Whitespace);
|
||||
SetState('\r', TheLetterE, Whitespace);
|
||||
SetState('\n', TheLetterE, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = EXPONENT_NEGATIVE_SIGN
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, ExponentSign, Label);
|
||||
SetStateRange('0', '9', ExponentSign, Exponent);
|
||||
// whitespace
|
||||
SetState(' ', ExponentSign, Whitespace);
|
||||
SetState('\t', ExponentSign, Whitespace);
|
||||
SetState('\r', ExponentSign, Whitespace);
|
||||
SetState('\n', ExponentSign, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = EXPONENT
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, Exponent, Label);
|
||||
SetStateRange('0', '9', Exponent, Exponent);
|
||||
// whitespace
|
||||
SetState(' ', Exponent, Whitespace);
|
||||
SetState('\t', Exponent, Whitespace);
|
||||
SetState('\r', Exponent, Whitespace);
|
||||
SetState(':', Exponent, Whitespace);
|
||||
SetState('\n', Exponent, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = END_OF_LINE
|
||||
//=========================
|
||||
SetStateRange(0, 255, EndOfLine, Label);
|
||||
SetStateRange( '0', '9', EndOfLine, WholeNumber);
|
||||
SetState( '-', EndOfLine, Sign);
|
||||
SetState( '\n', EndOfLine, EndOfLine);
|
||||
// whitespace
|
||||
SetState(' ', EndOfLine, Whitespace);
|
||||
SetState('\t', EndOfLine, Whitespace);
|
||||
SetState('\r', EndOfLine, Whitespace);
|
||||
|
||||
|
||||
//=========================
|
||||
// STATE = LABEL
|
||||
//=========================
|
||||
SetStateRange(0, 255, Label, Label);
|
||||
SetState('\n', Label, EndOfLine);
|
||||
// whitespace
|
||||
SetState(' ', Label, Whitespace);
|
||||
SetState('\t', Label, Whitespace);
|
||||
SetState('\r', Label, Whitespace);
|
||||
SetState(':', Label, Whitespace);
|
||||
|
||||
//=========================
|
||||
// STATE = LINE_COUNT_EOL
|
||||
//=========================
|
||||
SetStateRange(0, 255, LineCountEOL, LineCountOther);
|
||||
SetState( '\n', LineCountEOL, LineCountEOL);
|
||||
|
||||
//=========================
|
||||
// STATE = LINE_COUNT_OTHER
|
||||
//=========================
|
||||
SetStateRange(0, 255, LineCountOther, LineCountOther);
|
||||
SetState('\n', LineCountOther, LineCountEOL);
|
||||
}
|
||||
|
||||
|
||||
// reset all line state variables
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::PrepareStartLine()
|
||||
{
|
||||
m_numbersConvertedThisLine = 0;
|
||||
m_labelsConvertedThisLine = 0;
|
||||
m_elementsConvertedThisLine = 0;
|
||||
m_spaceDelimitedStart = m_byteCounter;
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
m_lastLabelIsString = false;
|
||||
m_beginSequence = m_endSequence = false;
|
||||
}
|
||||
|
||||
// reset all number accumulation variables
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::PrepareStartNumber()
|
||||
{
|
||||
m_partialResult = 0;
|
||||
m_builtUpNumber = 0;
|
||||
m_divider = 0;
|
||||
m_wholeNumberMultiplier = 1;
|
||||
m_exponentMultiplier = 1;
|
||||
}
|
||||
|
||||
// reset all state variables to start reading at a new position
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::PrepareStartPosition(size_t position)
|
||||
{
|
||||
m_current_state = Whitespace;
|
||||
m_byteCounter = position; // must come before PrepareStartLine...
|
||||
m_bufferStart = position;
|
||||
|
||||
// prepare state machine for new number and new line
|
||||
PrepareStartNumber();
|
||||
PrepareStartLine();
|
||||
m_totalNumbersConverted = 0;
|
||||
m_totalLabelsConverted = 0;
|
||||
}
|
||||
|
||||
// LUSequenceParser constructor
|
||||
template <typename NumType, typename LabelType>
|
||||
LUSequenceParser<NumType, LabelType>::LUSequenceParser()
|
||||
|
@ -236,121 +24,88 @@ LUSequenceParser<NumType, LabelType>::LUSequenceParser()
|
|||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::Init()
|
||||
{
|
||||
PrepareStartPosition(0);
|
||||
m_fileBuffer = NULL;
|
||||
m_pFile = NULL;
|
||||
m_stateTable = new DWORD[AllStateMax * 256];
|
||||
SetupStateTables();
|
||||
}
|
||||
|
||||
// Parser destructor
|
||||
template <typename NumType, typename LabelType>
|
||||
LUSequenceParser<NumType, LabelType>::~LUSequenceParser()
|
||||
{
|
||||
delete m_stateTable;
|
||||
delete m_fileBuffer;
|
||||
if (m_pFile)
|
||||
fclose(m_pFile);
|
||||
}
|
||||
|
||||
|
||||
// UpdateBuffer - load the next buffer full of data
|
||||
// returns - number of records read
|
||||
template <typename NumType, typename LabelType>
|
||||
size_t LUSequenceParser<NumType, LabelType>::UpdateBuffer()
|
||||
{
|
||||
// state machine might want to look back this far, so copy to beginning
|
||||
size_t saveBytes = m_byteCounter-m_spaceDelimitedStart;
|
||||
assert(saveBytes < m_bufferSize);
|
||||
if (saveBytes)
|
||||
{
|
||||
memcpy_s(m_fileBuffer, m_bufferSize, &m_fileBuffer[m_byteCounter-m_bufferStart-saveBytes], saveBytes);
|
||||
m_bufferStart = m_byteCounter-saveBytes;
|
||||
}
|
||||
|
||||
// read the next block
|
||||
size_t bytesToRead = min(m_bufferSize, m_fileSize-m_bufferStart)-saveBytes;
|
||||
size_t bytesRead = fread(m_fileBuffer+saveBytes, 1, bytesToRead, m_pFile);
|
||||
if (bytesRead == 0 && ferror(m_pFile))
|
||||
RuntimeError("LUSequenceParser::UpdateBuffer - error reading file");
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::SetParseMode(ParseMode mode)
|
||||
{
|
||||
// if already in this mode, nothing to do
|
||||
if (m_parseMode == mode)
|
||||
return;
|
||||
|
||||
// switching modes
|
||||
if (mode == ParseLineCount)
|
||||
m_current_state = LineCountOther;
|
||||
else
|
||||
{
|
||||
m_current_state = Whitespace;
|
||||
PrepareStartLine();
|
||||
PrepareStartNumber();
|
||||
}
|
||||
m_parseMode = mode;
|
||||
}
|
||||
|
||||
// SetTraceLevel - Set the level of screen output
|
||||
// traceLevel - traceLevel, zero means no output, 1 epoch related output, > 1 all output
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUSequenceParser<NumType, LabelType>::SetTraceLevel(int traceLevel)
|
||||
{
|
||||
m_traceLevel = traceLevel;
|
||||
}
|
||||
|
||||
// NOTE: Current code is identical to float, don't know how to specialize with template parameter that only covers one parameter
|
||||
|
||||
#ifdef STANDALONE
|
||||
int wmain(int argc, wchar_t* argv[])
|
||||
{
|
||||
LUSequenceParser<double, int> parser;
|
||||
std::vector<double> values;
|
||||
values.reserve(784000*6);
|
||||
std::vector<int> labels;
|
||||
labels.reserve(60000);
|
||||
parser.ParseInit(L"c:\\speech\\mnist\\mnist_train.txt", LabelFirst);
|
||||
//parser.ParseInit("c:\\speech\\parseTest.txt", LabelNone);
|
||||
int records = 0;
|
||||
do
|
||||
{
|
||||
int recordsRead = parser.Parse(10000, &values, &labels);
|
||||
if (recordsRead < 10000)
|
||||
parser.SetFilePosition(0); // go around again
|
||||
records += recordsRead;
|
||||
values.clear();
|
||||
labels.clear();
|
||||
}
|
||||
while (records < 150000);
|
||||
return records;
|
||||
}
|
||||
#endif
|
||||
|
||||
// instantiate UCI parsers for supported types
|
||||
template class LUSequenceParser<float, int>;
|
||||
template class LUSequenceParser<float, float>;
|
||||
template class LUSequenceParser<float, std::string>;
|
||||
template class LUSequenceParser<float, std::wstring>;
|
||||
template class LUSequenceParser<double, int>;
|
||||
template class LUSequenceParser<double, double>;
|
||||
template class LUSequenceParser<double, std::string>;
|
||||
template class LUSequenceParser<double, std::wstring>;
|
||||
|
||||
template <typename NumType, typename LabelType>
|
||||
void LUBatchLUSequenceParser<NumType, LabelType>::ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn = "<s>", std::string endSequenceIn = "</s>", std::string beginSequenceOut = "O", std::string endSequenceOut = "O")
|
||||
template<class NumType, class LabelType>
|
||||
long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested, std::vector<long> *labels, std::vector<vector<long>> *input, std::vector<SequencePosition> *seqPos, const map<wstring, long>& inputlabel2id, const map<wstring, long>& outputlabel2id)
|
||||
{
|
||||
LULUSequenceParser<NumType, LabelType>::ParseInit(fileName, dimLabelsIn, dimLabelsOut, beginSequenceIn, endSequenceIn, beginSequenceOut, endSequenceOut);
|
||||
}
|
||||
// transfer to member variables
|
||||
m_inputs = input;
|
||||
m_labels = labels;
|
||||
|
||||
template <typename NumType, typename LabelType>
|
||||
long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested, std::vector<LabelType> *labels, std::vector<vector<LabelType>> *inputs, std::vector<SequencePosition> *seqPos)
|
||||
{
|
||||
long linecnt;
|
||||
linecnt = LULUSequenceParser<NumType, LabelType>::Parse(recordsRequested, labels, inputs, seqPos);
|
||||
long recordCount = 0;
|
||||
long orgRecordCount = (long)labels->size();
|
||||
long lineCount = 0;
|
||||
bool bAtEOS = false; /// whether the reader is at the end of sentence position
|
||||
SequencePosition sequencePositionLast(0, 0, 0);
|
||||
/// get line
|
||||
wstring ch;
|
||||
|
||||
int prvat = 0;
|
||||
while (lineCount < recordsRequested && mFile.good())
|
||||
{
|
||||
getline(mFile, ch);
|
||||
ch = wtrim(ch);
|
||||
|
||||
if (mFile.eof())
|
||||
ParseReset(); /// restart from the corpus begining
|
||||
|
||||
std::vector<wstring> vstr;
|
||||
bool bBlankLine = (ch.length() == 0);
|
||||
if (bBlankLine && !bAtEOS && input->size() > 0 && labels->size() > 0)
|
||||
{
|
||||
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
|
||||
bAtEOS = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
vstr = wsep_string(ch, L" ");
|
||||
if (vstr.size() < 2)
|
||||
continue;
|
||||
|
||||
bAtEOS = false;
|
||||
vector<long> vtmp;
|
||||
for (size_t i = 0; i < vstr.size() - 1; i++)
|
||||
{
|
||||
if (inputlabel2id.find(vstr[i]) == inputlabel2id.end())
|
||||
LogicError("cannot find item %s in input label", vstr[i]);
|
||||
|
||||
vtmp.push_back(inputlabel2id.find(vstr[i])->second);
|
||||
}
|
||||
if (outputlabel2id.find(vstr[vstr.size() - 1]) == outputlabel2id.end())
|
||||
LogicError("cannot find item %s in output label", vstr[vstr.size() - 1]);
|
||||
labels->push_back(outputlabel2id.find(vstr[vstr.size() - 1])->second);
|
||||
input->push_back(vtmp);
|
||||
if ((vstr[vstr.size() - 1] == m_endSequenceOut ||
|
||||
/// below is for backward support
|
||||
vstr[0] == m_endTag) && input->size() > 0 && labels->size() > 0)
|
||||
{
|
||||
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
|
||||
bAtEOS = true;
|
||||
}
|
||||
|
||||
} // while
|
||||
|
||||
int prvat = 0;
|
||||
size_t i = 0;
|
||||
for (auto ptr = seqPos->begin(); ptr != seqPos->end(); ptr++, i++)
|
||||
{
|
||||
|
@ -359,15 +114,18 @@ long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested,
|
|||
stinfo.sLen = iln;
|
||||
stinfo.sBegin = prvat;
|
||||
stinfo.sEnd = (int)ptr->labelPos;
|
||||
mSentenceIndex2SentenceInfo.push_back(stinfo);
|
||||
mSentenceIndex2SentenceInfo.push_back(stinfo);
|
||||
|
||||
prvat = (int)ptr->labelPos;
|
||||
}
|
||||
|
||||
assert(mSentenceIndex2SentenceInfo.size() == linecnt);
|
||||
return linecnt;
|
||||
fprintf(stderr, "LUBatchLUSequenceParser: parse %d lines\n", lineCount);
|
||||
return lineCount;
|
||||
}
|
||||
|
||||
|
||||
template class LUBatchLUSequenceParser<float, std::string>;
|
||||
template class LUBatchLUSequenceParser<double, std::string>;
|
||||
template class LUBatchLUSequenceParser<float, std::wstring>;
|
||||
template class LUBatchLUSequenceParser<double, std::wstring>;
|
||||
}}}
|
||||
|
|
|
@ -20,151 +20,57 @@ using namespace std;
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
#define MAXSTRING 2048
|
||||
// UCI label location types
|
||||
enum LabelMode
|
||||
{
|
||||
LabelNone = 0,
|
||||
LabelFirst = 1,
|
||||
LabelLast = 2,
|
||||
};
|
||||
|
||||
enum ParseMode
|
||||
{
|
||||
ParseNormal = 0,
|
||||
ParseLineCount = 1
|
||||
};
|
||||
|
||||
enum SequenceFlags
|
||||
{
|
||||
seqFlagNull = 0,
|
||||
seqFlagLineBreak = 1, // line break on the parsed line
|
||||
seqFlagEmptyLine = 2, // empty line
|
||||
seqFlagStartLabel = 4,
|
||||
seqFlagStopLabel = 8
|
||||
};
|
||||
|
||||
// SequencePosition, save the ending indexes into the array for a sequence
|
||||
struct SequencePosition
|
||||
{
|
||||
size_t inputPos; // max position in the number array for this sequence
|
||||
size_t labelPos; // max position in the label array for this sequence
|
||||
long inputPos; // max position in the number array for this sequence
|
||||
long labelPos; // max position in the label array for this sequence
|
||||
unsigned flags; // flags that apply to this sequence
|
||||
SequencePosition(size_t inPos, size_t labelPos, unsigned flags):
|
||||
inputPos(inPos), labelPos(labelPos), flags(flags)
|
||||
SequencePosition(long inPos, long labelPos, unsigned flags) :
|
||||
inputPos(inPos), labelPos(labelPos), flags(flags)
|
||||
{}
|
||||
};
|
||||
|
||||
// LUSequenceParser - the parser for the UCI format files
|
||||
|
||||
// LUSequenceParser - the parser for the UCI format files
|
||||
// for ultimate speed, this class implements a state machine to read these format files
|
||||
template <typename NumType, typename LabelType=int>
|
||||
template <typename NumType, typename LabelType = wstring>
|
||||
class LUSequenceParser
|
||||
{
|
||||
public:
|
||||
using LabelIdType = long;
|
||||
|
||||
protected:
|
||||
enum ParseState
|
||||
{
|
||||
WholeNumber = 0,
|
||||
Remainder = 1,
|
||||
Exponent = 2,
|
||||
Whitespace = 3,
|
||||
Sign = 4,
|
||||
ExponentSign = 5,
|
||||
Period = 6,
|
||||
TheLetterE = 7,
|
||||
EndOfLine = 8,
|
||||
Label = 9, // any non-number things we run into
|
||||
ParseStateMax = 10, // number of parse states
|
||||
LineCountEOL = 10,
|
||||
LineCountOther = 11,
|
||||
AllStateMax = 12
|
||||
};
|
||||
|
||||
// type of label processing
|
||||
ParseMode m_parseMode;
|
||||
|
||||
// definition of label and feature dimensions
|
||||
size_t m_dimFeatures;
|
||||
|
||||
size_t m_dimLabelsIn;
|
||||
std::string m_beginSequenceIn; // starting sequence string (i.e. <s>)
|
||||
std::string m_endSequenceIn; // ending sequence string (i.e. </s>)
|
||||
wstring m_beginSequenceIn; // starting sequence string (i.e. <s>)
|
||||
wstring m_endSequenceIn; // ending sequence string (i.e. </s>)
|
||||
|
||||
size_t m_dimLabelsOut;
|
||||
std::string m_beginSequenceOut; // starting sequence string (i.e. 'O')
|
||||
std::string m_endSequenceOut; // ending sequence string (i.e. 'O')
|
||||
wstring m_beginSequenceOut; // starting sequence string (i.e. 'O')
|
||||
wstring m_endSequenceOut; // ending sequence string (i.e. 'O')
|
||||
|
||||
// level of screen output
|
||||
int m_traceLevel;
|
||||
|
||||
// current state of the state machine
|
||||
ParseState m_current_state;
|
||||
|
||||
// state tables
|
||||
DWORD *m_stateTable;
|
||||
|
||||
// numeric state machine variables
|
||||
double m_partialResult;
|
||||
double m_builtUpNumber;
|
||||
double m_divider;
|
||||
double m_wholeNumberMultiplier;
|
||||
double m_exponentMultiplier;
|
||||
|
||||
// label state machine variables
|
||||
size_t m_spaceDelimitedStart;
|
||||
size_t m_spaceDelimitedMax; // start of the next whitespace sequence (one past the end of the last word)
|
||||
int m_numbersConvertedThisLine;
|
||||
int m_labelsConvertedThisLine;
|
||||
int m_elementsConvertedThisLine;
|
||||
|
||||
// sequence state machine variables
|
||||
bool m_beginSequence;
|
||||
bool m_endSequence;
|
||||
std::string m_beginTag;
|
||||
std::string m_endTag;
|
||||
|
||||
// global stats
|
||||
int m_totalNumbersConverted;
|
||||
int m_totalLabelsConverted;
|
||||
wstring m_beginTag;
|
||||
wstring m_endTag;
|
||||
|
||||
// file positions/buffer
|
||||
FILE * m_pFile;
|
||||
int64_t m_byteCounter;
|
||||
int64_t m_fileSize;
|
||||
|
||||
BYTE * m_fileBuffer;
|
||||
size_t m_bufferStart;
|
||||
size_t m_bufferSize;
|
||||
|
||||
// last label was a string (for last label processing)
|
||||
bool m_lastLabelIsString;
|
||||
|
||||
// vectors to append to
|
||||
std::vector<vector<LabelType>>* m_inputs; // pointer to vectors to append with numbers
|
||||
std::vector<LabelType>* m_labels; // pointer to vector to append with labels (may be numeric)
|
||||
std::vector<vector<LabelIdType>>* m_inputs; // pointer to vectors to append with numbers
|
||||
std::vector<LabelIdType>* m_labels; // pointer to vector to append with labels (may be numeric)
|
||||
// FUTURE: do we want a vector to collect string labels in the non string label case? (signifies an error)
|
||||
|
||||
// SetState for a particular value
|
||||
void SetState(int value, ParseState m_current_state, ParseState next_state);
|
||||
|
||||
// SetStateRange - set states transitions for a range of values
|
||||
void SetStateRange(int value1, int value2, ParseState m_current_state, ParseState next_state);
|
||||
|
||||
// SetupStateTables - setup state transition tables for each state
|
||||
// each state has a block of 256 states indexed by the incoming character
|
||||
void SetupStateTables();
|
||||
|
||||
// reset all line state variables
|
||||
void PrepareStartLine();
|
||||
|
||||
// reset all number accumulation variables
|
||||
void PrepareStartNumber();
|
||||
|
||||
// reset all state variables to start reading at a new position
|
||||
void PrepareStartPosition(size_t position);
|
||||
|
||||
// UpdateBuffer - load the next buffer full of data
|
||||
// returns - number of records read
|
||||
size_t UpdateBuffer();
|
||||
|
||||
public:
|
||||
|
||||
// LUSequenceParser constructor
|
||||
|
@ -176,14 +82,6 @@ public:
|
|||
~LUSequenceParser();
|
||||
|
||||
public:
|
||||
// SetParseMode - Set the parsing mode
|
||||
// mode - set mode to either ParseLineCount, or ParseNormal
|
||||
void SetParseMode(ParseMode mode);
|
||||
|
||||
// SetTraceLevel - Set the level of screen output
|
||||
// traceLevel - traceLevel, zero means no output, 1 epoch related output, > 1 all output
|
||||
void SetTraceLevel(int traceLevel);
|
||||
|
||||
|
||||
// ParseInit - Initialize a parse of a file
|
||||
// fileName - path to the file to open
|
||||
|
@ -196,7 +94,7 @@ public:
|
|||
// endSequenceOut - endSequence output label
|
||||
// bufferSize - size of temporary buffer to store reads
|
||||
// startPosition - file position on which we should start
|
||||
void ParseInit(LPCWSTR fileName, size_t dimFeatures, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn="<s>", std::string endSequenceIn="</s>", std::string beginSequenceOut="O", std::string endSequenceOut="O", size_t bufferSize=1024*256, size_t startPosition=0)
|
||||
void ParseInit(LPCWSTR fileName, size_t dimFeatures, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut )
|
||||
{
|
||||
assert(fileName != NULL);
|
||||
m_dimFeatures = dimFeatures;
|
||||
|
@ -207,10 +105,7 @@ public:
|
|||
m_beginSequenceOut = beginSequenceOut;
|
||||
m_endSequenceOut = endSequenceOut;
|
||||
|
||||
m_parseMode = ParseNormal;
|
||||
m_traceLevel = 0;
|
||||
m_bufferSize = bufferSize;
|
||||
m_bufferStart = startPosition;
|
||||
|
||||
m_beginTag = m_beginSequenceIn;
|
||||
m_endTag = m_endSequenceIn;
|
||||
|
@ -225,18 +120,23 @@ public:
|
|||
int rc = _fseeki64(m_pFile, 0, SEEK_END);
|
||||
if (rc)
|
||||
RuntimeError("LUSequenceParser::ParseInit - error seeking in file");
|
||||
|
||||
m_fileBuffer = new BYTE[m_bufferSize];
|
||||
}
|
||||
};
|
||||
|
||||
/// language model sequence parser
|
||||
template <typename NumType, typename LabelType>
|
||||
class LULUSequenceParser : public LUSequenceParser<NumType, LabelType>
|
||||
typedef struct{
|
||||
size_t sLen;
|
||||
int sBegin;
|
||||
int sEnd;
|
||||
} stSentenceInfo;
|
||||
|
||||
template <typename NumType, typename LabelType = wstring>
|
||||
class LUBatchLUSequenceParser : public LUSequenceParser<NumType, LabelType>
|
||||
{
|
||||
protected:
|
||||
FILE * mFile;
|
||||
public:
|
||||
wifstream mFile;
|
||||
std::wstring mFileName;
|
||||
vector<stSentenceInfo> mSentenceIndex2SentenceInfo;
|
||||
|
||||
public:
|
||||
using LUSequenceParser<NumType, LabelType>::m_dimFeatures;
|
||||
|
@ -246,27 +146,21 @@ public:
|
|||
using LUSequenceParser<NumType, LabelType>::m_dimLabelsOut;
|
||||
using LUSequenceParser<NumType, LabelType>::m_beginSequenceOut;
|
||||
using LUSequenceParser<NumType, LabelType>::m_endSequenceOut;
|
||||
using LUSequenceParser<NumType, LabelType>::m_parseMode;
|
||||
using LUSequenceParser<NumType, LabelType>::m_traceLevel;
|
||||
using LUSequenceParser<NumType, LabelType>::m_bufferSize;
|
||||
using LUSequenceParser<NumType, LabelType>::m_bufferStart;
|
||||
using LUSequenceParser<NumType, LabelType>::m_beginTag;
|
||||
using LUSequenceParser<NumType, LabelType>::m_endTag;
|
||||
using LUSequenceParser<NumType, LabelType>::m_fileBuffer;
|
||||
using LUSequenceParser<NumType, LabelType>::m_fileSize;
|
||||
using LUSequenceParser<NumType, LabelType>::m_inputs;
|
||||
using LUSequenceParser<NumType, LabelType>::m_labels;
|
||||
using LUSequenceParser<NumType, LabelType>::m_beginSequence;
|
||||
using LUSequenceParser<NumType, LabelType>::m_endSequence;
|
||||
using LUSequenceParser<NumType, LabelType>::m_totalNumbersConverted;
|
||||
LULUSequenceParser() {
|
||||
mFile = nullptr;
|
||||
LUBatchLUSequenceParser() {
|
||||
};
|
||||
~LULUSequenceParser() {
|
||||
if (mFile) fclose(mFile);
|
||||
~LUBatchLUSequenceParser() {
|
||||
mFile.close();
|
||||
}
|
||||
|
||||
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn = "<s>", std::string endSequenceIn = "</s>", std::string beginSequenceOut = "O", std::string endSequenceOut = "O")
|
||||
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut)
|
||||
{
|
||||
assert(fileName != NULL);
|
||||
mFileName = fileName;
|
||||
|
@ -277,33 +171,27 @@ public:
|
|||
m_beginSequenceOut = beginSequenceOut;
|
||||
m_endSequenceOut = endSequenceOut;
|
||||
|
||||
m_parseMode = ParseNormal;
|
||||
m_traceLevel = 0;
|
||||
m_bufferSize = 0;
|
||||
m_bufferStart = 0;
|
||||
|
||||
m_beginTag = m_beginSequenceIn;
|
||||
m_endTag = m_endSequenceIn;
|
||||
|
||||
m_fileSize = -1;
|
||||
m_fileBuffer = NULL;
|
||||
mFile.close();
|
||||
|
||||
if (mFile) fclose(mFile);
|
||||
|
||||
if (_wfopen_s(&mFile, fileName, L"rt") != 0)
|
||||
mFile.open(fileName, wifstream::in);
|
||||
if (!mFile.good())
|
||||
RuntimeError("cannot open file %ls", fileName);
|
||||
}
|
||||
|
||||
void ParseReset()
|
||||
{
|
||||
if (mFile) fseek(mFile, 0, SEEK_SET);
|
||||
mFile.seekg(0, mFile.beg);
|
||||
}
|
||||
|
||||
void AddOneItem(std::vector<LabelType> *labels, std::vector<vector<LabelType>> *input, std::vector<SequencePosition> *seqPos, long& lineCount,
|
||||
void AddOneItem(std::vector<long> *labels, std::vector<vector<long>> *input, std::vector<SequencePosition> *seqPos, long& lineCount,
|
||||
long & recordCount, long orgRecordCount, SequencePosition& sequencePositionLast)
|
||||
{
|
||||
SequencePosition sequencePos(input->size(), labels->size(),
|
||||
m_beginSequence ? seqFlagStartLabel : 0 | m_endSequence ? seqFlagStopLabel : 0 | seqFlagLineBreak);
|
||||
SequencePosition sequencePos((long)input->size(), (long)labels->size(), 1);
|
||||
seqPos->push_back(sequencePos);
|
||||
sequencePositionLast = sequencePos;
|
||||
|
||||
|
@ -323,94 +211,8 @@ public:
|
|||
// numbers - pointer to vector to return the numbers
|
||||
// seqPos - pointers to the other two arrays showing positions of each sequence
|
||||
// returns - number of records actually read, if the end of file is reached the return value will be < requested records
|
||||
long Parse(size_t recordsRequested, std::vector<LabelType> *labels, std::vector<vector<LabelType>> *input, std::vector<SequencePosition> *seqPos)
|
||||
{
|
||||
assert(labels != NULL || m_dimLabelsIn == 0 && m_dimLabelsOut == 0 || m_parseMode == ParseLineCount);
|
||||
|
||||
// transfer to member variables
|
||||
m_inputs = input;
|
||||
m_labels = labels;
|
||||
|
||||
long TickStart = GetTickCount();
|
||||
long recordCount = 0;
|
||||
long orgRecordCount = (long)labels->size();
|
||||
long lineCount = 0;
|
||||
bool bAtEOS = false; /// whether the reader is at the end of sentence position
|
||||
SequencePosition sequencePositionLast(0, 0, seqFlagNull);
|
||||
/// get line
|
||||
char ch2[MAXSTRING];
|
||||
while (lineCount < recordsRequested && fgets(ch2, MAXSTRING, mFile) != nullptr)
|
||||
{
|
||||
|
||||
string ch = ch2;
|
||||
std::vector<string> vstr;
|
||||
bool bBlankLine = (trim(ch).length() == 0);
|
||||
if (bBlankLine && !bAtEOS && input->size() > 0 && labels->size() > 0)
|
||||
{
|
||||
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
|
||||
bAtEOS = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
vstr = sep_string(ch, " ");
|
||||
if (vstr.size() < 2)
|
||||
continue;
|
||||
|
||||
bAtEOS = false;
|
||||
vector<LabelType> vtmp;
|
||||
for (size_t i = 0; i < vstr.size() - 1; i++)
|
||||
{
|
||||
vtmp.push_back(vstr[i]);
|
||||
}
|
||||
labels->push_back(vstr[vstr.size() - 1]);
|
||||
input->push_back(vtmp);
|
||||
if ((vstr[vstr.size() - 1] == m_endSequenceOut ||
|
||||
/// below is for backward support
|
||||
vstr[0] == m_endTag) && input->size() > 0 && labels->size() > 0)
|
||||
{
|
||||
AddOneItem(labels, input, seqPos, lineCount, recordCount, orgRecordCount, sequencePositionLast);
|
||||
bAtEOS = true;
|
||||
}
|
||||
|
||||
} // while
|
||||
|
||||
long TickStop = GetTickCount();
|
||||
|
||||
long TickDelta = TickStop - TickStart;
|
||||
|
||||
if (m_traceLevel > 2)
|
||||
fprintf(stderr, "\n%d ms, %d numbers parsed\n\n", TickDelta, m_totalNumbersConverted);
|
||||
return lineCount;
|
||||
}
|
||||
|
||||
long Parse(size_t recordsRequested, std::vector<long> *labels, std::vector<vector<long>> *input, std::vector<SequencePosition> *seqPos, const map<wstring, long>& inputlabel2id, const map<wstring, long>& outputlabel2id);
|
||||
|
||||
};
|
||||
|
||||
typedef struct{
|
||||
size_t sLen;
|
||||
int sBegin;
|
||||
int sEnd;
|
||||
} stSentenceInfo;
|
||||
/// language model sequence parser
|
||||
template <typename NumType, typename LabelType>
|
||||
class LUBatchLUSequenceParser: public LULUSequenceParser<NumType, LabelType>
|
||||
{
|
||||
public:
|
||||
vector<stSentenceInfo> mSentenceIndex2SentenceInfo;
|
||||
|
||||
public:
|
||||
LUBatchLUSequenceParser() { };
|
||||
~LUBatchLUSequenceParser() { }
|
||||
|
||||
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, std::string beginSequenceIn = "<s>", std::string endSequenceIn = "</s>", std::string beginSequenceOut = "O", std::string endSequenceOut = "O");
|
||||
|
||||
// Parse - Parse the data
|
||||
// recordsRequested - number of records requested
|
||||
// labels - pointer to vector to return the labels
|
||||
// numbers - pointer to vector to return the numbers
|
||||
// seqPos - pointers to the other two arrays showing positions of each sequence
|
||||
// returns - number of records actually read, if the end of file is reached the return value will be < requested records
|
||||
long Parse(size_t recordsRequested, std::vector<LabelType> *labels, std::vector<vector<LabelType>> *inputs, std::vector<SequencePosition> *seqPos);
|
||||
|
||||
};
|
||||
}}};
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -18,12 +18,16 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
#ifdef DBG_SMT
|
||||
#define CACHE_BLOG_SIZE 100
|
||||
#else
|
||||
#define CACHE_BLOG_SIZE 50000
|
||||
#endif
|
||||
|
||||
#define STRIDX2CLS L"idx2cls"
|
||||
#define CLASSINFO L"classinfo"
|
||||
|
||||
#define MAX_STRING 2048
|
||||
#define MAX_STRING 100000
|
||||
|
||||
#define NULLLABEL 65532
|
||||
|
||||
|
@ -44,22 +48,18 @@ protected:
|
|||
|
||||
std::wstring m_file;
|
||||
public:
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
int nwords, dims, nsamps, nglen, nmefeats;
|
||||
using LabelType = wstring;
|
||||
using LabelIdType = long;
|
||||
long nwords, dims, nsamps, nglen, nmefeats;
|
||||
|
||||
int m_seed;
|
||||
bool mRandomize;
|
||||
|
||||
int class_size;
|
||||
map<int, vector<int>> class_words;
|
||||
vector<int>class_cn;
|
||||
|
||||
public:
|
||||
/// deal with OOV
|
||||
map<string, string> mWordMapping;
|
||||
map<LabelType, LabelType> mWordMapping;
|
||||
string mWordMappingFn;
|
||||
string mUnkStr;
|
||||
LabelType mUnkStr;
|
||||
|
||||
public:
|
||||
/// accumulated number of sentneces read so far
|
||||
|
@ -67,8 +67,7 @@ public:
|
|||
|
||||
protected:
|
||||
|
||||
LULUSequenceParser<ElemType, LabelType> m_parser;
|
||||
// LUBatchLUSequenceParser<ElemType, LabelType> m_parser;
|
||||
LUBatchLUSequenceParser<ElemType, LabelType> m_parser;
|
||||
size_t m_mbSize; // size of minibatch requested
|
||||
size_t m_mbStartSample; // starting sample # of the next minibatch
|
||||
size_t m_epochSize; // size of an epoch
|
||||
|
@ -119,12 +118,12 @@ protected:
|
|||
LabelKind type; // labels are categories, create mapping table
|
||||
std::map<LabelIdType, LabelType> mapIdToLabel;
|
||||
std::map<LabelType, LabelIdType> mapLabelToId;
|
||||
map<string, int> word4idx;
|
||||
map<int, string> idx4word;
|
||||
map<LabelType, LabelIdType> word4idx;
|
||||
map<LabelIdType, LabelType> idx4word;
|
||||
LabelIdType idMax; // maximum label ID we have encountered so far
|
||||
LabelIdType dim; // maximum label ID we will ever see (used for array dimensions)
|
||||
std::string beginSequence; // starting sequence string (i.e. <s>)
|
||||
std::string endSequence; // ending sequence string (i.e. </s>)
|
||||
long dim; // maximum label ID we will ever see (used for array dimensions)
|
||||
LabelType beginSequence; // starting sequence string (i.e. <s>)
|
||||
LabelType endSequence; // ending sequence string (i.e. </s>)
|
||||
bool busewordmap; /// whether using wordmap to map unseen words to unk
|
||||
std::wstring mapName;
|
||||
std::wstring fileToWrite; // set to the path if we need to write out the label file
|
||||
|
@ -146,37 +145,34 @@ protected:
|
|||
void WriteLabelFile();
|
||||
void LoadLabelFile(const std::wstring &filePath, std::vector<LabelType>& retLabels);
|
||||
|
||||
LabelIdType GetIdFromLabel(const std::string& label, LabelInfo& labelInfo);
|
||||
bool GetIdFromLabel(const vector<string>& label, LabelInfo& labelInfo, vector<LabelIdType>& val);
|
||||
bool CheckIdFromLabel(const std::string& labelValue, const LabelInfo& labelInfo, unsigned & labelId);
|
||||
LabelIdType GetIdFromLabel(const LabelType& label, LabelInfo& labelInfo);
|
||||
bool GetIdFromLabel(const vector<LabelIdType>& label, vector<LabelIdType>& val);
|
||||
bool CheckIdFromLabel(const LabelType& labelValue, const LabelInfo& labelInfo, unsigned & labelId);
|
||||
|
||||
virtual bool ReadRecord(size_t readSample);
|
||||
bool SentenceEnd();
|
||||
|
||||
public:
|
||||
virtual void Init(const ConfigParameters& config);
|
||||
void ReadLabelInfo(const wstring & vocfile, map<string, int> & word4idx,
|
||||
map<int, string>& idx4word) ;
|
||||
void ChangeMaping(const map<string, string>& maplist,
|
||||
const string & unkstr ,
|
||||
map<string, int> & word4idx);
|
||||
void Init(const ConfigParameters& ){};
|
||||
void ReadLabelInfo(const wstring & vocfile, map<LabelType, LabelIdType> & word4idx,
|
||||
map<LabelIdType, LabelType>& idx4word);
|
||||
void ChangeMaping(const map<LabelType, LabelType>& maplist,
|
||||
const LabelType& unkstr,
|
||||
map<LabelType, LabelIdType> & word4idx);
|
||||
|
||||
void ReadWord(char *wrod, FILE *fin);
|
||||
void Destroy() {};
|
||||
|
||||
virtual void Destroy();
|
||||
LUSequenceReader() {
|
||||
m_featuresBuffer=NULL; m_labelsBuffer=NULL; m_clsinfoRead = false; m_idx2clsRead = false;
|
||||
}
|
||||
virtual ~LUSequenceReader();
|
||||
virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
|
||||
~LUSequenceReader(){};
|
||||
void StartMinibatchLoop(size_t , size_t , size_t = requestDataSize) {};
|
||||
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {};
|
||||
void SentenceEnd(std::vector<size_t> &/*sentenceEnd*/) {};
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
|
||||
|
||||
public:
|
||||
int GetSentenceEndIdFromOutputLabel();
|
||||
};
|
||||
|
@ -185,9 +181,9 @@ template<class ElemType>
|
|||
class BatchLUSequenceReader : public LUSequenceReader<ElemType>
|
||||
{
|
||||
public:
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
using LUSequenceReader<ElemType>::mWordMappingFn;
|
||||
using LabelType = wstring;
|
||||
using LabelIdType = long;
|
||||
using LUSequenceReader<ElemType>::mWordMappingFn;
|
||||
using LUSequenceReader<ElemType>::m_cachingReader;
|
||||
using LUSequenceReader<ElemType>::mWordMapping;
|
||||
using LUSequenceReader<ElemType>::mUnkStr;
|
||||
|
@ -197,7 +193,6 @@ public:
|
|||
using LUSequenceReader<ElemType>::labelInfoMin;
|
||||
using LUSequenceReader<ElemType>::labelInfoMax;
|
||||
using LUSequenceReader<ElemType>::m_featureDim;
|
||||
using LUSequenceReader<ElemType>::class_size;
|
||||
using LUSequenceReader<ElemType>::m_labelInfo;
|
||||
// using LUSequenceReader<ElemType>::m_labelInfoIn;
|
||||
using LUSequenceReader<ElemType>::m_mbStartSample;
|
||||
|
@ -247,8 +242,8 @@ private:
|
|||
size_t mLastPosInSentence;
|
||||
size_t mNumRead ;
|
||||
|
||||
std::vector<vector<LabelType>> m_featureTemp;
|
||||
std::vector<LabelType> m_labelTemp;
|
||||
std::vector<vector<LabelIdType>> m_featureTemp;
|
||||
std::vector<LabelIdType> m_labelTemp;
|
||||
|
||||
bool mSentenceEnd;
|
||||
bool mSentenceBegin;
|
||||
|
|
Загрузка…
Ссылка в новой задаче