Use -1,0,1 to denote no observation, sentence begining and in-the-middle-of-sentence. changed LU sequence reader. But other readers haven't changed accordingly.

This commit is contained in:
kaisheny 2015-04-08 21:20:12 -07:00
Родитель 13c49fc3ad
Коммит 14ce7f4f2f
16 изменённых файлов: 154 добавлений и 48 удалений

Просмотреть файл

@ -110,11 +110,19 @@ void DataReader<ElemType>::SetNbrSlicesEachRecurrentIter(const size_t sz)
{
m_dataReader->SetNbrSlicesEachRecurrentIter(sz);
}
template<class ElemType>
void DataReader<ElemType>::SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd)
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &sentenceEnd)
{
m_dataReader->SetSentenceEndInBatch(sentenceEnd);
m_dataReader->SetSentenceSegBatch(sentenceEnd);
}
template<class ElemType>
void DataReader<ElemType>::SetRandomSeed(int seed)
{
m_dataReader->SetRandomSeed(seed);
}
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
template<class ElemType>

Просмотреть файл

@ -61,7 +61,8 @@ public:
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart) = 0;
virtual bool DataEnd(EndDataType endDataType) = 0;
virtual void SetSentenceEndInBatch(vector<size_t> &sentenceEnd) = 0;
virtual void SetSentenceSegBatch(Matrix<ElemType>&sentenceEnd) = 0;
virtual void SetRandomSeed(int) = 0;
};
// GetReader - get a reader type from the DLL
@ -156,7 +157,9 @@ public:
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);
void SetSentenceSegBatch(Matrix<ElemType>&sentenceEnd);
virtual void SetRandomSeed(int);
};
}}}

Просмотреть файл

@ -1131,4 +1131,11 @@ public:
#define EPSILON 1e-5
#define ISCLOSE(a, b, threshold) (abs(a - b) < threshold)?true:false
/**
These macros are used for sentence segmentation information.
*/
#define SENTENCE_BEGIN 0
#define SENTENCE_MIDDLE 1
#define NO_OBSERVATION -1
#endif // _BASETYPES_

Просмотреть файл

@ -413,13 +413,15 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {};
void SetSentenceSegBatch(std::vector<size_t> &/*sentence begin*/) {};
void SetSentenceSegBatch(Matrix<ElemType>&/*sentence begin*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<typename BinaryReader<ElemType>::LabelIdType, typename BinaryReader<ElemType>::LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetRandomSeed(int){ NOT_IMPLEMENTED; };
};
template<class ElemType>

Просмотреть файл

@ -142,11 +142,15 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/){};
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/){};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetRandomSeed(int){ NOT_IMPLEMENTED; }
};
}}}

Просмотреть файл

@ -114,9 +114,15 @@ void DataReader<ElemType>::SetNbrSlicesEachRecurrentIter(const size_t sz)
}
template<class ElemType>
void DataReader<ElemType>::SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd)
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceEnd)
{
m_dataReader->SetSentenceEndInBatch(sentenceEnd);
m_dataReader->SetSentenceSegBatch(sentenceEnd);
}
template<class ElemType>
void DataReader<ElemType>::SetRandomSeed(int seed)
{
m_dataReader->SetRandomSeed(seed);
}
// GetLabelMapping - Gets the label mapping from integer index to label type

Просмотреть файл

@ -1526,15 +1526,27 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
template<class ElemType>
void HTKMLFReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
void HTKMLFReader<ElemType>::SetSentenceSegBatch(vector<size_t> &sentenceEnd)
{
sentenceEnd.resize(m_switchFrame.size());
for (size_t i = 0; i < m_switchFrame.size() ; i++)
for (size_t i = 0; i < m_switchFrame.size(); i++)
{
sentenceEnd[i] = m_switchFrame[i];
}
}
template<class ElemType>
void HTKMLFReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &)
{
NOT_IMPLEMENTED;
}
template<class ElemType>
void HTKMLFReader<ElemType>::SetRandomSeed(int )
{
NOT_IMPLEMENTED;
}
// GetFileConfigNames - determine the names of the features and labels sections in the config file
// features - [in,out] a vector of feature name strings
// labels - [in,out] a vector of label name strings

Просмотреть файл

@ -107,8 +107,10 @@ public:
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetSentenceEndInBatch(vector<size_t> &/*sentenceEnd*/);
void SetSentenceSegBatch(vector<size_t> &/*sentenceEnd*/);
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/);
void SetSentenceEnd(int /*actualMbSize*/){};
void SetRandomSeed(int);
};
}}}

Просмотреть файл

@ -1461,6 +1461,12 @@ void BatchSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
mBlgSize = readerConfig("nbruttsineachrecurrentiter", "1");
}
template<class ElemType>
void BatchSequenceReader<ElemType>::SetRandomSeed(int)
{
NOT_IMPLEMENTED;
}
template<class ElemType>
void BatchSequenceReader<ElemType>::Reset()
{
@ -1862,7 +1868,7 @@ void BatchSequenceReader<ElemType>::SetSentenceBegin(int wrd, int pos, int /*act
}
template<class ElemType>
void BatchSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(vector<size_t> &sentenceEnd)
{
sentenceEnd.resize(mToProcess.size());
if (mSentenceBegin)
@ -1875,6 +1881,12 @@ void BatchSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &senten
}
}
template<class ElemType>
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &)
{
NOT_IMPLEMENTED;
}
template<class ElemType>
bool BatchSequenceReader<ElemType>::DataEnd(EndDataType endDataType)
{

Просмотреть файл

@ -178,12 +178,14 @@ public:
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {};
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {};
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/) {};
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
};
template<class ElemType>
@ -301,8 +303,10 @@ public:
bool EnsureDataAvailable(size_t mbStartSample);
size_t NumberSlicesInEachRecurrentIter();
void SetNbrSlicesEachRecurrentIter(const size_t mz);
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/);
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd);
void SetRandomSeed(int);
};
}}}

Просмотреть файл

@ -1255,6 +1255,11 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
if (mMaxSentenceLength > m_mbSize)
throw std::runtime_error("LUSequenceReader : minibatch size needs to be large enough to accomodate the longest sentence");
mtSentenceBegin.Resize(mToProcess.size(), mMaxSentenceLength);
mtSentenceBegin.SetValue((ElemType) SENTENCE_MIDDLE);
DEVICEID_TYPE sentenceSegDeviceId = mtSentenceBegin.GetDeviceId();
mtSentenceBegin.TransferFromDeviceToDevice(sentenceSegDeviceId, CPUDEVICE, true, false, false);
for (i = (int)mLastPosInSentence; j < (int)mMaxSentenceLength; i++, j++)
{
for (int k = 0; k < mToProcess.size(); k++)
@ -1264,7 +1269,9 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
if (i == mLastPosInSentence)
{
mSentenceBeginAt[k] = i;
mtSentenceBegin.SetValue(k, j, (ElemType) SENTENCE_BEGIN);
}
if (i == m_parser.mSentenceIndex2SentenceInfo[seq].sLen - 1)
{
mSentenceEndAt[k] = i;
@ -1321,6 +1328,7 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
m_featureWordContext.push_back(tmpCxt);
m_labelIdData.push_back((LabelIdType)NULLLABEL);
mtSentenceBegin.SetValue(k, j, (ElemType) NO_OBSERVATION);
}
m_totalSamples ++;
@ -1328,6 +1336,8 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
}
mLastPosInSentence = (i == mMaxSentenceLength)?0:i;
mtSentenceBegin.TransferFromDeviceToDevice(CPUDEVICE, sentenceSegDeviceId, true, false, false);
}
return bDataIsThere;
@ -1387,12 +1397,16 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
//loop through all the samples
Matrix<ElemType>& features = *matrices[m_featuresName];
if (matrices.find(m_featuresName) != matrices.end())
{
features.Resize(featInfo.dim * m_wordContext.size(), actualmbsize, true);
features.SetValue(0);
}
DEVICEID_TYPE featureDeviceId = features.GetDeviceId();
features.TransferFromDeviceToDevice(featureDeviceId, CPUDEVICE, true, false, false);
size_t utt_id = 0;
for (size_t j = 0; j < actualmbsize; ++j)
{
@ -1420,6 +1434,8 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
}
}
features.TransferFromDeviceToDevice(CPUDEVICE, featureDeviceId, true, false, false);
lablsize = GetLabelOutput(matrices, actualmbsize);
// go to the next sequence
@ -1432,7 +1448,7 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
}
// we read some records, so process them
if (lablsize == 0)
if (actualmbsize == 0)
return false;
else
return true;
@ -1464,17 +1480,10 @@ size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
}
template<class ElemType>
void BatchLUSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
void BatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin)
{
sentenceEnd.resize(mToProcess.size());
if (mSentenceBegin)
{
sentenceEnd.assign(mToProcess.size(), 0);
}
else
{
sentenceEnd.assign(mToProcess.size(), m_mbSize+2);
}
mtSentenceBegin.TransferFromDeviceToDevice(mtSentenceBegin.GetDeviceId(), sentenceBegin.GetDeviceId(), true, false, false);
sentenceBegin.SetValue(mtSentenceBegin);
}
template<class ElemType>
@ -1583,21 +1592,12 @@ bool MultiIOBatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring,
m_seed++;
/// run for each reader
size_t nlabels = 0;
size_t nsamples = 0;
vector<size_t> to_process;
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
{
if (to_process.size() > 0)
(p->second)->SetToProcessId(to_process);
nlabels = (p->second)->GetMinibatch(matrices);
if (to_process.size() == 0)
to_process = (p->second)->ReturnToProcessId();
nsamples = max(nlabels, nsamples);
if ((p->second)->GetMinibatch(matrices) == false)
return false;
}
if (nsamples == 0)
return false;
return true;
}
@ -1654,12 +1654,33 @@ void MultiIOBatchLUSequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, s
}
template<class ElemType>
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin)
{
/// run for each reader
vector<size_t> col;
size_t rows = 0, cols = 0;
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
{
(p->second)->SetSentenceEndInBatch(sentenceEnd);
(p->second)->SetSentenceSegBatch(sentenceBegin);
if (rows == 0)
rows = sentenceBegin.GetNumRows();
else
if (rows != sentenceBegin.GetNumRows())
LogicError("multiple streams for LU sequence reader must have the same number of rows for sentence begining");
size_t this_col = sentenceBegin.GetNumCols();
col.push_back(this_col);
cols += this_col;
}
sentenceBegin.Resize(rows, cols);
size_t i = 0, t = 0;
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
{
Matrix<ElemType> mtmp(sentenceBegin.GetDeviceId());
(p->second)->SetSentenceSegBatch(mtmp);
sentenceBegin.ColumnSlice(i, col[t]).SetValue(mtmp);
i += col[t];
t++;
}
}

Просмотреть файл

@ -172,7 +172,6 @@ public:
void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {};
void SentenceEnd(std::vector<size_t> &/*sentenceEnd*/) {};
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
@ -292,7 +291,7 @@ public:
size_t NumberSlicesInEachRecurrentIter();
void SetNbrSlicesEachRecurrentIter(const size_t mz);
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);
void SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin);
public:
void LoadWordMapping(const ConfigParameters& readerConfig);
@ -317,6 +316,20 @@ public:
size_t mMaxSentenceLength;
vector<int> mSentenceBeginAt;
vector<int> mSentenceEndAt;
/// a matrix of n_stream x n_length
/// n_stream is the number of streams
/// n_length is the maximum lenght of each stream
/// for example, two sentences used in parallel in one minibatch would be
/// [2 x 5] if the max length of one of the sentences is 5
/// the elements of the matrix is 0, 1, or -1, defined as SENTENCE_BEGIN, SENTENCE_MIDDLE, NO_OBSERVATION in cbasetype.h
/// 0 1 1 0 1
/// 1 0 1 0 0
/// for two parallel data streams. The first has two sentences, with 0 indicating begining of a sentence
/// the second data stream has two sentences, with 0 indicating begining of sentences
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
/// frame.
Matrix<ElemType> mtSentenceBegin;
};
template<class ElemType>
@ -345,8 +358,8 @@ public:
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples);
void SetSentenceEndInBatch(vector<size_t> &sentenceEnd);
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin);
size_t NumberSlicesInEachRecurrentIter();
void Init(const ConfigParameters& readerConfig);

Просмотреть файл

@ -144,11 +144,13 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/){};
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
};
}}}

Просмотреть файл

@ -106,11 +106,14 @@ public:
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/){};
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/) {};
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
virtual bool DataEnd(EndDataType endDataType);
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
};
}}}

Просмотреть файл

@ -164,14 +164,18 @@ public:
size_t NumberSlicesInEachRecurrentIter() {return 1;}
void SetNbrSlicesEachRecurrentIter(const size_t ) {}
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd)
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd)
{
sentenceEnd.resize(m_switchFrame.size());
for (size_t i = 0; i < m_switchFrame.size() ; i++)
for (size_t i = 0; i < m_switchFrame.size(); i++)
{
sentenceEnd[i] = m_switchFrame[i];
}
}
void SetSentenceSegBatch(Matrix<ElemType>&)
{
NOT_IMPLEMENTED;
}
void GetSentenceBoundary(std::vector<size_t> boundaryInfo)
{
m_switchFrame.resize(boundaryInfo.size());
@ -180,6 +184,9 @@ public:
m_switchFrame[i] = boundaryInfo[i];
}
}
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
// GetLabelMapping - Gets the label mapping from integer index to label type
// returns - a map from numeric datatype to native label type
virtual const std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType>& GetLabelMapping(const std::wstring& /*sectionName*/)

Просмотреть файл

@ -85,7 +85,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t actualMBSize = m_net.GetActualMBSize();
m_net.SetActualMiniBatchSize(actualMBSize);
m_net.SetActualNbrSlicesInEachRecIter(dataReader.NumberSlicesInEachRecurrentIter());
dataReader.SetSentenceEndInBatch(m_net.m_sentenceEnd);
dataReader.SetSentenceSegBatch(m_net.m_sentenceSeg);
for (int i=0; i<outputNodes.size(); i++)
{
@ -173,7 +173,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t actualMBSize = m_net.GetActualMBSize();
m_net.SetActualMiniBatchSize(actualMBSize);
dataReader.SetSentenceEndInBatch(m_net.m_sentenceEnd);
dataReader.SetSentenceSegBatch(m_net.m_sentenceSeg);
for (int i=0; i<outputNodes.size(); i++)
{