force expand blank
This commit is contained in:
Родитель
4c67aae929
Коммит
2ddfea26f6
|
@ -441,7 +441,7 @@ public:
|
|||
return a.logP < b.logP;
|
||||
}
|
||||
|
||||
vector<pair<size_t, ElemType>> getTopN(Microsoft::MSR::CNTK::Matrix<ElemType>& prob, size_t N)
|
||||
vector<pair<size_t, ElemType>> getTopN(Microsoft::MSR::CNTK::Matrix<ElemType>& prob, size_t N, size_t& blankid)
|
||||
{
|
||||
vector<pair<size_t, ElemType>> datapair;
|
||||
typedef vector<pair<size_t, ElemType>>::value_type ValueType;
|
||||
|
@ -453,6 +453,7 @@ public:
|
|||
nth_element(datapair.begin(), datapair.begin() + N, datapair.end(), [](ValueType const& x, ValueType const& y) -> bool {
|
||||
return y.second < x.second;
|
||||
});
|
||||
datapair.push_back(ValueType(blankid, probdata[blankid]));
|
||||
delete probdata;
|
||||
return datapair;
|
||||
}
|
||||
|
@ -656,7 +657,7 @@ public:
|
|||
|
||||
//sumofENandDE.Print("sum");
|
||||
//sort log posterior and get best N labels
|
||||
vector<pair<size_t, ElemType>> topN = getTopN(decodeOutput, expandBeam);
|
||||
vector<pair<size_t, ElemType>> topN = getTopN(decodeOutput, expandBeam, blankId);
|
||||
/*ElemType* logP = decodeOutput.CopyToArray();
|
||||
std::priority_queue<std::pair<double, int>> q;
|
||||
int iLabel;
|
||||
|
@ -681,38 +682,41 @@ public:
|
|||
CurSequences.push_back(seqK);
|
||||
q.pop();
|
||||
}*/
|
||||
//expand blank
|
||||
Sequence seqK = newSeq(tempSeq, deviceid);
|
||||
ElemType newlogP = topN[vocabSize].second + tempSeq.logP;
|
||||
seqK.logP = newlogP;
|
||||
bool existseq = false;
|
||||
for (auto itseq = nextSequences.begin(); itseq != nextSequences.end(); itseq++)
|
||||
//for (Sequence seqP : keyNextSequences) //does not work
|
||||
{
|
||||
//merge the score with same sequence
|
||||
if (seqK.labelseq == itseq->labelseq)
|
||||
{
|
||||
existseq = true;
|
||||
itseq->logP = decodeOutput.LogAdd(seqK.logP, itseq->logP);
|
||||
//itseq->lengthwithblank = (seqK.lengthwithblank + itseq->lengthwithblank) / 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!existseq)
|
||||
{
|
||||
nextSequences.push_back(seqK);
|
||||
}
|
||||
|
||||
int iLabel;
|
||||
for (iLabel = 0; iLabel < expandBeam; iLabel++)
|
||||
{
|
||||
|
||||
Sequence seqK = newSeq(tempSeq, deviceid);
|
||||
ElemType newlogP = topN[iLabel].second + tempSeq.logP;
|
||||
seqK = newSeq(tempSeq, deviceid);
|
||||
newlogP = topN[iLabel].second + tempSeq.logP;
|
||||
seqK.logP = newlogP;
|
||||
|
||||
if (topN[iLabel].first == blankId)
|
||||
if (topN[iLabel].first != blankId)
|
||||
{
|
||||
bool existseq = false;
|
||||
for (auto itseq = nextSequences.begin(); itseq != nextSequences.end(); itseq++)
|
||||
//for (Sequence seqP : keyNextSequences)
|
||||
{
|
||||
//merge the score with same sequence
|
||||
if (seqK.labelseq == itseq->labelseq)
|
||||
{
|
||||
existseq = true;
|
||||
itseq->logP = decodeOutput.LogAdd(seqK.logP, itseq->logP);
|
||||
//itseq->lengthwithblank = (seqK.lengthwithblank + itseq->lengthwithblank) / 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!existseq)
|
||||
{
|
||||
nextSequences.push_back(seqK);
|
||||
}
|
||||
//nextSequences.push_back(seqK);
|
||||
continue;
|
||||
}
|
||||
extendSeq(seqK, topN[iLabel].first, newlogP);
|
||||
CurSequences.push_back(seqK);
|
||||
}
|
||||
}
|
||||
vector<pair<size_t, ElemType>>().swap(topN);
|
||||
//delete topN;
|
||||
|
|
Загрузка…
Ссылка в новой задаче