Use -1,0,1 to denote no observation, sentence begining and in-the-middle-of-sentence. changed LU sequence reader. But other readers haven't changed accordingly.
This commit is contained in:
Родитель
13c49fc3ad
Коммит
14ce7f4f2f
|
@ -110,11 +110,19 @@ void DataReader<ElemType>::SetNbrSlicesEachRecurrentIter(const size_t sz)
|
|||
{
|
||||
m_dataReader->SetNbrSlicesEachRecurrentIter(sz);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DataReader<ElemType>::SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd)
|
||||
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &sentenceEnd)
|
||||
{
|
||||
m_dataReader->SetSentenceEndInBatch(sentenceEnd);
|
||||
m_dataReader->SetSentenceSegBatch(sentenceEnd);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DataReader<ElemType>::SetRandomSeed(int seed)
|
||||
{
|
||||
m_dataReader->SetRandomSeed(seed);
|
||||
}
|
||||
|
||||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
template<class ElemType>
|
||||
|
|
|
@ -61,7 +61,8 @@ public:
|
|||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart) = 0;
|
||||
virtual bool DataEnd(EndDataType endDataType) = 0;
|
||||
virtual void SetSentenceEndInBatch(vector<size_t> &sentenceEnd) = 0;
|
||||
virtual void SetSentenceSegBatch(Matrix<ElemType>&sentenceEnd) = 0;
|
||||
virtual void SetRandomSeed(int) = 0;
|
||||
};
|
||||
|
||||
// GetReader - get a reader type from the DLL
|
||||
|
@ -156,7 +157,9 @@ public:
|
|||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&sentenceEnd);
|
||||
|
||||
virtual void SetRandomSeed(int);
|
||||
};
|
||||
|
||||
}}}
|
||||
|
|
|
@ -1131,4 +1131,11 @@ public:
|
|||
#define EPSILON 1e-5
|
||||
#define ISCLOSE(a, b, threshold) (abs(a - b) < threshold)?true:false
|
||||
|
||||
/**
|
||||
These macros are used for sentence segmentation information.
|
||||
*/
|
||||
#define SENTENCE_BEGIN 0
|
||||
#define SENTENCE_MIDDLE 1
|
||||
#define NO_OBSERVATION -1
|
||||
|
||||
#endif // _BASETYPES_
|
||||
|
|
|
@ -413,13 +413,15 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {};
|
||||
void SetSentenceSegBatch(std::vector<size_t> &/*sentence begin*/) {};
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentence begin*/) {};
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<typename BinaryReader<ElemType>::LabelIdType, typename BinaryReader<ElemType>::LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetRandomSeed(int){ NOT_IMPLEMENTED; };
|
||||
};
|
||||
|
||||
template<class ElemType>
|
||||
|
|
|
@ -142,11 +142,15 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/){};
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
|
||||
void SetRandomSeed(int){ NOT_IMPLEMENTED; }
|
||||
};
|
||||
}}}
|
||||
|
|
|
@ -114,9 +114,15 @@ void DataReader<ElemType>::SetNbrSlicesEachRecurrentIter(const size_t sz)
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DataReader<ElemType>::SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd)
|
||||
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceEnd)
|
||||
{
|
||||
m_dataReader->SetSentenceEndInBatch(sentenceEnd);
|
||||
m_dataReader->SetSentenceSegBatch(sentenceEnd);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DataReader<ElemType>::SetRandomSeed(int seed)
|
||||
{
|
||||
m_dataReader->SetRandomSeed(seed);
|
||||
}
|
||||
|
||||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
|
|
|
@ -1526,15 +1526,27 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
|
||||
void HTKMLFReader<ElemType>::SetSentenceSegBatch(vector<size_t> &sentenceEnd)
|
||||
{
|
||||
sentenceEnd.resize(m_switchFrame.size());
|
||||
for (size_t i = 0; i < m_switchFrame.size() ; i++)
|
||||
for (size_t i = 0; i < m_switchFrame.size(); i++)
|
||||
{
|
||||
sentenceEnd[i] = m_switchFrame[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetRandomSeed(int )
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
// GetFileConfigNames - determine the names of the features and labels sections in the config file
|
||||
// features - [in,out] a vector of feature name strings
|
||||
// labels - [in,out] a vector of label name strings
|
||||
|
|
|
@ -107,8 +107,10 @@ public:
|
|||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetSentenceEndInBatch(vector<size_t> &/*sentenceEnd*/);
|
||||
void SetSentenceSegBatch(vector<size_t> &/*sentenceEnd*/);
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/);
|
||||
void SetSentenceEnd(int /*actualMbSize*/){};
|
||||
void SetRandomSeed(int);
|
||||
};
|
||||
|
||||
}}}
|
|
@ -1461,6 +1461,12 @@ void BatchSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
|
|||
mBlgSize = readerConfig("nbruttsineachrecurrentiter", "1");
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchSequenceReader<ElemType>::SetRandomSeed(int)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchSequenceReader<ElemType>::Reset()
|
||||
{
|
||||
|
@ -1862,7 +1868,7 @@ void BatchSequenceReader<ElemType>::SetSentenceBegin(int wrd, int pos, int /*act
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
|
||||
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(vector<size_t> &sentenceEnd)
|
||||
{
|
||||
sentenceEnd.resize(mToProcess.size());
|
||||
if (mSentenceBegin)
|
||||
|
@ -1875,6 +1881,12 @@ void BatchSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &senten
|
|||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool BatchSequenceReader<ElemType>::DataEnd(EndDataType endDataType)
|
||||
{
|
||||
|
|
|
@ -178,12 +178,14 @@ public:
|
|||
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
|
||||
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {};
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {};
|
||||
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/) {};
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/) {};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
|
||||
};
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -301,8 +303,10 @@ public:
|
|||
bool EnsureDataAvailable(size_t mbStartSample);
|
||||
size_t NumberSlicesInEachRecurrentIter();
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t mz);
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/);
|
||||
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd);
|
||||
|
||||
void SetRandomSeed(int);
|
||||
};
|
||||
|
||||
}}}
|
||||
|
|
|
@ -1255,6 +1255,11 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
|
|||
if (mMaxSentenceLength > m_mbSize)
|
||||
throw std::runtime_error("LUSequenceReader : minibatch size needs to be large enough to accomodate the longest sentence");
|
||||
|
||||
mtSentenceBegin.Resize(mToProcess.size(), mMaxSentenceLength);
|
||||
mtSentenceBegin.SetValue((ElemType) SENTENCE_MIDDLE);
|
||||
DEVICEID_TYPE sentenceSegDeviceId = mtSentenceBegin.GetDeviceId();
|
||||
mtSentenceBegin.TransferFromDeviceToDevice(sentenceSegDeviceId, CPUDEVICE, true, false, false);
|
||||
|
||||
for (i = (int)mLastPosInSentence; j < (int)mMaxSentenceLength; i++, j++)
|
||||
{
|
||||
for (int k = 0; k < mToProcess.size(); k++)
|
||||
|
@ -1264,7 +1269,9 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
|
|||
if (i == mLastPosInSentence)
|
||||
{
|
||||
mSentenceBeginAt[k] = i;
|
||||
mtSentenceBegin.SetValue(k, j, (ElemType) SENTENCE_BEGIN);
|
||||
}
|
||||
|
||||
if (i == m_parser.mSentenceIndex2SentenceInfo[seq].sLen - 1)
|
||||
{
|
||||
mSentenceEndAt[k] = i;
|
||||
|
@ -1321,6 +1328,7 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
|
|||
m_featureWordContext.push_back(tmpCxt);
|
||||
|
||||
m_labelIdData.push_back((LabelIdType)NULLLABEL);
|
||||
mtSentenceBegin.SetValue(k, j, (ElemType) NO_OBSERVATION);
|
||||
}
|
||||
|
||||
m_totalSamples ++;
|
||||
|
@ -1328,6 +1336,8 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
|
|||
}
|
||||
|
||||
mLastPosInSentence = (i == mMaxSentenceLength)?0:i;
|
||||
|
||||
mtSentenceBegin.TransferFromDeviceToDevice(CPUDEVICE, sentenceSegDeviceId, true, false, false);
|
||||
}
|
||||
|
||||
return bDataIsThere;
|
||||
|
@ -1387,12 +1397,16 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
|
|||
|
||||
//loop through all the samples
|
||||
Matrix<ElemType>& features = *matrices[m_featuresName];
|
||||
|
||||
if (matrices.find(m_featuresName) != matrices.end())
|
||||
{
|
||||
features.Resize(featInfo.dim * m_wordContext.size(), actualmbsize, true);
|
||||
features.SetValue(0);
|
||||
}
|
||||
|
||||
DEVICEID_TYPE featureDeviceId = features.GetDeviceId();
|
||||
features.TransferFromDeviceToDevice(featureDeviceId, CPUDEVICE, true, false, false);
|
||||
|
||||
size_t utt_id = 0;
|
||||
for (size_t j = 0; j < actualmbsize; ++j)
|
||||
{
|
||||
|
@ -1420,6 +1434,8 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
|
|||
}
|
||||
}
|
||||
|
||||
features.TransferFromDeviceToDevice(CPUDEVICE, featureDeviceId, true, false, false);
|
||||
|
||||
lablsize = GetLabelOutput(matrices, actualmbsize);
|
||||
|
||||
// go to the next sequence
|
||||
|
@ -1432,7 +1448,7 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
|
|||
}
|
||||
|
||||
// we read some records, so process them
|
||||
if (lablsize == 0)
|
||||
if (actualmbsize == 0)
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
|
@ -1464,17 +1480,10 @@ size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchLUSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
|
||||
void BatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin)
|
||||
{
|
||||
sentenceEnd.resize(mToProcess.size());
|
||||
if (mSentenceBegin)
|
||||
{
|
||||
sentenceEnd.assign(mToProcess.size(), 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
sentenceEnd.assign(mToProcess.size(), m_mbSize+2);
|
||||
}
|
||||
mtSentenceBegin.TransferFromDeviceToDevice(mtSentenceBegin.GetDeviceId(), sentenceBegin.GetDeviceId(), true, false, false);
|
||||
sentenceBegin.SetValue(mtSentenceBegin);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -1583,21 +1592,12 @@ bool MultiIOBatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring,
|
|||
m_seed++;
|
||||
|
||||
/// run for each reader
|
||||
size_t nlabels = 0;
|
||||
size_t nsamples = 0;
|
||||
vector<size_t> to_process;
|
||||
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
|
||||
{
|
||||
if (to_process.size() > 0)
|
||||
(p->second)->SetToProcessId(to_process);
|
||||
nlabels = (p->second)->GetMinibatch(matrices);
|
||||
if (to_process.size() == 0)
|
||||
to_process = (p->second)->ReturnToProcessId();
|
||||
nsamples = max(nlabels, nsamples);
|
||||
if ((p->second)->GetMinibatch(matrices) == false)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (nsamples == 0)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1654,12 +1654,33 @@ void MultiIOBatchLUSequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, s
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
|
||||
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin)
|
||||
{
|
||||
/// run for each reader
|
||||
vector<size_t> col;
|
||||
size_t rows = 0, cols = 0;
|
||||
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
|
||||
{
|
||||
(p->second)->SetSentenceEndInBatch(sentenceEnd);
|
||||
(p->second)->SetSentenceSegBatch(sentenceBegin);
|
||||
if (rows == 0)
|
||||
rows = sentenceBegin.GetNumRows();
|
||||
else
|
||||
if (rows != sentenceBegin.GetNumRows())
|
||||
LogicError("multiple streams for LU sequence reader must have the same number of rows for sentence begining");
|
||||
size_t this_col = sentenceBegin.GetNumCols();
|
||||
col.push_back(this_col);
|
||||
cols += this_col;
|
||||
}
|
||||
|
||||
sentenceBegin.Resize(rows, cols);
|
||||
size_t i = 0, t = 0;
|
||||
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
|
||||
{
|
||||
Matrix<ElemType> mtmp(sentenceBegin.GetDeviceId());
|
||||
(p->second)->SetSentenceSegBatch(mtmp);
|
||||
sentenceBegin.ColumnSlice(i, col[t]).SetValue(mtmp);
|
||||
i += col[t];
|
||||
t++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -172,7 +172,6 @@ public:
|
|||
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {};
|
||||
void SentenceEnd(std::vector<size_t> &/*sentenceEnd*/) {};
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {};
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
|
@ -292,7 +291,7 @@ public:
|
|||
size_t NumberSlicesInEachRecurrentIter();
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t mz);
|
||||
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);
|
||||
void SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin);
|
||||
|
||||
public:
|
||||
void LoadWordMapping(const ConfigParameters& readerConfig);
|
||||
|
@ -317,6 +316,20 @@ public:
|
|||
size_t mMaxSentenceLength;
|
||||
vector<int> mSentenceBeginAt;
|
||||
vector<int> mSentenceEndAt;
|
||||
|
||||
/// a matrix of n_stream x n_length
|
||||
/// n_stream is the number of streams
|
||||
/// n_length is the maximum lenght of each stream
|
||||
/// for example, two sentences used in parallel in one minibatch would be
|
||||
/// [2 x 5] if the max length of one of the sentences is 5
|
||||
/// the elements of the matrix is 0, 1, or -1, defined as SENTENCE_BEGIN, SENTENCE_MIDDLE, NO_OBSERVATION in cbasetype.h
|
||||
/// 0 1 1 0 1
|
||||
/// 1 0 1 0 0
|
||||
/// for two parallel data streams. The first has two sentences, with 0 indicating begining of a sentence
|
||||
/// the second data stream has two sentences, with 0 indicating begining of sentences
|
||||
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
|
||||
/// frame.
|
||||
Matrix<ElemType> mtSentenceBegin;
|
||||
};
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -345,7 +358,7 @@ public:
|
|||
|
||||
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples);
|
||||
|
||||
void SetSentenceEndInBatch(vector<size_t> &sentenceEnd);
|
||||
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin);
|
||||
|
||||
size_t NumberSlicesInEachRecurrentIter();
|
||||
|
||||
|
|
|
@ -144,11 +144,13 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/) {};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
|
||||
};
|
||||
}}}
|
||||
|
|
|
@ -106,11 +106,14 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
void SetSentenceSegBatch(std::vector<size_t> &/*sentenceEnd*/){};
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&/*sentenceEnd*/) {};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
|
||||
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
|
||||
};
|
||||
}}}
|
||||
|
|
|
@ -164,14 +164,18 @@ public:
|
|||
size_t NumberSlicesInEachRecurrentIter() {return 1;}
|
||||
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t ) {}
|
||||
void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd)
|
||||
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd)
|
||||
{
|
||||
sentenceEnd.resize(m_switchFrame.size());
|
||||
for (size_t i = 0; i < m_switchFrame.size() ; i++)
|
||||
for (size_t i = 0; i < m_switchFrame.size(); i++)
|
||||
{
|
||||
sentenceEnd[i] = m_switchFrame[i];
|
||||
}
|
||||
}
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
void GetSentenceBoundary(std::vector<size_t> boundaryInfo)
|
||||
{
|
||||
m_switchFrame.resize(boundaryInfo.size());
|
||||
|
@ -180,6 +184,9 @@ public:
|
|||
m_switchFrame[i] = boundaryInfo[i];
|
||||
}
|
||||
}
|
||||
|
||||
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
|
||||
|
||||
// GetLabelMapping - Gets the label mapping from integer index to label type
|
||||
// returns - a map from numeric datatype to native label type
|
||||
virtual const std::map<typename EvalReader<ElemType>::LabelIdType, typename EvalReader<ElemType>::LabelType>& GetLabelMapping(const std::wstring& /*sectionName*/)
|
||||
|
|
|
@ -85,7 +85,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
size_t actualMBSize = m_net.GetActualMBSize();
|
||||
m_net.SetActualMiniBatchSize(actualMBSize);
|
||||
m_net.SetActualNbrSlicesInEachRecIter(dataReader.NumberSlicesInEachRecurrentIter());
|
||||
dataReader.SetSentenceEndInBatch(m_net.m_sentenceEnd);
|
||||
dataReader.SetSentenceSegBatch(m_net.m_sentenceSeg);
|
||||
|
||||
for (int i=0; i<outputNodes.size(); i++)
|
||||
{
|
||||
|
@ -173,7 +173,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
size_t actualMBSize = m_net.GetActualMBSize();
|
||||
m_net.SetActualMiniBatchSize(actualMBSize);
|
||||
dataReader.SetSentenceEndInBatch(m_net.m_sentenceEnd);
|
||||
dataReader.SetSentenceSegBatch(m_net.m_sentenceSeg);
|
||||
|
||||
for (int i=0; i<outputNodes.size(); i++)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче