This commit is contained in:
Vadim Mazalov 2019-09-27 12:15:52 -07:00
Родитель 794b37e67b
Коммит b27daef0a9
1 изменённых файлов: 18 добавлений и 9 удалений

Просмотреть файл

@ -54,6 +54,7 @@ class SimpleOutputWriter
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
typedef typename std::vector<Sequence>::iterator iterator;
unordered_map<wstring, vector<shared_ptr<PastValueNode<ElemType>>>> m_nameToPastValueNodeCache;
vector<shared_ptr<Matrix<ElemType>>> m_decodeOutputCache;
public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
@ -304,7 +305,15 @@ public:
oneSeq.length = a.length;
oneSeq.lengthwithblank = a.lengthwithblank;
oneSeq.processlength = a.processlength;
if (m_decodeOutputCache.size() > 0)
{
oneSeq.decodeoutput = m_decodeOutputCache.back();
m_decodeOutputCache.pop_back();
}
else
{
oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
}
oneSeq.decodeoutput->SetValue(*(a.decodeoutput));
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
@ -317,13 +326,12 @@ public:
oneSeq.nameToParentNodeValues[it->first] = it->second;
a.refs++;
}
else
else
oneSeq.nameToParentNodeValues[it->first] = a.nameToParentNodeValues[it->first];
/*size_t ab = oneSeq.nameToParentNodeValues[it->first]->Value().GetNumElements();
if (ab > 0)
fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/
}
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
{
@ -354,13 +362,14 @@ public:
/*long t = oneSeq.nameToNodeValues[it->first].use_count();
fprintf(stderr, "use count %lu %lu \n", t, oneSeq.refs);*/
if (oneSeq.refs == 0)
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
/*std::ostringstream address;
address << oneSeq.nameToNodeValues[it->first];
fprintf(stderr, "deleteSeq %ls %s \n", it->first.c_str(), address.str().c_str());*/
}
oneSeq.decodeoutput->ReleaseMemory();
m_decodeOutputCache.push_back(oneSeq.decodeoutput);
//oneSeq.decodeoutput->ReleaseMemory();
vector<size_t>().swap(oneSeq.labelseq);
}
iterator getMaxSeq(const vector<Sequence>& seqs)
@ -415,9 +424,9 @@ public:
{
if (it->second && it->second->Value().GetNumElements() > 0)
{
it->second->CopyTo(s.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
it->second->CopyTo(s.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
/*std::ostringstream address;
address << s.nameToNodeValues[it->first];
address << s.nameToNodeValues[it->first];
fprintf(stderr, "prepareSequence %ls %s \n", it->first.c_str(), address.str().c_str());*/
}
}
@ -499,7 +508,7 @@ public:
{
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);
}
/* else
/* else
{
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
}*/
@ -524,7 +533,7 @@ public:
out << (*oneSeq.decodeoutput)(m_i, j);
}
}
out << string("\n");
out << string("\n");
out.close();*/
@ -584,9 +593,9 @@ public:
//plus broadcast
(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusNode)->Value())->SetValue(sumofENandDE);
//SumMatrix.SetValue(sumofENandDE);
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
auto PlusMBlayout = PlusNode->GetMBLayout();
PlusMBlayout->Init(1, 1);
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
PlusMBlayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1);
m_net->ForwardPropFromTo(Plusnodes, Plustransnodes);
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusTransNode)->Value()));