Merge branch 'amitaga/prefetcher_packingmask'
This commit is contained in:
Коммит
1726a5944e
|
@ -17,10 +17,17 @@ template<class ElemType>
|
|||
class MinibatchFetcher
|
||||
{
|
||||
public:
|
||||
MinibatchFetcher(IDataReader<ElemType>* trainSetDataReader, const std::map<std::wstring, Matrix<ElemType>*>* inputMatrices) :
|
||||
MinibatchFetcher(IDataReader<ElemType>* trainSetDataReader,
|
||||
std::map<std::wstring, Matrix<ElemType>*>* inputMatrices,
|
||||
Matrix<ElemType>* sentenceBegin,
|
||||
vector<MinibatchPackingFlag>* sentenceExistsBeginOrNoLabels)
|
||||
:
|
||||
m_reader(trainSetDataReader),
|
||||
m_inputMatrices(inputMatrices)
|
||||
m_inputMatrices(inputMatrices),
|
||||
m_sentenceBegin(sentenceBegin),
|
||||
m_sentenceExistsBeginOrNoLabels(sentenceExistsBeginOrNoLabels)
|
||||
{
|
||||
assert((m_sentenceBegin != nullptr) && (m_sentenceExistsBeginOrNoLabels != nullptr));
|
||||
}
|
||||
|
||||
// This virtual dtor is necessary to allow invocation of derived dtors, which have some required synchronization points
|
||||
|
@ -28,12 +35,17 @@ public:
|
|||
|
||||
virtual bool GetMinibatch()
|
||||
{
|
||||
return m_reader->GetMinibatch(*const_cast<std::map<std::wstring, Matrix<ElemType>*>*>(m_inputMatrices));
|
||||
bool retVal = m_reader->GetMinibatch(*m_inputMatrices);
|
||||
m_reader->SetSentenceSegBatch(*m_sentenceBegin, *m_sentenceExistsBeginOrNoLabels);
|
||||
|
||||
return retVal;
|
||||
}
|
||||
|
||||
protected:
|
||||
IDataReader<ElemType>* m_reader;
|
||||
const std::map<std::wstring, Matrix<ElemType>*>* m_inputMatrices;
|
||||
std::map<std::wstring, Matrix<ElemType>*>* m_inputMatrices;
|
||||
Matrix<ElemType>* m_sentenceBegin;
|
||||
vector<MinibatchPackingFlag>* m_sentenceExistsBeginOrNoLabels;
|
||||
};
|
||||
|
||||
}}}
|
|
@ -24,8 +24,13 @@ template<class ElemType>
|
|||
class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
|
||||
{
|
||||
public:
|
||||
MinibatchPrefetcher(IDataReader<ElemType>* trainSetDataReader, const std::map<std::wstring, Matrix<ElemType>*>* inputMatrices) :
|
||||
MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices),
|
||||
MinibatchPrefetcher(IDataReader<ElemType>* trainSetDataReader,
|
||||
std::map<std::wstring, Matrix<ElemType>*>* inputMatrices,
|
||||
Matrix<ElemType>* sentenceBegin,
|
||||
vector<MinibatchPackingFlag>* sentenceExistsBeginOrNoLabels) :
|
||||
MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices, sentenceBegin, sentenceExistsBeginOrNoLabels),
|
||||
m_prefetchSentenceBegin(nullptr),
|
||||
m_prefetchSentenceExistsBeginOrNoLabels(nullptr),
|
||||
m_isEpochReadingDone(false),
|
||||
m_minibatchReady(false),
|
||||
m_isTerminating(false)
|
||||
|
@ -42,6 +47,20 @@ public:
|
|||
iter->second->GetFormat());
|
||||
}
|
||||
|
||||
if (sentenceBegin != nullptr)
|
||||
{
|
||||
m_prefetchSentenceBegin = new Matrix<ElemType>(sentenceBegin->GetNumRows(),
|
||||
sentenceBegin->GetNumCols(),
|
||||
sentenceBegin->GetDeviceId(),
|
||||
sentenceBegin->GetMatrixType(),
|
||||
sentenceBegin->GetFormat());
|
||||
}
|
||||
|
||||
if (sentenceExistsBeginOrNoLabels != nullptr)
|
||||
{
|
||||
m_prefetchSentenceExistsBeginOrNoLabels = new vector<MinibatchPackingFlag>();
|
||||
}
|
||||
|
||||
// Launch a worker thread
|
||||
m_prefetchThread = std::thread([this]() { this->PrefetchWorker(); });
|
||||
}
|
||||
|
@ -66,6 +85,9 @@ public:
|
|||
{
|
||||
delete iter->second;
|
||||
}
|
||||
|
||||
delete m_prefetchSentenceBegin;
|
||||
delete m_prefetchSentenceExistsBeginOrNoLabels;
|
||||
}
|
||||
|
||||
virtual bool GetMinibatch()
|
||||
|
@ -97,6 +119,17 @@ public:
|
|||
std::swap(*(iter->second), *m_prefetchInput[iter->first]);
|
||||
}
|
||||
|
||||
if (m_sentenceBegin != nullptr)
|
||||
{
|
||||
assert(m_sentenceBegin->GetDeviceId() == m_prefetchSentenceBegin->GetDeviceId());
|
||||
std::swap(*m_sentenceBegin, *m_prefetchSentenceBegin);
|
||||
}
|
||||
|
||||
if (m_sentenceExistsBeginOrNoLabels != nullptr)
|
||||
{
|
||||
std::swap(*m_sentenceExistsBeginOrNoLabels, *m_prefetchSentenceExistsBeginOrNoLabels);
|
||||
}
|
||||
|
||||
hasMoreEpochReading = true;
|
||||
}
|
||||
|
||||
|
@ -160,7 +193,9 @@ private:
|
|||
Matrix<ElemType>::SyncComputeBeforeRead(m_deviceId);
|
||||
|
||||
// Get the next minibatch and wait for it to be available on the device
|
||||
bool isDone = !this->m_reader->GetMinibatch(const_cast<std::map<std::wstring, Matrix<ElemType>*>&>(m_prefetchInput));
|
||||
bool isDone = !this->m_reader->GetMinibatch(m_prefetchInput);
|
||||
this->m_reader->SetSentenceSegBatch(*m_prefetchSentenceBegin, *m_prefetchSentenceExistsBeginOrNoLabels);
|
||||
|
||||
Matrix<ElemType>::SyncPendingRead(m_deviceId);
|
||||
|
||||
return isDone;
|
||||
|
@ -168,6 +203,8 @@ private:
|
|||
|
||||
// @TODO: Add support for more than one prefetch buffer
|
||||
std::map<std::wstring, Matrix<ElemType>*> m_prefetchInput;
|
||||
Matrix<ElemType>* m_prefetchSentenceBegin;
|
||||
vector<MinibatchPackingFlag>* m_prefetchSentenceExistsBeginOrNoLabels;
|
||||
std::thread m_prefetchThread;
|
||||
std::mutex m_mutex;
|
||||
std::condition_variable m_cv;
|
||||
|
|
|
@ -1710,8 +1710,8 @@ protected:
|
|||
AttemptUtteranceDerivativeFeatures(net, trainSetDataReader, FeatureNodes, inputMatrices);
|
||||
std::unique_ptr<MinibatchFetcher<ElemType>> mbFetcher(
|
||||
m_doPrefetchTrainingData ?
|
||||
new MinibatchPrefetcher<ElemType>(trainSetDataReader, inputMatrices) :
|
||||
new MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices));
|
||||
new MinibatchPrefetcher<ElemType>(trainSetDataReader, inputMatrices, &(net.SentenceBoundary()), &(net.MinibatchPackingFlags())) :
|
||||
new MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices, &(net.SentenceBoundary()), &(net.MinibatchPackingFlags())));
|
||||
|
||||
fprintf(stderr, "\nStarting minibatch loop, prefetching is: %s\n", m_doPrefetchTrainingData ? "ENABLED" : "DISABLED");
|
||||
|
||||
|
@ -1735,7 +1735,6 @@ protected:
|
|||
|
||||
net.SetActualMiniBatchSize(actualMBSize);
|
||||
net.SetActualNbrSlicesInEachRecIter(trainSetDataReader->NumberSlicesInEachRecurrentIter());
|
||||
trainSetDataReader->SetSentenceSegBatch(net.SentenceBoundary(), net.MinibatchPackingFlags());
|
||||
|
||||
#ifndef EVALDLL
|
||||
if (m_doGradientCheck && GradientCheck(net, criterionNodes, learnableNodes, 0) == false)
|
||||
|
|
Загрузка…
Ссылка в новой задаче