Fix the error during validation stage. Bug: when the anchor node is not the root in a strong component connection, the post visited order may initialize the forward computation incorrect.
This commit is contained in:
Родитель
12b8ec6e0f
Коммит
e773dc6264
|
@ -1746,8 +1746,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
for (ComputationNodePtr node : FinalCriterionNodes())
|
||||
{
|
||||
PrintComputationTree(node, false);
|
||||
if(!allowFragment) FormRecurentLoops(node);
|
||||
PrintComputationTree(node, false);
|
||||
size_t actualMBSize = this->GetActualMBSize();
|
||||
this->SetActualMiniBatchSize(actualMBSize);
|
||||
ValidateNetwork(node);
|
||||
|
@ -1760,8 +1760,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// now output nodes
|
||||
if (OutputNodes().size() > 0)
|
||||
{
|
||||
for (ComputationNodePtr node : OutputNodes())
|
||||
ValidateNetwork(node);
|
||||
for (ComputationNodePtr node : OutputNodes())
|
||||
{
|
||||
if (!allowFragment) FormRecurentLoops(node);
|
||||
ValidateNetwork(node);
|
||||
}
|
||||
}
|
||||
else if (!allowFragment)
|
||||
{
|
||||
|
@ -1770,8 +1773,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// now evaluation nodes
|
||||
if (EvaluationNodes().size() > 0)
|
||||
{
|
||||
for (ComputationNodePtr node : EvaluationNodes())
|
||||
ValidateNetwork(node);
|
||||
for (ComputationNodePtr node : EvaluationNodes())
|
||||
{
|
||||
if (!allowFragment) FormRecurentLoops(node);
|
||||
ValidateNetwork(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2040,6 +2046,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
std::vector<ComputationNodePtr> sourceLoopNodes;
|
||||
getStrongSCC(rootNode);
|
||||
std::list<ComputationNodePtr>& nodes = GetEvalOrder(rootNode, sourceLoopNodes);
|
||||
std::list<ComputationNodePtr> nodesForGrad;
|
||||
|
||||
/// debug purpose
|
||||
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
|
||||
|
@ -2081,7 +2088,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
for (auto iter = m_recurrentInfo.begin(); iter != m_recurrentInfo.end(); iter++)
|
||||
{
|
||||
// sort the recurrent nodes in their ascending name, which is the same as visiting nodes in G^R
|
||||
if ((*iter).m_recurrentNodes.size() > 1 && (*iter).m_recurrentNodesForForward.size() == 0)
|
||||
(*iter).m_recurrentNodesForForward.clear();
|
||||
if ((*iter).m_recurrentNodes.size() > 1)
|
||||
{
|
||||
std::list<ComputationNodePtr> result;
|
||||
std::unordered_set<ComputationNodePtr> visited;
|
||||
|
@ -2113,7 +2121,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
result.pop_front();
|
||||
}
|
||||
|
||||
|
||||
(*iter).m_recurrentNodes = (*iter).m_recurrentNodesForForward;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2125,12 +2133,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
std::list<ComputationNodePtr> noRecurrentNodes;
|
||||
|
||||
noRecurrentNodes = rootNode->ReshuffleNodes(recurrentNodes);
|
||||
|
||||
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
|
||||
|
||||
nodes.sort(IsSmaller);
|
||||
|
||||
ReorderLoops(nodes, recurrentNodes, noRecurrentNodes);
|
||||
|
||||
m_cacheEvalOrders[rootNode] = nodes;
|
||||
nodesForGrad = nodes;
|
||||
nodesForGrad.reverse();
|
||||
m_cacheGradientCalcOrders[rootNode] = nodesForGrad;
|
||||
|
||||
#ifdef DISPLAY_DEBUG
|
||||
fprintf(stderr, "Reordered nodes\n");
|
||||
|
@ -2150,13 +2161,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
std::list<ComputationNodePtr> vTmp;
|
||||
std::list<ComputationNodePtr> vRecurrentTmp;
|
||||
int prevId = -1;
|
||||
//int prevId = -1;
|
||||
vector<bool> accessed;
|
||||
accessed.assign(m_recurrentInfo.size(),false);
|
||||
for (auto nodeIter=nodes.begin(); nodeIter != nodes.end(); nodeIter++)
|
||||
{
|
||||
int iId = FindInRecurrentLoop(*nodeIter);
|
||||
if (iId >= 0)
|
||||
{
|
||||
if (prevId != iId && vRecurrentTmp.size() > 0)
|
||||
|
||||
if (! accessed[iId])
|
||||
{
|
||||
newList.insert(newList.end(), m_recurrentInfo[iId].m_recurrentNodes.begin(), m_recurrentInfo[iId].m_recurrentNodes.end());
|
||||
accessed[iId] = true;
|
||||
}
|
||||
|
||||
/*if (prevId != iId && vRecurrentTmp.size() > 0)
|
||||
{
|
||||
newList.insert(newList.end(), vRecurrentTmp.begin(), vRecurrentTmp.end());
|
||||
vRecurrentTmp.clear();
|
||||
|
@ -2170,11 +2190,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
vRecurrentTmp.push_back(*nodeIter);
|
||||
|
||||
prevId = iId;
|
||||
prevId = iId;*/
|
||||
}
|
||||
else
|
||||
{
|
||||
vTmp.push_back(*nodeIter);
|
||||
//vTmp.push_back(*nodeIter);
|
||||
newList.push_back(*nodeIter);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче