This commit is contained in:
Rui Zhao (SPEECH) 2019-09-24 15:56:55 -07:00
Родитель 4c67aae929
Коммит 2ddfea26f6
1 изменённых файлов: 29 добавлений и 25 удалений

Просмотреть файл

@ -441,7 +441,7 @@ public:
return a.logP < b.logP;
}
vector<pair<size_t, ElemType>> getTopN(Microsoft::MSR::CNTK::Matrix<ElemType>& prob, size_t N)
vector<pair<size_t, ElemType>> getTopN(Microsoft::MSR::CNTK::Matrix<ElemType>& prob, size_t N, size_t& blankid)
{
vector<pair<size_t, ElemType>> datapair;
typedef vector<pair<size_t, ElemType>>::value_type ValueType;
@ -453,6 +453,7 @@ public:
nth_element(datapair.begin(), datapair.begin() + N, datapair.end(), [](ValueType const& x, ValueType const& y) -> bool {
return y.second < x.second;
});
datapair.push_back(ValueType(blankid, probdata[blankid]));
delete probdata;
return datapair;
}
@ -656,7 +657,7 @@ public:
//sumofENandDE.Print("sum");
//sort log posterior and get best N labels
vector<pair<size_t, ElemType>> topN = getTopN(decodeOutput, expandBeam);
vector<pair<size_t, ElemType>> topN = getTopN(decodeOutput, expandBeam, blankId);
/*ElemType* logP = decodeOutput.CopyToArray();
std::priority_queue<std::pair<double, int>> q;
int iLabel;
@ -681,38 +682,41 @@ public:
CurSequences.push_back(seqK);
q.pop();
}*/
//expand blank
Sequence seqK = newSeq(tempSeq, deviceid);
ElemType newlogP = topN[vocabSize].second + tempSeq.logP;
seqK.logP = newlogP;
bool existseq = false;
for (auto itseq = nextSequences.begin(); itseq != nextSequences.end(); itseq++)
//for (Sequence seqP : keyNextSequences) //does not work
{
//merge the score with same sequence
if (seqK.labelseq == itseq->labelseq)
{
existseq = true;
itseq->logP = decodeOutput.LogAdd(seqK.logP, itseq->logP);
//itseq->lengthwithblank = (seqK.lengthwithblank + itseq->lengthwithblank) / 2;
break;
}
}
if (!existseq)
{
nextSequences.push_back(seqK);
}
int iLabel;
for (iLabel = 0; iLabel < expandBeam; iLabel++)
{
Sequence seqK = newSeq(tempSeq, deviceid);
ElemType newlogP = topN[iLabel].second + tempSeq.logP;
seqK = newSeq(tempSeq, deviceid);
newlogP = topN[iLabel].second + tempSeq.logP;
seqK.logP = newlogP;
if (topN[iLabel].first == blankId)
if (topN[iLabel].first != blankId)
{
bool existseq = false;
for (auto itseq = nextSequences.begin(); itseq != nextSequences.end(); itseq++)
//for (Sequence seqP : keyNextSequences)
{
//merge the score with same sequence
if (seqK.labelseq == itseq->labelseq)
{
existseq = true;
itseq->logP = decodeOutput.LogAdd(seqK.logP, itseq->logP);
//itseq->lengthwithblank = (seqK.lengthwithblank + itseq->lengthwithblank) / 2;
break;
}
}
if (!existseq)
{
nextSequences.push_back(seqK);
}
//nextSequences.push_back(seqK);
continue;
}
extendSeq(seqK, topN[iLabel].first, newlogP);
CurSequences.push_back(seqK);
}
}
vector<pair<size_t, ElemType>>().swap(topN);
//delete topN;