Somehow these files were not committed to source control.
This commit is contained in:
Родитель
539c9968aa
Коммит
590172dcc2
|
@ -0,0 +1,765 @@
|
|||
// DSSMParser.cpp : Parses the DSSM format using a custom state machine (for speed)
|
||||
//
|
||||
//
|
||||
// <copyright file="DSSMParser.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "DSSMParser.h"
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
|
||||
|
||||
// SetState for a particular value
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::SetState(int value, ParseState m_current_state, ParseState next_state)
|
||||
{
|
||||
DWORD ul = (DWORD)next_state;
|
||||
int range_shift = ((int)m_current_state) << 8;
|
||||
m_stateTable[range_shift+value] = ul;
|
||||
}
|
||||
|
||||
// SetStateRange - set states transitions for a range of values
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::SetStateRange(int value1, int value2, ParseState m_current_state, ParseState next_state)
|
||||
{
|
||||
DWORD ul = (DWORD)next_state;
|
||||
int range_shift = ((int)m_current_state) << 8;
|
||||
for (int value = value1; value <= value2; value++)
|
||||
{
|
||||
m_stateTable[range_shift+value] = ul;
|
||||
}
|
||||
}
|
||||
|
||||
// SetupStateTables - setup state transition tables for each state
|
||||
// each state has a block of 256 states indexed by the incoming character
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::SetupStateTables()
|
||||
{
|
||||
//=========================
|
||||
// STATE = WHITESPACE
|
||||
//=========================
|
||||
|
||||
SetStateRange(0,255, Whitespace, Label);
|
||||
SetStateRange('0', '9', Whitespace, WholeNumber);
|
||||
SetState('-', Whitespace, Sign);
|
||||
SetState('+', Whitespace, Sign);
|
||||
// whitespace
|
||||
SetState(' ', Whitespace, Whitespace);
|
||||
SetState('\t', Whitespace, Whitespace);
|
||||
SetState('\r', Whitespace, Whitespace);
|
||||
SetState('\n', Whitespace, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = NEGATIVE_SIGN
|
||||
//=========================
|
||||
|
||||
SetStateRange( 0, 255, Sign, Label);
|
||||
SetStateRange( '0', '9', Sign, WholeNumber);
|
||||
// whitespace
|
||||
SetState(' ', Sign, Whitespace);
|
||||
SetState('\t', Sign, Whitespace);
|
||||
SetState('\r', Sign, Whitespace);
|
||||
SetState('\n', Sign, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = NUMBER
|
||||
//=========================
|
||||
|
||||
SetStateRange( 0, 255, WholeNumber, Label);
|
||||
SetStateRange( '0', '9', WholeNumber, WholeNumber);
|
||||
SetState('.', WholeNumber, Period);
|
||||
SetState('e', WholeNumber, TheLetterE);
|
||||
SetState('E', WholeNumber, TheLetterE);
|
||||
// whitespace
|
||||
SetState(' ', WholeNumber, Whitespace);
|
||||
SetState('\t', WholeNumber, Whitespace);
|
||||
SetState('\r', WholeNumber, Whitespace);
|
||||
SetState('\n', WholeNumber, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = PERIOD
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, Period, Label);
|
||||
SetStateRange('0', '9', Period, Remainder);
|
||||
// whitespace
|
||||
SetState(' ', Period, Whitespace);
|
||||
SetState('\t', Period, Whitespace);
|
||||
SetState('\r', Period, Whitespace);
|
||||
SetState('\n', Period, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = REMAINDER
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, Remainder, Label);
|
||||
SetStateRange('0', '9', Remainder, Remainder);
|
||||
SetState('e', Remainder, TheLetterE);
|
||||
SetState('E', Remainder, TheLetterE);
|
||||
// whitespace
|
||||
SetState(' ', Remainder, Whitespace);
|
||||
SetState('\t', Remainder, Whitespace);
|
||||
SetState('\r', Remainder, Whitespace);
|
||||
SetState('\n', Remainder, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = THE_LETTER_E
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, TheLetterE, Label);
|
||||
SetStateRange('0', '9', TheLetterE, Exponent);
|
||||
SetState('-', TheLetterE, ExponentSign);
|
||||
SetState('+', TheLetterE, ExponentSign);
|
||||
// whitespace
|
||||
SetState(' ', TheLetterE, Whitespace);
|
||||
SetState('\t', TheLetterE, Whitespace);
|
||||
SetState('\r', TheLetterE, Whitespace);
|
||||
SetState('\n', TheLetterE, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = EXPONENT_NEGATIVE_SIGN
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, ExponentSign, Label);
|
||||
SetStateRange('0', '9', ExponentSign, Exponent);
|
||||
// whitespace
|
||||
SetState(' ', ExponentSign, Whitespace);
|
||||
SetState('\t', ExponentSign, Whitespace);
|
||||
SetState('\r', ExponentSign, Whitespace);
|
||||
SetState('\n', ExponentSign, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = EXPONENT
|
||||
//=========================
|
||||
|
||||
SetStateRange(0, 255, Exponent, Label);
|
||||
SetStateRange('0', '9', Exponent, Exponent);
|
||||
// whitespace
|
||||
SetState(' ', Exponent, Whitespace);
|
||||
SetState('\t', Exponent, Whitespace);
|
||||
SetState('\r', Exponent, Whitespace);
|
||||
SetState('\n', Exponent, EndOfLine);
|
||||
|
||||
//=========================
|
||||
// STATE = END_OF_LINE
|
||||
//=========================
|
||||
SetStateRange(0, 255, EndOfLine, Label);
|
||||
SetStateRange( '0', '9', EndOfLine, WholeNumber);
|
||||
SetState( '-', EndOfLine, Sign);
|
||||
SetState( '\n', EndOfLine, EndOfLine);
|
||||
// whitespace
|
||||
SetState(' ', EndOfLine, Whitespace);
|
||||
SetState('\t', EndOfLine, Whitespace);
|
||||
SetState('\r', EndOfLine, Whitespace);
|
||||
|
||||
|
||||
//=========================
|
||||
// STATE = LABEL
|
||||
//=========================
|
||||
SetStateRange(0, 255, Label, Label);
|
||||
SetState('\n', Label, EndOfLine);
|
||||
// whitespace
|
||||
SetState(' ', Label, Whitespace);
|
||||
SetState('\t', Label, Whitespace);
|
||||
SetState('\r', Label, Whitespace);
|
||||
|
||||
//=========================
|
||||
// STATE = LINE_COUNT_EOL
|
||||
//=========================
|
||||
SetStateRange(0, 255, LineCountEOL, LineCountOther);
|
||||
SetState( '\n', LineCountEOL, LineCountEOL);
|
||||
|
||||
//=========================
|
||||
// STATE = LINE_COUNT_OTHER
|
||||
//=========================
|
||||
SetStateRange(0, 255, LineCountOther, LineCountOther);
|
||||
SetState('\n', LineCountOther, LineCountEOL);
|
||||
}
|
||||
|
||||
|
||||
// reset all line state variables
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::PrepareStartLine()
|
||||
{
|
||||
m_numbersConvertedThisLine = 0;
|
||||
m_labelsConvertedThisLine = 0;
|
||||
m_elementsConvertedThisLine = 0;
|
||||
m_spaceDelimitedStart = m_byteCounter;
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
m_lastLabelIsString = false;
|
||||
}
|
||||
|
||||
// reset all number accumulation variables
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::PrepareStartNumber()
|
||||
{
|
||||
m_partialResult = 0;
|
||||
m_builtUpNumber = 0;
|
||||
m_divider = 0;
|
||||
m_wholeNumberMultiplier = 1;
|
||||
m_exponentMultiplier = 1;
|
||||
}
|
||||
|
||||
// reset all state variables to start reading at a new position
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::PrepareStartPosition(size_t position)
|
||||
{
|
||||
m_current_state = Whitespace;
|
||||
m_byteCounter = position; // must come before PrepareStartLine...
|
||||
m_bufferStart = position;
|
||||
|
||||
// prepare state machine for new number and new line
|
||||
PrepareStartNumber();
|
||||
PrepareStartLine();
|
||||
m_totalNumbersConverted = 0;
|
||||
m_totalLabelsConverted = 0;
|
||||
}
|
||||
|
||||
// DSSMParser constructor
|
||||
template <typename NumType, typename LabelType>
|
||||
DSSMParser<NumType, LabelType>::DSSMParser()
|
||||
{
|
||||
Init();
|
||||
}
|
||||
|
||||
// setup all the state variables and state tables for state machine
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::Init()
|
||||
{
|
||||
PrepareStartPosition(0);
|
||||
m_fileBuffer = NULL;
|
||||
m_pFile = NULL;
|
||||
m_stateTable = new DWORD[AllStateMax * 256];
|
||||
SetupStateTables();
|
||||
}
|
||||
|
||||
// Parser destructor
|
||||
template <typename NumType, typename LabelType>
|
||||
DSSMParser<NumType, LabelType>::~DSSMParser()
|
||||
{
|
||||
delete m_stateTable;
|
||||
delete m_fileBuffer;
|
||||
if (m_pFile)
|
||||
fclose(m_pFile);
|
||||
}
|
||||
|
||||
// DoneWithLabel - Called when a string label is found
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::DoneWithLabel()
|
||||
{
|
||||
// if we haven't set the max yet, use the current byte Counter
|
||||
if (m_spaceDelimitedMax <= m_spaceDelimitedStart)
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
{
|
||||
std::string label((LPCSTR)&m_fileBuffer[m_spaceDelimitedStart-m_bufferStart], m_spaceDelimitedMax-m_spaceDelimitedStart);
|
||||
fprintf(stderr, "\n** String found in numeric-only file: %s\n", label.c_str());
|
||||
m_labelsConvertedThisLine++;
|
||||
m_elementsConvertedThisLine++;
|
||||
m_lastLabelIsString = true;
|
||||
}
|
||||
PrepareStartNumber();
|
||||
}
|
||||
|
||||
// Called when a number is complete
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::DoneWithValue()
|
||||
{
|
||||
// if we are storing it
|
||||
if (m_numbers != NULL)
|
||||
{
|
||||
NumType FinalResult = 0;
|
||||
if (m_current_state == Exponent)
|
||||
{
|
||||
FinalResult = (NumType)(m_partialResult*pow(10.0,m_exponentMultiplier * m_builtUpNumber));
|
||||
}
|
||||
else if (m_divider != 0)
|
||||
{
|
||||
FinalResult = (NumType)(m_partialResult + (m_builtUpNumber / m_divider));
|
||||
}
|
||||
else
|
||||
{
|
||||
FinalResult = (NumType)m_builtUpNumber;
|
||||
}
|
||||
|
||||
FinalResult = (NumType)(FinalResult*m_wholeNumberMultiplier);
|
||||
|
||||
// if it's a label, store in label location instead of number location
|
||||
int index = m_elementsConvertedThisLine;
|
||||
bool stored=false;
|
||||
if (m_startLabels <= index && index < m_startLabels + m_dimLabels)
|
||||
{
|
||||
StoreLabel(FinalResult);
|
||||
stored=true;
|
||||
}
|
||||
if (m_startFeatures <= index && index < m_startFeatures + m_dimFeatures)
|
||||
{
|
||||
m_numbers->push_back(FinalResult);
|
||||
m_totalNumbersConverted++;
|
||||
m_numbersConvertedThisLine++;
|
||||
m_elementsConvertedThisLine++;
|
||||
m_lastLabelIsString = false;
|
||||
stored=true;
|
||||
}
|
||||
// if we haven't stored anything we need to skip the current symbol, so increment
|
||||
if (!stored)
|
||||
{
|
||||
m_elementsConvertedThisLine++;
|
||||
}
|
||||
}
|
||||
|
||||
PrepareStartNumber();
|
||||
}
|
||||
|
||||
// store label is specialized by LabelType
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::StoreLabel(NumType value)
|
||||
{
|
||||
m_labels->push_back((LabelType)value);
|
||||
m_totalNumbersConverted++;
|
||||
m_numbersConvertedThisLine++;
|
||||
m_elementsConvertedThisLine++;
|
||||
m_lastLabelIsString = false;
|
||||
}
|
||||
|
||||
// StoreLastLabel - store the last label (for numeric types), tranfers to label vector
|
||||
// string label types handled in specialization
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::StoreLastLabel()
|
||||
{
|
||||
assert(!m_lastLabelIsString); // file format error, last label was a string...
|
||||
NumType value = m_numbers->back();
|
||||
m_numbers->pop_back();
|
||||
m_labels->push_back((LabelType)value);
|
||||
}
|
||||
|
||||
// ParseInit - Initialize a parse of a file
|
||||
// fileName - path to the file to open
|
||||
// startFeatures - column (zero based) where features start
|
||||
// dimFeatures - number of features
|
||||
// startLabels - column (zero based) where Labels start
|
||||
// dimLabels - number of Labels
|
||||
// bufferSize - size of temporary buffer to store reads
|
||||
// startPosition - file position on which we should start
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::ParseInit(LPCWSTR fileName, size_t startFeatures, size_t dimFeatures, size_t startLabels, size_t dimLabels, size_t bufferSize, size_t startPosition)
|
||||
{
|
||||
assert(fileName != NULL);
|
||||
m_startLabels = startLabels;
|
||||
m_dimLabels = dimLabels;
|
||||
m_startFeatures = startFeatures;
|
||||
m_dimFeatures = dimFeatures;
|
||||
m_parseMode = ParseNormal;
|
||||
m_traceLevel = 0;
|
||||
m_bufferSize = bufferSize;
|
||||
m_bufferStart = startPosition;
|
||||
|
||||
// if we have a file already open, cleanup
|
||||
if (m_pFile != NULL)
|
||||
DSSMParser<NumType, LabelType>::~DSSMParser();
|
||||
|
||||
errno_t err = _wfopen_s( &m_pFile, fileName, L"rb" );
|
||||
if (err)
|
||||
std::runtime_error("DSSMParser::ParseInit - error opening file");
|
||||
int rc = _fseeki64(m_pFile, 0, SEEK_END);
|
||||
if (rc)
|
||||
std::runtime_error("DSSMParser::ParseInit - error seeking in file");
|
||||
|
||||
m_fileSize = GetFilePosition();
|
||||
m_fileBuffer = new BYTE[m_bufferSize];
|
||||
SetFilePosition(startPosition);
|
||||
}
|
||||
|
||||
// GetFilePosition - Get the current file position in the text file
|
||||
// returns current position in the file
|
||||
template <typename NumType, typename LabelType>
|
||||
int64_t DSSMParser<NumType, LabelType>::GetFilePosition()
|
||||
{
|
||||
int64_t position = _ftelli64(m_pFile);
|
||||
if (position == -1L)
|
||||
std::runtime_error("DSSMParser::GetFilePosition - error retrieving file position in file");
|
||||
return position;
|
||||
}
|
||||
|
||||
// SetFilePosition - Set the current file position from the beginning of the file, and read in the first block of data
|
||||
// state machine mode will be initialized similar to the beginning of the file
|
||||
// it is recommneded that only return values from GetFilePosition() known to be the start of a line
|
||||
// and zero be passed to this function
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::SetFilePosition(int64_t position)
|
||||
{
|
||||
int rc = _fseeki64(m_pFile, position, SEEK_SET);
|
||||
if (rc)
|
||||
std::runtime_error("DSSMParser::SetFilePosition - error seeking in file");
|
||||
|
||||
// setup state machine to start at this position
|
||||
PrepareStartPosition(position);
|
||||
|
||||
// read in the first buffer of data from this position, first buffer is expected to be read after a reposition
|
||||
UpdateBuffer();
|
||||
|
||||
// FUTURE: in debug we could validate the value is either 0, or the previous character is a '\n'
|
||||
}
|
||||
|
||||
// HasMoreData - test if the current dataset have more data, or just whitespace
|
||||
// returns - true if it has more data, false if not
|
||||
template <typename NumType, typename LabelType>
|
||||
bool DSSMParser<NumType, LabelType>::HasMoreData()
|
||||
{
|
||||
long long byteCounter = m_byteCounter;
|
||||
size_t bufferIndex = m_byteCounter-m_bufferStart;
|
||||
|
||||
// test without moving parser state
|
||||
for (;byteCounter < m_fileSize; byteCounter++, bufferIndex++)
|
||||
{
|
||||
// if we reach the end of the buffer, just assume we have more data
|
||||
// won't be right 100% of the time, but close enough
|
||||
if (bufferIndex >= m_bufferSize)
|
||||
return true;
|
||||
|
||||
char ch = m_fileBuffer[bufferIndex];
|
||||
ParseState nextState = (ParseState)m_stateTable[(Whitespace<<8)+ch];
|
||||
if (!(nextState == Whitespace || nextState == EndOfLine))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// UpdateBuffer - load the next buffer full of data
|
||||
// returns - number of records read
|
||||
template <typename NumType, typename LabelType>
|
||||
size_t DSSMParser<NumType, LabelType>::UpdateBuffer()
|
||||
{
|
||||
// state machine might want to look back this far, so copy to beginning
|
||||
size_t saveBytes = m_byteCounter-m_spaceDelimitedStart;
|
||||
assert(saveBytes < m_bufferSize);
|
||||
if (saveBytes)
|
||||
{
|
||||
memcpy_s(m_fileBuffer, m_bufferSize, &m_fileBuffer[m_byteCounter-m_bufferStart-saveBytes], saveBytes);
|
||||
m_bufferStart = m_byteCounter-saveBytes;
|
||||
}
|
||||
|
||||
// read the next block
|
||||
size_t bytesToRead = min(m_bufferSize, m_fileSize-m_bufferStart)-saveBytes;
|
||||
size_t bytesRead = fread(m_fileBuffer+saveBytes, 1, bytesToRead, m_pFile);
|
||||
if (bytesRead == 0 && ferror(m_pFile))
|
||||
std::runtime_error("DSSMParser::UpdateBuffer - error reading file");
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::SetParseMode(ParseMode mode)
|
||||
{
|
||||
// if already in this mode, nothing to do
|
||||
if (m_parseMode == mode)
|
||||
return;
|
||||
|
||||
// switching modes
|
||||
if (mode == ParseLineCount)
|
||||
m_current_state = LineCountOther;
|
||||
else
|
||||
{
|
||||
m_current_state = Whitespace;
|
||||
PrepareStartLine();
|
||||
PrepareStartNumber();
|
||||
}
|
||||
m_parseMode = mode;
|
||||
}
|
||||
|
||||
// SetTraceLevel - Set the level of screen output
|
||||
// traceLevel - traceLevel, zero means no output, 1 epoch related output, > 1 all output
|
||||
template <typename NumType, typename LabelType>
|
||||
void DSSMParser<NumType, LabelType>::SetTraceLevel(int traceLevel)
|
||||
{
|
||||
m_traceLevel = traceLevel;
|
||||
}
|
||||
|
||||
// Parse - Parse the data
|
||||
// recordsRequested - number of records requested
|
||||
// numbers - pointer to vector to return the numbers (must be allocated)
|
||||
// labels - pointer to vector to return the labels (defaults to null)
|
||||
// returns - number of records actually read, if the end of file is reached the return value will be < requested records
|
||||
template <typename NumType, typename LabelType>
|
||||
long DSSMParser<NumType, LabelType>::Parse(size_t recordsRequested, std::vector<NumType> *numbers, std::vector<LabelType> *labels)
|
||||
{
|
||||
assert(numbers != NULL || m_dimFeatures == 0 || m_parseMode == ParseLineCount);
|
||||
assert(labels != NULL || m_dimLabels == 0 || m_parseMode == ParseLineCount);
|
||||
|
||||
// transfer to member variables
|
||||
m_numbers = numbers;
|
||||
m_labels = labels;
|
||||
|
||||
long TickStart = GetTickCount( );
|
||||
long recordCount = 0;
|
||||
size_t bufferIndex = m_byteCounter-m_bufferStart;
|
||||
while (m_byteCounter < m_fileSize && recordCount < recordsRequested)
|
||||
{
|
||||
// check to see if we need to update the buffer
|
||||
if (bufferIndex >= m_bufferSize)
|
||||
{
|
||||
UpdateBuffer();
|
||||
bufferIndex = m_byteCounter-m_bufferStart;
|
||||
}
|
||||
|
||||
char ch = m_fileBuffer[bufferIndex];
|
||||
|
||||
ParseState nextState = (ParseState)m_stateTable[(m_current_state<<8)+ch];
|
||||
|
||||
if( nextState <= Exponent )
|
||||
{
|
||||
m_builtUpNumber = m_builtUpNumber * 10 + (ch - '0');
|
||||
// if we are in the decimal portion of a number increase the divider
|
||||
if (nextState == Remainder)
|
||||
m_divider *= 10;
|
||||
}
|
||||
|
||||
// only do a test on a state transition
|
||||
if (m_current_state != nextState)
|
||||
{
|
||||
// System.Diagnostics.Debug.WriteLine("Current state = " + m_current_state + ", next state = " + nextState);
|
||||
|
||||
// if the nextState is a label, we don't want to do any number processing, it's a number prefixed string
|
||||
if (nextState != Label)
|
||||
{
|
||||
// do the numeric processing
|
||||
switch (m_current_state)
|
||||
{
|
||||
case TheLetterE:
|
||||
if (m_divider != 0) // decimal number
|
||||
m_partialResult += m_builtUpNumber / m_divider;
|
||||
else // integer
|
||||
m_partialResult = m_builtUpNumber;
|
||||
m_builtUpNumber = 0;
|
||||
break;
|
||||
case WholeNumber:
|
||||
// could be followed by a remainder, or an exponent
|
||||
if (nextState != TheLetterE)
|
||||
if( nextState != Period)
|
||||
DoneWithValue();
|
||||
if (nextState == Period)
|
||||
{
|
||||
m_partialResult = m_builtUpNumber;
|
||||
m_divider = 1;
|
||||
m_builtUpNumber = 0;
|
||||
}
|
||||
break;
|
||||
case Remainder:
|
||||
// can only be followed by a exponent
|
||||
if (nextState != TheLetterE)
|
||||
DoneWithValue();
|
||||
break;
|
||||
case Exponent:
|
||||
DoneWithValue();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// label handling
|
||||
switch (m_current_state)
|
||||
{
|
||||
case Label:
|
||||
DoneWithLabel();
|
||||
break;
|
||||
case EndOfLine:
|
||||
PrepareStartLine();
|
||||
break;
|
||||
case Whitespace:
|
||||
// this is the start of the next space delimited entity
|
||||
if (nextState != EndOfLine)
|
||||
m_spaceDelimitedStart = m_byteCounter;
|
||||
break;
|
||||
}
|
||||
|
||||
// label handling for next state
|
||||
switch (nextState)
|
||||
{
|
||||
// do sign processing on nextState, since we still have the character handy
|
||||
case Sign:
|
||||
if (ch == '-')
|
||||
m_wholeNumberMultiplier = -1;
|
||||
break;
|
||||
case ExponentSign:
|
||||
if (ch == '-')
|
||||
m_exponentMultiplier = -1;
|
||||
break;
|
||||
// going into whitespace or endOfLine, so end of space delimited entity
|
||||
case Whitespace:
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
// hit whitespace and nobody processed anything, so add as label
|
||||
//if (m_elementsConvertedThisLine == elementsProcessed)
|
||||
// DoneWithLabel();
|
||||
break;
|
||||
case EndOfLine:
|
||||
if (m_current_state != Whitespace)
|
||||
{
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
// hit whitespace and nobody processed anything, so add as label
|
||||
//if (m_elementsConvertedThisLine == elementsProcessed)
|
||||
// DoneWithLabel();
|
||||
}
|
||||
// process the label at the end of a line
|
||||
//if (m_labelMode == LabelLast && m_labels != NULL)
|
||||
//{
|
||||
// StoreLastLabel();
|
||||
//}
|
||||
// intentional fall-through
|
||||
case LineCountEOL:
|
||||
recordCount++; // done with another record
|
||||
if (m_traceLevel > 1)
|
||||
{
|
||||
// print progress dots
|
||||
if (recordCount % 100 == 0)
|
||||
{
|
||||
if (recordCount % 1000 == 0)
|
||||
{
|
||||
if (recordCount % 10000 == 0)
|
||||
{
|
||||
fprintf(stderr, "#");
|
||||
}
|
||||
else
|
||||
{
|
||||
fprintf(stderr, "+");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
fprintf(stderr, ".");
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case LineCountOther:
|
||||
m_spaceDelimitedStart = m_byteCounter;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
m_current_state = nextState;
|
||||
|
||||
// move to next character
|
||||
m_byteCounter++;
|
||||
bufferIndex++;
|
||||
} // while 1
|
||||
|
||||
long TickStop = GetTickCount( );
|
||||
|
||||
long TickDelta = TickStop - TickStart;
|
||||
|
||||
if (m_traceLevel > 2)
|
||||
fprintf(stderr, "\n%d ms, %d numbers parsed\n\n", TickDelta, m_totalNumbersConverted );
|
||||
return recordCount;
|
||||
}
|
||||
|
||||
// StoreLabel - string version gets last space delimited string and stores in labels vector
|
||||
template <>
|
||||
void DSSMParser<float, std::string>::StoreLabel(float /*finalResult*/)
|
||||
{
|
||||
// for LabelFirst, Max will not be set yet, but the current byte counter is the Max, so set it
|
||||
if (m_spaceDelimitedMax <= m_spaceDelimitedStart)
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
std::string label((LPCSTR)&m_fileBuffer[m_spaceDelimitedStart-m_bufferStart], m_spaceDelimitedMax-m_spaceDelimitedStart);
|
||||
m_labels->push_back(move(label));
|
||||
m_labelsConvertedThisLine++;
|
||||
m_elementsConvertedThisLine++;
|
||||
m_lastLabelIsString = true;
|
||||
}
|
||||
|
||||
// DoneWithLabel - string version stores string label
|
||||
template <>
|
||||
void DSSMParser<float, std::string>::DoneWithLabel()
|
||||
{
|
||||
if (m_labels != NULL)
|
||||
StoreLabel(0); // store the string label
|
||||
PrepareStartNumber();
|
||||
}
|
||||
|
||||
// StoreLastLabel - string version
|
||||
template <>
|
||||
void DSSMParser<float, std::string>::StoreLastLabel()
|
||||
{
|
||||
// see if it was already stored as a string label
|
||||
if (m_lastLabelIsString)
|
||||
return;
|
||||
StoreLabel(0);
|
||||
|
||||
// we already stored a numeric version of this label in the numbers array
|
||||
// so get rid of that, the user wants it as a string
|
||||
m_numbers->pop_back();
|
||||
PrepareStartNumber();
|
||||
}
|
||||
|
||||
// NOTE: Current code is identical to float, don't know how to specialize with template parameter that only covers one parameter
|
||||
|
||||
// StoreLabel - string version gets last space delimited string and stores in labels vector
|
||||
template <>
|
||||
void DSSMParser<double, std::string>::StoreLabel(double /*finalResult*/)
|
||||
{
|
||||
// for LabelFirst, Max will not be set yet, but the current byte counter is the Max, so set it
|
||||
if (m_spaceDelimitedMax <= m_spaceDelimitedStart)
|
||||
m_spaceDelimitedMax = m_byteCounter;
|
||||
std::string label((LPCSTR)&m_fileBuffer[m_spaceDelimitedStart-m_bufferStart], m_spaceDelimitedMax-m_spaceDelimitedStart);
|
||||
m_labels->push_back(move(label));
|
||||
m_labelsConvertedThisLine++;
|
||||
m_elementsConvertedThisLine++;
|
||||
m_lastLabelIsString = true;
|
||||
}
|
||||
|
||||
// DoneWithLabel - string version stores string label
|
||||
template <>
|
||||
void DSSMParser<double, std::string>::DoneWithLabel()
|
||||
{
|
||||
if (m_labels != NULL)
|
||||
StoreLabel(0); // store the string label
|
||||
PrepareStartNumber();
|
||||
}
|
||||
|
||||
// StoreLastLabel - string version
|
||||
template <>
|
||||
void DSSMParser<double, std::string>::StoreLastLabel()
|
||||
{
|
||||
// see if it was already stored as a string label
|
||||
if (m_lastLabelIsString)
|
||||
return;
|
||||
StoreLabel(0);
|
||||
|
||||
// we already stored a numeric version of this label in the numbers array
|
||||
// so get rid of that, the user wants it as a string
|
||||
m_numbers->pop_back();
|
||||
PrepareStartNumber();
|
||||
}
|
||||
|
||||
#ifdef STANDALONE
// wmain - standalone test driver (only built when STANDALONE is defined)
// Parses the training file in batches of 10000 records, rewinding to the start
// of the file whenever a batch comes back short, until at least 150000 records
// have been read.
// returns - total number of records read
int wmain(int argc, wchar_t* argv[])
{
    DSSMParser<double, int> parser;
    std::vector<double> values;
    values.reserve(784000*6);
    std::vector<int> labels;
    labels.reserve(60000);

    // ParseInit takes explicit column ranges (startFeatures, dimFeatures,
    // startLabels, dimLabels); the former (fileName, LabelFirst) call no longer
    // matches its signature.
    // NOTE(review): one label in column 0 followed by 784 features is assumed
    // from LabelFirst and the reserve sizes above - confirm against the data file.
    parser.ParseInit(L"c:\\speech\\mnist\\mnist_train.txt", 1, 784, 0, 1);

    const long batchSize = 10000;
    int records = 0;
    do
    {
        const long recordsRead = parser.Parse(batchSize, &values, &labels);
        if (recordsRead < batchSize)
            parser.SetFilePosition(0); // go around again
        records += (int)recordsRead;
        values.clear();
        labels.clear();
    }
    while (records < 150000);
    return records;
}
#endif
|
||||
|
||||
// instantiate DSSM parsers for supported types
|
||||
template class DSSMParser<float, int>;
|
||||
template class DSSMParser<float, float>;
|
||||
template class DSSMParser<float, std::string>;
|
||||
template class DSSMParser<double, int>;
|
||||
template class DSSMParser<double, double>;
|
||||
template class DSSMParser<double, std::string>;
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
// DSSMParser.h : Parses the DSSM format using a custom state machine (for speed)
|
||||
//
|
||||
//
|
||||
// <copyright file="DSSMParser.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// DSSM label location types
// Values are consecutive starting at zero.
enum LabelMode
{
    LabelNone,  // 0
    LabelFirst, // 1
    LabelLast   // 2
};
|
||||
|
||||
// Parsing modes selectable through DSSMParser::SetParseMode
// Values are consecutive starting at zero.
enum ParseMode
{
    ParseNormal,   // 0
    ParseLineCount // 1
};
|
||||
|
||||
// DSSMParser - the parser for the DSSM format files
|
||||
// for ultimate speed, this class implements a state machine to read these format files
|
||||
template <typename NumType, typename LabelType=int>
|
||||
class DSSMParser
|
||||
{
|
||||
private:
|
||||
// States of the parsing state machine. The value of a state selects its block
// of 256 entries in the state table. The three digit-consuming states come
// first so that a single "<= Exponent" comparison identifies them.
enum ParseState
{
    // digit-consuming states
    WholeNumber,                    // 0
    Remainder,                      // 1
    Exponent,                       // 2

    // remaining text-parsing states
    Whitespace,                     // 3
    Sign,                           // 4
    ExponentSign,                   // 5
    Period,                         // 6
    TheLetterE,                     // 7
    EndOfLine,                      // 8
    Label,                          // 9, any non-number things we run into
    ParseStateMax,                  // 10, number of parse states

    // line counting states
    LineCountEOL = ParseStateMax,   // 10
    LineCountOther,                 // 11

    AllStateMax,                    // 12, number of states in the state table
    Error = AllStateMax             // 12
};
|
||||
|
||||
// type of label processing
|
||||
ParseMode m_parseMode;
|
||||
|
||||
// definition of label and feature locations
|
||||
size_t m_startLabels;
|
||||
size_t m_dimLabels;
|
||||
size_t m_startFeatures;
|
||||
size_t m_dimFeatures;
|
||||
|
||||
// level of screen output
|
||||
int m_traceLevel;
|
||||
|
||||
// current state of the state machine
|
||||
ParseState m_current_state;
|
||||
|
||||
// state tables
|
||||
DWORD *m_stateTable;
|
||||
|
||||
// numeric state machine variables
|
||||
double m_partialResult;
|
||||
double m_builtUpNumber;
|
||||
double m_divider;
|
||||
double m_wholeNumberMultiplier;
|
||||
double m_exponentMultiplier;
|
||||
|
||||
// label state machine variables
|
||||
size_t m_spaceDelimitedStart;
|
||||
size_t m_spaceDelimitedMax; // start of the next whitespace sequence (one past the end of the last word)
|
||||
int m_numbersConvertedThisLine;
|
||||
int m_labelsConvertedThisLine;
|
||||
int m_elementsConvertedThisLine;
|
||||
|
||||
// global stats
|
||||
int m_totalNumbersConverted;
|
||||
int m_totalLabelsConverted;
|
||||
|
||||
// file positions/buffer
|
||||
FILE * m_pFile;
|
||||
int64_t m_byteCounter;
|
||||
int64_t m_fileSize;
|
||||
|
||||
BYTE * m_fileBuffer;
|
||||
size_t m_bufferStart;
|
||||
size_t m_bufferSize;
|
||||
|
||||
// last label was a string (for last label processing)
|
||||
bool m_lastLabelIsString;
|
||||
|
||||
// vectors to append to
|
||||
std::vector<NumType>* m_numbers; // pointer to vectors to append with numbers
|
||||
std::vector<LabelType>* m_labels; // pointer to vector to append with labels (may be numeric)
|
||||
// FUTURE: do we want a vector to collect string labels in the non string label case? (signifies an error)
|
||||
|
||||
// SetState for a particular value
|
||||
void SetState(int value, ParseState m_current_state, ParseState next_state);
|
||||
|
||||
// SetStateRange - set states transitions for a range of values
|
||||
void SetStateRange(int value1, int value2, ParseState m_current_state, ParseState next_state);
|
||||
|
||||
// SetupStateTables - setup state transition tables for each state
|
||||
// each state has a block of 256 states indexed by the incoming character
|
||||
void SetupStateTables();
|
||||
|
||||
// reset all line state variables
|
||||
void PrepareStartLine();
|
||||
|
||||
// reset all number accumulation variables
|
||||
void PrepareStartNumber();
|
||||
|
||||
// reset all state variables to start reading at a new position
|
||||
void PrepareStartPosition(size_t position);
|
||||
|
||||
// UpdateBuffer - load the next buffer full of data
|
||||
// returns - number of bytes read from the file
|
||||
size_t UpdateBuffer();
|
||||
|
||||
public:
|
||||
|
||||
// DSSMParser constructor
|
||||
DSSMParser();
|
||||
// setup all the state variables and state tables for state machine
|
||||
void Init();
|
||||
|
||||
// Parser destructor
|
||||
~DSSMParser();
|
||||
|
||||
private:
|
||||
// DoneWithLabel - Called when a string label is found
|
||||
void DoneWithLabel();
|
||||
|
||||
// Called when a number is complete
|
||||
void DoneWithValue();
|
||||
|
||||
// store label is specialized by LabelType
|
||||
void StoreLabel(NumType value);
|
||||
|
||||
// StoreLastLabel - store the last label (for numeric types), transfers to label vector
|
||||
// string label types handled in specialization
|
||||
void StoreLastLabel();
|
||||
|
||||
public:
|
||||
// SetParseMode - Set the parsing mode
|
||||
// mode - set mode to either ParseLineCount, or ParseNormal
|
||||
void SetParseMode(ParseMode mode);
|
||||
|
||||
// SetTraceLevel - Set the level of screen output
|
||||
// traceLevel - traceLevel, zero means no output, 1 epoch related output, > 1 all output
|
||||
void SetTraceLevel(int traceLevel);
|
||||
|
||||
|
||||
// ParseInit - Initialize a parse of a file
|
||||
// fileName - path to the file to open
|
||||
// startFeatures - column (zero based) where features start
|
||||
// dimFeatures - number of features
|
||||
// startLabels - column (zero based) where Labels start
|
||||
// dimLabels - number of Labels
|
||||
// bufferSize - size of temporary buffer to store reads
|
||||
// startPosition - file position on which we should start
|
||||
void ParseInit(LPCWSTR fileName, size_t startFeatures, size_t dimFeatures, size_t startLabels, size_t dimLabels, size_t bufferSize=1024*256, size_t startPosition=0);
|
||||
|
||||
// Parse - Parse the data
|
||||
// recordsRequested - number of records requested
|
||||
// numbers - pointer to vector to return the numbers (must be allocated)
|
||||
// labels - pointer to vector to return the labels (defaults to null)
|
||||
// returns - number of records actually read, if the end of file is reached the return value will be < requested records
|
||||
long Parse(size_t recordsRequested, std::vector<NumType> *numbers, std::vector<LabelType> *labels=NULL);
|
||||
|
||||
int64_t GetFilePosition();
|
||||
void SetFilePosition(int64_t position);
|
||||
|
||||
// HasMoreData - test if the current dataset have more data
|
||||
// returns - true if it does, false if not
|
||||
bool HasMoreData();
|
||||
};
|
||||
|
||||
// StoreLabel - string version gets last space delimited string and stores in labels vector
|
||||
template <>
|
||||
void DSSMParser<float, std::string>::StoreLabel(float finalResult);
|
||||
|
||||
// DoneWithLabel - string version stores string label
|
||||
template <>
|
||||
void DSSMParser<float, std::string>::DoneWithLabel();
|
||||
|
||||
// StoreLastLabel - string version
|
||||
template <>
|
||||
void DSSMParser<float, std::string>::StoreLastLabel();
|
||||
|
||||
// NOTE: Current code is identical to float, don't know how to specialize with template parameter that only covers one parameter
|
||||
|
||||
// StoreLabel - string version gets last space delimited string and stores in labels vector
|
||||
template <>
|
||||
void DSSMParser<double, std::string>::StoreLabel(double finalResult);
|
||||
|
||||
// DoneWithLabel - string version stores string label
|
||||
template <>
|
||||
void DSSMParser<double, std::string>::DoneWithLabel();
|
||||
|
||||
// StoreLastLabel - string version
|
||||
template <>
|
||||
void DSSMParser<double, std::string>::StoreLastLabel();
|
|
@ -0,0 +1,679 @@
|
|||
//
|
||||
// <copyright file="DSSMReader.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// DSSMReader.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#define DATAREADER_EXPORTS // creating the exports here
|
||||
#include "DataReader.h"
|
||||
#include "DSSMReader.h"
|
||||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // leak detection
|
||||
#endif
|
||||
#include "fileutil.h" // for fexists()
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// HIDWORD - bits 32..63 of size (zero when size_t is only 32 bits wide).
// The value is widened to 64 bits before shifting so that the shift count is
// always smaller than the operand width, which a direct size>>32 does not
// guarantee on a 32-bit build.
DWORD HIDWORD(size_t size) { return static_cast<DWORD>(static_cast<unsigned long long>(size) >> 32); }
|
||||
// LODWORD - the low-order 32 bits of size.
DWORD LODWORD(size_t size)
{
    const size_t lowMask = 0xFFFFFFFF;
    return size & lowMask;
}
|
||||
|
||||
std::string ws2s(const std::wstring& wstr)
|
||||
{
|
||||
int size_needed = WideCharToMultiByte(CP_ACP, 0, wstr.c_str(), int(wstr.length() + 1), 0, 0, 0, 0);
|
||||
std::string strTo(size_needed, 0);
|
||||
WideCharToMultiByte(CP_ACP, 0, wstr.c_str(), int(wstr.length() + 1), &strTo[0], size_needed, 0, 0);
|
||||
return strTo;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
size_t DSSMReader<ElemType>::RandomizeSweep(size_t mbStartSample)
|
||||
{
|
||||
//size_t randomRangePerEpoch = (m_epochSize+m_randomizeRange-1)/m_randomizeRange;
|
||||
//return m_epoch*randomRangePerEpoch + epochSample/m_randomizeRange;
|
||||
return mbStartSample/m_randomizeRange;
|
||||
}
|
||||
|
||||
// ReadRecord - placeholder for reading a single record; this reader does not use it
|
||||
// readSample - sample to read in global sample space (ignored)
|
||||
// returns - always false
|
||||
template<class ElemType>
|
||||
bool DSSMReader<ElemType>::ReadRecord(size_t /*readSample*/)
|
||||
{
|
||||
return false; // not used
|
||||
}
|
||||
|
||||
// RecordsToRead - Determine number of records to read to populate record buffers
|
||||
// mbStartSample - the starting sample from which to read
|
||||
// tail - we are checking for possible remainer records to read (default false)
|
||||
// returns - true if we have more to read, false if we hit the end of the dataset
|
||||
template<class ElemType>
|
||||
size_t DSSMReader<ElemType>::RecordsToRead(size_t mbStartSample, bool tail)
|
||||
{
|
||||
assert(mbStartSample >= m_epochStartSample);
|
||||
// determine how far ahead we need to read
|
||||
bool randomize = Randomize();
|
||||
// need to read to the end of the next minibatch
|
||||
size_t epochSample = mbStartSample;
|
||||
epochSample %= m_epochSize;
|
||||
|
||||
// determine number left to read for this epoch
|
||||
size_t numberToEpoch = m_epochSize - epochSample;
|
||||
// we will take either a minibatch or the number left in the epoch
|
||||
size_t numberToRead = min(numberToEpoch, m_mbSize);
|
||||
if (numberToRead == 0 && !tail)
|
||||
numberToRead = m_mbSize;
|
||||
|
||||
if (randomize)
|
||||
{
|
||||
size_t randomizeSweep = RandomizeSweep(mbStartSample);
|
||||
// if first read or read takes us to another randomization range
|
||||
// we need to read at least randomization range records
|
||||
if (randomizeSweep != m_randomordering.CurrentSeed()) // the range has changed since last time
|
||||
{
|
||||
numberToRead = RoundUp(epochSample, m_randomizeRange) - epochSample;
|
||||
if (numberToRead == 0 && !tail)
|
||||
numberToRead = m_randomizeRange;
|
||||
}
|
||||
}
|
||||
return numberToRead;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::WriteLabelFile()
|
||||
{
|
||||
// write out the label file if they don't have one
|
||||
if (!m_labelFileToWrite.empty())
|
||||
{
|
||||
if (m_mapIdToLabel.size() > 0)
|
||||
{
|
||||
File labelFile(m_labelFileToWrite, fileOptionsWrite | fileOptionsText);
|
||||
for (int i=0; i < m_mapIdToLabel.size(); ++i)
|
||||
{
|
||||
labelFile << m_mapIdToLabel[i] << '\n';
|
||||
}
|
||||
fprintf(stderr, "label file %ws written to disk\n", m_labelFileToWrite.c_str());
|
||||
m_labelFileToWrite.clear();
|
||||
}
|
||||
else if (!m_cachingWriter)
|
||||
{
|
||||
fprintf(stderr, "WARNING: file %ws NOT written to disk yet, will be written the first time the end of the entire dataset is found.\n", m_labelFileToWrite.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy - delete this reader object
|
||||
// NOTE: this destroys the object (delete this), so it must not be used past this point
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::Destroy()
|
||||
{
|
||||
delete this;
|
||||
}
|
||||
|
||||
// Init - Reader Initialize for multiple data sets
|
||||
// config - [in] configuration parameters for the datareader
|
||||
// Sample format below:
|
||||
//# Parameter values for the reader
|
||||
//reader=[
|
||||
// # reader to use
|
||||
// readerType=DSSMReader
|
||||
// miniBatchMode=Partial
|
||||
// randomize=None
|
||||
// features=[
|
||||
// dim=784
|
||||
// start=1
|
||||
// file=c:\speech\mnist\mnist_test.txt
|
||||
// ]
|
||||
// labels=[
|
||||
// dim=1
|
||||
// start=0
|
||||
// file=c:\speech\mnist\mnist_test.txt
|
||||
// labelMappingFile=c:\speech\mnist\labels.txt
|
||||
// labelDim=10
|
||||
// labelType=Category
|
||||
// ]
|
||||
//]
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::Init(const ConfigParameters& readerConfig)
|
||||
{
|
||||
std::vector<std::wstring> features;
|
||||
std::vector<std::wstring> labels;
|
||||
|
||||
// Determine the names of the features and lables sections in the config file.
|
||||
// features - [in,out] a vector of feature name strings
|
||||
// labels - [in,out] a vector of label name strings
|
||||
// For DSSM dataset, we only need features. No label is necessary. The following "labels" just serves as a place holder
|
||||
GetFileConfigNames(readerConfig, features, labels);
|
||||
|
||||
// For DSSM dataset, it must have exactly two features
|
||||
// In the config file, we must specify query features first, then document features. The sequence is different here. Pay attention
|
||||
if (features.size() == 2 && labels.size() == 1)
|
||||
{
|
||||
m_featuresNameQuery = features[1];
|
||||
m_featuresNameDoc = features[0];
|
||||
m_labelsName = labels[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
RuntimeError("DSSM requires exactly two features and one label. Their names should match those in NDL definition");
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
m_mbStartSample = m_epoch = m_totalSamples = m_epochStartSample = 0;
|
||||
m_labelIdMax = m_labelDim = 0;
|
||||
m_partialMinibatch = m_endReached = false;
|
||||
m_labelType = labelCategory;
|
||||
m_readNextSample = 0;
|
||||
m_traceLevel = readerConfig("traceLevel", "0");
|
||||
|
||||
if (readerConfig.Exists("randomize"))
|
||||
{
|
||||
string randomizeString = readerConfig("randomize");
|
||||
if (randomizeString == "None")
|
||||
{
|
||||
m_randomizeRange = randomizeNone;
|
||||
}
|
||||
else if (randomizeString == "Auto")
|
||||
{
|
||||
m_randomizeRange = randomizeAuto;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_randomizeRange = readerConfig("randomize");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
m_randomizeRange = randomizeNone;
|
||||
}
|
||||
|
||||
std::string minibatchMode(readerConfig("minibatchMode", "Partial"));
|
||||
m_partialMinibatch = !_stricmp(minibatchMode.c_str(), "Partial");
|
||||
|
||||
|
||||
// Get the config parameters for query feature and doc feature
|
||||
ConfigParameters configFeaturesQuery = readerConfig(m_featuresNameQuery, "");
|
||||
ConfigParameters configFeaturesDoc = readerConfig(m_featuresNameDoc, "");
|
||||
|
||||
if (configFeaturesQuery.size() == 0)
|
||||
RuntimeError("features file not found, required in configuration: i.e. 'features=[file=c:\\myfile.txt;start=1;dim=123]'");
|
||||
if (configFeaturesDoc.size() == 0)
|
||||
RuntimeError("features file not found, required in configuration: i.e. 'features=[file=c:\\myfile.txt;start=1;dim=123]'");
|
||||
|
||||
// Read in feature size information
|
||||
// This information will be used to handle OOVs
|
||||
m_featuresDimQuery = configFeaturesQuery(L"dim");
|
||||
m_featuresDimDoc = configFeaturesDoc(L"dim");
|
||||
|
||||
std::wstring fileQ = configFeaturesQuery("file");
|
||||
std::wstring fileD = configFeaturesDoc("file");
|
||||
|
||||
dssm_queryInput.Init(fileQ, m_featuresDimQuery);
|
||||
dssm_docInput.Init(fileD, m_featuresDimDoc);
|
||||
|
||||
m_totalSamples = dssm_queryInput.numRows;
|
||||
if (read_order == NULL)
|
||||
{
|
||||
read_order = new int[m_totalSamples];
|
||||
for (int c = 0; c < m_totalSamples; c++)
|
||||
{
|
||||
read_order[c] = c;
|
||||
}
|
||||
}
|
||||
m_mbSize = 0;
|
||||
|
||||
}
|
||||
// destructor - virtual so it gets called properly
// Releases the buffers and vectors handled by ReleaseMemory().
// NOTE(review): read_order (allocated with new[] in Init) is not released by any code visible here - confirm
|
||||
template<class ElemType>
|
||||
DSSMReader<ElemType>::~DSSMReader()
|
||||
{
|
||||
ReleaseMemory();
|
||||
}
|
||||
|
||||
// ReleaseMemory - release the memory footprint of DSSMReader
|
||||
// used when the caching reader is taking over
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::ReleaseMemory()
|
||||
{
|
||||
if (m_qfeaturesBuffer!=NULL)
|
||||
delete[] m_qfeaturesBuffer;
|
||||
m_qfeaturesBuffer=NULL;
|
||||
if (m_dfeaturesBuffer!=NULL)
|
||||
delete[] m_dfeaturesBuffer;
|
||||
m_dfeaturesBuffer=NULL;
|
||||
if (m_labelsBuffer!=NULL)
|
||||
delete[] m_labelsBuffer;
|
||||
m_labelsBuffer=NULL;
|
||||
if (m_labelsIdBuffer!=NULL)
|
||||
delete[] m_labelsIdBuffer;
|
||||
m_labelsIdBuffer=NULL;
|
||||
m_featureData.clear();
|
||||
m_labelIdData.clear();
|
||||
m_labelData.clear();
|
||||
}
|
||||
|
||||
//SetupEpoch - intended to set up the file position and other variables to start a particular epoch
// Currently an empty function: it performs no work.
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::SetupEpoch()
|
||||
{
|
||||
}
|
||||
|
||||
// RoundUp - round an integer up to the next multiple of size
// value - number to round
// size - the multiple to round to; 0 means "no rounding" and returns value unchanged
// returns - the smallest multiple of size that is >= value
// The result is computed from the remainder rather than from value + size - 1,
// so it does not wrap when value is close to the maximum size_t and already a multiple.
size_t RoundUp(size_t value, size_t size)
{
    if (size == 0)
        return value;

    const size_t remainder = value % size;
    return remainder == 0 ? value : value + (size - remainder);
}
|
||||
|
||||
//StartMinibatchLoop - Startup a minibatch loop
|
||||
// mbSize - [in] size of the minibatch (number of Samples, etc.)
|
||||
// epoch - [in] epoch number for this loop, if > 0 the requestedEpochSamples must be specified (unless epoch zero was completed this run)
|
||||
// requestedEpochSamples - [in] number of samples to randomize, defaults to requestDataSize which uses the number of samples there are in the dataset
|
||||
// this value must be a multiple of mbSize, if it is not, it will be rounded up to one.
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples)
|
||||
{
|
||||
size_t mbStartSample = m_epoch * m_epochSize;
|
||||
if (m_totalSamples == 0)
|
||||
{
|
||||
m_totalSamples = dssm_queryInput.numRows;
|
||||
}
|
||||
|
||||
size_t fileRecord = m_totalSamples ? mbStartSample % m_totalSamples : 0;
|
||||
fprintf(stderr, "starting epoch %lld at record count %lld, and file position %lld\n", m_epoch, mbStartSample, fileRecord);
|
||||
size_t currentFileRecord = m_mbStartSample % m_totalSamples;
|
||||
|
||||
|
||||
|
||||
// reset the next read sample
|
||||
m_readNextSample = 0;
|
||||
m_epochStartSample = m_mbStartSample = mbStartSample;
|
||||
m_mbSize = mbSize;
|
||||
m_epochSize = requestedEpochSamples;
|
||||
dssm_queryInput.SetupEpoch(mbSize);
|
||||
dssm_docInput.SetupEpoch(mbSize);
|
||||
if (m_epochSize > dssm_queryInput.numRows)
|
||||
{
|
||||
m_epochSize = dssm_queryInput.numRows;
|
||||
}
|
||||
if (Randomize())
|
||||
{
|
||||
random_shuffle(&read_order[0], &read_order[m_epochSize]);
|
||||
}
|
||||
m_epoch = epoch;
|
||||
m_mbStartSample = epoch*m_epochSize;
|
||||
|
||||
}
|
||||
|
||||
// function to store the LabelType in an ElemType
|
||||
// required for string labels, which can't be stored in ElemType arrays
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::StoreLabel(ElemType& labelStore, const LabelType& labelValue)
|
||||
{
|
||||
labelStore = (ElemType)m_mapLabelToId[labelValue];
|
||||
}
|
||||
|
||||
// GetMinibatch - Get the next minibatch (features and labels)
|
||||
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponing matrix,
|
||||
// [out] each matrix resized if necessary containing data.
|
||||
// returns - true if there are more minibatches, false if no more minibatchs remain
|
||||
template<class ElemType>
|
||||
bool DSSMReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices)
|
||||
{
|
||||
if (m_readNextSample >= m_totalSamples)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// In my unit test example, the input matrices contain 5: N, S, fD, fQ and labels
|
||||
// Both N and S serve as a pre-set constant values, no need to change them
|
||||
// In this node, we only need to fill in these matrices: fD, fQ, labels
|
||||
Matrix<ElemType>& featuresQ = *matrices[m_featuresNameQuery];
|
||||
Matrix<ElemType>& featuresD = *matrices[m_featuresNameDoc];
|
||||
Matrix<ElemType>& labels = *matrices[m_labelsName]; // will change this part later.
|
||||
|
||||
size_t actualMBSize = ( m_readNextSample + m_mbSize > m_totalSamples ) ? m_totalSamples - m_readNextSample : m_mbSize;
|
||||
|
||||
|
||||
featuresQ.SwitchToMatrixType(MatrixType::SPARSE, MatrixFormat::matrixFormatSparseCSC);
|
||||
featuresD.SwitchToMatrixType(MatrixType::SPARSE, MatrixFormat::matrixFormatSparseCSC);
|
||||
|
||||
/*
|
||||
featuresQ.Resize(dssm_queryInput.numRows, actualMBSize);
|
||||
featuresD.Resize(dssm_docInput.numRows, actualMBSize);
|
||||
*/
|
||||
|
||||
//fprintf(stderr, "featuresQ\n");
|
||||
dssm_queryInput.Next_Batch(featuresQ, m_readNextSample, actualMBSize, read_order);
|
||||
//fprintf(stderr, "\n\n\nfeaturesD\n");
|
||||
dssm_docInput.Next_Batch(featuresD, m_readNextSample, actualMBSize, read_order);
|
||||
//fprintf(stderr, "\n\n\n\n\n");
|
||||
m_readNextSample += actualMBSize;
|
||||
/*
|
||||
featuresQ.Print("featuresQ");
|
||||
fprintf(stderr, "\n");
|
||||
featuresD.Print("featuresD");
|
||||
fprintf(stderr, "\n");
|
||||
*/
|
||||
|
||||
/*
|
||||
GPUSPARSE_INDEX_TYPE* h_CSCCol;
|
||||
GPUSPARSE_INDEX_TYPE* h_Row;
|
||||
ElemType* h_val;
|
||||
size_t nz;
|
||||
size_t nrs;
|
||||
size_t ncols;
|
||||
featuresQ.GetMatrixFromCSCFormat(&h_CSCCol, &h_Row, &h_val, &nz, &nrs, &ncols);
|
||||
|
||||
for (int j = 0, k=0; j < nz; j++)
|
||||
{
|
||||
if (h_CSCCol[k] >= j)
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
k++;
|
||||
}
|
||||
fprintf(stderr, "%d:%.f ", h_Row[j], h_val[j]);
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
featuresQ.TransferFromDeviceToDevice(featuresQ.GetDeviceId(), -1);
|
||||
featuresQ.SwitchToMatrixType(MatrixType::DENSE, MatrixFormat::matrixFormatDense);
|
||||
featuresQ.Print("featuresQ");
|
||||
|
||||
featuresD.TransferFromDeviceToDevice(featuresD.GetDeviceId(), -1);
|
||||
featuresD.SwitchToMatrixType(MatrixType::DENSE, MatrixFormat::matrixFormatDense);
|
||||
featuresD.Print("featuresD");
|
||||
|
||||
exit(1);
|
||||
*/
|
||||
|
||||
if (actualMBSize > m_mbSize || m_labelsBuffer == NULL) {
|
||||
size_t rows = labels.GetNumRows();
|
||||
labels.Resize(rows, actualMBSize);
|
||||
labels.SetValue(0.0);
|
||||
m_labelsBuffer = new ElemType[rows * actualMBSize];
|
||||
memset(m_labelsBuffer, 0, sizeof(ElemType)* rows * actualMBSize);
|
||||
for (int i = 0; i < actualMBSize; i++)
|
||||
{
|
||||
m_labelsBuffer[i * rows] = 1;
|
||||
}
|
||||
labels.SetValue(rows, actualMBSize, m_labelsBuffer, 0, labels.GetDeviceId());
|
||||
|
||||
}
|
||||
/*
|
||||
featuresQ.Print("featuresQ");
|
||||
featuresD.Print("featuresD");
|
||||
labels.print("labels");
|
||||
*/
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template<class ElemType>
|
||||
const std::map<typename IDataReader<ElemType>::LabelIdType, typename IDataReader<ElemType>::LabelType>& DSSMReader<ElemType>::GetLabelMapping(const std::wstring& sectionName)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
return m_cachingReader->GetLabelMapping(sectionName);
|
||||
}
|
||||
return m_mapIdToLabel;
|
||||
}
|
||||
|
||||
// SetLabelMapping - Sets the label mapping from integer index to label
|
||||
// labelMapping - mapping table from label values to IDs (must be 0-n)
|
||||
// note: for tasks with labels, the mapping table must be the same between a training run and a testing run
|
||||
template<class ElemType>
|
||||
void DSSMReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/, const std::map<typename IDataReader<ElemType>::LabelIdType, typename LabelType>& labelMapping)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
throw runtime_error("Cannot set mapping table when the caching reader is being used");
|
||||
}
|
||||
m_mapIdToLabel = labelMapping;
|
||||
m_mapLabelToId.clear();
|
||||
for (std::pair<unsigned, LabelType> var : labelMapping)
|
||||
{
|
||||
m_mapLabelToId[var.second] = var.first;
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool DSSMReader<ElemType>::DataEnd(EndDataType endDataType)
|
||||
{
|
||||
bool ret = false;
|
||||
switch (endDataType)
|
||||
{
|
||||
case endDataNull:
|
||||
assert(false);
|
||||
break;
|
||||
case endDataEpoch:
|
||||
//ret = (m_mbStartSample / m_epochSize < m_epoch);
|
||||
ret = (m_readNextSample >= m_totalSamples);
|
||||
break;
|
||||
case endDataSet:
|
||||
ret = (m_readNextSample >= m_totalSamples);
|
||||
break;
|
||||
case endDataSentence: // for fast reader each minibatch is considered a "sentence", so always true
|
||||
ret = true;
|
||||
break;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
DSSM_BinaryInput<ElemType>::DSSM_BinaryInput(){
|
||||
}
|
||||
// destructor - releases everything Dispose() releases
template<class ElemType>
|
||||
DSSM_BinaryInput<ElemType>::~DSSM_BinaryInput(){
|
||||
Dispose();
|
||||
}
|
||||
template<class ElemType>
|
||||
void DSSM_BinaryInput<ElemType>::Init(wstring fileName, size_t dim){
|
||||
|
||||
m_dim = dim;
|
||||
mbSize = 0;
|
||||
/*
|
||||
m_hndl = CreateFileA(fileName.c_str(), GENERIC_READ,
|
||||
FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
|
||||
*/
|
||||
m_hndl = CreateFile(fileName.c_str(), GENERIC_READ,
|
||||
FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
|
||||
if (m_hndl == INVALID_HANDLE_VALUE)
|
||||
{
|
||||
char message[256];
|
||||
sprintf_s(message, "Unable to Open/Create file %ls, error %x", fileName.c_str(), GetLastError());
|
||||
throw runtime_error(message);
|
||||
}
|
||||
|
||||
m_filemap = CreateFileMapping(m_hndl, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
|
||||
SYSTEM_INFO sysinfo;
|
||||
GetSystemInfo(&sysinfo);
|
||||
DWORD sysGran = sysinfo.dwAllocationGranularity;
|
||||
|
||||
header_buffer = MapViewOfFile(m_filemap, // handle to map object
|
||||
FILE_MAP_READ, // get correct permissions
|
||||
HIDWORD(0),
|
||||
LODWORD(0),
|
||||
sizeof(int64_t)* 2 + sizeof(int32_t));
|
||||
|
||||
//cout << "After mapviewoffile" << endl;
|
||||
|
||||
memcpy(&numRows, header_buffer, sizeof(int64_t));
|
||||
memcpy(&numCols, (char*)header_buffer + sizeof(int64_t), sizeof(int32_t));
|
||||
memcpy(&totalNNz, (char*)header_buffer + sizeof(int64_t)+sizeof(int32_t), sizeof(int64_t));
|
||||
|
||||
//cout << "After gotvalues" << endl;
|
||||
int64_t base_offset = sizeof(int64_t)* 2 + sizeof(int32_t);
|
||||
|
||||
int64_t offsets_padding = base_offset % sysGran;
|
||||
base_offset -= offsets_padding;
|
||||
|
||||
int64_t header_size = numRows*sizeof(int64_t)+offsets_padding;
|
||||
|
||||
void* offsets_orig = MapViewOfFile(m_filemap, // handle to map object
|
||||
FILE_MAP_READ, // get correct permissions
|
||||
HIDWORD(base_offset),
|
||||
LODWORD(base_offset),
|
||||
header_size);
|
||||
|
||||
offsets_buffer = (char*)offsets_orig + offsets_padding;
|
||||
|
||||
if (offsets != NULL){
|
||||
free(offsets);
|
||||
}
|
||||
offsets = (int64_t*)malloc(sizeof(int64_t)*numRows);
|
||||
memcpy(offsets, offsets_buffer, numRows*sizeof(int64_t));
|
||||
|
||||
|
||||
int64_t header_offset = base_offset + offsets_padding + numRows * sizeof(int64_t);
|
||||
|
||||
int64_t data_padding = header_offset % sysGran;
|
||||
header_offset -= data_padding;
|
||||
|
||||
void* data_orig = MapViewOfFile(m_filemap, // handle to map object
|
||||
FILE_MAP_READ, // get correct permissions
|
||||
HIDWORD(header_offset),
|
||||
LODWORD(header_offset),
|
||||
0);
|
||||
data_buffer = (char*)data_orig + data_padding;
|
||||
|
||||
}
|
||||
template<class ElemType>
|
||||
bool DSSM_BinaryInput<ElemType>::SetupEpoch( size_t minibatchSize){
|
||||
if (values == NULL || mbSize < minibatchSize)
|
||||
{
|
||||
if (values != NULL)
|
||||
{
|
||||
free(values);
|
||||
free(colIndices);
|
||||
free(rowIndices);
|
||||
}
|
||||
|
||||
values = (ElemType*)malloc(sizeof(ElemType)*MAX_BUFFER*minibatchSize);
|
||||
colIndices = (int32_t*)malloc(sizeof(int32_t)*(minibatchSize+1));
|
||||
rowIndices = (int32_t*)malloc(sizeof(int32_t)*MAX_BUFFER*minibatchSize);
|
||||
//fprintf(stderr, "values size: %d",sizeof(ElemType)*MAX_BUFFER*minibatchSize);
|
||||
//fprintf(stderr, "colindi size: %d",sizeof(int32_t)*MAX_BUFFER*(1+minibatchSize));
|
||||
//fprintf(stderr, "rowindi size: %d",sizeof(int32_t)*MAX_BUFFER*minibatchSize);
|
||||
}
|
||||
if (minibatchSize > mbSize)
|
||||
{
|
||||
mbSize = minibatchSize;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
template<class ElemType>
|
||||
bool DSSM_BinaryInput<ElemType>::Next_Batch(Matrix<ElemType>& matrices, int cur, int numToRead, int* ordering){
|
||||
/*
|
||||
int devId = matrices.GetDeviceId();
|
||||
matrices.TransferFromDeviceToDevice(devId, -1);
|
||||
*/
|
||||
|
||||
int64_t cur_index = 0;
|
||||
|
||||
for (int c = 0; c < numToRead; c++,cur++)
|
||||
{
|
||||
//int64_t cur_offset = offsets[ordering[cur]];
|
||||
int64_t cur_offset = offsets[cur];
|
||||
//int64_t cur_offset = offsets[ordering[c]];
|
||||
int32_t nnz;
|
||||
colIndices[c] = cur_index;
|
||||
memcpy(&nnz, (char*)data_buffer + cur_offset, sizeof(int32_t));
|
||||
memcpy(values+cur_index, (char*)data_buffer + cur_offset + sizeof(int32_t), sizeof(ElemType)*nnz);
|
||||
memcpy(rowIndices+cur_index, (char*)data_buffer + cur_offset + sizeof(int32_t)+sizeof(ElemType)*nnz, sizeof(int32_t)*nnz);
|
||||
/**
|
||||
fprintf(stderr, "%4d (%3d, %6d): ", c, nnz, cur_index + nnz);
|
||||
for (int i = 0; i < nnz; i++)
|
||||
{
|
||||
fprintf(stderr, "%d:%.f ", rowIndices[cur_index+i], values[cur_index+i]);
|
||||
//matrices.SetValue(rowIndices[cur_index + i], c, values[cur_index + i]);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
**/
|
||||
|
||||
cur_index += nnz;
|
||||
}
|
||||
colIndices[numToRead] = cur_index;
|
||||
/*
|
||||
int col = 0;
|
||||
for (int c = 0; c < cur_index; c++)
|
||||
{
|
||||
if (colIndices[col] == c)
|
||||
{
|
||||
fprintf(stderr, "\n%4d: ", col);
|
||||
col++;
|
||||
}
|
||||
fprintf(stderr, "%d:%.f ", rowIndices[c], values[c]);
|
||||
}
|
||||
*/
|
||||
/*
|
||||
fprintf(stderr, "\nXXXX nnz: %d\n", cur_index);
|
||||
fprintf(stderr, "XXXX max values read: %d vs %d\n", sizeof(ElemType)*cur_index, sizeof(ElemType)*MAX_BUFFER*numToRead);
|
||||
fprintf(stderr, "XXXX max indices read: %d vs %d\n", sizeof(int32_t)*cur_index, sizeof(int32_t)*MAX_BUFFER*numToRead);
|
||||
fprintf(stderr, "XXXX sizeof(int32_t) = %d, sizeof(int) = %d\n", sizeof(int32_t), sizeof(int));
|
||||
*/
|
||||
/*
|
||||
values = (ElemType*)malloc(sizeof(ElemType)*MAX_BUFFER*minibatchSize);
|
||||
colIndices = (int32_t*)malloc(sizeof(int32_t)*MAX_BUFFER*(minibatchSize+1));
|
||||
rowIndices = (int32_t*)malloc(sizeof(int32_t)*MAX_BUFFER*minibatchSize);
|
||||
*/
|
||||
|
||||
matrices.SetMatrixFromCSCFormat(colIndices, rowIndices, values, cur_index, m_dim, numToRead);
|
||||
//matrices.Print("actual values");
|
||||
//exit(1);
|
||||
/*
|
||||
matrices.SwitchToMatrixType(MatrixType::DENSE, MatrixFormat::matrixFormatDense);
|
||||
matrices.Print("featuresQ");
|
||||
exit(1);
|
||||
matrices.TransferFromDeviceToDevice(-1,devId);
|
||||
*/
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DSSM_BinaryInput<ElemType>::Dispose(){
|
||||
if (offsets_orig != NULL){
|
||||
UnmapViewOfFile(offsets_orig);
|
||||
}
|
||||
if (data_orig != NULL)
|
||||
{
|
||||
UnmapViewOfFile(data_orig);
|
||||
}
|
||||
|
||||
if (offsets!= NULL)
|
||||
{
|
||||
free(offsets);// = (ElemType*)malloc(sizeof(float)* 230 * 1024);
|
||||
}
|
||||
if (values != NULL)
|
||||
{
|
||||
free(values);// = (ElemType*)malloc(sizeof(float)* 230 * 1024);
|
||||
}
|
||||
if (rowIndices != NULL){
|
||||
free(rowIndices);// = (int*)malloc(sizeof(float)* 230 * 1024);
|
||||
}
|
||||
if (colIndices != NULL){
|
||||
free(colIndices);// = (int*)malloc(sizeof(float)* 230 * 1024);
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool DSSMReader<ElemType>::GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart)
|
||||
{
|
||||
if (m_cachingReader)
|
||||
{
|
||||
return m_cachingReader->GetData(sectionName, numRecords, data, dataBufferSize, recordStart);
|
||||
}
|
||||
throw runtime_error("GetData not supported in DSSMReader");
|
||||
}
|
||||
// instantiate all the combinations we expect to be used
|
||||
template class DSSMReader<double>;
|
||||
template class DSSMReader<float>;
|
||||
}}}
|
|
@ -0,0 +1,156 @@
|
|||
//
|
||||
// <copyright file="DSSMReader.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// DSSMReader.h - Include file for the MTK and MLF format of features and samples
|
||||
#pragma once
|
||||
#include "DataReader.h"
|
||||
#include "DataWriter.h"
|
||||
#include "DSSMParser.h"
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "minibatchsourcehelpers.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// LabelKind - how a reader treats its labels; the values run from 0 to 3
enum LabelKind
{
    // no labels to worry about
    labelNone = 0,
    // category labels, creates mapping tables
    labelCategory,
    // regression labels
    labelRegression,
    // some other type of label
    labelOther
};
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
class DSSM_BinaryInput {
|
||||
private:
|
||||
HANDLE m_hndl;
|
||||
HANDLE m_filemap;
|
||||
HANDLE m_header;
|
||||
HANDLE m_offsets;
|
||||
HANDLE m_data;
|
||||
|
||||
//void* header_orig; // Don't need this since the header is at the start of the file
|
||||
void* offsets_orig;
|
||||
void* data_orig;
|
||||
|
||||
void* header_buffer;
|
||||
void* offsets_buffer;
|
||||
void* data_buffer;
|
||||
|
||||
size_t m_dim;
|
||||
size_t mbSize;
|
||||
size_t MAX_BUFFER = 300;
|
||||
|
||||
ElemType* values; // = (ElemType*)malloc(sizeof(float)* 230 * 1024);
|
||||
int64_t* offsets; // = (int*)malloc(sizeof(int)* 230 * 1024);
|
||||
int32_t* colIndices; // = (int*)malloc(sizeof(int)* (batchsize + 1));
|
||||
int32_t* rowIndices; // = (int*)malloc(sizeof(int)* MAX_BUFFER * batchsize);
|
||||
|
||||
public:
|
||||
int64_t numRows;
|
||||
int32_t numCols;
|
||||
int64_t totalNNz;
|
||||
|
||||
DSSM_BinaryInput();
|
||||
~DSSM_BinaryInput();
|
||||
void Init(std::wstring fileName, size_t dim);
|
||||
bool SetupEpoch( size_t minibatchSize);
|
||||
bool Next_Batch(Matrix<ElemType>& matrices, int cur, int numToRead, int* ordering);
|
||||
void Dispose();
|
||||
};
|
||||
|
||||
template<class ElemType>
|
||||
class DSSMReader : public IDataReader<ElemType>
|
||||
{
|
||||
//public:
|
||||
// typedef std::string LabelType;
|
||||
// typedef unsigned LabelIdType;
|
||||
private:
|
||||
int* read_order; // array to shuffle to reorder the dataset
|
||||
std::wstring m_featuresNameQuery;
|
||||
std::wstring m_featuresNameDoc;
|
||||
size_t m_featuresDimQuery;
|
||||
size_t m_featuresDimDoc;
|
||||
DSSM_BinaryInput<ElemType> dssm_queryInput;
|
||||
DSSM_BinaryInput<ElemType> dssm_docInput;
|
||||
|
||||
DSSMParser<ElemType, LabelType> m_parser;
|
||||
size_t m_mbSize; // size of minibatch requested
|
||||
LabelIdType m_labelIdMax; // maximum label ID we have encountered so far
|
||||
LabelIdType m_labelDim; // maximum label ID we will ever see (used for array dimensions)
|
||||
size_t m_mbStartSample; // starting sample # of the next minibatch
|
||||
size_t m_epochSize; // size of an epoch
|
||||
size_t m_epoch; // which epoch are we on
|
||||
size_t m_epochStartSample; // the starting sample for the epoch
|
||||
size_t m_totalSamples; // number of samples in the dataset
|
||||
size_t m_randomizeRange; // randomization range
|
||||
size_t m_featureCount; // feature count
|
||||
size_t m_readNextSample; // next sample to read
|
||||
bool m_labelFirst; // the label is the first element in a line
|
||||
bool m_partialMinibatch; // a partial minibatch is allowed
|
||||
LabelKind m_labelType; // labels are categories, create mapping table
|
||||
msra::dbn::randomordering m_randomordering; // randomizing class
|
||||
|
||||
std::wstring m_labelsName;
|
||||
std::wstring m_featuresName;
|
||||
std::wstring m_labelsCategoryName;
|
||||
std::wstring m_labelsMapName;
|
||||
ElemType* m_qfeaturesBuffer;
|
||||
ElemType* m_dfeaturesBuffer;
|
||||
ElemType* m_labelsBuffer;
|
||||
LabelIdType* m_labelsIdBuffer;
|
||||
std::wstring m_labelFileToWrite; // set to the path if we need to write out the label file
|
||||
|
||||
bool m_endReached;
|
||||
int m_traceLevel;
|
||||
|
||||
// feature and label data are parallel arrays
|
||||
std::vector<ElemType> m_featureData;
|
||||
std::vector<LabelIdType> m_labelIdData;
|
||||
std::vector<LabelType> m_labelData;
|
||||
|
||||
// map is from ElemType to LabelType
|
||||
// For DSSM, we really only need an int for label data, but we have to transmit in Matrix, so use ElemType instead
|
||||
std::map<LabelIdType, LabelType> m_mapIdToLabel;
|
||||
std::map<LabelType, LabelIdType> m_mapLabelToId;
|
||||
|
||||
// caching support
|
||||
DataReader<ElemType>* m_cachingReader;
|
||||
DataWriter<ElemType>* m_cachingWriter;
|
||||
ConfigParameters m_readerConfig;
|
||||
void InitCache(const ConfigParameters& config);
|
||||
|
||||
size_t RandomizeSweep(size_t epochSample);
|
||||
//bool Randomize() {return m_randomizeRange != randomizeNone;}
|
||||
bool Randomize() { return false; }
|
||||
size_t UpdateDataVariables(size_t mbStartSample);
|
||||
void SetupEpoch();
|
||||
void StoreLabel(ElemType& labelStore, const LabelType& labelValue);
|
||||
size_t RecordsToRead(size_t mbStartSample, bool tail=false);
|
||||
void ReleaseMemory();
|
||||
void WriteLabelFile();
|
||||
|
||||
|
||||
virtual bool ReadRecord(size_t readSample);
|
||||
public:
|
||||
virtual void Init(const ConfigParameters& config);
|
||||
virtual void Destroy();
|
||||
DSSMReader() { m_qfeaturesBuffer = NULL; m_dfeaturesBuffer = NULL; m_labelsBuffer = NULL; }
|
||||
virtual ~DSSMReader();
|
||||
virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
|
||||
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
|
||||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
};
|
||||
}}}
|
|
@ -180,11 +180,9 @@
|
|||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\..\Common\Include\fileutil.h" />
|
||||
<ClInclude Include="minibatchsourcehelpers.h" />
|
||||
<ClInclude Include="stdafx.h" />
|
||||
<ClInclude Include="targetver.h" />
|
||||
<ClInclude Include="DSSMReader.h" />
|
||||
<ClInclude Include="DSSMParser.h" />
|
||||
<ClInclude Include="DSSMReader.h" />
|
||||
<ClInclude Include="minibatchsourcehelpers.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\..\Common\ConfigFile.cpp">
|
||||
|
@ -205,31 +203,12 @@
|
|||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">NotUsing</PrecompiledHeader>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">false</CompileAsManaged>
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
|
||||
</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
</PrecompiledHeader>
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">false</CompileAsManaged>
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
|
||||
</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DSSMReader.cpp" />
|
||||
<ClCompile Include="dllmain.cpp" />
|
||||
<ClCompile Include="DSSMParser.cpp" />
|
||||
<ClCompile Include="DSSMReader.cpp" />
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
</ImportGroup>
|
||||
</Project>
|
||||
</Project>
|
|
@ -21,21 +21,6 @@
|
|||
</Filter>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DSSMParser.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DSSMReader.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\ConfigFile.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
|
@ -51,23 +36,20 @@
|
|||
<ClCompile Include="..\..\Common\fileutil.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DSSMParser.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DSSMReader.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="DSSMParser.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="DSSMReader.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="minibatchsourcehelpers.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="stdafx.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="targetver.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\..\Common\Include\basetypes.h">
|
||||
<Filter>Common\Include</Filter>
|
||||
</ClInclude>
|
||||
|
@ -83,5 +65,14 @@
|
|||
<ClInclude Include="..\..\Common\Include\fileutil.h">
|
||||
<Filter>Common\Include</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="DSSMParser.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="DSSMReader.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="minibatchsourcehelpers.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
</Project>
|
|
@ -0,0 +1,31 @@
|
|||
//
|
||||
// <copyright file="Exports.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// Exports.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#define DATAREADER_EXPORTS
|
||||
#include "DataReader.h"
|
||||
#include "DSSMReader.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template<class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
{
|
||||
*preader = new DSSMReader<ElemType>();
|
||||
}
|
||||
|
||||
// GetReaderF - C-linkage entry point that creates the single-precision (float) reader.
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
{
    GetReader<float>(preader);
}
|
||||
// GetReaderD - C-linkage entry point that creates the double-precision (double) reader.
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
{
    GetReader<double>(preader);
}
|
||||
|
||||
}}}
|
|
@ -0,0 +1,19 @@
|
|||
// dllmain.cpp : Defines the entry point for the DLL application.
|
||||
#include "stdafx.h"
|
||||
|
||||
// DllMain - DLL entry point. No work is done for any notification;
// every reason code falls through to the same no-op and TRUE is returned.
BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserved)
{
    switch (ul_reason_for_call)
    {
    case DLL_PROCESS_ATTACH:
    case DLL_THREAD_ATTACH:
    case DLL_THREAD_DETACH:
    case DLL_PROCESS_DETACH:
    default:
        // nothing to set up or tear down
        break;
    }
    return TRUE;
}
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
//
|
||||
// <copyright file="minibatchsourcehelpers.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// minibatchsourcehelpers.h -- helper classes for minibatch sources
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// randomordering -- class to help manage randomization of input data
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// rand - pseudo-random index in the half-open range [begin, end)
// begin - first value that may be returned
// end - one past the last value that may be returned
// returns: begin + (random number modulo range length); 'begin' itself when the range is empty
// Two ::rand() draws are combined to widen the value range beyond a single RAND_MAX.
static inline size_t rand (const size_t begin, const size_t end)
{
    // an empty (or inverted) range has no valid element; also avoids a modulo by zero below
    if (end <= begin)
        return begin;
    // widen to size_t BEFORE multiplying: 'int * int' overflows (undefined behavior) where RAND_MAX is 2^31-1
    const size_t randno = (size_t) ::rand() * RAND_MAX + (size_t) ::rand(); // BUGBUG: still only covers 32-bit range
    return begin + randno % (end - begin);
}
|
||||
|
||||
class randomordering // note: NOT thread-safe at all
|
||||
{
|
||||
// constants for randomization
|
||||
const static size_t randomizeDisable=0;
|
||||
|
||||
typedef unsigned int INDEXTYPE; // don't use size_t, as this saves HUGE amounts of RAM
|
||||
std::vector<INDEXTYPE> map; // [t] -> t' indices in randomized order
|
||||
size_t currentseed; // seed for current sequence
|
||||
size_t randomizationrange; // t - randomizationrange/2 <= t' < t + randomizationrange/2 (we support this to enable swapping)
|
||||
// special values (randomizeDisable)
|
||||
void invalidate() { currentseed = (size_t) -1; }
|
||||
public:
|
||||
randomordering() { invalidate(); randomizationrange = randomizeDisable;}
|
||||
|
||||
void resize (size_t len, size_t p_randomizationrange) { randomizationrange = p_randomizationrange; if (len > 0) map.resize (len); invalidate(); }
|
||||
|
||||
// return the randomized feature bounds for a time range
|
||||
std::pair<size_t,size_t> bounds (size_t ts, size_t te) const
|
||||
{
|
||||
size_t tbegin = max (ts, randomizationrange/2) - randomizationrange/2;
|
||||
size_t tend = min (te + randomizationrange/2, map.size());
|
||||
return std::make_pair<size_t,size_t> (move(tbegin), move(tend));
|
||||
}
|
||||
|
||||
// this returns the map directly (read-only) and will lazily initialize it for a given seed
|
||||
const std::vector<INDEXTYPE> & operator() (size_t seed) //throw()
|
||||
{
|
||||
// if wrong seed then lazily recache the sequence
|
||||
if (seed != currentseed && randomizationrange != randomizeDisable)
|
||||
{
|
||||
// test for numeric overflow
|
||||
if (map.size()-1 != (INDEXTYPE) (map.size()-1))
|
||||
throw std::runtime_error ("randomordering: INDEXTYPE has too few bits for this corpus");
|
||||
// 0, 1, 2...
|
||||
foreach_index (t, map) map[t] = (INDEXTYPE) t;
|
||||
|
||||
if (map.size() > RAND_MAX * (size_t) RAND_MAX)
|
||||
throw std::runtime_error ("randomordering: too large training set: need to change to different random generator!");
|
||||
srand ((unsigned int) seed);
|
||||
size_t retries = 0;
|
||||
foreach_index (t, map)
|
||||
{
|
||||
for (int tries = 0; tries < 5; tries++)
|
||||
{
|
||||
// swap current pos with a random position
|
||||
// Random positions are limited to t+randomizationrange.
|
||||
// This ensures some locality suitable for paging with a sliding window.
|
||||
const size_t tbegin = max ((size_t) t, randomizationrange/2) - randomizationrange/2; // range of window --TODO: use bounds() function above
|
||||
const size_t tend = min (t + randomizationrange/2, map.size());
|
||||
assert (tend >= tbegin); // (guard against potential numeric-wraparound bug)
|
||||
const size_t trand = rand (tbegin, tend); // random number within windows
|
||||
assert ((size_t) t <= trand + randomizationrange/2 && trand < (size_t) t + randomizationrange/2);
|
||||
// if range condition is fulfilled then swap
|
||||
if (trand <= map[t] + randomizationrange/2 && map[t] < trand + randomizationrange/2
|
||||
&& (size_t) t <= map[trand] + randomizationrange/2 && map[trand] < (size_t) t + randomizationrange/2)
|
||||
{
|
||||
::swap (map[t], map[trand]);
|
||||
break;
|
||||
}
|
||||
// but don't multi-swap stuff out of its range (for swapping positions that have been swapped before)
|
||||
// instead, try again with a different random number
|
||||
retries++;
|
||||
}
|
||||
}
|
||||
fprintf (stderr, "randomordering: %d retries for %d elements (%.1f%%) to ensure window condition\n", retries, map.size(), 100.0 * retries / map.size());
|
||||
// ensure the window condition
|
||||
foreach_index (t, map) assert ((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2);
|
||||
#if 0 // and a live check since I don't trust myself here yet
|
||||
foreach_index (t, map) if (!((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2))
|
||||
{
|
||||
fprintf (stderr, "randomordering: windowing condition violated %d -> %d\n", t, map[t]);
|
||||
throw std::logic_error ("randomordering: windowing condition violated");
|
||||
}
|
||||
#endif
|
||||
#if 0 // test whether it is indeed a unique complete sequence
|
||||
auto map2 = map;
|
||||
::sort (map2.begin(), map2.end());
|
||||
foreach_index (t, map2) assert (map2[t] == (size_t) t);
|
||||
#endif
|
||||
fprintf (stderr, "randomordering: recached sequence for seed %d: %d, %d, ...\n", (int) seed, (int) map[0], (int) map[1]);
|
||||
currentseed = seed;
|
||||
}
|
||||
return map; // caller can now access it through operator[]
|
||||
}
|
||||
size_t CurrentSeed() {return currentseed;}
|
||||
};
|
||||
|
||||
typedef unsigned short CLASSIDTYPE; // type to store state ids; don't use size_t --saves HUGE amounts of RAM
|
||||
|
||||
};};
|
Загрузка…
Ссылка в новой задаче