Optimize release
This commit is contained in:
Родитель
794b37e67b
Коммит
b27daef0a9
|
@ -54,6 +54,7 @@ class SimpleOutputWriter
|
|||
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
|
||||
typedef typename std::vector<Sequence>::iterator iterator;
|
||||
unordered_map<wstring, vector<shared_ptr<PastValueNode<ElemType>>>> m_nameToPastValueNodeCache;
|
||||
vector<shared_ptr<Matrix<ElemType>>> m_decodeOutputCache;
|
||||
|
||||
public:
|
||||
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
|
||||
|
@ -304,7 +305,15 @@ public:
|
|||
oneSeq.length = a.length;
|
||||
oneSeq.lengthwithblank = a.lengthwithblank;
|
||||
oneSeq.processlength = a.processlength;
|
||||
if (m_decodeOutputCache.size() > 0)
|
||||
{
|
||||
oneSeq.decodeoutput = m_decodeOutputCache.back();
|
||||
m_decodeOutputCache.pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
|
||||
}
|
||||
oneSeq.decodeoutput->SetValue(*(a.decodeoutput));
|
||||
|
||||
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
|
||||
|
@ -317,13 +326,12 @@ public:
|
|||
oneSeq.nameToParentNodeValues[it->first] = it->second;
|
||||
a.refs++;
|
||||
}
|
||||
else
|
||||
else
|
||||
oneSeq.nameToParentNodeValues[it->first] = a.nameToParentNodeValues[it->first];
|
||||
/*size_t ab = oneSeq.nameToParentNodeValues[it->first]->Value().GetNumElements();
|
||||
if (ab > 0)
|
||||
fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/
|
||||
}
|
||||
|
||||
auto itin = m_nameToPastValueNodeCache.find(it->first);
|
||||
if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
|
||||
{
|
||||
|
@ -354,13 +362,14 @@ public:
|
|||
/*long t = oneSeq.nameToNodeValues[it->first].use_count();
|
||||
fprintf(stderr, "use count %lu %lu \n", t, oneSeq.refs);*/
|
||||
if (oneSeq.refs == 0)
|
||||
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
|
||||
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
|
||||
|
||||
/*std::ostringstream address;
|
||||
address << oneSeq.nameToNodeValues[it->first];
|
||||
fprintf(stderr, "deleteSeq %ls %s \n", it->first.c_str(), address.str().c_str());*/
|
||||
}
|
||||
oneSeq.decodeoutput->ReleaseMemory();
|
||||
m_decodeOutputCache.push_back(oneSeq.decodeoutput);
|
||||
//oneSeq.decodeoutput->ReleaseMemory();
|
||||
vector<size_t>().swap(oneSeq.labelseq);
|
||||
}
|
||||
iterator getMaxSeq(const vector<Sequence>& seqs)
|
||||
|
@ -415,9 +424,9 @@ public:
|
|||
{
|
||||
if (it->second && it->second->Value().GetNumElements() > 0)
|
||||
{
|
||||
it->second->CopyTo(s.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
|
||||
it->second->CopyTo(s.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
|
||||
/*std::ostringstream address;
|
||||
address << s.nameToNodeValues[it->first];
|
||||
address << s.nameToNodeValues[it->first];
|
||||
fprintf(stderr, "prepareSequence %ls %s \n", it->first.c_str(), address.str().c_str());*/
|
||||
}
|
||||
}
|
||||
|
@ -499,7 +508,7 @@ public:
|
|||
{
|
||||
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);
|
||||
}
|
||||
/* else
|
||||
/* else
|
||||
{
|
||||
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
|
||||
}*/
|
||||
|
@ -524,7 +533,7 @@ public:
|
|||
out << (*oneSeq.decodeoutput)(m_i, j);
|
||||
}
|
||||
}
|
||||
out << string("\n");
|
||||
out << string("\n");
|
||||
|
||||
out.close();*/
|
||||
|
||||
|
@ -584,9 +593,9 @@ public:
|
|||
//plus broadcast
|
||||
(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusNode)->Value())->SetValue(sumofENandDE);
|
||||
//SumMatrix.SetValue(sumofENandDE);
|
||||
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
|
||||
auto PlusMBlayout = PlusNode->GetMBLayout();
|
||||
PlusMBlayout->Init(1, 1);
|
||||
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
|
||||
PlusMBlayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1);
|
||||
m_net->ForwardPropFromTo(Plusnodes, Plustransnodes);
|
||||
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusTransNode)->Value()));
|
||||
|
|
Загрузка…
Ссылка в новой задаче