Fix CntkEval.dll Related Issues

Consumers of the CntkEval interface can pass in feature values
without label values. Most of the code was okay with this, but
the SimpleOutputWriter insisted on initializing both features
and labels. This caused GetMinibatchIntoNetwork function in
DataReaderHelpers.h to fail, when the layout of the label node
was incorrect.

I've changed the SimpleOutputWriter to only use input nodes
that are ancestors of the desired output nodes. This fixes the
issue for CntkEval usage.
This commit is contained in:
Jasha Droppo 2016-01-28 17:26:13 -08:00
Родитель b8badf66b1
Коммит 91cf01e8d1
1 изменённых файлов: 8 добавлений и 8 удалений

Просмотреть файл

@ -54,13 +54,11 @@ public:
m_net->AllocateAllMatrices({}, outputNodes, nullptr);
// specify feature value nodes
std::vector<ComputationNodeBasePtr>& featureNodes = m_net->FeatureNodes();
std::vector<ComputationNodeBasePtr>& labelNodes = m_net->LabelNodes();
std::map<std::wstring, Matrix<ElemType>*> inputMatrices;
for (size_t i = 0; i < featureNodes.size(); i++)
inputMatrices[featureNodes[i]->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(featureNodes[i])->Value();
for (size_t i = 0; i < labelNodes.size(); i++)
inputMatrices[labelNodes[i]->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(labelNodes[i])->Value();
for (auto& onode : outputNodes)
for (auto& inode : m_net->InputNodes(onode))
inputMatrices[inode->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(inode)->Value();
// Matrix<ElemType> endOfFile = Matrix<ElemType>((size_t)1,(size_t)1);
// endOfFile(0,0)=0;
@ -76,8 +74,10 @@ public:
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
ComputationNetwork::BumpEvalTimeStamp(labelNodes);
// Update timestamp for all input nodes ancestors of the output nodes
for (auto& onode : outputNodes)
for (auto& inode : m_net->InputNodes(onode))
inode->BumpEvalTimeStamp();
for (int i = 0; i < outputNodes.size(); i++)
{