CNTK/DataReader/HTKMLFReader/minibatchiterator.h

300 строки
16 KiB
C++

//
// <copyright file="minibatchiterator.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// minibatchiterator.h -- iterator for minibatches
#pragma once
#define NONUMLATTICEMMI // [v-hansu] move from main.cpp, no numerator lattice for mmi training
#include <vector>
#include <unordered_map>
#include "ssematrix.h"
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
#include "simple_checked_arrays.h" // for const_array_ref
namespace msra { namespace dbn {
// ---------------------------------------------------------------------------
// latticesource -- manages loading of lattices for MMI (in pairs for numer and denom)
// ---------------------------------------------------------------------------
class latticesource
{
const msra::lattices::archive numlattices, denlattices;
public:
latticesource (std::pair<std::vector<wstring>,std::vector<wstring>> latticetocs, const std::unordered_map<std::string,size_t> & modelsymmap)
: numlattices (latticetocs.first, modelsymmap), denlattices (latticetocs.second, modelsymmap) {}
bool empty() const
{
#ifndef NONUMLATTICEMMI // TODO:set NUM lattice to null so as to save memory
if (numlattices.empty() ^ denlattices.empty())
throw std::runtime_error("latticesource: numerator and denominator lattices must be either both empty or both not empty");
#endif
return denlattices.empty();
}
bool haslattice (wstring key) const
{
#ifdef NONUMLATTICEMMI
return denlattices.haslattice (key);
#else
return numlattices.haslattice (key) && denlattices.haslattice (key);
#endif
}
class latticepair : public pair<msra::lattices::lattice,msra::lattices::lattice>
{
public:
// NOTE: we don't check numerator lattice now
size_t getnumframes () const { return second.getnumframes(); }
size_t getnumnodes () const { return second.getnumnodes(); }
size_t getnumedges () const { return second.getnumedges(); }
wstring getkey () const { return second.getkey(); }
};
void getlattices (const std::wstring & key, shared_ptr<const latticesource::latticepair> & L, size_t expectedframes) const
{
shared_ptr<latticepair> LP (new latticepair);
denlattices.getlattice (key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
L = LP;
}
};
// ---------------------------------------------------------------------------
// minibatchsource -- abstracted interface into frame sources
// There are three implementations:
// - the old minibatchframesource to randomize across frames and page to disk
// - minibatchutterancesource that randomizes in chunks and pages from input files directly
// - a wrapper that uses a thread to read ahead in parallel to CPU/GPU processing
// ---------------------------------------------------------------------------
class minibatchsource
{
public:
// read a minibatch
// This function returns all values in a "caller can keep them" fashion:
// - uids are stored in a huge 'const' array, and will never go away
// - transcripts are copied by value
// - lattices are returned as a shared_ptr
// Thus, getbatch() can be called in a thread-safe fashion, allowing for a 'minibatchsource' implementation that wraps another with a read-ahead thread.
// Return value is 'true' if it did read anything from disk, and 'false' if data came only from RAM cache. This is used for controlling the read-ahead thread.
virtual bool getbatch (const size_t globalts,
const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices) = 0;
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
virtual bool getbatch (const size_t globalts,
const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices) = 0;
// getbatch() overload to support subsetting of mini-batches for parallel training
// Default implementation does not support subsetting and throws an exception on
// calling this overload with a numsubsets value other than 1.
virtual bool getbatch(const size_t globalts,
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t & framesadvanced,
std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
{
assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting()); subsetnum; numsubsets;
bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices);
framesadvanced = feat[0].cols();
return retVal;
}
virtual bool supportsbatchsubsetting() const
{
return false;
}
virtual size_t totalframes() const = 0;
virtual double gettimegetbatch () = 0; // used to report runtime
virtual size_t firstvalidglobalts (const size_t globalts) = 0; // get first valid epoch start from intended 'globalts'
virtual const std::vector<size_t> & unitcounts() const = 0; // report number of senones
virtual void setverbosity(int newverbosity) = 0;
virtual ~minibatchsource() { }
};
// ---------------------------------------------------------------------------
// minibatchiterator -- class to iterate over one epoch, minibatch by minibatch
// This iterator supports both random frames and random utterances through the minibatchsource interface whichis common to both.
// This supports multiple data passes with identical randomization; which is intended to be used for utterance-based training.
// ---------------------------------------------------------------------------
class minibatchiterator
{
void operator= (const minibatchiterator &); // (non-copyable)
const size_t epochstartframe;
const size_t epochendframe;
size_t firstvalidepochstartframe; // epoch start frame rounded up to first utterance boundary after epoch boundary
const size_t requestedmbframes; // requested mb size; actual minibatches can be smaller (or even larger for lattices)
const size_t datapasses; // we return the data this many times; caller must sub-sample with 'datapass'
msra::dbn::minibatchsource & source; // feature source to read from
// subset to read during distributed data-parallel training (no subsetting: (0,1))
size_t subsetnum;
size_t numsubsets;
std::vector<msra::dbn::matrix> featbuf; // buffer for holding curernt minibatch's frames
std::vector<std::vector<size_t>> uids; // buffer for storing current minibatch's frame-level label sequence
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise)
std::vector<shared_ptr<const latticesource::latticepair>> lattices; // lattices of the utterances in current minibatch (empty in frame mode)
size_t mbstartframe; // current start frame into generalized time line (used for frame-wise mode and for diagnostic messages)
size_t actualmbframes; // actual number of frames in current minibatch
size_t mbframesadvanced; // logical number of frames the current MB represents (to advance time; > featbuf.cols() possible, intended for the case of distributed data-parallel training)
size_t datapass; // current datapass = pass through the data
double timegetbatch; // [v-hansu] for time measurement
double timechecklattice;
private:
// fetch the next mb
// This updates featbuf, uids[], mbstartframe, and actualmbframes.
void fillorclear()
{
if (!hasdata()) // we hit the end of the epoch: just cleanly clear out everything (not really needed, can't be requested ever)
{
foreach_index(i, featbuf)
featbuf[i].resize (0, 0);
foreach_index(i,uids)
uids[i].clear();
transcripts.clear();
actualmbframes = 0;
return;
}
// process one mini-batch (accumulation and update)
assert (requestedmbframes > 0);
const size_t requestedframes = min (requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end)
assert (requestedframes > 0);
source.getbatch (mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices);
timegetbatch = source.gettimegetbatch();
actualmbframes = featbuf[0].cols(); // for single i/o, there featbuf is length 1
// note:
// - in frame mode, actualmbframes may still return less if at end of sweep
// - in utterance mode, it likely returns less than requested, and
// it may also be > epochendframe (!) for the last utterance, which, most likely, crosses the epoch boundary
// - in case of data parallelism, featbuf.cols() < mbframesadvanced
auto_timer timerchecklattice;
if (!lattices.empty())
{
size_t totalframes = 0;
foreach_index (i, lattices)
totalframes += lattices[i]->getnumframes();
if (totalframes != actualmbframes)
throw std::logic_error ("fillorclear: frames in lattices do not match minibatch size");
}
timechecklattice = timerchecklattice;
}
bool hasdata() const { return mbstartframe < epochendframe; } // true if we can access and/or advance
void checkhasdata() const { if (!hasdata()) throw std::logic_error ("minibatchiterator: access beyond end of epoch"); }
public:
// interface: for (minibatchiterator i (...), i, i++) { ... }
minibatchiterator (msra::dbn::minibatchsource & source, size_t epoch, size_t epochframes, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
: source (source),
epochstartframe (epoch * epochframes),
epochendframe (epochstartframe + epochframes),
requestedmbframes (requestedmbframes),
subsetnum(subsetnum), numsubsets(numsubsets),
datapasses (datapasses),
timegetbatch (0), timechecklattice (0)
{
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
(int)epoch, (int)epochstartframe, (int)epochendframe, (int)firstvalidepochstartframe, (int)subsetnum, (int)numsubsets, (int)datapasses);
mbstartframe = firstvalidepochstartframe;
datapass = 0;
fillorclear(); // get the first batch
}
// TODO not nice, but don't know how to access these frames otherwise
// mbiterator constructor, set epochstart and -endframe explicitly
minibatchiterator(msra::dbn::minibatchsource & source, size_t epoch, size_t epochstart, size_t epochend, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
: source (source),
epochstartframe (epochstart),
epochendframe (epochend),
requestedmbframes (requestedmbframes),
subsetnum(subsetnum), numsubsets(numsubsets),
datapasses(datapasses),
timegetbatch (0), timechecklattice (0)
{
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
(int)epoch, (int)epochstartframe, (int)epochendframe, (int)firstvalidepochstartframe, (int)subsetnum, (int)numsubsets, (int)datapasses);
mbstartframe = firstvalidepochstartframe;
datapass = 0;
fillorclear(); // get the first batch
}
// need virtual destructor to ensure proper destruction
virtual ~minibatchiterator()
{}
// returns true if we still have data
operator bool() const { return hasdata(); }
// advance to the next minimb
void operator++(int/*denotes postfix version*/)
{
checkhasdata();
mbstartframe += mbframesadvanced;
// if we hit the end, we will get mbstartframe >= epochendframe <=> !hasdata()
// (most likely actually mbstartframe > epochendframe since the last utterance likely crosses the epoch boundary)
// in case of multiple datapasses, reset to start when hitting the end
if (!hasdata() && datapass + 1 < datapasses)
{
mbstartframe = firstvalidepochstartframe;
datapass++;
fprintf (stderr, "\nminibatchiterator: entering %d-th repeat pass through the data\n", (int)(datapass+1));
}
fillorclear();
}
// accessors to current minibatch
size_t currentmbstartframe() const { return mbstartframe; }
size_t currentmbframes() const { return actualmbframes; }
size_t currentmbframesadvanced() const { return mbframesadvanced; }
size_t currentmblattices() const { return lattices.size(); }
size_t currentdatapass() const { return datapass; } // 0..datapasses-1; use this for sub-sampling
size_t requestedframes() const {return requestedmbframes; }
double gettimegetbatch () {return timegetbatch;}
double gettimechecklattice () {return timechecklattice;}
bool isfirst() const { return mbstartframe == firstvalidepochstartframe && datapass == 0; }
float progress() const // (note: 100%+eps possible for last utterance)
{
const float epochframes = (float) (epochendframe - epochstartframe);
return (mbstartframe + mbframesadvanced - epochstartframe + datapass * epochframes) / (datapasses * epochframes);
}
std::pair<size_t,size_t> range() const { return make_pair (epochstartframe, epochendframe); }
// return the current minibatch frames as a matrix ref into the feature buffer
// Number of frames is frames().cols() == currentmbframes().
// For frame-based randomization, this is 'requestedmbframes' most of the times, while for utterance randomization,
// this depends highly on the utterance lengths.
// User is allowed to manipulate the frames... for now--TODO: move silence filtering here as well
msra::dbn::matrixstripe frames(size_t i) { checkhasdata(); assert(featbuf.size()>=i+1); return msra::dbn::matrixstripe (featbuf[i], 0, actualmbframes); }
msra::dbn::matrixstripe frames() { checkhasdata(); assert(featbuf.size()==1); return msra::dbn::matrixstripe (featbuf[0], 0, actualmbframes); }
// return the reference transcript labels (state alignment) for current minibatch
/*const*/ std::vector<size_t> & labels() { checkhasdata(); assert(uids.size()==1);return uids[0]; }
/*const*/ std::vector<size_t> & labels(size_t i) { checkhasdata(); assert(uids.size()>=i+1); return uids[i]; }
// return a lattice for an utterance (caller should first get total through currentmblattices())
shared_ptr<const msra::dbn::latticesource::latticepair> lattice (size_t uttindex) const { return lattices[uttindex]; } // lattices making up the current
// return the reference transcript labels (words with alignments) for current minibatch (or empty if no transcripts requested)
const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word> transcript (size_t uttindex) { return transcripts.empty() ? const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>() : transcripts[uttindex]; }
};
};};