This commit is contained in:
Vadim Mazalov 2019-09-27 12:15:52 -07:00
Родитель 794b37e67b
Коммит b27daef0a9
1 изменённых файлов: 18 добавлений и 9 удалений

Просмотреть файл

@ -54,6 +54,7 @@ class SimpleOutputWriter
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
typedef typename std::vector<Sequence>::iterator iterator;
unordered_map<wstring, vector<shared_ptr<PastValueNode<ElemType>>>> m_nameToPastValueNodeCache;
vector<shared_ptr<Matrix<ElemType>>> m_decodeOutputCache;
public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
@ -304,7 +305,15 @@ public:
oneSeq.length = a.length;
oneSeq.lengthwithblank = a.lengthwithblank;
oneSeq.processlength = a.processlength;
if (m_decodeOutputCache.size() > 0)
{
oneSeq.decodeoutput = m_decodeOutputCache.back();
m_decodeOutputCache.pop_back();
}
else
{
oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
}
oneSeq.decodeoutput->SetValue(*(a.decodeoutput));
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
@ -323,7 +332,6 @@ public:
if (ab > 0)
fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/
}
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
{
@ -360,7 +368,8 @@ public:
address << oneSeq.nameToNodeValues[it->first];
fprintf(stderr, "deleteSeq %ls %s \n", it->first.c_str(), address.str().c_str());*/
}
oneSeq.decodeoutput->ReleaseMemory();
m_decodeOutputCache.push_back(oneSeq.decodeoutput);
//oneSeq.decodeoutput->ReleaseMemory();
vector<size_t>().swap(oneSeq.labelseq);
}
iterator getMaxSeq(const vector<Sequence>& seqs)
@ -584,9 +593,9 @@ public:
//plus broadcast
(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusNode)->Value())->SetValue(sumofENandDE);
//SumMatrix.SetValue(sumofENandDE);
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
auto PlusMBlayout = PlusNode->GetMBLayout();
PlusMBlayout->Init(1, 1);
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
PlusMBlayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1);
m_net->ForwardPropFromTo(Plusnodes, Plustransnodes);
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusTransNode)->Value()));