This commit is contained in:
Родитель
fff5be60d6
Коммит
8ca0d095d6
|
@ -12,9 +12,7 @@
|
|||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <cstdio>
|
||||
/* guoye: start */
|
||||
#include <vector>
|
||||
/* guoye: end */
|
||||
|
||||
#undef INITIAL_STRANGE // [v-hansu] initialize structs to strange values
|
||||
#define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this
|
||||
|
@ -34,21 +32,14 @@ struct nodeinfo
|
|||
// uint64_t firstoutedge : 24; // index of first outgoing edge
|
||||
// uint64_t t : 16; // time associated with this
|
||||
|
||||
/* guoye: start */
|
||||
uint64_t wid; // word ID associated with the node
|
||||
/* guoye: end */
|
||||
unsigned short t; // time associated with this
|
||||
|
||||
// Construct a node from its time index and its word ID.
//   pt   - time associated with the node; narrowed to unsigned short when stored in t
//   pwid - word ID associated with the node; stored in wid
// Each stored value is handed to checkoverflow() together with the original argument
// (presumably to detect a lossy narrowing -- confirm against checkoverflow's definition).
nodeinfo(size_t pt, size_t pwid)
|
||||
/* guoye: start */
|
||||
// previous initializer list, kept for reference:
// : t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE)
|
||||
// NOTE(review): wid is declared before t in this struct, so wid is initialized first
// regardless of the order written here. The two initializers are independent, so the
// result is the same, but some compilers warn about the mismatch.
: t((unsigned short)pt), wid(pwid)
|
||||
/* guoye: end */
|
||||
{
|
||||
// compare the stored (narrowed) time against the size_t argument
checkoverflow(t, pt, "nodeinfo::t");
|
||||
/* guoye: start */
|
||||
// same check for the word ID
checkoverflow(wid, pwid, "nodeinfo::wid");
|
||||
/* guoye: end */
|
||||
// disabled: checks for the edge-index fields that are commented out of the struct
// checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge");
|
||||
// checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge");
|
||||
}
|
||||
|
|
|
@ -1959,11 +1959,6 @@ protected:
|
|||
template<typename ValueType>
|
||||
void TypedRequestMatrixFromPool(shared_ptr<Matrix<ValueType>>& matrixPtr, MatrixPool& matrixPool, size_t matrixSize=0, bool mbScale=false, bool isWorkSpace=false, bool aliasing=false)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h:RequestMatrixFromPool, debug 0 \n");
|
||||
|
||||
// fprintf(stderr, "\n computationnode.h:RequestMatrixFromPool, debug 1 \n");
|
||||
/* guoye: end */
|
||||
if (matrixPtr == nullptr)
|
||||
{
|
||||
if (aliasing)
|
||||
|
|
|
@ -247,17 +247,8 @@ public:
|
|||
// Request the matrices that are needed for gradient computation.
//   matrixPool - the pool the matrices are requested from
// Calls the base-class implementation first, then requests m_gradientTemp.
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
|
||||
/* guoye: end */
|
||||
// let the base class request its own matrices
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 8 \n");
|
||||
/* guoye: end */
|
||||
// additionally request m_gradientTemp from the same pool
// (presumably scratch space for this node's backprop -- confirm where m_gradientTemp is used)
RequestMatrixFromPool(m_gradientTemp, matrixPool);
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 9 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
|
||||
|
@ -327,17 +318,8 @@ public:
|
|||
// Request the matrices that are needed for gradient computation.
//   matrixPool - the pool the matrices are requested from
// Calls the base-class implementation first, then requests m_diff.
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
|
||||
/* guoye: end */
|
||||
// let the base class request its own matrices
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 5 \n");
|
||||
/* guoye: end */
|
||||
// additionally request m_diff from the same pool
RequestMatrixFromPool(m_diff, matrixPool);
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
|
||||
|
@ -403,17 +385,8 @@ public:
|
|||
// Request the matrices that are needed for gradient computation.
//   matrixPool - the pool the matrices are requested from
// Calls the base-class implementation first, then requests m_softmax.
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
// let the base class request its own matrices
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
|
||||
/* guoye: end */
|
||||
// additionally request m_softmax from the same pool
RequestMatrixFromPool(m_softmax, matrixPool);
|
||||
/* guoye: start */
|
||||
// disabled debug trace:
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
|
||||
|
|
|
@ -370,10 +370,7 @@ public:
|
|||
|
||||
// sequence training
|
||||
GPUMatrix<ElemType>& DropFrame(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& gamma, const ElemType& threshhold);
|
||||
/* guoye: start */
|
||||
//GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha);
|
||||
GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha, bool MBR);
|
||||
/* guoye: end */
|
||||
GPUMatrix<ElemType>& AssignCTCScore(const GPUMatrix<ElemType>& prob, GPUMatrix<ElemType>& alpha, GPUMatrix<ElemType>& beta,
|
||||
const GPUMatrix<ElemType> phoneSeq, const GPUMatrix<ElemType> phoneBoundary, GPUMatrix<ElemType> & totalScore, const vector<size_t>& uttMap, const vector<size_t> & uttBeginFrame, const vector<size_t> & uttFrameNum,
|
||||
const vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const size_t blankTokenId, const int delayConstraint, const bool isColWise);
|
||||
|
|
|
@ -534,7 +534,6 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
// this excludes the unknown words that are introduced into the lattice when the numerator lattice is merged into the denominator lattice.
|
||||
specialwordids.insert(0xfffff);
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
|
@ -555,8 +554,6 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
int start_id = 6;
|
||||
readwordidmap(wordidmappath, wordidmap, start_id);
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
specialwordids.clear();
|
||||
specialwords.clear();
|
||||
|
||||
|
@ -593,7 +590,6 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
// this excludes the unknown words that are introduced into the lattice when the numerator lattice is merged into the denominator lattice.
|
||||
specialwordids.insert(0xfffff);
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
|
@ -636,41 +632,18 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
double htktimetoframe = 100000.0; // default is 10ms
|
||||
// std::vector<msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence>> labelsmulti;
|
||||
/* guoye: start */
|
||||
// std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
|
||||
std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> labelsmulti;
|
||||
// std::vector<std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence>> wordlabelsmulti;
|
||||
|
||||
/* debug to clean wordidmap */
|
||||
// wordidmap.clear();
|
||||
/* guoye: end */
|
||||
// std::vector<std::wstring> pagepath;
|
||||
foreach_index (i, mlfpathsmulti)
|
||||
{
|
||||
/* guoye: start */
|
||||
/*
|
||||
const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL;
|
||||
msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF
|
||||
*/
|
||||
msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
|
||||
// msra::asr::htkmlfreader<msra::asr::htkmlfentry>
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, htktimetoframe); // label MLF
|
||||
// labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF
|
||||
/* guoye: end */
|
||||
// get the temp file name for the page file
|
||||
|
||||
// Make sure 'msra::asr::htkmlfreader' type has a move constructor
|
||||
static_assert(std::is_move_constructible<msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>>::value,
|
||||
"Type 'msra::asr::htkmlfreader' should be move constructible!");
|
||||
|
||||
/* guoye: start */
|
||||
// map<wstring, msra::lattices::lattice::htkmlfwordsequence> wordlabels = labels.get_wordlabels();
|
||||
// guoye debug purpose
|
||||
// fprintf(stderr, "debug to set wordlabels to empty");
|
||||
// map<wstring, msra::lattices::lattice::htkmlfwordsequence> wordlabels;
|
||||
// wordlabelsmulti.push_back(std::move(wordlabels));
|
||||
/* guoye: end */
|
||||
labelsmulti.push_back(std::move(labels));
|
||||
}
|
||||
|
||||
|
@ -683,11 +656,7 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
// now get the frame source. This has better randomization and doesn't create temp files
|
||||
bool useMersenneTwisterRand = readerConfig(L"useMersenneTwisterRand", false);
|
||||
/* guoye: start */
|
||||
// m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
|
||||
// m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, wordlabelsmulti, specialwordids, m_featDims, m_labelDims,
|
||||
m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, specialwordids, m_featDims, m_labelDims,
|
||||
/* guoye: end */
|
||||
numContextLeft, numContextRight, randomize,
|
||||
*m_lattices, m_latticeMap, m_frameMode,
|
||||
m_expandToUtt, m_maxUtteranceLength, m_truncated));
|
||||
|
@ -921,10 +890,8 @@ void HTKMLFReader<ElemType>::StartDistributedMinibatchLoop(size_t requestedMBSiz
|
|||
// for the multi-utterance process for lattice and phone boundary
|
||||
m_latticeBufferMultiUtt.assign(m_numSeqsPerMB, nullptr);
|
||||
m_labelsIDBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
/* guoye: start */
|
||||
m_wlabelsIDBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
m_nwsBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
/* guoye: end */
|
||||
m_phoneboundaryIDBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
|
||||
if (m_frameMode && (m_numSeqsPerMB > 1))
|
||||
|
@ -1063,17 +1030,11 @@ void HTKMLFReader<ElemType>::StartMinibatchLoopToWrite(size_t mbSize, size_t /*e
|
|||
|
||||
template <class ElemType>
|
||||
bool HTKMLFReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput,
|
||||
/* guoye: start */
|
||||
vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
// vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
/* guoye: end */
|
||||
{
|
||||
if (m_trainOrTest)
|
||||
{
|
||||
/* guoye: start */
|
||||
// return GetMinibatch4SEToTrainOrTest(latticeinput, uids, boundaries, extrauttmap);
|
||||
return GetMinibatch4SEToTrainOrTest(latticeinput, uids, wids, nws, boundaries, extrauttmap);
|
||||
/* guoye: end */
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1083,30 +1044,21 @@ bool HTKMLFReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::
|
|||
template <class ElemType>
|
||||
bool HTKMLFReader<ElemType>::GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput,
|
||||
|
||||
/* guoye: start */
|
||||
std::vector<size_t>& uids, std::vector<size_t>& wids, std::vector<short>& nws, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
|
||||
// std::vector<size_t>& uids, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
|
||||
|
||||
/* guoye: end */
|
||||
{
|
||||
latticeinput.clear();
|
||||
uids.clear();
|
||||
/* guoye: start */
|
||||
wids.clear();
|
||||
nws.clear();
|
||||
/* guoye: end */
|
||||
boundaries.clear();
|
||||
extrauttmap.clear();
|
||||
for (size_t i = 0; i < m_extraSeqsPerMB.size(); i++)
|
||||
{
|
||||
latticeinput.push_back(m_extraLatticeBufferMultiUtt[i]);
|
||||
uids.insert(uids.end(), m_extraLabelsIDBufferMultiUtt[i].begin(), m_extraLabelsIDBufferMultiUtt[i].end());
|
||||
/* guoye: start */
|
||||
wids.insert(wids.end(), m_extraWLabelsIDBufferMultiUtt[i].begin(), m_extraWLabelsIDBufferMultiUtt[i].end());
|
||||
|
||||
nws.insert(nws.end(), m_extraNWsBufferMultiUtt[i].begin(), m_extraNWsBufferMultiUtt[i].end());
|
||||
|
||||
/* guoye: end */
|
||||
boundaries.insert(boundaries.end(), m_extraPhoneboundaryIDBufferMultiUtt[i].begin(), m_extraPhoneboundaryIDBufferMultiUtt[i].end());
|
||||
}
|
||||
|
||||
|
@ -1174,11 +1126,8 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
|
|||
m_extraLabelsIDBufferMultiUtt.clear();
|
||||
m_extraPhoneboundaryIDBufferMultiUtt.clear();
|
||||
m_extraSeqsPerMB.clear();
|
||||
/* guoye: start */
|
||||
m_extraWLabelsIDBufferMultiUtt.clear();
|
||||
|
||||
m_extraNWsBufferMultiUtt.clear();
|
||||
/* guoye: end */
|
||||
if (m_noData && m_numFramesToProcess[0] == 0) // no data left for the first channel of this minibatch,
|
||||
{
|
||||
return false;
|
||||
|
@ -1259,11 +1208,8 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
|
|||
{
|
||||
m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[i]);
|
||||
m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[i]);
|
||||
/* guoye: start */
|
||||
m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[i]);
|
||||
|
||||
m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[i]);
|
||||
/* guoye: end */
|
||||
m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[i]);
|
||||
}
|
||||
}
|
||||
|
@ -1306,12 +1252,9 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
|
|||
{
|
||||
m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[src]);
|
||||
m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[src]);
|
||||
/* guoye: start */
|
||||
m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[src]);
|
||||
|
||||
m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[src]);
|
||||
|
||||
/* guoye: end */
|
||||
m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[src]);
|
||||
}
|
||||
|
||||
|
@ -2017,15 +1960,12 @@ bool HTKMLFReader<ElemType>::ReNewBufferForMultiIO(size_t i)
|
|||
m_phoneboundaryIDBufferMultiUtt[i] = m_mbiter->bounds();
|
||||
m_labelsIDBufferMultiUtt[i].clear();
|
||||
m_labelsIDBufferMultiUtt[i] = m_mbiter->labels();
|
||||
/* guoye: start */
|
||||
m_wlabelsIDBufferMultiUtt[i].clear();
|
||||
m_wlabelsIDBufferMultiUtt[i] = m_mbiter->wlabels();
|
||||
|
||||
m_nwsBufferMultiUtt[i].clear();
|
||||
m_nwsBufferMultiUtt[i] = m_mbiter->nwords();
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
m_processedFrame[i] = 0;
|
||||
|
|
|
@ -855,7 +855,6 @@ void lattice::parallelstate::getedgealignments(std::vector<unsigned short>& edge
|
|||
{
|
||||
pimpl->getedgealignments(edgealignments);
|
||||
}
|
||||
/* guoye: start */
|
||||
void lattice::parallelstate::getlogbetas(std::vector<double>& logbetas)
|
||||
{
|
||||
pimpl->getlogbetas(logbetas);
|
||||
|
@ -877,7 +876,6 @@ void lattice::parallelstate::setedgeweights(const std::vector<double>& edgeweigh
|
|||
|
||||
|
||||
|
||||
/* guoye: end */
|
||||
//template<class ElemType>
|
||||
void lattice::parallelstate::setloglls(const Microsoft::MSR::CNTK::Matrix<float>& loglls)
|
||||
{
|
||||
|
@ -1062,10 +1060,6 @@ double lattice::parallelforwardbackwardlattice(parallelstate& parallelstate, con
|
|||
return totalfwscore;
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
|
||||
// parallelbackwardlatticeEMBR() -- backward pass over the lattice for EMBR; outputs go to edgelogbetas/logbetas, returns the total backward score
|
||||
double lattice::parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores,
|
||||
const float lmf, const float wp, const float amf,
|
||||
std::vector<double>& edgelogbetas, std::vector<double>& logbetas) const
|
||||
|
@ -1128,7 +1122,6 @@ double lattice::parallelbackwardlatticeEMBR(parallelstate& parallelstate, const
|
|||
}
|
||||
return totalbwscore;
|
||||
}
|
||||
/* guoye: end */
|
||||
// ------------------------------------------------------------------------
|
||||
// parallel implementations of sMBR error updating step
|
||||
// ------------------------------------------------------------------------
|
||||
|
@ -1168,7 +1161,6 @@ void lattice::parallelsMBRerrorsignal(parallelstate& parallelstate, const edgeal
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// ------------------------------------------------------------------------
|
||||
void lattice::parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
|
||||
const std::vector<double>& edgeweights,
|
||||
|
@ -1196,8 +1188,6 @@ void lattice::parallelEMBRerrorsignal(parallelstate& parallelstate, const edgeal
|
|||
emulateEMBRerrorsignal(thisedgealignments.getalignmentsbuffer(), thisedgealignments.getalignoffsets(), edges, nodes, edgeweights, errorsignal);
|
||||
}
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
// parallel implementations of MMI error updating step
|
||||
// ------------------------------------------------------------------------
|
||||
|
|
Загрузка…
Ссылка в новой задаче