Fix the error during validation stage. Bug: when the anchor node is not the root in a strong component connection, the post visited order may initialize the forward computation incorrect.

This commit is contained in:
yzhang87 2015-01-03 01:55:27 -05:00
Родитель 12b8ec6e0f
Коммит e773dc6264
1 изменённых файлов: 34 добавлений и 13 удалений

Просмотреть файл

@ -1746,8 +1746,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{ {
for (ComputationNodePtr node : FinalCriterionNodes()) for (ComputationNodePtr node : FinalCriterionNodes())
{ {
PrintComputationTree(node, false);
if(!allowFragment) FormRecurentLoops(node); if(!allowFragment) FormRecurentLoops(node);
PrintComputationTree(node, false);
size_t actualMBSize = this->GetActualMBSize(); size_t actualMBSize = this->GetActualMBSize();
this->SetActualMiniBatchSize(actualMBSize); this->SetActualMiniBatchSize(actualMBSize);
ValidateNetwork(node); ValidateNetwork(node);
@ -1760,8 +1760,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now output nodes // now output nodes
if (OutputNodes().size() > 0) if (OutputNodes().size() > 0)
{ {
for (ComputationNodePtr node : OutputNodes()) for (ComputationNodePtr node : OutputNodes())
ValidateNetwork(node); {
if (!allowFragment) FormRecurentLoops(node);
ValidateNetwork(node);
}
} }
else if (!allowFragment) else if (!allowFragment)
{ {
@ -1770,8 +1773,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now evaluation nodes // now evaluation nodes
if (EvaluationNodes().size() > 0) if (EvaluationNodes().size() > 0)
{ {
for (ComputationNodePtr node : EvaluationNodes()) for (ComputationNodePtr node : EvaluationNodes())
ValidateNetwork(node); {
if (!allowFragment) FormRecurentLoops(node);
ValidateNetwork(node);
}
} }
} }
@ -2040,6 +2046,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::vector<ComputationNodePtr> sourceLoopNodes; std::vector<ComputationNodePtr> sourceLoopNodes;
getStrongSCC(rootNode); getStrongSCC(rootNode);
std::list<ComputationNodePtr>& nodes = GetEvalOrder(rootNode, sourceLoopNodes); std::list<ComputationNodePtr>& nodes = GetEvalOrder(rootNode, sourceLoopNodes);
std::list<ComputationNodePtr> nodesForGrad;
/// debug purpose /// debug purpose
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++) for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
@ -2081,7 +2088,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++) for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
{ {
// sort the recurrent nodes in their ascending name, which is the same as visiting nodes in G^R // sort the recurrent nodes in their ascending name, which is the same as visiting nodes in G^R
if ((*iter).m_recurrentNodes.size() > 1 && (*iter).m_recurrentNodesForForward.size() == 0) (*iter).m_recurrentNodesForForward.clear();
if ((*iter).m_recurrentNodes.size() > 1)
{ {
std::list<ComputationNodePtr> result; std::list<ComputationNodePtr> result;
std::unordered_set<ComputationNodePtr> visited; std::unordered_set<ComputationNodePtr> visited;
@ -2113,7 +2121,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
result.pop_front(); result.pop_front();
} }
(*iter).m_recurrentNodes = (*iter).m_recurrentNodesForForward;
} }
} }
@ -2125,12 +2133,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::list<ComputationNodePtr> noRecurrentNodes; std::list<ComputationNodePtr> noRecurrentNodes;
noRecurrentNodes = rootNode->ReshuffleNodes(recurrentNodes); noRecurrentNodes = rootNode->ReshuffleNodes(recurrentNodes);
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
nodes.sort(IsSmaller); nodes.sort(IsSmaller);
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
m_cacheEvalOrders[rootNode] = nodes; m_cacheEvalOrders[rootNode] = nodes;
nodesForGrad = nodes;
nodesForGrad.reverse();
m_cacheGradientCalcOrders[rootNode] = nodesForGrad;
#ifdef DISPLAY_DEBUG #ifdef DISPLAY_DEBUG
fprintf(stderr, "Reordered nodes\n"); fprintf(stderr, "Reordered nodes\n");
@ -2150,13 +2161,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::list<ComputationNodePtr> vTmp; std::list<ComputationNodePtr> vTmp;
std::list<ComputationNodePtr> vRecurrentTmp; std::list<ComputationNodePtr> vRecurrentTmp;
int prevId = -1; //int prevId = -1;
vector<bool> accessed;
accessed.assign(m_recurrentInfo.size(),false);
for (auto nodeIter=nodes.begin(); nodeIter != nodes.end(); nodeIter++) for (auto nodeIter=nodes.begin(); nodeIter != nodes.end(); nodeIter++)
{ {
int iId = FindInRecurrentLoop(*nodeIter); int iId = FindInRecurrentLoop(*nodeIter);
if (iId >= 0) if (iId >= 0)
{ {
if (prevId != iId && vRecurrentTmp.size() > 0)
if (! accessed[iId])
{
newList.insert(newList.end(), m_recurrentInfo[iId].m_recurrentNodes.begin(), m_recurrentInfo[iId].m_recurrentNodes.end());
accessed[iId] = true;
}
/*if (prevId != iId && vRecurrentTmp.size() > 0)
{ {
newList.insert(newList.end(), vRecurrentTmp.begin(), vRecurrentTmp.end()); newList.insert(newList.end(), vRecurrentTmp.begin(), vRecurrentTmp.end());
vRecurrentTmp.clear(); vRecurrentTmp.clear();
@ -2170,11 +2190,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
vRecurrentTmp.push_back(*nodeIter); vRecurrentTmp.push_back(*nodeIter);
prevId = iId; prevId = iId;*/
} }
else else
{ {
vTmp.push_back(*nodeIter); //vTmp.push_back(*nodeIter);
newList.push_back(*nodeIter);
} }
} }