fast version
This commit is contained in:
Родитель
30fc8a5d04
Коммит
8a494db3d5
|
@ -222,10 +222,10 @@ public:
|
|||
let len = inputSequences[i].GetNumTimeSteps();
|
||||
// first see if we find a row that has enough space
|
||||
// TODO: Should we use a proper priority_queue?
|
||||
size_t s = rowAllocations.size();
|
||||
/*for (s = 0; s < rowAllocations.size(); s++)
|
||||
size_t s;
|
||||
for (s = 0; s < rowAllocations.size(); s++)
|
||||
if (rowAllocations[s] + len <= width)
|
||||
break; */ // yep, it fits
|
||||
break; // yep, it fits
|
||||
// we did not find a s that fit then create a new one
|
||||
if (s == rowAllocations.size())
|
||||
rowAllocations.push_back(0);
|
||||
|
|
|
@ -59,7 +59,8 @@ public:
|
|||
InputRef(1).ValueFor(fr).VectorMax(*m_maxIndexes1, *m_maxValues, true, m_topK);
|
||||
MaskMissingColumnsToZero(*m_maxIndexes0, InputRef(0).GetMBLayout(), fr);
|
||||
MaskMissingColumnsToZero(*m_maxIndexes1, InputRef(1).GetMBLayout(), fr);
|
||||
m_maxIndexes0->Print("LM out");
|
||||
//m_maxIndexes0->Print("LM out");
|
||||
//m_maxIndexes1->Print("LSTM out");
|
||||
Value().AssignNumOfDiff(*m_maxIndexes0, *m_maxIndexes1, m_topK > 1);
|
||||
#if NANCHECK
|
||||
Value().HasNan("ClassificationError");
|
||||
|
@ -512,6 +513,9 @@ public:
|
|||
|
||||
MaskMissingColumnsToZero(*m_maxIndexes0, Input(0)->GetMBLayout(), frameRange);
|
||||
MaskMissingColumnsToZero(*m_maxIndexes1, Input(1)->GetMBLayout(), frameRange);
|
||||
|
||||
//m_maxIndexes1->Print("LSTM output");
|
||||
//m_maxIndexes0->Print("label output");
|
||||
Value()(0, 0) = ComputeEditDistanceError(*m_maxIndexes0, *m_maxIndexes1, Input(0)->GetMBLayout(), m_subPen, m_delPen, m_insPen, m_squashInputs, m_tokensToIgnore);
|
||||
Value().TransferToDeviceIfNotThere(Input(0)->GetDeviceId());
|
||||
}
|
||||
|
@ -615,6 +619,7 @@ public:
|
|||
ExtractSampleSequence(firstSeq, columnIndices, squashInputs, tokensToIgnore, firstSeqVec);
|
||||
ExtractSampleSequence(secondSeq, columnIndices, squashInputs, tokensToIgnore, secondSeqVec);
|
||||
|
||||
|
||||
//calculate edit distance
|
||||
size_t firstSize = firstSeqVec.size();
|
||||
size_t secondSize = secondSeqVec.size();
|
||||
|
@ -623,6 +628,20 @@ public:
|
|||
else
|
||||
totalSampleNum += firstSize;
|
||||
|
||||
/*fprintf(stderr, "label: ");
|
||||
for (size_t n = 0; n < firstSize; n++)
|
||||
{
|
||||
fprintf(stderr, "%d ", firstSeqVec[n]);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
fprintf(stderr, "output: ");
|
||||
for (size_t n = 0; n < secondSize; n++)
|
||||
{
|
||||
fprintf(stderr, "%d ", secondSeqVec[n]);
|
||||
}
|
||||
fprintf(stderr, "\n");*/
|
||||
|
||||
grid.Resize(firstSize + 1, secondSize + 1);
|
||||
insMatrix.Resize(firstSize + 1, secondSize + 1);
|
||||
delMatrix.Resize(firstSize + 1, secondSize + 1);
|
||||
|
|
|
@ -6879,6 +6879,17 @@ void _assignRNNTScore(
|
|||
y = exp(x);
|
||||
RNNTscore[probId] -= y;
|
||||
}
|
||||
|
||||
if (u == phoneNum - 1 && t == uttFrameNum[uttId] - 1)
|
||||
{
|
||||
size_t probId = tuID *totalPhoneNum + blankTokenId;
|
||||
x = alphaScore[alphaId] + prob[probId] - P_lx;
|
||||
if (x < LZERO)
|
||||
y = 0.0f;
|
||||
else
|
||||
y = exp(x);
|
||||
RNNTscore[probId] -= y;
|
||||
}
|
||||
//for (size_t k == 0; k < totalPhoneNum; k++)
|
||||
|
||||
}
|
||||
|
@ -7303,19 +7314,19 @@ CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignRNNTScore(const CPUMatrix<ElemTy
|
|||
|
||||
m_derivativeForF.SetValue(0.0);
|
||||
m_derivativeForG.SetValue(0.0);
|
||||
//_assignRNNTScore(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttFrameToChanInd,
|
||||
// uttBeginForOutputditribution, numParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId);
|
||||
_assignRNNTScore2(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttPhoneBeginIdx, uttFrameToChanInd, uttPhoneToChanInd,
|
||||
uttBeginForOutputditribution, numParallelSequences, numPhoneParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId, m_derivativeForF.Data(), m_derivativeForG.Data());
|
||||
_assignRNNTScore(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttFrameToChanInd,
|
||||
uttBeginForOutputditribution, numParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId);
|
||||
// _assignRNNTScore2(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttPhoneBeginIdx, uttFrameToChanInd, uttPhoneToChanInd,
|
||||
// uttBeginForOutputditribution, numParallelSequences, numPhoneParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId, m_derivativeForF.Data(), m_derivativeForG.Data());
|
||||
//this->Print("RNNT score");
|
||||
totalScore(0, 0) = 0.0;
|
||||
fprintf(stderr, "utt score: ");
|
||||
//fprintf(stderr, "utt score: ");
|
||||
for (size_t utt = 0; utt < uttNum; utt++)
|
||||
{
|
||||
fprintf(stderr, "%f ", scores[utt]);
|
||||
//fprintf(stderr, "%f ", scores[utt]);
|
||||
totalScore(0, 0) -= scores[utt];
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
//fprintf(stderr, "\n");
|
||||
//alpha.SetValue(0.0);
|
||||
//ElemType score= compute_alphas(prob.Data(), alpha.Data(), (int)maxFrameNum, (int)maxPhoneNum, phoneSeq.Data());
|
||||
//CPUMatrix<ElemType> trans_grads, predict_grads;
|
||||
|
|
|
@ -562,7 +562,7 @@ public:
|
|||
ElemType finalscore = 0;
|
||||
finalscore = totalScore.Get00Element();
|
||||
//fprintf(stderr, "finalscore:%f\n", finalscore);
|
||||
//if (finalscore > 50 || finalscore < 0)
|
||||
if (finalscore > 50 || finalscore < 0)
|
||||
{
|
||||
for (size_t i = 0; i < uttFrameNum.size(); i++)
|
||||
{
|
||||
|
@ -579,10 +579,10 @@ public:
|
|||
matrixOutputDistribution.ReleaseMemory();
|
||||
|
||||
//compute derivatives for F and G
|
||||
//m_derivativeForF.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
|
||||
// numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 0);
|
||||
//m_derivativeForG.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
|
||||
// numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 1);
|
||||
m_derivativeForF.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
|
||||
numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 0);
|
||||
m_derivativeForG.AssignUserOp2(RNNTPosterior, uttFrameToChanInd, uttPhoneToChanInd, uttFrameBeginIdx, uttPhoneBeginIdx, uttBeginForOutputditribution, uttFrameNum, uttPhoneNum,
|
||||
numParallelSequences, numPhoneParallelSequences, maxFrameNum, maxPhoneNum, 1);
|
||||
|
||||
//m_derivativeForF.Print("derivative for F");
|
||||
//m_derivativeForG.Print("derivative for G");
|
||||
|
|
Загрузка…
Ссылка в новой задаче