Merge branch 'amitaga/prefetcher_packingmask'

This commit is contained in:
Amit Agarwal 2015-07-17 20:40:16 -07:00
Родитель 329d3bc6fd 9dcc79948d
Коммит 1726a5944e
3 изменённых файлов: 58 добавлений и 10 удалений

Просмотреть файл

@ -17,10 +17,17 @@ template<class ElemType>
class MinibatchFetcher
{
public:
MinibatchFetcher(IDataReader<ElemType>* trainSetDataReader, const std::map<std::wstring, Matrix<ElemType>*>* inputMatrices) :
MinibatchFetcher(IDataReader<ElemType>* trainSetDataReader,
std::map<std::wstring, Matrix<ElemType>*>* inputMatrices,
Matrix<ElemType>* sentenceBegin,
vector<MinibatchPackingFlag>* sentenceExistsBeginOrNoLabels)
:
m_reader(trainSetDataReader),
m_inputMatrices(inputMatrices)
m_inputMatrices(inputMatrices),
m_sentenceBegin(sentenceBegin),
m_sentenceExistsBeginOrNoLabels(sentenceExistsBeginOrNoLabels)
{
assert((m_sentenceBegin != nullptr) && (m_sentenceExistsBeginOrNoLabels != nullptr));
}
// This virtual dtor is necessary to allow invocation of derived dtors, which have some required synchronization points
@ -28,12 +35,17 @@ public:
virtual bool GetMinibatch()
{
return m_reader->GetMinibatch(*const_cast<std::map<std::wstring, Matrix<ElemType>*>*>(m_inputMatrices));
bool retVal = m_reader->GetMinibatch(*m_inputMatrices);
m_reader->SetSentenceSegBatch(*m_sentenceBegin, *m_sentenceExistsBeginOrNoLabels);
return retVal;
}
protected:
IDataReader<ElemType>* m_reader;
const std::map<std::wstring, Matrix<ElemType>*>* m_inputMatrices;
std::map<std::wstring, Matrix<ElemType>*>* m_inputMatrices;
Matrix<ElemType>* m_sentenceBegin;
vector<MinibatchPackingFlag>* m_sentenceExistsBeginOrNoLabels;
};
}}}

Просмотреть файл

@ -24,8 +24,13 @@ template<class ElemType>
class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
{
public:
MinibatchPrefetcher(IDataReader<ElemType>* trainSetDataReader, const std::map<std::wstring, Matrix<ElemType>*>* inputMatrices) :
MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices),
MinibatchPrefetcher(IDataReader<ElemType>* trainSetDataReader,
std::map<std::wstring, Matrix<ElemType>*>* inputMatrices,
Matrix<ElemType>* sentenceBegin,
vector<MinibatchPackingFlag>* sentenceExistsBeginOrNoLabels) :
MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices, sentenceBegin, sentenceExistsBeginOrNoLabels),
m_prefetchSentenceBegin(nullptr),
m_prefetchSentenceExistsBeginOrNoLabels(nullptr),
m_isEpochReadingDone(false),
m_minibatchReady(false),
m_isTerminating(false)
@ -42,6 +47,20 @@ public:
iter->second->GetFormat());
}
// Create the prefetch-side counterpart of the caller's sentence-begin matrix, copying its
// row/column counts, device id, matrix type and format from the matrix passed in.
// NOTE(review): the base-class constructor asserts that both pointers are non-null, so
// these null checks only matter in builds where assert is compiled out - confirm intent.
if (sentenceBegin != nullptr)
{
m_prefetchSentenceBegin = new Matrix<ElemType>(sentenceBegin->GetNumRows(),
sentenceBegin->GetNumCols(),
sentenceBegin->GetDeviceId(),
sentenceBegin->GetMatrixType(),
sentenceBegin->GetFormat());
}
// Create the prefetch-side counterpart of the packing-flag vector; it starts out empty.
if (sentenceExistsBeginOrNoLabels != nullptr)
{
m_prefetchSentenceExistsBeginOrNoLabels = new vector<MinibatchPackingFlag>();
}
// Launch a worker thread
m_prefetchThread = std::thread([this]() { this->PrefetchWorker(); });
}
@ -66,6 +85,9 @@ public:
{
delete iter->second;
}
// Free the prefetch-side sentence buffers allocated in the constructor. Both members are
// initialized to nullptr, and deleting a null pointer is a no-op.
delete m_prefetchSentenceBegin;
delete m_prefetchSentenceExistsBeginOrNoLabels;
}
virtual bool GetMinibatch()
@ -97,6 +119,17 @@ public:
std::swap(*(iter->second), *m_prefetchInput[iter->first]);
}
// Hand the prefetched sentence-boundary data to the caller by swapping it with the
// caller-visible buffers, in the same way as the input-matrix swap just above.
// m_sentenceBegin and m_sentenceExistsBeginOrNoLabels are declared in the dependent base
// class MinibatchFetcher<ElemType>, so they are accessed through this->. Unqualified names
// are not looked up in a dependent base by conforming compilers, and the rest of this
// class already uses this->m_reader for the same reason.
if (this->m_sentenceBegin != nullptr)
{
    // Both matrices are expected to live on the same device before their contents are swapped.
    assert(this->m_sentenceBegin->GetDeviceId() == m_prefetchSentenceBegin->GetDeviceId());
    std::swap(*(this->m_sentenceBegin), *m_prefetchSentenceBegin);
}
if (this->m_sentenceExistsBeginOrNoLabels != nullptr)
{
    std::swap(*(this->m_sentenceExistsBeginOrNoLabels), *m_prefetchSentenceExistsBeginOrNoLabels);
}
hasMoreEpochReading = true;
}
@ -160,7 +193,9 @@ private:
Matrix<ElemType>::SyncComputeBeforeRead(m_deviceId);
// Get the next minibatch and wait for it to be available on the device
bool isDone = !this->m_reader->GetMinibatch(const_cast<std::map<std::wstring, Matrix<ElemType>*>&>(m_prefetchInput));
bool isDone = !this->m_reader->GetMinibatch(m_prefetchInput);
this->m_reader->SetSentenceSegBatch(*m_prefetchSentenceBegin, *m_prefetchSentenceExistsBeginOrNoLabels);
Matrix<ElemType>::SyncPendingRead(m_deviceId);
return isDone;
@ -168,6 +203,8 @@ private:
// @TODO: Add support for more than one prefetch buffer
std::map<std::wstring, Matrix<ElemType>*> m_prefetchInput;
Matrix<ElemType>* m_prefetchSentenceBegin;
vector<MinibatchPackingFlag>* m_prefetchSentenceExistsBeginOrNoLabels;
std::thread m_prefetchThread;
std::mutex m_mutex;
std::condition_variable m_cv;

Просмотреть файл

@ -1710,8 +1710,8 @@ protected:
AttemptUtteranceDerivativeFeatures(net, trainSetDataReader, FeatureNodes, inputMatrices);
std::unique_ptr<MinibatchFetcher<ElemType>> mbFetcher(
m_doPrefetchTrainingData ?
new MinibatchPrefetcher<ElemType>(trainSetDataReader, inputMatrices) :
new MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices));
new MinibatchPrefetcher<ElemType>(trainSetDataReader, inputMatrices, &(net.SentenceBoundary()), &(net.MinibatchPackingFlags())) :
new MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices, &(net.SentenceBoundary()), &(net.MinibatchPackingFlags())));
fprintf(stderr, "\nStarting minibatch loop, prefetching is: %s\n", m_doPrefetchTrainingData ? "ENABLED" : "DISABLED");
@ -1735,7 +1735,6 @@ protected:
net.SetActualMiniBatchSize(actualMBSize);
net.SetActualNbrSlicesInEachRecIter(trainSetDataReader->NumberSlicesInEachRecurrentIter());
trainSetDataReader->SetSentenceSegBatch(net.SentenceBoundary(), net.MinibatchPackingFlags());
#ifndef EVALDLL
if (m_doGradientCheck && GradientCheck(net, criterionNodes, learnableNodes, 0) == false)