From 2ddfea26f6d506c980477fa0fb7ba5dbdce896a8 Mon Sep 17 00:00:00 2001 From: "Rui Zhao (SPEECH)" Date: Tue, 24 Sep 2019 15:56:55 -0700 Subject: [PATCH] force expand blank --- Source/SGDLib/SimpleOutputWriter.h | 54 ++++++++++++++++-------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/Source/SGDLib/SimpleOutputWriter.h b/Source/SGDLib/SimpleOutputWriter.h index 10145d402..369a9b49e 100644 --- a/Source/SGDLib/SimpleOutputWriter.h +++ b/Source/SGDLib/SimpleOutputWriter.h @@ -441,7 +441,7 @@ public: return a.logP < b.logP; } - vector> getTopN(Microsoft::MSR::CNTK::Matrix& prob, size_t N) + vector> getTopN(Microsoft::MSR::CNTK::Matrix& prob, size_t N, size_t& blankid) { vector> datapair; typedef vector>::value_type ValueType; @@ -453,6 +453,7 @@ public: nth_element(datapair.begin(), datapair.begin() + N, datapair.end(), [](ValueType const& x, ValueType const& y) -> bool { return y.second < x.second; }); + datapair.push_back(ValueType(blankid, probdata[blankid])); delete probdata; return datapair; } @@ -656,7 +657,7 @@ public: //sumofENandDE.Print("sum"); //sort log posterior and get best N labels - vector> topN = getTopN(decodeOutput, expandBeam); + vector> topN = getTopN(decodeOutput, expandBeam, blankId); /*ElemType* logP = decodeOutput.CopyToArray(); std::priority_queue> q; int iLabel; @@ -681,38 +682,41 @@ public: CurSequences.push_back(seqK); q.pop(); }*/ + //expand blank + Sequence seqK = newSeq(tempSeq, deviceid); + ElemType newlogP = topN[vocabSize].second + tempSeq.logP; + seqK.logP = newlogP; + bool existseq = false; + for (auto itseq = nextSequences.begin(); itseq != nextSequences.end(); itseq++) + //for (Sequence seqP : keyNextSequences) //does not work + { + //merge the score with same sequence + if (seqK.labelseq == itseq->labelseq) + { + existseq = true; + itseq->logP = decodeOutput.LogAdd(seqK.logP, itseq->logP); + //itseq->lengthwithblank = (seqK.lengthwithblank + itseq->lengthwithblank) / 2; + break; + } + } + if (!existseq) + { + nextSequences.push_back(seqK); + } + int iLabel; for (iLabel = 0; iLabel < expandBeam; iLabel++) { - Sequence seqK = newSeq(tempSeq, deviceid); - ElemType newlogP = topN[iLabel].second + tempSeq.logP; + seqK = newSeq(tempSeq, deviceid); + newlogP = topN[iLabel].second + tempSeq.logP; seqK.logP = newlogP; - if (topN[iLabel].first == blankId) + if (topN[iLabel].first != blankId) { - bool existseq = false; - for (auto itseq = nextSequences.begin(); itseq != nextSequences.end(); itseq++) - //for (Sequence seqP : keyNextSequences) - { - //merge the score with same sequence - if (seqK.labelseq == itseq->labelseq) - { - existseq = true; - itseq->logP = decodeOutput.LogAdd(seqK.logP, itseq->logP); - //itseq->lengthwithblank = (seqK.lengthwithblank + itseq->lengthwithblank) / 2; - break; - } - } - if (!existseq) - { - nextSequences.push_back(seqK); - } - //nextSequences.push_back(seqK); - continue; - } extendSeq(seqK, topN[iLabel].first, newlogP); CurSequences.push_back(seqK); +} } vector>().swap(topN); //delete topN;