310 строки
14 KiB
C++
310 строки
14 KiB
C++
//
|
|
// <copyright file="readaheadsource.h" company="Microsoft">
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// </copyright>
|
|
//
|
|
// readaheadsource.h -- wrapper ('minibatchreadaheadsource') of a read-ahead thread that pre-rolls feature and lattice data
|
|
//
|
|
// F. Seide, Oct 2012
|
|
//
|
|
#pragma once
|
|
|
|
#include "basetypes.h"
|
|
#include "simplethread.h"
|
|
#include "Matrix.h"
|
|
#include "DataReader.h"
|
|
#include <deque>
|
|
#include <stdexcept>
|
|
#include <vector>
|
|
#include <memory>
|
|
|
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// minibatchreadaheadsource -- read-ahead thread that pre-rolls feature and lattice data
|
|
// ---------------------------------------------------------------------------
|
|
template<class ElemType, typename LabelType>
|
|
class minibatchreadaheadsource : public noncopyable/*assignment operator needed somewhere*/,
|
|
CCritSec/*for multi-threaded access*/
|
|
{
|
|
typedef unsigned LabelIdType;
|
|
size_t epochframes; // epoch size
|
|
IDataReader<ElemType, LabelType>* m_dataReader; // data source
|
|
std::unique_ptr<msra::util::simplethread> thread;
|
|
int verbosity;
|
|
// the FIFO
|
|
struct batchdata // all arguments to/from getbatch
|
|
{
|
|
size_t globalts; // time for which we get the data
|
|
// return values
|
|
Matrix<ElemType> feat;
|
|
Matrix<ElemType> labels;
|
|
bool readRecords;
|
|
// constructor takes feature and labels array and moves them
|
|
batchdata (size_t globalts) {Init(globalts); }
|
|
void Init(size_t p_globalts)
|
|
{
|
|
globalts = p_globalts;
|
|
readRecords = false;
|
|
}
|
|
|
|
// move constructor
|
|
batchdata(batchdata&& moveFrom)
|
|
{
|
|
globalts = moveFrom.globalts;
|
|
readRecords = moveFrom.readRecords;
|
|
feat = move(moveFrom.feat);
|
|
labels = move(moveFrom.labels); //shallow copy the pointer
|
|
}
|
|
//move assignment operator, shallow copy
|
|
batchdata& operator=(batchdata&& moveFrom)
|
|
{
|
|
if (this != &moveFrom)
|
|
{
|
|
globalts = moveFrom.globalts;
|
|
readRecords = moveFrom.readRecords;
|
|
// swap the matricies so we don't loose any of them
|
|
swap(feat,moveFrom.feat);
|
|
swap(labels,moveFrom.labels);
|
|
}
|
|
return *this;
|
|
}
|
|
};
|
|
|
|
deque<batchdata> batchCache; // batches waiting to be used
|
|
deque<batchdata> fifo; // fifo queue for batches that are ready
|
|
size_t epoch; // which epoch we are in currently
|
|
|
|
// parameters for the thread proc (set by caller; taken over once newglobalts is set to non-SIZE_MAX (cleared back by thread))
|
|
volatile size_t newglobalts; // reset request
|
|
volatile size_t currentepochreqframes; // minibatch size for this epoch (taken from the first getbatch() call)
|
|
volatile size_t currentepochendframe; // we cannot request beyond
|
|
|
|
// signaling
|
|
mutable msra::util::signallingevent callerchangedsignal, threadchangedsignal;
|
|
void waitcallerchanged() const { callerchangedsignal.wait(); }
|
|
void flagcallerchanged() const { callerchangedsignal.flag(); }
|
|
void waitthreadchanged() const { threadchangedsignal.wait(); }
|
|
void flagthreadchanged() const { threadchangedsignal.flag(); }
|
|
|
|
// GetBatch from the batch cache, or return a new one
|
|
batchdata GetBatch(size_t globalts)
|
|
{
|
|
CAutoLock(*this); // lock the batch cache
|
|
if (batchCache.empty())
|
|
{
|
|
batchdata batch(globalts);
|
|
batchCache.push_back(move(batch));
|
|
}
|
|
batchdata cacheBatch = move(batchCache.front());
|
|
batchCache.pop_front();
|
|
cacheBatch.Init(globalts);
|
|
return move(cacheBatch);
|
|
}
|
|
|
|
// the thread proc
|
|
volatile bool terminaterequest; // threadproc must respond to this
|
|
size_t globalts; // read cursor, owned by thread only
|
|
void threadproc()
|
|
{
|
|
fprintf (stderr, "minibatchreadaheadsource: read-ahead thread entered\n");
|
|
try
|
|
{
|
|
size_t epochreqframes = 0; // minibatch size for this epoch (taken from the first getbatch() call)
|
|
size_t epochendframe = 0; // we cannot request beyond
|
|
size_t globalts = 0; // reset request
|
|
while (!terminaterequest)
|
|
{
|
|
bool stillhasdata;
|
|
{
|
|
CAutoLock lock (*this);
|
|
// if reset request then do it
|
|
if (newglobalts != SIZE_MAX)
|
|
{
|
|
// take over parameters from caller
|
|
globalts = newglobalts;
|
|
epochreqframes = currentepochreqframes;
|
|
epochendframe = currentepochendframe;
|
|
newglobalts = SIZE_MAX; // remember we got it
|
|
// reset the FIFO
|
|
fifo.clear();
|
|
flagthreadchanged(); // signal state change (needed?)
|
|
fprintf (stderr, "minibatchreadaheadsource: thread entered new epoch, frame pos reset to %d\n", (int) globalts);
|
|
continue;
|
|
}
|
|
stillhasdata = !fifo.empty();
|
|
}
|
|
// we kick in once the FIFO is empty (and only once we know the mbsize)
|
|
if (globalts >= epochendframe || stillhasdata)
|
|
{
|
|
waitcallerchanged(); // nothing to do: wait for caller state change and check again
|
|
continue;
|
|
}
|
|
|
|
// we will bring in data from the current 'globalts' until the sub-getbatch() tells us
|
|
// that we loaded new data (which means subsequent getbatch() will be free until the next load).
|
|
// We assume the access pattern that
|
|
// - we start at or closely after the epoch boundary
|
|
// - we never go across an epoch boundary
|
|
// - the number of requested frames within an epoch is always the same except for the last MB
|
|
// This pattern is implemented by the minibatchiterator. We require it.
|
|
// (but it is possible that less is returned, i.e. at a sweep boundary or epoch end).
|
|
bool readfromdisk = false;
|
|
// we stop once data was read (the subsequent fetches will be cheap until the next data read)
|
|
// For small setups, all data may be in RAM and thus no reading will happen anymore.
|
|
// To guard against that, we limit the number of frames we pre-read.
|
|
fprintf (stderr, "minibatchreadaheadsource: thread entering reading loop, frame read pos %d\n", (int) globalts);
|
|
size_t batchesread = 0;
|
|
const size_t prerollendframe = globalts + 360000; // read max. 1 hour --to guard against setups that fit to RAM entirely (no disk reading after startup)
|
|
while (!terminaterequest && !readfromdisk && globalts < epochendframe && globalts < prerollendframe)
|
|
{
|
|
// get batch and append to FIFO (outside the lock)
|
|
batchdata batch = GetBatch(globalts);
|
|
const size_t requestedframes = min (epochreqframes, epochendframe - globalts); // we must not request beyond the epoch
|
|
bool readRecords = m_dataReader->GetMinibatch(batch.feat, batch.labels);
|
|
readfromdisk = !readRecords;
|
|
batch.readRecords = readRecords;
|
|
if (readRecords)
|
|
{
|
|
batchesread++;
|
|
// Note: We may still get data beyond the end of the epoch, in utterance mode, since the epoch boundary likely falls within an utterance.
|
|
CAutoLock lock (*this);
|
|
if (!fifo.empty() && globalts != fifo.back().globalts + fifo.back().feat.GetNumCols())
|
|
throw std::logic_error ("minibatchreadaheadsource: FIFO got out of order while pre-reading new batch");
|
|
if (newglobalts != SIZE_MAX)
|
|
throw std::logic_error ("minibatchreadaheadsource: main thread reset to new epoch while current epoch not yet finished");
|
|
globalts += batch.feat.GetNumCols();
|
|
}
|
|
else
|
|
{
|
|
// we usually get here because the epoch has ended, or we hit end of file
|
|
// we support skipping the last minibatch if it is partial, so set globalts to correct value
|
|
if (globalts < epochendframe) // && batchesread > 0)
|
|
{
|
|
fprintf (stderr, "minibatchreadaheadsource: skipping to next epoch, epoch ended at %d\n", (int) globalts);
|
|
globalts = epochendframe;
|
|
} // set to end of frame so the thread management works properly
|
|
}
|
|
|
|
// push into the fifo queue, if we didn't read anything the other side won't use data arrays, just flags
|
|
fifo.push_back (std::move (batch));
|
|
flagthreadchanged(); // signal state change so caller can pick up the new batch
|
|
}
|
|
fprintf (stderr, "minibatchreadaheadsource: thread exited reading loop, %d batches read up to frame position %d-1\n", (int) batchesread, (int) globalts);
|
|
}
|
|
}
|
|
catch (...)
|
|
{
|
|
flagthreadchanged();
|
|
}
|
|
fprintf (stderr, "minibatchreadaheadsource: read-ahead thread exited\n");
|
|
}
|
|
void cancelthread()
|
|
{
|
|
terminaterequest = true;
|
|
flagcallerchanged();
|
|
thread->wait();
|
|
}
|
|
public:
|
|
minibatchreadaheadsource (IDataReader<ElemType, LabelType>* p_datareader)
|
|
: terminaterequest (false), globalts (SIZE_MAX),
|
|
epoch (SIZE_MAX), currentepochreqframes (0), currentepochendframe (0), newglobalts (SIZE_MAX), verbosity(2)
|
|
{
|
|
// kick off the thread
|
|
m_dataReader = p_datareader;
|
|
fprintf (stderr, "minibatchreadaheadsource: kicking off read-ahead thread\n");
|
|
thread.reset (new msra::util::simplethread ([this] () { threadproc(); }));
|
|
}
|
|
|
|
// Init - Initialize the readahead reader for another epoch
|
|
// p_globalts - sample number we are reading
|
|
// p_framesrequested - minibatch size,
|
|
// p_epoch - epoch we are requesting
|
|
// p_epochframes - number of frames in the epoch
|
|
void Init(size_t p_globalts, size_t p_framesrequested, size_t p_epoch, size_t p_epochframes)
|
|
{
|
|
// first check whether the thread is still alive
|
|
thread->check();
|
|
|
|
// transfer over the parameter values
|
|
epochframes = p_epochframes;
|
|
globalts = p_globalts;
|
|
epoch = p_epoch;
|
|
|
|
assert(p_epoch == globalts / epochframes);
|
|
|
|
// now start everthing rolling
|
|
fprintf (stderr, "minibatchreadaheadsource: signalling thread to enter new epoch\n");
|
|
CAutoLock lock (*this);
|
|
if (!fifo.empty())
|
|
throw std::logic_error ("getbatch: FIFO not cleared at end of epoch");
|
|
newglobalts = globalts;
|
|
currentepochreqframes = p_framesrequested; // it is assumed that these won't change
|
|
currentepochendframe = (epoch + 1) * epochframes;
|
|
flagcallerchanged();
|
|
}
|
|
|
|
~minibatchreadaheadsource()
|
|
{
|
|
cancelthread();
|
|
}
|
|
void setverbosity(int newverbosity){ verbosity = newverbosity; }
|
|
|
|
// getbatch - get a batch from the queue
|
|
// globalts - time stamp for this batch
|
|
// framesrequested - number of frames requested
|
|
// feat - feature matrix
|
|
// labels - label matrix
|
|
bool getbatch (const size_t globalts,
|
|
const size_t framesrequested, Matrix<ElemType> & feat, Matrix<ElemType> & labels)
|
|
{
|
|
// first check whether the thread is still alive
|
|
thread->check();
|
|
// in case of epoch change, we signal the thread
|
|
size_t thisepoch = globalts / epochframes;
|
|
if (thisepoch != epoch)
|
|
{
|
|
// if epoch changed just return;
|
|
return false;
|
|
}
|
|
|
|
if (globalts + framesrequested < currentepochendframe && currentepochreqframes != framesrequested)
|
|
throw std::logic_error ("getbatch: cannot change minibatch size mid-epoch");
|
|
// loop
|
|
for(;;) // wait for batch to appear
|
|
{
|
|
thread->check();
|
|
{
|
|
CAutoLock lock (*this);
|
|
if (!fifo.empty())
|
|
{
|
|
// get the first batch from the FIFO
|
|
batchdata front = std::move (fifo.front());
|
|
fifo.pop_front();
|
|
flagcallerchanged();
|
|
// it must be the correct one
|
|
if (front.globalts != globalts)
|
|
throw std::logic_error ("getbatch: data in FIFO out of sequence");
|
|
|
|
// if we actually read anything put it in here
|
|
if (front.readRecords)
|
|
{
|
|
// swap the Matricies between the passed in values and the batchdata
|
|
std::swap(feat, front.feat);
|
|
std::swap(labels, front.labels);
|
|
}
|
|
|
|
// put the batch in the batch cache;
|
|
batchCache.push_back(front);
|
|
return front.readRecords;
|
|
}
|
|
}
|
|
// batch not there --keep looping
|
|
waitthreadchanged();
|
|
}
|
|
return false;
|
|
}
|
|
};
|
|
|
|
}}}
|