Fix the error during validation stage. Bug: when the anchor node is not the root in a strong component connection, the post visited order may initialize the forward computation incorrect.
This commit is contained in:
Родитель
12b8ec6e0f
Коммит
e773dc6264
|
@ -1746,8 +1746,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
{
|
{
|
||||||
for (ComputationNodePtr node : FinalCriterionNodes())
|
for (ComputationNodePtr node : FinalCriterionNodes())
|
||||||
{
|
{
|
||||||
PrintComputationTree(node, false);
|
|
||||||
if(!allowFragment) FormRecurentLoops(node);
|
if(!allowFragment) FormRecurentLoops(node);
|
||||||
|
PrintComputationTree(node, false);
|
||||||
size_t actualMBSize = this->GetActualMBSize();
|
size_t actualMBSize = this->GetActualMBSize();
|
||||||
this->SetActualMiniBatchSize(actualMBSize);
|
this->SetActualMiniBatchSize(actualMBSize);
|
||||||
ValidateNetwork(node);
|
ValidateNetwork(node);
|
||||||
|
@ -1760,8 +1760,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
// now output nodes
|
// now output nodes
|
||||||
if (OutputNodes().size() > 0)
|
if (OutputNodes().size() > 0)
|
||||||
{
|
{
|
||||||
for (ComputationNodePtr node : OutputNodes())
|
for (ComputationNodePtr node : OutputNodes())
|
||||||
ValidateNetwork(node);
|
{
|
||||||
|
if (!allowFragment) FormRecurentLoops(node);
|
||||||
|
ValidateNetwork(node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (!allowFragment)
|
else if (!allowFragment)
|
||||||
{
|
{
|
||||||
|
@ -1770,8 +1773,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
// now evaluation nodes
|
// now evaluation nodes
|
||||||
if (EvaluationNodes().size() > 0)
|
if (EvaluationNodes().size() > 0)
|
||||||
{
|
{
|
||||||
for (ComputationNodePtr node : EvaluationNodes())
|
for (ComputationNodePtr node : EvaluationNodes())
|
||||||
ValidateNetwork(node);
|
{
|
||||||
|
if (!allowFragment) FormRecurentLoops(node);
|
||||||
|
ValidateNetwork(node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2040,6 +2046,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
std::vector<ComputationNodePtr> sourceLoopNodes;
|
std::vector<ComputationNodePtr> sourceLoopNodes;
|
||||||
getStrongSCC(rootNode);
|
getStrongSCC(rootNode);
|
||||||
std::list<ComputationNodePtr>& nodes = GetEvalOrder(rootNode, sourceLoopNodes);
|
std::list<ComputationNodePtr>& nodes = GetEvalOrder(rootNode, sourceLoopNodes);
|
||||||
|
std::list<ComputationNodePtr> nodesForGrad;
|
||||||
|
|
||||||
/// debug purpose
|
/// debug purpose
|
||||||
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
|
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
|
||||||
|
@ -2081,7 +2088,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
|
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
|
||||||
{
|
{
|
||||||
// sort the recurrent nodes in their ascending name, which is the same as visiting nodes in G^R
|
// sort the recurrent nodes in their ascending name, which is the same as visiting nodes in G^R
|
||||||
if ((*iter).m_recurrentNodes.size() > 1 && (*iter).m_recurrentNodesForForward.size() == 0)
|
(*iter).m_recurrentNodesForForward.clear();
|
||||||
|
if ((*iter).m_recurrentNodes.size() > 1)
|
||||||
{
|
{
|
||||||
std::list<ComputationNodePtr> result;
|
std::list<ComputationNodePtr> result;
|
||||||
std::unordered_set<ComputationNodePtr> visited;
|
std::unordered_set<ComputationNodePtr> visited;
|
||||||
|
@ -2113,7 +2121,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
result.pop_front();
|
result.pop_front();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(*iter).m_recurrentNodes = (*iter).m_recurrentNodesForForward;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2125,12 +2133,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
std::list<ComputationNodePtr> noRecurrentNodes;
|
std::list<ComputationNodePtr> noRecurrentNodes;
|
||||||
|
|
||||||
noRecurrentNodes = rootNode->ReshuffleNodes(recurrentNodes);
|
noRecurrentNodes = rootNode->ReshuffleNodes(recurrentNodes);
|
||||||
|
|
||||||
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
|
|
||||||
|
|
||||||
nodes.sort(IsSmaller);
|
nodes.sort(IsSmaller);
|
||||||
|
|
||||||
|
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
|
||||||
|
|
||||||
m_cacheEvalOrders[rootNode] = nodes;
|
m_cacheEvalOrders[rootNode] = nodes;
|
||||||
|
nodesForGrad = nodes;
|
||||||
|
nodesForGrad.reverse();
|
||||||
|
m_cacheGradientCalcOrders[rootNode] = nodesForGrad;
|
||||||
|
|
||||||
#ifdef DISPLAY_DEBUG
|
#ifdef DISPLAY_DEBUG
|
||||||
fprintf(stderr, "Reordered nodes\n");
|
fprintf(stderr, "Reordered nodes\n");
|
||||||
|
@ -2150,13 +2161,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
|
|
||||||
std::list<ComputationNodePtr> vTmp;
|
std::list<ComputationNodePtr> vTmp;
|
||||||
std::list<ComputationNodePtr> vRecurrentTmp;
|
std::list<ComputationNodePtr> vRecurrentTmp;
|
||||||
int prevId = -1;
|
//int prevId = -1;
|
||||||
|
vector<bool> accessed;
|
||||||
|
accessed.assign(m_recurrentInfo.size(),false);
|
||||||
for (auto nodeIter=nodes.begin(); nodeIter != nodes.end(); nodeIter++)
|
for (auto nodeIter=nodes.begin(); nodeIter != nodes.end(); nodeIter++)
|
||||||
{
|
{
|
||||||
int iId = FindInRecurrentLoop(*nodeIter);
|
int iId = FindInRecurrentLoop(*nodeIter);
|
||||||
if (iId >= 0)
|
if (iId >= 0)
|
||||||
{
|
{
|
||||||
if (prevId != iId && vRecurrentTmp.size() > 0)
|
|
||||||
|
if (! accessed[iId])
|
||||||
|
{
|
||||||
|
newList.insert(newList.end(), m_recurrentInfo[iId].m_recurrentNodes.begin(), m_recurrentInfo[iId].m_recurrentNodes.end());
|
||||||
|
accessed[iId] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*if (prevId != iId && vRecurrentTmp.size() > 0)
|
||||||
{
|
{
|
||||||
newList.insert(newList.end(), vRecurrentTmp.begin(), vRecurrentTmp.end());
|
newList.insert(newList.end(), vRecurrentTmp.begin(), vRecurrentTmp.end());
|
||||||
vRecurrentTmp.clear();
|
vRecurrentTmp.clear();
|
||||||
|
@ -2170,11 +2190,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
|
|
||||||
vRecurrentTmp.push_back(*nodeIter);
|
vRecurrentTmp.push_back(*nodeIter);
|
||||||
|
|
||||||
prevId = iId;
|
prevId = iId;*/
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
vTmp.push_back(*nodeIter);
|
//vTmp.push_back(*nodeIter);
|
||||||
|
newList.push_back(*nodeIter);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче