Adding interface for utterance derivative computation in the Kaldi2Reader
This commit is contained in:
Родитель
17eae08d98
Коммит
85f5777a52
|
@ -48,7 +48,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_mbiter = NULL;
|
||||
m_frameSource = NULL;
|
||||
m_lattices = NULL;
|
||||
m_sequenceTrainingIO = NULL;
|
||||
m_seqTrainDeriv = NULL;
|
||||
m_uttDerivBuffer = NULL;
|
||||
m_minibatchBuffer.resize(0);
|
||||
m_minibatchBufferIndex = 0;
|
||||
m_minibatchBufferLeftovers = 0;
|
||||
|
@ -56,6 +57,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_convertLabelsToTargets = false;
|
||||
m_doSeqTrain = false;
|
||||
m_getMinibatchCopy = false;
|
||||
m_doMinibatchBuffering = false;
|
||||
|
||||
if (readerConfig.Exists("legacyMode"))
|
||||
{
|
||||
|
@ -172,47 +174,59 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
aliRspecifier = wstring(aliConfig("rx"));
|
||||
|
||||
// Initializes sequence training interface.
|
||||
m_sequenceTrainingIO = new KaldiSequenceTrainingIO<ElemType>(
|
||||
denlatRspecifier, aliRspecifier, transModelFilename,
|
||||
silencePhoneStr, m_seqTrainCriterion, oldAcousticScale,
|
||||
acousticScale, lmScale,
|
||||
oneSilenceClass, m_numberOfuttsPerMinibatch);
|
||||
|
||||
// Scans the configurations to get "seqTrainDeriv" type input and
|
||||
// "seqTrainObj" type input. Both are feature nodes, we feed derivatives
|
||||
// to training criterion node through "seqTrainDeriv" and feed objective
|
||||
// through "seqTrainObj".
|
||||
// Scans the configurations to get "readerDeriv" type input and
|
||||
// "readerObj" type input. Both are feature nodes, we feed derivatives
|
||||
// to training criterion node through "readerDeriv" and feed objective
|
||||
// through "readerObj".
|
||||
bool hasDrive = false, hasObj = false;
|
||||
for (auto iter = readerConfig.begin(); iter != readerConfig.end(); ++iter)
|
||||
{
|
||||
ConfigParameters temp = iter->second;
|
||||
if (temp.ExistsCurrent("type"))
|
||||
{
|
||||
if (temp("type") == "seqTrainDeriv")
|
||||
if (temp("type") == "readerDeriv"
|
||||
|| temp("type") == "seqTrainDeriv" /*for back compatibility */)
|
||||
{
|
||||
m_nameToTypeMap[msra::strfun::utf16(iter->first)] = InputOutputTypes::seqTrainDeriv;
|
||||
m_nameToTypeMap[msra::strfun::utf16(iter->first)] = InputOutputTypes::readerDeriv;
|
||||
hasDrive = true;
|
||||
}
|
||||
else if (temp("type") == "seqTrainObj")
|
||||
else if (temp("type") == "readerObj"
|
||||
|| temp("type") == "seqTrainObj" /*for back compatibility */)
|
||||
{
|
||||
m_nameToTypeMap[msra::strfun::utf16(iter->first)] = InputOutputTypes::seqTrainObj;
|
||||
m_nameToTypeMap[msra::strfun::utf16(iter->first)] = InputOutputTypes::readerObj;
|
||||
hasObj = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!hasDrive || !hasObj)
|
||||
{
|
||||
LogicError("Missing seqTrainDeriv or seqTrainObj type feature\n");
|
||||
LogicError("Missing readerDeriv or readerObj type feature\n");
|
||||
}
|
||||
|
||||
// Initializes sequence training interface.
|
||||
m_seqTrainDeriv = new KaldiSequenceTrainingDerivative<ElemType>(
|
||||
denlatRspecifier, aliRspecifier, transModelFilename,
|
||||
silencePhoneStr, m_seqTrainCriterion, oldAcousticScale,
|
||||
acousticScale, lmScale, oneSilenceClass);
|
||||
|
||||
// Initializes derivative buffering.
|
||||
m_doMinibatchBuffering = true;
|
||||
if (m_uttDerivBuffer != NULL)
|
||||
{
|
||||
LogicError("Derivative buffer has already been set, are you doing "
|
||||
"sequence with some other metric that using derivative "
|
||||
"buffering?\n");
|
||||
}
|
||||
m_uttDerivBuffer = new UtteranceDerivativeBuffer<ElemType>(
|
||||
m_numberOfuttsPerMinibatch, m_seqTrainDeriv);
|
||||
}
|
||||
|
||||
// Loads input and output data for training and testing. Below we list the
|
||||
// categories for different input/output:
|
||||
// features: InputOutputTypes::real
|
||||
// labels: InputOutputTypes::category
|
||||
// derivatives: InputOutputTypes::seqTrainDeriv
|
||||
// objectives: InputOutputTypes::seqTrainObj
|
||||
// derivatives: InputOutputTypes::readerDeriv
|
||||
// objectives: InputOutputTypes::readerObj
|
||||
//
|
||||
// Note that we treat <derivatives> and <objectives> as features, but they
|
||||
// will be computed in the reader, rather than being read from disk. Those
|
||||
|
@ -225,10 +239,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
PrepareForSequenceTraining(readerConfig);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_sequenceTrainingIO = NULL;
|
||||
}
|
||||
|
||||
// Variables related to multi-utterance.
|
||||
// m_featuresBufferMultiUtt:
|
||||
|
@ -672,11 +682,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
delete m_lattices;
|
||||
m_lattices = NULL;
|
||||
}
|
||||
if (m_sequenceTrainingIO != NULL)
|
||||
if (m_seqTrainDeriv != NULL)
|
||||
{
|
||||
delete m_sequenceTrainingIO;
|
||||
m_sequenceTrainingIO = NULL;
|
||||
delete m_seqTrainDeriv;
|
||||
m_seqTrainDeriv = NULL;
|
||||
}
|
||||
if (m_uttDerivBuffer != NULL)
|
||||
{
|
||||
delete m_uttDerivBuffer;
|
||||
m_uttDerivBuffer = NULL;
|
||||
}
|
||||
|
||||
|
||||
if (!m_featuresBufferMultiIO.empty())
|
||||
{
|
||||
|
@ -803,11 +819,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
size_t currentMBSize = (m_framemode == true) ? mbSize : 1;
|
||||
m_mbiter = new msra::dbn::minibatchiterator(*source, epoch, requestedEpochSamples, currentMBSize, datapasses);
|
||||
|
||||
// Resets sequence training class.
|
||||
if (m_doSeqTrain)
|
||||
// Resets utterance derivative buffering class.
|
||||
m_doMinibatchBuffering = true;
|
||||
if (m_doMinibatchBuffering)
|
||||
{
|
||||
assert(m_sequenceTrainingIO != NULL);
|
||||
m_sequenceTrainingIO->ResetEpoch();
|
||||
assert(m_uttDerivBuffer != NULL);
|
||||
m_uttDerivBuffer->ResetEpoch();
|
||||
}
|
||||
|
||||
// Clears minibatch buffer.
|
||||
|
@ -1021,9 +1038,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_checkDictionaryKeys=false;
|
||||
}
|
||||
|
||||
// If we are doing sequence training, we need to keep the utterance
|
||||
// information.
|
||||
if (m_doSeqTrain)
|
||||
// If we are doing utterance derivative buffering, we need to keep the
|
||||
// utterance information.
|
||||
if (m_doMinibatchBuffering)
|
||||
{
|
||||
m_minibatchUttInfo.assign(m_numberOfuttsPerMinibatch,
|
||||
std::vector<std::pair<wstring, size_t>>(0));
|
||||
|
@ -1100,7 +1117,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
endFrame = startFrame + m_currentMBSize;
|
||||
bool populateSucc = PopulateUtteranceInMinibatch(matrices, i, startFrame, endFrame, m_currentMBSize);
|
||||
if (m_doSeqTrain && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
if (m_doMinibatchBuffering && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
m_processedFrame[i] += m_currentMBSize;
|
||||
}
|
||||
else if ((startFrame + m_currentMBSize) == m_toProcess[i])
|
||||
|
@ -1134,7 +1151,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// next one.
|
||||
endFrame = startFrame + m_currentMBSize;
|
||||
bool populateSucc = PopulateUtteranceInMinibatch(matrices, i, startFrame, endFrame, m_currentMBSize);
|
||||
if (m_doSeqTrain && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
if (m_doMinibatchBuffering && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
m_processedFrame[i] += m_currentMBSize;
|
||||
bool reNewSucc = ReNewBufferForMultiIO(i);
|
||||
}
|
||||
|
@ -1196,7 +1213,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
endFrame = m_toProcess[i];
|
||||
size_t currentMBFilled = endFrame - startFrame;
|
||||
bool populateSucc = PopulateUtteranceInMinibatch(matrices, i, startFrame, endFrame, m_currentMBSize);
|
||||
if (m_doSeqTrain && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
if (m_doMinibatchBuffering && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
m_processedFrame[i] += currentMBFilled;
|
||||
bool reNewSucc = ReNewBufferForMultiIO(i);
|
||||
|
||||
|
@ -1211,7 +1228,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_sentenceBegin.SetValue(i, currentMBFilled + m_toProcess[i] - 1, (ElemType)SEQUENCE_END);
|
||||
m_minibatchPackingFlag[currentMBFilled + m_toProcess[i] - 1] |= MinibatchPackingFlag::SequenceEnd;
|
||||
populateSucc = PopulateUtteranceInMinibatch(matrices, i, 0, m_toProcess[i], m_currentMBSize, currentMBFilled);
|
||||
if (m_doSeqTrain && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
if (m_doMinibatchBuffering && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
assert(m_processedFrame[i] == 0);
|
||||
m_processedFrame[i] = m_toProcess[i];
|
||||
currentMBFilled += m_toProcess[i];
|
||||
|
@ -1223,7 +1240,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
if (reNewSucc && !m_framemode && m_truncated)
|
||||
{
|
||||
populateSucc = PopulateUtteranceInMinibatch(matrices, i, 0, m_currentMBSize - currentMBFilled, m_currentMBSize, currentMBFilled);
|
||||
if (m_doSeqTrain && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
if (m_doMinibatchBuffering && populateSucc) { m_minibatchUttInfo[i].push_back(m_uttInfo[i][0]); }
|
||||
m_processedFrame[i] += m_currentMBSize - currentMBFilled;
|
||||
if (currentMBFilled < m_currentMBSize)
|
||||
{
|
||||
|
@ -1257,7 +1274,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
template<class ElemType>
|
||||
bool HTKMLFReader<ElemType>::ShouldCopyMinibatchFromBuffer()
|
||||
{
|
||||
if (m_doSeqTrain)
|
||||
if (m_doMinibatchBuffering)
|
||||
{
|
||||
// If <m_getMinibatchCopy> is false, then we should copy data from
|
||||
// buffer for back-propagation.
|
||||
|
@ -1359,18 +1376,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_minibatchBuffer[index].labels[id].data(),
|
||||
matrixFlagNormal);
|
||||
}
|
||||
else if (m_doSeqTrain && !m_getMinibatchCopy)
|
||||
else if (m_doMinibatchBuffering && !m_getMinibatchCopy)
|
||||
{
|
||||
if (m_nameToTypeMap[iter->first] == InputOutputTypes::seqTrainDeriv)
|
||||
if (m_nameToTypeMap[iter->first] == InputOutputTypes::readerDeriv)
|
||||
{
|
||||
m_sequenceTrainingIO->GetDerivative(
|
||||
m_uttDerivBuffer->GetDerivative(
|
||||
m_minibatchUttInfo, m_sentenceBegin,
|
||||
m_minibatchPackingFlag, matrices[iter->first]);
|
||||
}
|
||||
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::seqTrainObj)
|
||||
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::readerObj)
|
||||
{
|
||||
m_sequenceTrainingIO->GetObjective(m_minibatchUttInfo,
|
||||
matrices[iter->first]);
|
||||
m_uttDerivBuffer->GetObjective(m_minibatchUttInfo,
|
||||
matrices[iter->first]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1408,14 +1425,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
assert(id < labelBuffer.size());
|
||||
data.SetValue(dim, size, labelBuffer[id], matrixFlagNormal);
|
||||
}
|
||||
else if (m_doSeqTrain)
|
||||
else if (m_doMinibatchBuffering)
|
||||
{
|
||||
if (m_nameToTypeMap[iter->first] == InputOutputTypes::seqTrainDeriv)
|
||||
if (m_nameToTypeMap[iter->first] == InputOutputTypes::readerDeriv)
|
||||
{
|
||||
data.Resize(data.GetNumRows(), m_currentMBSize);
|
||||
data.SetValue(0);
|
||||
}
|
||||
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::seqTrainObj)
|
||||
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::readerObj)
|
||||
{
|
||||
data.SetValue(0);
|
||||
}
|
||||
|
@ -1453,9 +1470,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
// If we are in the "copy" mode, and we cannot get a full minibatch,
|
||||
// then we have computed the posteriors for all the minibatches.
|
||||
if (m_doSeqTrain && !success && m_getMinibatchCopy)
|
||||
if (m_doMinibatchBuffering && !success && m_getMinibatchCopy)
|
||||
{
|
||||
m_sequenceTrainingIO->SetEpochEnd();
|
||||
m_uttDerivBuffer->SetEpochEnd();
|
||||
}
|
||||
|
||||
return success;
|
||||
|
@ -1637,8 +1654,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
return ReNewBufferForMultiIO(i);
|
||||
}
|
||||
|
||||
if (m_doSeqTrain && !m_sequenceTrainingIO->HasLatticeAndAlignment(
|
||||
m_uttInfo[i][0].first))
|
||||
if (m_doMinibatchBuffering
|
||||
&& !m_uttDerivBuffer->HasResourceForDerivative(
|
||||
m_uttInfo[i][0].first))
|
||||
{
|
||||
(*m_mbiter)++;
|
||||
if (!(*m_mbiter))
|
||||
|
@ -1646,14 +1664,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_noData = true;
|
||||
}
|
||||
fprintf(stderr, "WARNING: Utterance \"%S\" does not have "
|
||||
"lattice or alignment, skipping it.\n",
|
||||
"resource to compute derivative, skipping it.\n",
|
||||
m_uttInfo[i][0].first.c_str());
|
||||
return ReNewBufferForMultiIO(i);
|
||||
}
|
||||
|
||||
// We don't support having two utterances in the same buffer.
|
||||
if (m_doSeqTrain &&
|
||||
m_sequenceTrainingIO->HasUtterance(m_uttInfo[i][0].first))
|
||||
if (m_doMinibatchBuffering &&
|
||||
m_uttDerivBuffer->HasUtterance(m_uttInfo[i][0].first))
|
||||
{
|
||||
(*m_mbiter)++;
|
||||
if (!(*m_mbiter))
|
||||
|
@ -1813,10 +1831,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
// We need to get a "copy" of the minibatch to do the forward
|
||||
// computation for sequence training.
|
||||
if (m_doSeqTrain)
|
||||
if (m_doMinibatchBuffering)
|
||||
{
|
||||
assert(m_framemode == false);
|
||||
if (m_sequenceTrainingIO->NeedLikelihoodToComputeDerivative())
|
||||
if (m_uttDerivBuffer->NeedLikelihoodToComputeDerivative())
|
||||
{
|
||||
m_getMinibatchCopy = true;
|
||||
if (GetMinibatchToTrainOrTest(matrices))
|
||||
|
@ -1843,14 +1861,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
// Set the likelihoods for the utterance with which we can compute the
|
||||
// derivatives. Note that the minibatch may only contain partial output
|
||||
// for the utterance, <m_sequenceTrainingIO> takes care of "pasting"
|
||||
// them together.
|
||||
if (m_doSeqTrain)
|
||||
// for the utterance, <m_uttDerivBuffer> takes care of "gluing" them
|
||||
// together.
|
||||
if (m_doMinibatchBuffering)
|
||||
{
|
||||
assert(m_framemode == false);
|
||||
return m_sequenceTrainingIO->SetLikelihood(uttInfo, outputs,
|
||||
sentenceBegin,
|
||||
minibatchPackingFlag);
|
||||
return m_uttDerivBuffer->SetLikelihood(uttInfo, outputs,
|
||||
sentenceBegin,
|
||||
minibatchPackingFlag);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
|
||||
#pragma once
|
||||
#include "DataReader.h"
|
||||
#include "KaldiSequenceTrainingIO.h"
|
||||
#include "KaldiSequenceTrainingDerivative.h"
|
||||
#include "UtteranceDerivativeBuffer.h"
|
||||
#include "commandArgUtil.h" // for intargvector
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
@ -25,6 +26,11 @@ private:
|
|||
map<wstring,msra::lattices::lattice::htkmlfwordsequence> m_latticeMap;
|
||||
|
||||
// Sequence training related members.
|
||||
bool m_doSeqTrain;
|
||||
wstring m_seqTrainCriterion;
|
||||
KaldiSequenceTrainingDerivative<ElemType>* m_seqTrainDeriv;
|
||||
|
||||
// Minibatch buffering.
|
||||
struct MinibatchBufferUnit
|
||||
{
|
||||
std::vector<std::vector<ElemType>> features;
|
||||
|
@ -33,14 +39,15 @@ private:
|
|||
vector<MinibatchPackingFlag> minibatchPackingFlag;
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>> minibatchUttInfo;
|
||||
size_t currentMBSize;
|
||||
};
|
||||
bool m_doSeqTrain;
|
||||
};
|
||||
bool m_doMinibatchBuffering;
|
||||
bool m_getMinibatchCopy;
|
||||
size_t m_minibatchBufferIndex;
|
||||
size_t m_minibatchBufferLeftovers;
|
||||
wstring m_seqTrainCriterion;
|
||||
KaldiSequenceTrainingIO<ElemType>* m_sequenceTrainingIO;
|
||||
std::deque<MinibatchBufferUnit> m_minibatchBuffer;
|
||||
UtteranceDerivativeBuffer<ElemType>* m_uttDerivBuffer;
|
||||
|
||||
// Utterance information.
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>> m_uttInfo;
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>> m_minibatchUttInfo;
|
||||
|
||||
|
@ -136,8 +143,8 @@ private:
|
|||
{
|
||||
real,
|
||||
category,
|
||||
seqTrainDeriv, /*sequence training derivative, computed in the reader*/
|
||||
seqTrainObj, /*sequence training objective, computed in the reader*/
|
||||
readerDeriv, /*derivative computed in the reader*/
|
||||
readerObj, /*objective computed in the reader*/
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,255 @@
|
|||
#include "basetypes.h"
|
||||
#include "htkfeatio_utils.h"
|
||||
#include "KaldiSequenceTrainingDerivative.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// Constructor.
|
||||
template<class ElemType>
|
||||
KaldiSequenceTrainingDerivative<ElemType>::KaldiSequenceTrainingDerivative(
|
||||
const wstring& denlatRspecifier, const wstring& aliRspecifier,
|
||||
const wstring& transModelFilename, const wstring& silencePhoneStr,
|
||||
const wstring& trainCriterion,
|
||||
ElemType oldAcousticScale, ElemType acousticScale,
|
||||
ElemType lmScale, bool oneSilenceClass)
|
||||
{
|
||||
using namespace msra::asr;
|
||||
assert(denlatRspecifier != L"");
|
||||
assert(aliRspecifier != L"");
|
||||
m_denlatReader = new kaldi::RandomAccessCompactLatticeReader(
|
||||
trimmed(fileToStr(toStr(denlatRspecifier))));
|
||||
m_aliReader = new kaldi::RandomAccessInt32VectorReader(
|
||||
trimmed(fileToStr(toStr(aliRspecifier))));
|
||||
ReadKaldiObject(toStr(transModelFilename), &m_transModel);
|
||||
m_oldAcousticScale = oldAcousticScale;
|
||||
m_acousticScale = acousticScale;
|
||||
m_lmScale = lmScale;
|
||||
m_trainCriterion = trainCriterion;
|
||||
m_oneSilenceClass = oneSilenceClass;
|
||||
if (!kaldi::SplitStringToIntegers(toStr(silencePhoneStr),
|
||||
":", false, &m_silencePhones))
|
||||
{
|
||||
LogicError("Invalid silence phone sequence.\n");
|
||||
}
|
||||
if (m_trainCriterion != L"mpfe" && m_trainCriterion != L"smbr")
|
||||
{
|
||||
LogicError("Supported sequence training criterion: mpfe, smbr.\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Destructor.
|
||||
template<class ElemType>
|
||||
KaldiSequenceTrainingDerivative<ElemType>::~KaldiSequenceTrainingDerivative()
|
||||
{
|
||||
if (m_denlatReader != NULL)
|
||||
{
|
||||
delete m_denlatReader;
|
||||
m_denlatReader = NULL;
|
||||
}
|
||||
if (m_aliReader != NULL)
|
||||
{
|
||||
delete m_aliReader;
|
||||
m_aliReader = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingDerivative<ElemType>::ComputeDerivative(
|
||||
const wstring& uttID,
|
||||
const Matrix<ElemType>& logLikelihood,
|
||||
Matrix<ElemType>* derivative,
|
||||
ElemType* objective)
|
||||
{
|
||||
std::string uttIDStr = msra::asr::toStr(uttID);
|
||||
|
||||
// Sanity check.
|
||||
if (m_transModel.NumPdfs() != logLikelihood.GetNumRows())
|
||||
{
|
||||
RuntimeError("Number of labels in logLikelihood does not match that"
|
||||
" in the Kaldi model for utterance %S: %d v.s. %d\n",
|
||||
uttID.c_str(), logLikelihood.GetNumRows(),
|
||||
m_transModel.NumPdfs());
|
||||
}
|
||||
|
||||
// Reads alignment.
|
||||
if (!m_aliReader->HasKey(uttIDStr))
|
||||
{
|
||||
RuntimeError("Alignment not found for utterance %s\n",
|
||||
uttIDStr.c_str());
|
||||
}
|
||||
const std::vector<int32> ali = m_aliReader->Value(uttIDStr);
|
||||
if (ali.size() != logLikelihood.GetNumCols())
|
||||
{
|
||||
RuntimeError("Number of frames in logLikelihood does not match that"
|
||||
" in the alignment for utterance %S: %d v.s. %d\n",
|
||||
uttID.c_str(), logLikelihood.GetNumCols(), ali.size());
|
||||
}
|
||||
|
||||
// Reads denominator lattice.
|
||||
if (!m_denlatReader->HasKey(uttIDStr))
|
||||
{
|
||||
RuntimeError("Denominator lattice not found for utterance %S\n",
|
||||
uttID.c_str());
|
||||
}
|
||||
kaldi::CompactLattice clat = m_denlatReader->Value(uttIDStr);
|
||||
fst::CreateSuperFinal(&clat); /* One final state with weight One() */
|
||||
kaldi::Lattice lat;
|
||||
fst::ConvertLattice(clat, &lat);
|
||||
|
||||
// Does a first path of acoustic scaling. Typically this sets the old
|
||||
// acoustic scale to 0.
|
||||
if (m_oldAcousticScale != 1.0)
|
||||
{
|
||||
fst::ScaleLattice(fst::AcousticLatticeScale(m_oldAcousticScale),
|
||||
&lat);
|
||||
}
|
||||
|
||||
// Topsort lattice.
|
||||
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
|
||||
if (!(props & fst::kTopSorted))
|
||||
{
|
||||
if (fst::TopSort(&lat) == false)
|
||||
{
|
||||
RuntimeError("Cycles detected in lattice: %S\n", uttID.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Does lattice acoustic rescoring with the new posteriors from the
|
||||
// neural network.
|
||||
LatticeAcousticRescore(uttID, logLikelihood, &lat);
|
||||
|
||||
// Second pass acoustic and language model scale.
|
||||
if (m_acousticScale != 1.0 || m_lmScale != 1.0)
|
||||
{
|
||||
fst::ScaleLattice(fst::LatticeScale(m_lmScale, m_acousticScale),
|
||||
&lat);
|
||||
}
|
||||
|
||||
// Forward-backward on the lattice.
|
||||
kaldi::Posterior post, pdfPost;
|
||||
if (m_trainCriterion == L"smbr")
|
||||
{
|
||||
*objective = kaldi::LatticeForwardBackwardMpeVariants(
|
||||
m_transModel, m_silencePhones, lat,
|
||||
ali, "smbr", m_oneSilenceClass, &post);
|
||||
}
|
||||
else if (m_trainCriterion == L"mpfe")
|
||||
{
|
||||
*objective = kaldi::LatticeForwardBackwardMpeVariants(
|
||||
m_transModel, m_silencePhones, lat,
|
||||
ali, "mpfe", m_oneSilenceClass, &post);
|
||||
}
|
||||
|
||||
ConvertPosteriorToDerivative(post, derivative);
|
||||
assert(derivative->GetNumCols() == logLikelihood.GetNumCols());
|
||||
|
||||
// Uses "expected error rate" instead of "expected accuracy".
|
||||
*objective = logLikelihood.GetNumCols() - *objective;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void KaldiSequenceTrainingDerivative<ElemType>::ConvertPosteriorToDerivative(
|
||||
const kaldi::Posterior& post,
|
||||
Matrix<ElemType>* derivative)
|
||||
{
|
||||
kaldi::Posterior pdfPost;
|
||||
kaldi::ConvertPosteriorToPdfs(m_transModel, post, &pdfPost);
|
||||
|
||||
derivative->Resize(m_transModel.NumPdfs(), pdfPost.size());
|
||||
derivative->SetValue(0);
|
||||
|
||||
for (size_t t = 0; t < pdfPost.size(); ++t)
|
||||
{
|
||||
for (size_t i = 0; i < pdfPost[t].size(); ++i)
|
||||
{
|
||||
size_t pdf_id = pdfPost[t][i].first;
|
||||
assert(pdf_id < m_transModel.NumPdfs());
|
||||
// Flips the sign below.
|
||||
(*derivative)(pdf_id, t) -= pdfPost[t][i].second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void KaldiSequenceTrainingDerivative<ElemType>::LatticeAcousticRescore(
|
||||
const wstring& uttID,
|
||||
const Matrix<ElemType>& logLikelihood,
|
||||
kaldi::Lattice* lat) const
|
||||
{
|
||||
// Gets time information for the lattice.
|
||||
std::vector<kaldi::int32> stateTimes;
|
||||
kaldi::int32 maxTime = kaldi::LatticeStateTimes(*lat, &stateTimes);
|
||||
if (maxTime != logLikelihood.GetNumCols())
|
||||
{
|
||||
RuntimeError("Number of frames in the logLikelihood does not match"
|
||||
" that in the denominator lattice for utterance %S\n",
|
||||
uttID.c_str(), logLikelihood.GetNumRows(), maxTime);
|
||||
}
|
||||
|
||||
std::vector<std::vector<kaldi::int32>> timeStateMap(
|
||||
logLikelihood.GetNumCols());
|
||||
size_t num_states = lat->NumStates();
|
||||
for (size_t s = 0; s < num_states; s++)
|
||||
{
|
||||
assert(stateTimes[s] >= 0
|
||||
&& stateTimes[s] <= logLikelihood.GetNumCols());
|
||||
if (stateTimes[s] < logLikelihood.GetNumCols())
|
||||
{
|
||||
timeStateMap[stateTimes[s]].push_back(s);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 0; t < logLikelihood.GetNumCols(); ++t)
|
||||
{
|
||||
for (size_t i = 0; i < timeStateMap[t].size(); ++i)
|
||||
{
|
||||
kaldi::int32 state = timeStateMap[t][i];
|
||||
for (fst::MutableArcIterator<kaldi::Lattice> aiter(lat, state);
|
||||
!aiter.Done(); aiter.Next())
|
||||
{
|
||||
kaldi::LatticeArc arc = aiter.Value();
|
||||
kaldi::int32 trans_id = arc.ilabel;
|
||||
if (trans_id != 0)
|
||||
{
|
||||
kaldi::int32 pdf_id =
|
||||
m_transModel.TransitionIdToPdf(trans_id);
|
||||
arc.weight.SetValue2(-logLikelihood(pdf_id, t)
|
||||
+ arc.weight.Value2());
|
||||
aiter.SetValue(arc);
|
||||
}
|
||||
}
|
||||
// Checks final state.
|
||||
kaldi::LatticeWeight final = lat->Final(state);
|
||||
if (final != kaldi::LatticeWeight::Zero())
|
||||
{
|
||||
final.SetValue2(0.0);
|
||||
lat->SetFinal(state, final);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingDerivative<ElemType>::HasResourceForDerivative(
|
||||
const wstring& uttID) const
|
||||
{
|
||||
if(m_aliReader == false || m_denlatReader == false)
|
||||
{
|
||||
fprintf(stderr, "WARNING: lattice or alignemnt reader has not been"
|
||||
" set up yet.\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string uttIDStr = msra::asr::toStr(uttID);
|
||||
if(!m_aliReader->HasKey(uttIDStr) || !m_denlatReader->HasKey(uttIDStr))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template class KaldiSequenceTrainingDerivative<float>;
|
||||
template class KaldiSequenceTrainingDerivative<double>;
|
||||
}}}
|
|
@ -0,0 +1,58 @@
|
|||
#pragma once
|
||||
|
||||
#include "kaldi.h"
|
||||
#include "Matrix.h"
|
||||
#include "basetypes.h"
|
||||
#include "UtteranceDerivativeComputationInterface.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// This class deals with the interaction with Kaldi in order to do sequence
|
||||
// in CNTK.
|
||||
template<class ElemType>
|
||||
class KaldiSequenceTrainingDerivative :
|
||||
public UtteranceDerivativeComputationInterface<ElemType>
|
||||
{
|
||||
private:
|
||||
bool m_oneSilenceClass;
|
||||
wstring m_trainCriterion;
|
||||
ElemType m_oldAcousticScale;
|
||||
ElemType m_acousticScale;
|
||||
ElemType m_lmScale;
|
||||
std::vector<kaldi::int32> m_silencePhones;
|
||||
kaldi::TransitionModel m_transModel;
|
||||
kaldi::RandomAccessCompactLatticeReader* m_denlatReader;
|
||||
kaldi::RandomAccessInt32VectorReader* m_aliReader;
|
||||
|
||||
// Rescores the lattice with the lastest posteriors from the neural network.
|
||||
void LatticeAcousticRescore(const wstring& uttID,
|
||||
const Matrix<ElemType>& outputs,
|
||||
kaldi::Lattice* lat) const;
|
||||
|
||||
void ConvertPosteriorToDerivative(const kaldi::Posterior& post,
|
||||
Matrix<ElemType>* derivative);
|
||||
|
||||
public:
|
||||
// Constructor.
|
||||
KaldiSequenceTrainingDerivative(const wstring& denlatRspecifier,
|
||||
const wstring& aliRspecifier,
|
||||
const wstring& transModelFilename,
|
||||
const wstring& silencePhoneStr,
|
||||
const wstring& trainCriterion,
|
||||
ElemType oldAcousticScale,
|
||||
ElemType acousticScale,
|
||||
ElemType lmScale,
|
||||
bool oneSilenceClass);
|
||||
|
||||
// Destructor.
|
||||
~KaldiSequenceTrainingDerivative();
|
||||
|
||||
bool ComputeDerivative(const wstring& uttID,
|
||||
const Matrix<ElemType>& logLikelihood,
|
||||
Matrix<ElemType>* derivative,
|
||||
ElemType* objective);
|
||||
|
||||
bool HasResourceForDerivative(const wstring& uttID) const;
|
||||
};
|
||||
|
||||
}}}
|
|
@ -1,228 +1,33 @@
|
|||
#include "basetypes.h"
|
||||
#include "htkfeatio_utils.h"
|
||||
#include "KaldiSequenceTrainingIO.h"
|
||||
#include "UtteranceDerivativeBuffer.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// Constructor.
|
||||
template<class ElemType>
|
||||
KaldiSequenceTrainingIO<ElemType>::KaldiSequenceTrainingIO(
|
||||
const wstring& denlatRspecifier, const wstring& aliRspecifier,
|
||||
const wstring& transModelFilename, const wstring& silencePhoneStr,
|
||||
const wstring& trainCriterion,
|
||||
ElemType oldAcousticScale, ElemType acousticScale,
|
||||
ElemType lmScale, bool oneSilenceClass, size_t numberOfuttsPerMinibatch)
|
||||
UtteranceDerivativeBuffer<ElemType>::UtteranceDerivativeBuffer(
|
||||
size_t numberOfuttsPerMinibatch,
|
||||
UtteranceDerivativeComputationInterface<ElemType>* derivativeInterface)
|
||||
{
|
||||
using namespace msra::asr;
|
||||
assert(denlatRspecifier != L"");
|
||||
assert(aliRspecifier != L"");
|
||||
m_denlatReader = new kaldi::RandomAccessCompactLatticeReader(
|
||||
trimmed(fileToStr(toStr(denlatRspecifier))));
|
||||
m_aliReader = new kaldi::RandomAccessInt32VectorReader(
|
||||
trimmed(fileToStr(toStr(aliRspecifier))));
|
||||
ReadKaldiObject(toStr(transModelFilename), &m_transModel);
|
||||
m_oldAcousticScale = oldAcousticScale;
|
||||
m_acousticScale = acousticScale;
|
||||
m_lmScale = lmScale;
|
||||
m_trainCriterion = trainCriterion;
|
||||
m_oneSilenceClass = oneSilenceClass;
|
||||
assert(derivativeInterface != NULL);
|
||||
m_derivativeInterface = derivativeInterface;
|
||||
m_numUttsPerMinibatch = numberOfuttsPerMinibatch;
|
||||
m_needLikelihood = true;
|
||||
m_currentObj = 0;
|
||||
m_minibatchIndex = 1;
|
||||
m_lastCompleteMinibatch.assign(m_numUttsPerMinibatch, 0);
|
||||
m_epochEnd = false;
|
||||
if (!kaldi::SplitStringToIntegers(toStr(silencePhoneStr),
|
||||
":", false, &m_silencePhones))
|
||||
{
|
||||
LogicError("Invalid silence phone sequence.\n");
|
||||
}
|
||||
if (m_trainCriterion != L"mpfe" && m_trainCriterion != L"smbr")
|
||||
{
|
||||
LogicError("Supported sequence training criterion: mpfe, smbr.\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Destructor.
|
||||
template<class ElemType>
|
||||
KaldiSequenceTrainingIO<ElemType>::~KaldiSequenceTrainingIO()
|
||||
{
|
||||
if (m_denlatReader != NULL)
|
||||
{
|
||||
delete m_denlatReader;
|
||||
m_denlatReader = NULL;
|
||||
}
|
||||
if (m_aliReader != NULL)
|
||||
{
|
||||
delete m_aliReader;
|
||||
m_aliReader = NULL;
|
||||
}
|
||||
m_dimension = 0;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingIO<ElemType>::ComputeDerivative(
|
||||
const wstring& uttID)
|
||||
{
|
||||
assert(m_uttPool.find(uttID) != m_uttPool.end());
|
||||
assert(m_uttPool[uttID].hasDerivative == false);
|
||||
Matrix<ElemType>& logLikelihood = m_uttPool[uttID].logLikelihood;
|
||||
|
||||
std::string uttIDStr = msra::asr::toStr(uttID);
|
||||
|
||||
// Sanity check.
|
||||
if (m_transModel.NumPdfs() != logLikelihood.GetNumRows())
|
||||
{
|
||||
RuntimeError("Number of labels in logLikelihood does not match that"
|
||||
" in the Kaldi model for utterance %S: %d v.s. %d\n",
|
||||
uttID.c_str(), logLikelihood.GetNumRows(),
|
||||
m_transModel.NumPdfs());
|
||||
}
|
||||
|
||||
// Reads alignment.
|
||||
if (!m_aliReader->HasKey(uttIDStr))
|
||||
{
|
||||
RuntimeError("Alignment not found for utterance %s\n",
|
||||
uttIDStr.c_str());
|
||||
}
|
||||
const std::vector<int32> ali = m_aliReader->Value(uttIDStr);
|
||||
if (ali.size() != logLikelihood.GetNumCols())
|
||||
{
|
||||
RuntimeError("Number of frames in logLikelihood does not match that"
|
||||
" in the alignment for utterance %S: %d v.s. %d\n",
|
||||
uttID.c_str(), logLikelihood.GetNumCols(), ali.size());
|
||||
}
|
||||
|
||||
// Reads denominator lattice.
|
||||
if (!m_denlatReader->HasKey(uttIDStr))
|
||||
{
|
||||
RuntimeError("Denominator lattice not found for utterance %S\n",
|
||||
uttID.c_str());
|
||||
}
|
||||
kaldi::CompactLattice clat = m_denlatReader->Value(uttIDStr);
|
||||
fst::CreateSuperFinal(&clat); /* One final state with weight One() */
|
||||
kaldi::Lattice lat;
|
||||
fst::ConvertLattice(clat, &lat);
|
||||
|
||||
// Does a first path of acoustic scaling. Typically this sets the old
|
||||
// acoustic scale to 0.
|
||||
if (m_oldAcousticScale != 1.0)
|
||||
{
|
||||
fst::ScaleLattice(fst::AcousticLatticeScale(m_oldAcousticScale),
|
||||
&lat);
|
||||
}
|
||||
|
||||
// Topsort lattice.
|
||||
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
|
||||
if (!(props & fst::kTopSorted))
|
||||
{
|
||||
if (fst::TopSort(&lat) == false)
|
||||
{
|
||||
RuntimeError("Cycles detected in lattice: %S\n", uttID.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Gets time information for the lattice.
|
||||
std::vector<kaldi::int32> stateTimes;
|
||||
kaldi::int32 maxTime = kaldi::LatticeStateTimes(lat, &stateTimes);
|
||||
if (maxTime != logLikelihood.GetNumCols())
|
||||
{
|
||||
RuntimeError("Number of frames in the logLikelihood does not match"
|
||||
" that in the denominator lattice for utterance %S\n",
|
||||
uttID.c_str(), logLikelihood.GetNumRows(), maxTime);
|
||||
}
|
||||
|
||||
// Does lattice acoustic rescoring with the new posteriors from the
|
||||
// neural network.
|
||||
LatticeAcousticRescore(stateTimes, logLikelihood, &lat);
|
||||
|
||||
// Second pass acoustic and language model scale.
|
||||
if (m_acousticScale != 1.0 || m_lmScale != 1.0)
|
||||
{
|
||||
fst::ScaleLattice(fst::LatticeScale(m_lmScale, m_acousticScale),
|
||||
&lat);
|
||||
}
|
||||
|
||||
// Forward-backward on the lattice.
|
||||
kaldi::Posterior post;
|
||||
ElemType thisObj = 0;
|
||||
if (m_trainCriterion == L"smbr")
|
||||
{
|
||||
thisObj = kaldi::LatticeForwardBackwardMpeVariants(
|
||||
m_transModel, m_silencePhones, lat,
|
||||
ali, "smbr", m_oneSilenceClass, &post);
|
||||
}
|
||||
else if (m_trainCriterion == L"mpfe")
|
||||
{
|
||||
thisObj = kaldi::LatticeForwardBackwardMpeVariants(
|
||||
m_transModel, m_silencePhones, lat,
|
||||
ali, "mpfe", m_oneSilenceClass, &post);
|
||||
}
|
||||
|
||||
kaldi::ConvertPosteriorToPdfs(m_transModel,
|
||||
post, &(m_uttPool[uttID].posterior));
|
||||
|
||||
// Uses "expected error rate" instead of "expected accuracy".
|
||||
m_uttPool[uttID].objective = logLikelihood.GetNumCols() - thisObj;
|
||||
|
||||
assert(m_uttPool[uttID].posterior.size() == logLikelihood.GetNumCols());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void KaldiSequenceTrainingIO<ElemType>::LatticeAcousticRescore(
|
||||
const std::vector<kaldi::int32>& stateTimes,
|
||||
const Matrix<ElemType>& logLikelihood, kaldi::Lattice* lat) const
|
||||
{
|
||||
std::vector<std::vector<kaldi::int32>> timeStateMap(
|
||||
logLikelihood.GetNumCols());
|
||||
size_t num_states = lat->NumStates();
|
||||
for (size_t s = 0; s < num_states; s++)
|
||||
{
|
||||
assert(stateTimes[s] >= 0
|
||||
&& stateTimes[s] <= logLikelihood.GetNumCols());
|
||||
if (stateTimes[s] < logLikelihood.GetNumCols())
|
||||
{
|
||||
timeStateMap[stateTimes[s]].push_back(s);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t t = 0; t < logLikelihood.GetNumCols(); ++t)
|
||||
{
|
||||
for (size_t i = 0; i < timeStateMap[t].size(); ++i)
|
||||
{
|
||||
kaldi::int32 state = timeStateMap[t][i];
|
||||
for (fst::MutableArcIterator<kaldi::Lattice> aiter(lat, state);
|
||||
!aiter.Done(); aiter.Next())
|
||||
{
|
||||
kaldi::LatticeArc arc = aiter.Value();
|
||||
kaldi::int32 trans_id = arc.ilabel;
|
||||
if (trans_id != 0)
|
||||
{
|
||||
kaldi::int32 pdf_id =
|
||||
m_transModel.TransitionIdToPdf(trans_id);
|
||||
arc.weight.SetValue2(-logLikelihood(pdf_id, t)
|
||||
+ arc.weight.Value2());
|
||||
aiter.SetValue(arc);
|
||||
}
|
||||
}
|
||||
// Checks final state.
|
||||
kaldi::LatticeWeight final = lat->Final(state);
|
||||
if (final != kaldi::LatticeWeight::Zero())
|
||||
{
|
||||
final.SetValue2(0.0);
|
||||
lat->SetFinal(state, final);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void KaldiSequenceTrainingIO<ElemType>::ProcessUttInfo(
|
||||
void UtteranceDerivativeBuffer<ElemType>::ProcessUttInfo(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
const Matrix<ElemType>& sentenceBegin,
|
||||
const std::vector<MinibatchPackingFlag>& minibatchPackingFlag,
|
||||
std::vector<std::vector<std::pair<wstring, std::pair<size_t, size_t>>>>* uttInfoInMinibatch) const
|
||||
std::vector<std::vector<std::pair<
|
||||
wstring, std::pair<size_t, size_t>>>>* uttInfoInMinibatch) const
|
||||
{
|
||||
assert(uttInfoInMinibatch != NULL);
|
||||
assert(uttInfo.size() == m_numUttsPerMinibatch);
|
||||
|
@ -236,22 +41,23 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
size_t numFrames = 0;
|
||||
for (size_t j = 0; j < sentenceBegin.GetNumCols(); ++j)
|
||||
{
|
||||
if (((size_t)sentenceBegin(i, j) & NO_LABEL) == NO_LABEL)
|
||||
if (((int)sentenceBegin(i, j) & NO_LABEL) == NO_LABEL)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if (((size_t)sentenceBegin(i, j) & NO_FEATURE) == NO_FEATURE)
|
||||
if (((int)sentenceBegin(i, j) & NO_FEATURE) == NO_FEATURE)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
numFrames += 1;
|
||||
if ((((size_t)sentenceBegin(i, j) & SEQUENCE_END) == SEQUENCE_END)
|
||||
if ((((int)sentenceBegin(i, j) & SEQUENCE_END) == SEQUENCE_END)
|
||||
|| j == sentenceBegin.GetNumCols() - 1)
|
||||
{
|
||||
size_t uttIndex = (*uttInfoInMinibatch)[i].size();
|
||||
wstring uttID = uttInfo[i][uttIndex].first;
|
||||
(*uttInfoInMinibatch)[i].push_back(
|
||||
make_pair(uttID, make_pair(startFrameIndexInMinibatch, numFrames)));
|
||||
make_pair(uttID, make_pair(startFrameIndexInMinibatch,
|
||||
numFrames)));
|
||||
startFrameIndexInMinibatch = j + 1;
|
||||
numFrames = 0;
|
||||
}
|
||||
|
@ -266,7 +72,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// 2: a21 b21 c21 a22 b22 c22...
|
||||
// 3: a31 b31 c31 a32 b32 c32...
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingIO<ElemType>::SetLikelihood(
|
||||
bool UtteranceDerivativeBuffer<ElemType>::SetLikelihood(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
const Matrix<ElemType>& logLikelihoodIn,
|
||||
const Matrix<ElemType>& sentenceBegin,
|
||||
|
@ -274,6 +80,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
assert(m_needLikelihood == true);
|
||||
assert(m_epochEnd == false);
|
||||
|
||||
if (m_dimension == 0)
|
||||
{
|
||||
m_dimension = logLikelihoodIn.GetNumRows();
|
||||
}
|
||||
assert(m_dimension == logLikelihoodIn.GetNumRows());
|
||||
|
||||
std::vector<std::vector<
|
||||
std::pair<wstring, std::pair<size_t, size_t>>>> uttInfoInMinibatch;
|
||||
ProcessUttInfo(uttInfo, sentenceBegin,
|
||||
|
@ -287,7 +100,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
logLikelihood.GetDeviceId(), CPUDEVICE, true, false, false);
|
||||
}
|
||||
|
||||
bool minibatchComplete = true;
|
||||
size_t currentMBSize = minibatchPackingFlag.size();
|
||||
for (size_t i = 0; i < uttInfo.size(); ++i)
|
||||
{
|
||||
|
@ -302,7 +114,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
tmpUttUnit.uttLength = uttInfo[i][j].second;
|
||||
tmpUttUnit.progress = 0;
|
||||
tmpUttUnit.streamID = i;
|
||||
tmpUttUnit.logLikelihood.Resize(m_transModel.NumPdfs(),
|
||||
tmpUttUnit.logLikelihood.Resize(logLikelihood.GetNumRows(),
|
||||
tmpUttUnit.uttLength);
|
||||
m_uttPool[uttID] = tmpUttUnit;
|
||||
}
|
||||
|
@ -329,7 +141,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_uttPool[uttID].progress += numFrames;
|
||||
if (m_uttPool[uttID].progress == m_uttPool[uttID].uttLength)
|
||||
{
|
||||
ComputeDerivative(uttID);
|
||||
m_derivativeInterface->ComputeDerivative(
|
||||
uttID,
|
||||
m_uttPool[uttID].logLikelihood,
|
||||
&m_uttPool[uttID].derivative,
|
||||
&m_uttPool[uttID].objective);
|
||||
m_uttPool[uttID].hasDerivative = true;
|
||||
m_uttPool[uttID].progress = 0;
|
||||
if (startFrame + numFrames == currentMBSize)
|
||||
|
@ -360,7 +176,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// 2: a21 b21 c21 a22 b22 c22...
|
||||
// 3: a31 b31 c31 a32 b32 c32...
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingIO<ElemType>::GetDerivative(
|
||||
bool UtteranceDerivativeBuffer<ElemType>::GetDerivative(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
const Matrix<ElemType>& sentenceBegin,
|
||||
const std::vector<MinibatchPackingFlag>& minibatchPackingFlag,
|
||||
|
@ -373,12 +189,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
ProcessUttInfo(uttInfo, sentenceBegin,
|
||||
minibatchPackingFlag, &uttInfoInMinibatch);
|
||||
|
||||
Matrix<ElemType> derivatives(CPUDEVICE);
|
||||
derivatives.Resize(m_transModel.NumPdfs(),
|
||||
sentenceBegin.GetNumCols() * sentenceBegin.GetNumRows());
|
||||
derivatives.SetValue(0);
|
||||
|
||||
m_currentObj = 0;
|
||||
Matrix<ElemType> derivatives(CPUDEVICE);
|
||||
derivatives.Resize(m_dimension,
|
||||
sentenceBegin.GetNumCols() * sentenceBegin.GetNumRows());
|
||||
for (size_t i = 0; i < uttInfo.size(); ++i)
|
||||
{
|
||||
assert(uttInfo[i].size() == uttInfoInMinibatch[i].size());
|
||||
|
@ -402,17 +216,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
size_t numFrames = uttInfoInMinibatch[i][j].second.second;
|
||||
for (size_t k = 0; k < numFrames; ++k)
|
||||
{
|
||||
size_t posStart = startFrameInUtt + k;
|
||||
for (size_t l = 0;
|
||||
l < m_uttPool[uttID].posterior[posStart].size(); ++l)
|
||||
{
|
||||
size_t pdf_id =
|
||||
m_uttPool[uttID].posterior[posStart][l].first;
|
||||
assert(pdf_id < m_transModel.NumPdfs());
|
||||
derivatives(pdf_id,
|
||||
(startFrame + k) * m_numUttsPerMinibatch + i) -=
|
||||
m_uttPool[uttID].posterior[posStart][l].second;
|
||||
}
|
||||
derivatives.SetColumn(
|
||||
m_uttPool[uttID].derivative.ColumnSlice(
|
||||
startFrameInUtt + k, 1),
|
||||
(startFrame + k) * m_numUttsPerMinibatch + i);
|
||||
}
|
||||
m_currentObj += m_uttPool[uttID].objective
|
||||
* numFrames / m_uttPool[uttID].uttLength;
|
||||
|
@ -459,38 +266,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingIO<ElemType>::GetObjective(
|
||||
bool UtteranceDerivativeBuffer<ElemType>::GetObjective(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
Matrix<ElemType>* objectivesIn)
|
||||
{
|
||||
assert(objectivesIn != NULL);
|
||||
|
||||
// Checks utterance information.
|
||||
bool match = true;
|
||||
if (uttInfo.size() == m_currentUttInfo.size())
|
||||
{
|
||||
for (size_t i = 0; i < uttInfo.size(); ++i)
|
||||
{
|
||||
if (uttInfo[i].size() != m_currentUttInfo[i].size())
|
||||
{
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
for (size_t j = 0; j < uttInfo[i].size(); ++j)
|
||||
{
|
||||
if (uttInfo[i][j].first != m_currentUttInfo[i][j].first ||
|
||||
uttInfo[i][j].second != m_currentUttInfo[i][j].second)
|
||||
{
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
match = false;
|
||||
}
|
||||
bool match = CompareUttInfo(uttInfo, m_currentUttInfo);
|
||||
if (!match)
|
||||
{
|
||||
RuntimeError("Current objective does not correspond to the"
|
||||
|
@ -506,37 +289,58 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool KaldiSequenceTrainingIO<ElemType>::HasLatticeAndAlignment(
|
||||
bool UtteranceDerivativeBuffer<ElemType>::HasResourceForDerivative(
|
||||
const wstring& uttID) const
|
||||
{
|
||||
if(m_aliReader == false || m_denlatReader == false)
|
||||
{
|
||||
fprintf(stderr, "WARNING: lattice or alignemnt reader has not been"
|
||||
" set up yet.\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string uttIDStr = msra::asr::toStr(uttID);
|
||||
if(!m_aliReader->HasKey(uttIDStr) || !m_denlatReader->HasKey(uttIDStr))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return m_derivativeInterface->HasResourceForDerivative(uttID);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void KaldiSequenceTrainingIO<ElemType>::ResetEpoch()
|
||||
bool UtteranceDerivativeBuffer<ElemType>::CompareUttInfo(
    const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo1,
    const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo2)
{
    // Two utterance-info tables match when they have the same number of
    // streams, every stream holds the same number of entries, and every
    // entry agrees on both its first (utterance ID) and second member. That
    // is exactly element-wise equality of the nested vectors of pairs.
    return uttInfo1 == uttInfo2;
}
|
||||
|
||||
template<class ElemType>
|
||||
void UtteranceDerivativeBuffer<ElemType>::ResetEpoch()
|
||||
{
|
||||
m_uttPool.clear();
|
||||
m_needLikelihood = true;
|
||||
m_currentObj = 0;
|
||||
m_minibatchIndex = 1;
|
||||
m_lastCompleteMinibatch.assign(m_numUttsPerMinibatch, 0);
|
||||
m_minCompleteMinibatchIndex = 0;
|
||||
m_epochEnd = false;
|
||||
m_lastCompleteMinibatch.assign(m_numUttsPerMinibatch, 0);
|
||||
m_uttPool.clear();
|
||||
m_currentUttInfo.clear();
|
||||
}
|
||||
|
||||
template class KaldiSequenceTrainingIO<float>;
|
||||
template class KaldiSequenceTrainingIO<double>;
|
||||
template class UtteranceDerivativeBuffer<float>;
|
||||
template class UtteranceDerivativeBuffer<double>;
|
||||
}}}
|
|
@ -1,30 +1,18 @@
|
|||
#pragma once
|
||||
|
||||
#include "kaldi.h"
|
||||
#include "Matrix.h"
|
||||
#include "basetypes.h"
|
||||
#include "UtteranceDerivativeComputationInterface.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// This class deals with the interaction with Kaldi in order to do sequence training
|
||||
// in CNTK.
|
||||
// This class "glues" together the log-likelihood from different minibatches,
|
||||
// and then calls <UtteranceDerivativeComputationInterface> class to compute
|
||||
// the derivative for given utterance.
|
||||
template<class ElemType>
|
||||
class KaldiSequenceTrainingIO
|
||||
class UtteranceDerivativeBuffer
|
||||
{
|
||||
private:
|
||||
bool m_oneSilenceClass;
|
||||
bool m_needLikelihood;
|
||||
bool m_epochEnd;
|
||||
size_t m_numUttsPerMinibatch;
|
||||
wstring m_trainCriterion;
|
||||
ElemType m_oldAcousticScale;
|
||||
ElemType m_acousticScale;
|
||||
ElemType m_lmScale;
|
||||
std::vector<kaldi::int32> m_silencePhones;
|
||||
kaldi::TransitionModel m_transModel;
|
||||
kaldi::RandomAccessCompactLatticeReader* m_denlatReader;
|
||||
kaldi::RandomAccessInt32VectorReader* m_aliReader;
|
||||
|
||||
struct UtteranceDerivativeUnit
|
||||
{
|
||||
bool hasDerivative;
|
||||
|
@ -32,10 +20,11 @@ private:
|
|||
size_t progress;
|
||||
size_t streamID;
|
||||
Matrix<ElemType> logLikelihood;
|
||||
kaldi::Posterior posterior;
|
||||
Matrix<ElemType> derivative;
|
||||
ElemType objective;
|
||||
|
||||
UtteranceDerivativeUnit() : logLikelihood(CPUDEVICE)
|
||||
UtteranceDerivativeUnit() :
|
||||
logLikelihood(CPUDEVICE), derivative(CPUDEVICE)
|
||||
{
|
||||
hasDerivative = false;
|
||||
uttLength = 0;
|
||||
|
@ -43,17 +32,18 @@ private:
|
|||
streamID = 0;
|
||||
}
|
||||
};
|
||||
ElemType m_currentObj;
|
||||
|
||||
bool m_needLikelihood;
|
||||
bool m_epochEnd;
|
||||
int m_minCompleteMinibatchIndex;
|
||||
int m_minibatchIndex;
|
||||
size_t m_numUttsPerMinibatch;
|
||||
size_t m_dimension;
|
||||
ElemType m_currentObj;
|
||||
std::vector<int> m_lastCompleteMinibatch;
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>> m_currentUttInfo;
|
||||
unordered_map<wstring, UtteranceDerivativeUnit> m_uttPool;
|
||||
|
||||
// Rescores the lattice with the latest posteriors from the neural network.
|
||||
void LatticeAcousticRescore(
|
||||
const std::vector<kaldi::int32>& stateTimes,
|
||||
const Matrix<ElemType>& outputs, kaldi::Lattice* lat) const;
|
||||
UtteranceDerivativeComputationInterface<ElemType>* m_derivativeInterface;
|
||||
|
||||
// <uttInfoInMinibatch> is a vector of vector of the following:
|
||||
// uttID startFrameIndexInMinibatch numFrames
|
||||
|
@ -64,23 +54,19 @@ private:
|
|||
std::vector<std::vector<std::pair<
|
||||
wstring, std::pair<size_t, size_t>>>>* uttInfoInMinibatch) const;
|
||||
|
||||
bool ComputeDerivative(const wstring& uttID);
|
||||
bool CompareUttInfo(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo1,
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo2);
|
||||
|
||||
public:
|
||||
// Constructor.
|
||||
KaldiSequenceTrainingIO(const wstring& denlatRspecifier,
|
||||
const wstring& aliRspecifier,
|
||||
const wstring& transModelFilename,
|
||||
const wstring& silencePhoneStr,
|
||||
const wstring& trainCriterion,
|
||||
ElemType oldAcousticScale,
|
||||
ElemType acousticScale,
|
||||
ElemType lmScale,
|
||||
bool oneSilenceClass,
|
||||
size_t numberOfuttsPerMinibatch);
|
||||
// Does not take ownership of <derivativeInterface>.
|
||||
UtteranceDerivativeBuffer(
|
||||
size_t numberOfuttsPerMinibatch,
|
||||
UtteranceDerivativeComputationInterface<ElemType>* derivativeInterface);
|
||||
|
||||
// Destructor.
|
||||
~KaldiSequenceTrainingIO();
|
||||
~UtteranceDerivativeBuffer() {}
|
||||
|
||||
bool NeedLikelihoodToComputeDerivative() const { return m_needLikelihood; }
|
||||
|
||||
|
@ -102,7 +88,7 @@ public:
|
|||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
Matrix<ElemType>* objectivesIn);
|
||||
|
||||
bool HasLatticeAndAlignment(const wstring& uttID) const;
|
||||
bool HasResourceForDerivative(const wstring& uttID) const;
|
||||
|
||||
bool HasUtterance(const wstring& uttID) const
|
||||
{
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
|
||||
#include "Matrix.h"
|
||||
#include "basetypes.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// This class defines the interface for utterance derivative computation.
|
||||
template<class ElemType>
|
||||
class UtteranceDerivativeComputationInterface
|
||||
{
|
||||
public:
|
||||
// Computes derivative and objective for given utterance ID and
|
||||
// log-likelihood from neural network output.
|
||||
virtual bool ComputeDerivative(const wstring& /*uttID*/,
|
||||
const Matrix<ElemType>& /*logLikelihood*/,
|
||||
Matrix<ElemType>* /*derivative*/,
|
||||
ElemType* /*objective*/) = 0;
|
||||
|
||||
// Returns true if we have resources to comptue the derivative, otherwise
|
||||
// returns false.
|
||||
virtual bool HasResourceForDerivative(const wstring& /*uttID*/) const = 0;
|
||||
};
|
||||
|
||||
}}}
|
|
@ -93,7 +93,7 @@ CN_SRC = MachineLearning/CNTK/NetworkDescriptionLanguage.cpp MachineLearning/CN
|
|||
BINARYREADER_SRC = DataReader/BinaryReader/BinaryWriter.cpp DataReader/BinaryReader/BinaryReader.cpp DataReader/BinaryReader/BinaryFile.cpp
|
||||
HTKMLFREADER_SRC = DataReader/HTKMLFReader_linux/HTKMLFWriter.cpp DataReader/HTKMLFReader_linux/DataWriter.cpp DataReader/HTKMLFReader_linux/DataReader.cpp DataReader/HTKMLFReader_linux/HTKMLFReader.cpp
|
||||
KALDIREADER_SRC = DataReader/KaldiReader/HTKMLFWriter.cpp DataReader/KaldiReader/DataWriter.cpp DataReader/KaldiReader/DataReader.cpp DataReader/KaldiReader/HTKMLFReader.cpp
|
||||
KALDI2READER_SRC = DataReader/Kaldi2Reader/HTKMLFWriter.cpp DataReader/Kaldi2Reader/DataWriter.cpp DataReader/Kaldi2Reader/DataReader.cpp DataReader/Kaldi2Reader/HTKMLFReader.cpp DataReader/Kaldi2Reader/KaldiSequenceTrainingIO.cpp
|
||||
KALDI2READER_SRC = DataReader/Kaldi2Reader/HTKMLFWriter.cpp DataReader/Kaldi2Reader/DataWriter.cpp DataReader/Kaldi2Reader/DataReader.cpp DataReader/Kaldi2Reader/HTKMLFReader.cpp DataReader/Kaldi2Reader/UtteranceDerivativeBuffer.cpp DataReader/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp
|
||||
SEQUENCEREADER_SRC = DataReader/LMSequenceReader/SequenceReader.cpp DataReader/LMSequenceReader/SequenceParser.cpp DataReader/LMSequenceReader/Exports.cpp
|
||||
LUSEQUENCEREADER_SRC = DataReader/LUSequenceReader/LUSequenceReader.cpp DataReader/LUSequenceReader/LUSequenceParser.cpp DataReader/LUSequenceReader/Exports.cpp
|
||||
UCIFASTREADER_SRC = DataReader/UCIFastReader/UCIParser.cpp DataReader/UCIFastReader/UCIFastReader.cpp DataReader/UCIFastReader/Exports.cpp
|
||||
|
|
|
@ -105,7 +105,7 @@ CN_SRC = MachineLearning/CNTK/NetworkDescriptionLanguage.cpp MachineLearning/CN
|
|||
BINARYREADER_SRC = #DataReader/BinaryReader/BinaryWriter.cpp DataReader/BinaryReader/BinaryReader.cpp DataReader/BinaryReader/BinaryFile.cpp
|
||||
HTKMLFREADER_SRC = DataReader/HTKMLFReader_linux/HTKMLFWriter.cpp DataReader/HTKMLFReader_linux/DataWriter.cpp DataReader/HTKMLFReader_linux/DataReader.cpp DataReader/HTKMLFReader_linux/HTKMLFReader.cpp
|
||||
KALDIREADER_SRC = DataReader/KaldiReader/HTKMLFWriter.cpp DataReader/KaldiReader/DataWriter.cpp DataReader/KaldiReader/DataReader.cpp DataReader/KaldiReader/HTKMLFReader.cpp
|
||||
KALDI2READER_SRC = DataReader/Kaldi2Reader/HTKMLFWriter.cpp DataReader/Kaldi2Reader/DataWriter.cpp DataReader/Kaldi2Reader/DataReader.cpp DataReader/Kaldi2Reader/HTKMLFReader.cpp DataReader/Kaldi2Reader/KaldiSequenceTrainingIO.cpp
|
||||
KALDI2READER_SRC = DataReader/Kaldi2Reader/HTKMLFWriter.cpp DataReader/Kaldi2Reader/DataWriter.cpp DataReader/Kaldi2Reader/DataReader.cpp DataReader/Kaldi2Reader/HTKMLFReader.cpp DataReader/Kaldi2Reader/UtteranceDerivativeBuffer.cpp DataReader/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp
|
||||
|
||||
SEQUENCEREADER_SRC = DataReader/LMSequenceReader/SequenceReader.cpp DataReader/LMSequenceReader/SequenceParser.cpp DataReader/LMSequenceReader/Exports.cpp
|
||||
LUSEQUENCEREADER_SRC = DataReader/LUSequenceReader/LUSequenceReader.cpp DataReader/LUSequenceReader/LUSequenceParser.cpp DataReader/LUSequenceReader/Exports.cpp
|
||||
|
|
Загрузка…
Ссылка в новой задаче