diff --git a/DataReader/LUSequenceReader/LUSequenceReader.cpp b/DataReader/LUSequenceReader/LUSequenceReader.cpp index 64ae0d9e6..83b2a7a0e 100644 --- a/DataReader/LUSequenceReader/LUSequenceReader.cpp +++ b/DataReader/LUSequenceReader/LUSequenceReader.cpp @@ -15,6 +15,7 @@ #include // leak detection #endif #include +#include // std::default_random_engine #include "fileutil.h" namespace Microsoft { namespace MSR { namespace CNTK { @@ -353,16 +354,17 @@ void LUSequenceReader::Init(const ConfigParameters& readerConfig) m_traceLevel = readerConfig("traceLevel","0"); m_parser.SetTraceLevel(m_traceLevel); + mRandomize = false; if (readerConfig.Exists("randomize")) { string randomizeString = readerConfig("randomize"); if (randomizeString == "None") { - ; + mRandomize = false; } else if (randomizeString == "Auto") { - ; + mRandomize = true; } else { @@ -374,6 +376,8 @@ void LUSequenceReader::Init(const ConfigParameters& readerConfig) ; //randomizeAuto; } + m_seed = 0; + // The input data is a combination of the label Data and extra feature dims together // m_featureCount = m_featureDim + m_labelInfo[labelInfoIn].dim; m_featureCount = 1; @@ -1138,7 +1142,12 @@ bool BatchLUSequenceReader::EnsureDataAvailable(size_t /*mbStartSample mNumRead = m_parser.Parse(CACHE_BLOG_SIZE, &m_labelTemp, &m_featureTemp, &seqPos); if (mNumRead == 0) return false; - // std::random_shuffle(m_parser.mSentenceIndex2SentenceInfo.begin(), m_parser.mSentenceIndex2SentenceInfo.end()); + if (mRandomize) + { + unsigned seed = m_seed; + std::shuffle(m_parser.mSentenceIndex2SentenceInfo.begin(), m_parser.mSentenceIndex2SentenceInfo.end(), std::default_random_engine(seed)); + m_seed++; + } m_readNextSampleLine += mNumRead; sLn = FindNextSentences(mNumRead); diff --git a/DataReader/LUSequenceReader/LUSequenceReader.h b/DataReader/LUSequenceReader/LUSequenceReader.h index d3e053201..d43bb7f14 100644 --- a/DataReader/LUSequenceReader/LUSequenceReader.h +++ b/DataReader/LUSequenceReader/LUSequenceReader.h @@ -44,6 +44,9 @@ protected: public: int nwords, dims, nsamps, nglen, nmefeats; + int m_seed; + bool mRandomize; + int class_size; map> class_words; vectorclass_cn;