Enable CUDA evaluation of RNN-T models
Parent: 2ddfea26f6
Commit: 1386278827
@@ -18,7 +18,6 @@
#include <cstdio>
#include "ProgressTracing.h"
#include "ComputationNetworkBuilder.h"
#include "RecurrentNodes.h"
#include <algorithm>

using namespace std;
@@ -46,11 +45,9 @@ class SimpleOutputWriter
{
return logP < rhs.logP;
}
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>> nameToNodeValues;
};
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
typedef typename std::vector<Sequence>::iterator iterator;
unordered_map<wstring, vector<shared_ptr<PastValueNode<ElemType>>>> m_nameToPastValueNodeCache;

public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
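For context on the comparison shown above: each Sequence hypothesis carries a label history and an accumulated log-probability, and ordering by logP is what lets the decoder rank candidates. The sketch below is illustrative only; Hypothesis, Prune, and beamWidth are stand-in names and are not part of the CNTK sources.

    // Illustrative stand-in for the Sequence struct above: hypotheses ordered by logP.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct Hypothesis
    {
        std::vector<std::size_t> labelseq; // emitted label indices
        double logP = 0.0;                 // accumulated log-probability
        bool operator<(const Hypothesis& rhs) const { return logP < rhs.logP; }
    };

    // Keep only the beamWidth highest-scoring hypotheses.
    void Prune(std::vector<Hypothesis>& beam, std::size_t beamWidth)
    {
        std::sort(beam.begin(), beam.end()); // ascending by logP
        if (beam.size() > beamWidth)
            beam.erase(beam.begin(), beam.end() - beamWidth); // drop the lowest-scoring
    }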
@@ -291,11 +288,6 @@
Sequence newSeq(size_t numRow, size_t numCol, DEVICEID_TYPE deviceId)
{
Sequence oneSeq = {std::vector<size_t>(), 0.0, 0, 0, 0, make_shared<Matrix<ElemType>>(numRow, (size_t) 1, deviceId)};
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
vector<ElemType> v;
oneSeq.nameToNodeValues[m_nodesToCache[i]] = make_shared<PastValueNode<ElemType>>(deviceId, L"test");
}
return oneSeq;
}
Sequence newSeq(Sequence& a, DEVICEID_TYPE deviceId)
@@ -308,34 +300,10 @@
oneSeq.processlength = a.processlength;
oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
oneSeq.decodeoutput->SetValue(*(a.decodeoutput));
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++)
{
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
{
oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back();
m_nameToPastValueNodeCache[it->first].pop_back();
}
else
{
oneSeq.nameToNodeValues[it->first] = make_shared<PastValueNode<ElemType>>(deviceId, it->first);
}

it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
}
return oneSeq;
}
void deleteSeq(Sequence oneSeq)
{
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
for (it = oneSeq.nameToNodeValues.begin(); it != oneSeq.nameToNodeValues.end(); it++)
{
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin == m_nameToPastValueNodeCache.end())
m_nameToPastValueNodeCache[it->first] = vector<shared_ptr<PastValueNode<ElemType>>>();
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
}
oneSeq.decodeoutput->ReleaseMemory();
vector<size_t>().swap(oneSeq.labelseq);
}
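The newSeq/deleteSeq pair above copies per-hypothesis decoder state (one PastValueNode per cached node name) and recycles released nodes through m_nameToPastValueNodeCache instead of allocating a fresh node for every hypothesis. Below is a minimal, self-contained sketch of that recycle-or-allocate pattern; State and StateCache are simplified stand-ins for the PastValueNode machinery, not CNTK types.

    // Recycle-or-allocate pool keyed by node name, mirroring m_nameToPastValueNodeCache.
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct State { std::vector<float> value; }; // stand-in for a cached node's value

    class StateCache
    {
    public:
        // Reuse a pooled State for `name` if one is available, otherwise allocate one.
        std::shared_ptr<State> Acquire(const std::wstring& name)
        {
            auto it = m_pool.find(name);
            if (it != m_pool.end() && !it->second.empty())
            {
                auto state = it->second.back();
                it->second.pop_back();
                return state;
            }
            return std::make_shared<State>();
        }

        // Return a State to the pool so a later sequence can reuse it (as deleteSeq does above).
        void Release(const std::wstring& name, std::shared_ptr<State> state)
        {
            m_pool[name].push_back(std::move(state));
        }

    private:
        std::unordered_map<std::wstring, std::vector<std::shared_ptr<State>>> m_pool;
    };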
@@ -369,18 +337,7 @@
insequence.lengthwithblank++;
}

std::vector<ComputationNodeBasePtr> GetNodesByNames(const std::vector<std::wstring>& names) const
{
std::vector<ComputationNodeBasePtr> nodes;
for (size_t i = 0; i < names.size(); i++)
{
auto nodesByName = m_net->GetNodesFromName(names[i]);
if (nodesByName.size() != 1)
LogicError("Exactly one node is expected for name %ls", names[i]);
nodes.push_back(nodesByName[0]);
}
return nodes;
}

void forward_decode(Sequence& oneSeq, StreamMinibatchInputs decodeinputMatrices, DEVICEID_TYPE deviceID, const std::vector<ComputationNodeBasePtr>& decodeOutputNodes,
const std::vector<ComputationNodeBasePtr>& decodeinputNodes, size_t vocabSize, size_t plength)
{
@@ -393,46 +350,23 @@
Matrix<ElemType> lmin(deviceID);

//greedyOutput.SetValue(greedyOutputMax.ColumnSlice(0, lmt));
lmin.Resize(vocabSize, 1);
lmin.Resize(vocabSize, plength);
lmin.SetValue(0.0);
lmin(oneSeq.labelseq[plength - 1], 0) = 1.0;
for (size_t n = 0; n < plength; n++)
{
lmin(oneSeq.labelseq[n], n) = 1.0;
}
auto lminput = decodeinputMatrices.begin();
lminput->second.pMBLayout->Init(1, 1);
lminput->second.pMBLayout->Init(1, plength);
//std::swap(lminput->second.GetMatrix<ElemType>(), lmin);
lminput->second.GetMatrix<ElemType>().SetValue(lmin);
if (plength == 1)
{
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1);
}
else
{
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, SentinelValueIndicatingUnspecifedSequenceBeginIdx, 1);
}
ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
bool shallowCopy = false;
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements() > 0)
{
oneSeq.nameToNodeValues[m_nodesToCache[i]]->CopyTo(nodePtr, m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
shallowCopy = true;
}
}
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, plength);
ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
DataReaderHelpers::NotifyChangedNodes<ElemType>(m_net, decodeinputMatrices);
m_net->ForwardProp(decodeOutputNodes[0]);
//Matrix<ElemType> tempMatrix = *(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value());
oneSeq.decodeoutput->SetValue((*(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value())));
oneSeq.decodeoutput->SetValue((*(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value())).ColumnSlice(plength - 1, 1));
oneSeq.processlength = plength;
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (shallowCopy)
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
else
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);
}
lmin.ReleaseMemory();
}
}
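The forward_decode hunk above feeds the prediction network from the hypothesis's full label history: lmin is resized to vocabSize x plength and filled with one one-hot column per label, the MBLayout describes a single sequence of length plength, and after ForwardProp only the last output column (ColumnSlice(plength - 1, 1)) is kept as the decoder output. Below is a minimal sketch of those two steps using a plain column-major std::vector<float> rather than the CNTK Matrix API; OneHotHistory and LastColumn are illustrative names, not CNTK functions.

    #include <cstddef>
    #include <vector>

    // Expand the label history into a column-major vocabSize x plength one-hot matrix.
    std::vector<float> OneHotHistory(const std::vector<std::size_t>& labelseq,
                                     std::size_t vocabSize, std::size_t plength)
    {
        std::vector<float> lmin(vocabSize * plength, 0.0f);
        for (std::size_t n = 0; n < plength; n++)
            lmin[n * vocabSize + labelseq[n]] = 1.0f; // one-hot for decode step n
        return lmin;
    }

    // Keep only the last column of a column-major numRows x numCols output,
    // mirroring ColumnSlice(plength - 1, 1) in the diff.
    std::vector<float> LastColumn(const std::vector<float>& output,
                                  std::size_t numRows, std::size_t numCols)
    {
        return std::vector<float>(output.begin() + (numCols - 1) * numRows,
                                  output.begin() + numCols * numRows);
    }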
@@ -533,16 +467,6 @@
//get decode input matrix
std::vector<std::wstring> decodeOutputNodeNames(outputNodeNames.begin() + 1, outputNodeNames.begin() + 2);
std::vector<ComputationNodeBasePtr> decodeOutputNodes = m_net->OutputNodesByName(decodeOutputNodeNames);
std::list<ComputationNodeBasePtr> pastValueNodes = m_net->PastValueNodesForOutputs(decodeOutputNodes);
std::list<ComputationNodeBasePtr>::iterator it;
for (it = pastValueNodes.begin(); it != pastValueNodes.end(); ++it)
{
auto pastValueNode = dynamic_pointer_cast<PastValueNode<ElemType>>(*it); //DelayedValueNodeBase
if (pastValueNode || !(*it)->NodeName().compare(0, 5, L"Loop_"))
{
m_nodesToCache.push_back((*it)->NodeName());
}
}
std::vector<ComputationNodeBasePtr> decodeinputNodes = m_net->InputNodesForOutputs(decodeOutputNodeNames);
StreamMinibatchInputs decodeinputMatrices = DataReaderHelpers::RetrieveInputMatrices(decodeinputNodes);

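The setup hunk above appears to record which nodes feeding the decode output need per-hypothesis state: a node qualifies if it is a PastValueNode or if its name begins with "Loop_" (the compare(0, 5, L"Loop_") == 0 check). A minimal sketch of that selection rule follows; Node and PastValueNode here are simplified stand-ins for the CNTK computation-node classes.

    #include <memory>
    #include <string>
    #include <vector>

    struct Node { virtual ~Node() = default; std::wstring name; };
    struct PastValueNode : Node {};

    // Collect the names of nodes whose values must be carried per hypothesis.
    std::vector<std::wstring> NodesToCache(const std::vector<std::shared_ptr<Node>>& nodes)
    {
        std::vector<std::wstring> names;
        for (const auto& n : nodes)
        {
            const bool isPastValue = std::dynamic_pointer_cast<PastValueNode>(n) != nullptr;
            const bool startsWithLoop = n->name.compare(0, 5, L"Loop_") == 0; // compare == 0 means prefix match
            if (isPastValue || startsWithLoop)
                names.push_back(n->name);
        }
        return names;
    }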
@@ -1000,7 +924,6 @@

private:
ComputationNetworkPtr m_net;
std::vector<wstring> m_nodesToCache;
int m_verbosity;
void operator=(const SimpleOutputWriter&); // (not assignable)
};