Fix the error during validation stage. Bug: when the anchor node is not the root in a strong component connection, the post visited order may initialize the forward computation incorrect.

This commit is contained in:
yzhang87 2015-01-03 01:55:27 -05:00
Родитель 12b8ec6e0f
Коммит e773dc6264
1 изменённых файлов: 34 добавлений и 13 удалений

Просмотреть файл

@ -1746,8 +1746,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
for (ComputationNodePtr node : FinalCriterionNodes())
{
PrintComputationTree(node, false);
if(!allowFragment) FormRecurentLoops(node);
PrintComputationTree(node, false);
size_t actualMBSize = this->GetActualMBSize();
this->SetActualMiniBatchSize(actualMBSize);
ValidateNetwork(node);
@ -1760,8 +1760,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now output nodes
if (OutputNodes().size() > 0)
{
for (ComputationNodePtr node : OutputNodes())
ValidateNetwork(node);
for (ComputationNodePtr node : OutputNodes())
{
if (!allowFragment) FormRecurentLoops(node);
ValidateNetwork(node);
}
}
else if (!allowFragment)
{
@ -1770,8 +1773,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now evaluation nodes
if (EvaluationNodes().size() > 0)
{
for (ComputationNodePtr node : EvaluationNodes())
ValidateNetwork(node);
for (ComputationNodePtr node : EvaluationNodes())
{
if (!allowFragment) FormRecurentLoops(node);
ValidateNetwork(node);
}
}
}
@ -2040,6 +2046,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::vector<ComputationNodePtr> sourceLoopNodes;
getStrongSCC(rootNode);
std::list<ComputationNodePtr>& nodes = GetEvalOrder(rootNode, sourceLoopNodes);
std::list<ComputationNodePtr> nodesForGrad;
/// debug purpose
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
@ -2081,7 +2088,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
{
// sort the recurrent nodes in their ascending name, which is the same as visiting nodes in G^R
if ((*iter).m_recurrentNodes.size() > 1 && (*iter).m_recurrentNodesForForward.size() == 0)
(*iter).m_recurrentNodesForForward.clear();
if ((*iter).m_recurrentNodes.size() > 1)
{
std::list<ComputationNodePtr> result;
std::unordered_set<ComputationNodePtr> visited;
@ -2113,7 +2121,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
result.pop_front();
}
(*iter).m_recurrentNodes = (*iter).m_recurrentNodesForForward;
}
}
@ -2125,12 +2133,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::list<ComputationNodePtr> noRecurrentNodes;
noRecurrentNodes = rootNode->ReshuffleNodes(recurrentNodes);
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
nodes.sort(IsSmaller);
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
m_cacheEvalOrders[rootNode] = nodes;
nodesForGrad = nodes;
nodesForGrad.reverse();
m_cacheGradientCalcOrders[rootNode] = nodesForGrad;
#ifdef DISPLAY_DEBUG
fprintf(stderr, "Reordered nodes\n");
@ -2150,13 +2161,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::list<ComputationNodePtr> vTmp;
std::list<ComputationNodePtr> vRecurrentTmp;
int prevId = -1;
//int prevId = -1;
vector<bool> accessed;
accessed.assign(m_recurrentInfo.size(),false);
for (auto nodeIter=nodes.begin(); nodeIter != nodes.end(); nodeIter++)
{
int iId = FindInRecurrentLoop(*nodeIter);
if (iId >= 0)
{
if (prevId != iId && vRecurrentTmp.size() > 0)
if (! accessed[iId])
{
newList.insert(newList.end(), m_recurrentInfo[iId].m_recurrentNodes.begin(), m_recurrentInfo[iId].m_recurrentNodes.end());
accessed[iId] = true;
}
/*if (prevId != iId && vRecurrentTmp.size() > 0)
{
newList.insert(newList.end(), vRecurrentTmp.begin(), vRecurrentTmp.end());
vRecurrentTmp.clear();
@ -2170,11 +2190,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
vRecurrentTmp.push_back(*nodeIter);
prevId = iId;
prevId = iId;*/
}
else
{
vTmp.push_back(*nodeIter);
//vTmp.push_back(*nodeIter);
newList.push_back(*nodeIter);
}
}