This commit is contained in:
Rui Zhao (SPEECH) 2018-08-27 17:03:04 -07:00
Родитель 30fc8a5d04
Коммит 8a494db3d5
4 изменённых файла: 46 добавлений и 16 удалений

Просмотреть файл

@ -222,10 +222,10 @@ public:
let len = inputSequences[i].GetNumTimeSteps();
// first see if we find a row that has enough space
// TODO: Should we use a proper priority_queue?
size_t s = rowAllocations.size();
/*for (s = 0; s < rowAllocations.size(); s++)
size_t s;
for (s = 0; s < rowAllocations.size(); s++)
if (rowAllocations[s] + len <= width)
break; */ // yep, it fits
break; // yep, it fits
// if we did not find an s that fits, create a new one
if (s == rowAllocations.size())
rowAllocations.push_back(0);

Просмотреть файл

@ -59,7 +59,8 @@ public:
InputRef(1).ValueFor(fr).VectorMax(*m_maxIndexes1, *m_maxValues, true, m_topK);
MaskMissingColumnsToZero(*m_maxIndexes0, InputRef(0).GetMBLayout(), fr);
MaskMissingColumnsToZero(*m_maxIndexes1, InputRef(1).GetMBLayout(), fr);
m_maxIndexes0->Print("LM out");
//m_maxIndexes0->Print("LM out");
//m_maxIndexes1->Print("LSTM out");
Value().AssignNumOfDiff(*m_maxIndexes0, *m_maxIndexes1, m_topK > 1);
#if NANCHECK
Value().HasNan("ClassificationError");
@ -512,6 +513,9 @@ public:
MaskMissingColumnsToZero(*m_maxIndexes0, Input(0)->GetMBLayout(), frameRange);
MaskMissingColumnsToZero(*m_maxIndexes1, Input(1)->GetMBLayout(), frameRange);
//m_maxIndexes1->Print("LSTM output");
//m_maxIndexes0->Print("label output");
Value()(0, 0) = ComputeEditDistanceError(*m_maxIndexes0, *m_maxIndexes1, Input(0)->GetMBLayout(), m_subPen, m_delPen, m_insPen, m_squashInputs, m_tokensToIgnore);
Value().TransferToDeviceIfNotThere(Input(0)->GetDeviceId());
}
@ -615,6 +619,7 @@ public:
ExtractSampleSequence(firstSeq, columnIndices, squashInputs, tokensToIgnore, firstSeqVec);
ExtractSampleSequence(secondSeq, columnIndices, squashInputs, tokensToIgnore, secondSeqVec);
//calculate edit distance
size_t firstSize = firstSeqVec.size();
size_t secondSize = secondSeqVec.size();
@ -623,6 +628,20 @@ public:
else
totalSampleNum += firstSize;
/*fprintf(stderr, "label: ");
for (size_t n = 0; n < firstSize; n++)
{
fprintf(stderr, "%d ", firstSeqVec[n]);
}
fprintf(stderr, "\n");
fprintf(stderr, "output: ");
for (size_t n = 0; n < secondSize; n++)
{
fprintf(stderr, "%d ", secondSeqVec[n]);
}
fprintf(stderr, "\n");*/
grid.Resize(firstSize + 1, secondSize + 1);
insMatrix.Resize(firstSize + 1, secondSize + 1);
delMatrix.Resize(firstSize + 1, secondSize + 1);

Просмотреть файл

@ -6879,6 +6879,17 @@ void _assignRNNTScore(
y = exp(x);
RNNTscore[probId] -= y;
}
if (u == phoneNum - 1 && t == uttFrameNum[uttId] - 1)
{
size_t probId = tuID *totalPhoneNum + blankTokenId;
x = alphaScore[alphaId] + prob[probId] - P_lx;
if (x < LZERO)
y = 0.0f;
else
y = exp(x);
RNNTscore[probId] -= y;
}
//for (size_t k == 0; k < totalPhoneNum; k++)
}
@ -7303,19 +7314,19 @@ CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignRNNTScore(const CPUMatrix<ElemTy
m_derivativeForF.SetValue(0.0);
m_derivativeForG.SetValue(0.0);
//_assignRNNTScore(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttFrameToChanInd,
// uttBeginForOutputditribution, numParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId);
_assignRNNTScore2(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttPhoneBeginIdx, uttFrameToChanInd, uttPhoneToChanInd,
uttBeginForOutputditribution, numParallelSequences, numPhoneParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId, m_derivativeForF.Data(), m_derivativeForG.Data());
_assignRNNTScore(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttFrameToChanInd,
uttBeginForOutputditribution, numParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId);
// _assignRNNTScore2(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttPhoneBeginIdx, uttFrameToChanInd, uttPhoneToChanInd,
// uttBeginForOutputditribution, numParallelSequences, numPhoneParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId, m_derivativeForF.Data(), m_derivativeForG.Data());
//this->Print("RNNT score");
totalScore(0, 0) = 0.0;
fprintf(stderr, "utt score: ");
//fprintf(stderr, "utt score: ");
for (size_t utt = 0; utt < uttNum; utt++)
{
fprintf(stderr, "%f ", scores[utt]);
//fprintf(stderr, "%f ", scores[utt]);
totalScore(0, 0) -= scores[utt];
}
fprintf(stderr, "\n");
//fprintf(stderr, "\n");
//alpha.SetValue(0.0);
//ElemType score= compute_alphas(prob.Data(), alpha.Data(), (int)maxFrameNum, (int)maxPhoneNum, phoneSeq.Data());
//CPUMatrix<ElemType> trans_grads, predict_grads;

Просмотреть файл

@ -562,7 +562,7 @@ public:
ElemType finalscore = 0;
finalscore = totalScore.Get00Element();
//fprintf(stderr, "finalscore:%f\n", finalscore);
//if (finalscore > 50 || finalscore < 0)
if (finalscore > 50 || finalscore < 0)
{
for (size_t i = 0; i < uttFrameNum.size(); i++)
{
@ -579,10 +579,10 @@ public:
matrixOutputDistribution.ReleaseMemory();
//compute derivatives for F and G
//m_derivativeForF.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
// numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 0);
//m_derivativeForG.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
// numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 1);
m_derivativeForF.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 0);
m_derivativeForG.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 1);
//m_derivativeForF.Print("derivative for F");
//m_derivativeForG.Print("derivative for G");