Enable CUDA evaluation of RNNT models

Vadim Mazalov 2020-01-09 17:27:22 -08:00
Parent 2ddfea26f6
Commit 1386278827
1 changed file with 9 additions and 86 deletions


@@ -18,7 +18,6 @@
#include <cstdio>
#include "ProgressTracing.h"
#include "ComputationNetworkBuilder.h"
#include "RecurrentNodes.h"
#include <algorithm>
using namespace std;
@@ -46,11 +45,9 @@ class SimpleOutputWriter
{
return logP < rhs.logP;
}
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>> nameToNodeValues;
};
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
typedef typename std::vector<Sequence>::iterator iterator;
unordered_map<wstring, vector<shared_ptr<PastValueNode<ElemType>>>> m_nameToPastValueNodeCache;
public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
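Note: with the per-node value cache removed above, a decoding hypothesis only needs to carry its label prefix, scores, and the latest decoder output. A minimal sketch of the resulting per-hypothesis state, inferred from the fields used elsewhere in this diff (field order and exact types are assumptions):

struct Sequence
{
    std::vector<size_t> labelseq;               // decoded label prefix
    ElemType logP;                              // accumulated path log-probability
    size_t length;                              // emitted label count (assumed)
    size_t lengthwithblank;                     // length including blank steps
    size_t processlength;                       // prefix length already run through the decoder
    shared_ptr<Matrix<ElemType>> decodeoutput;  // last decoder output column
    bool operator<(const Sequence& rhs) const
    {
        return logP < rhs.logP;
    }
};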
@@ -291,11 +288,6 @@ public:
Sequence newSeq(size_t numRow, size_t numCol, DEVICEID_TYPE deviceId)
{
Sequence oneSeq = {std::vector<size_t>(), 0.0, 0, 0, 0, make_shared<Matrix<ElemType>>(numRow, (size_t) 1, deviceId)};
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
vector<ElemType> v;
oneSeq.nameToNodeValues[m_nodesToCache[i]] = make_shared<PastValueNode<ElemType>>(deviceId, L"test");
}
return oneSeq;
}
Sequence newSeq(Sequence& a, DEVICEID_TYPE deviceId)
@@ -308,34 +300,10 @@ public:
oneSeq.processlength = a.processlength;
oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
oneSeq.decodeoutput->SetValue(*(a.decodeoutput));
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++)
{
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
{
oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back();
m_nameToPastValueNodeCache[it->first].pop_back();
}
else
{
oneSeq.nameToNodeValues[it->first] = make_shared<PastValueNode<ElemType>>(deviceId, it->first);
}
it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
}
return oneSeq;
}
void deleteSeq(Sequence oneSeq)
{
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
for (it = oneSeq.nameToNodeValues.begin(); it != oneSeq.nameToNodeValues.end(); it++)
{
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin == m_nameToPastValueNodeCache.end())
m_nameToPastValueNodeCache[it->first] = vector<shared_ptr<PastValueNode<ElemType>>>();
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
}
oneSeq.decodeoutput->ReleaseMemory();
vector<size_t>().swap(oneSeq.labelseq);
}
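With the PastValueNode bookkeeping gone, duplicating and releasing a hypothesis reduces to copying or freeing its decoder-output matrix. Roughly, the surviving shape of these two helpers (a sketch only; the field-by-field copies above are abbreviated and the comments are mine):

Sequence newSeq(Sequence& a, DEVICEID_TYPE deviceId)
{
    Sequence oneSeq = a;  // copy label prefix, scores, and lengths
    // Clone the last decoder output onto the same device; no cached
    // recurrent state is carried over any more.
    oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
    oneSeq.decodeoutput->SetValue(*(a.decodeoutput));
    return oneSeq;
}

void deleteSeq(Sequence oneSeq)
{
    // Releasing a hypothesis now just frees its matrix and label storage.
    oneSeq.decodeoutput->ReleaseMemory();
    vector<size_t>().swap(oneSeq.labelseq);
}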
@@ -369,18 +337,7 @@ public:
insequence.lengthwithblank++;
}
std::vector<ComputationNodeBasePtr> GetNodesByNames(const std::vector<std::wstring>& names) const
{
std::vector<ComputationNodeBasePtr> nodes;
for (size_t i = 0; i < names.size(); i++)
{
auto nodesByName = m_net->GetNodesFromName(names[i]);
if (nodesByName.size() != 1)
LogicError("Exactly one node is expected for name %ls", names[i]);
nodes.push_back(nodesByName[0]);
}
return nodes;
}
void forward_decode(Sequence& oneSeq, StreamMinibatchInputs decodeinputMatrices, DEVICEID_TYPE deviceID, const std::vector<ComputationNodeBasePtr>& decodeOutputNodes,
const std::vector<ComputationNodeBasePtr>& decodeinputNodes, size_t vocabSize, size_t plength)
{
@@ -393,46 +350,23 @@ public:
Matrix<ElemType> lmin(deviceID);
//greedyOutput.SetValue(greedyOutputMax.ColumnSlice(0, lmt));
lmin.Resize(vocabSize, 1);
lmin.Resize(vocabSize, plength);
lmin.SetValue(0.0);
lmin(oneSeq.labelseq[plength - 1], 0) = 1.0;
for (size_t n = 0; n < plength; n++)
{
lmin(oneSeq.labelseq[n], n) = 1.0;
}
auto lminput = decodeinputMatrices.begin();
lminput->second.pMBLayout->Init(1, 1);
lminput->second.pMBLayout->Init(1, plength);
//std::swap(lminput->second.GetMatrix<ElemType>(), lmin);
lminput->second.GetMatrix<ElemType>().SetValue(lmin);
if (plength == 1)
{
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1);
}
else
{
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, SentinelValueIndicatingUnspecifedSequenceBeginIdx, 1);
}
ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
bool shallowCopy = false;
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements() > 0)
{
oneSeq.nameToNodeValues[m_nodesToCache[i]]->CopyTo(nodePtr, m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
shallowCopy = true;
}
}
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, plength);
ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
DataReaderHelpers::NotifyChangedNodes<ElemType>(m_net, decodeinputMatrices);
m_net->ForwardProp(decodeOutputNodes[0]);
//Matrix<ElemType> tempMatrix = *(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value());
oneSeq.decodeoutput->SetValue((*(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value())));
oneSeq.decodeoutput->SetValue((*(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value())).ColumnSlice(plength - 1, 1));
oneSeq.processlength = plength;
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (shallowCopy)
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
else
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);
}
lmin.ReleaseMemory();
}
}
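The net effect of this hunk: instead of restoring cached PastValueNode states and feeding one label at a time, forward_decode now re-runs the prediction network over the entire label prefix on every call. The prefix is one-hot encoded as a vocabSize x plength matrix, presented to the network as a single sequence of plength frames, and only the last output column is kept, so the per-step work stays inside ForwardProp and Matrix operations, which run on the GPU when deviceID refers to a CUDA device. A condensed view of the resulting body (assembled from the added and context lines above; comments are mine):

Matrix<ElemType> lmin(deviceID);
lmin.Resize(vocabSize, plength);              // one column per label in the prefix
lmin.SetValue(0.0);
for (size_t n = 0; n < plength; n++)
    lmin(oneSeq.labelseq[n], n) = 1.0;        // one-hot encode the whole prefix

auto lminput = decodeinputMatrices.begin();
lminput->second.pMBLayout->Init(1, plength);  // one sequence of plength frames
lminput->second.GetMatrix<ElemType>().SetValue(lmin);
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, plength);

ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
DataReaderHelpers::NotifyChangedNodes<ElemType>(m_net, decodeinputMatrices);
m_net->ForwardProp(decodeOutputNodes[0]);

// Keep only the last frame of the decoder output for this prefix.
oneSeq.decodeoutput->SetValue(
    dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value().ColumnSlice(plength - 1, 1));
oneSeq.processlength = plength;
lmin.ReleaseMemory();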
@@ -533,16 +467,6 @@ public:
//get decode input matrix
std::vector<std::wstring> decodeOutputNodeNames(outputNodeNames.begin() + 1, outputNodeNames.begin() + 2);
std::vector<ComputationNodeBasePtr> decodeOutputNodes = m_net->OutputNodesByName(decodeOutputNodeNames);
std::list<ComputationNodeBasePtr> pastValueNodes = m_net->PastValueNodesForOutputs(decodeOutputNodes);
std::list<ComputationNodeBasePtr>::iterator it;
for (it = pastValueNodes.begin(); it != pastValueNodes.end(); ++it)
{
auto pastValueNode = dynamic_pointer_cast<PastValueNode<ElemType>>(*it); //DelayedValueNodeBase
if (pastValueNode || !(*it)->NodeName().compare(0, 5, L"Loop_"))
{
m_nodesToCache.push_back((*it)->NodeName());
}
}
std::vector<ComputationNodeBasePtr> decodeinputNodes = m_net->InputNodesForOutputs(decodeOutputNodeNames);
StreamMinibatchInputs decodeinputMatrices = DataReaderHelpers::RetrieveInputMatrices(decodeinputNodes);
@@ -1000,7 +924,6 @@ public:
private:
ComputationNetworkPtr m_net;
std::vector<wstring> m_nodesToCache;
int m_verbosity;
void operator=(const SimpleOutputWriter&); // (not assignable)
};