renamed -StrongSCC- to -SCC- since the S already stands for 'Strong'
This commit is contained in:
Родитель
f2145fd949
Коммит
34bbf53e0e
|
@ -712,8 +712,8 @@ private:
|
|||
|
||||
// This is part of the FormRecurrentLoops() process, and only called from there.
|
||||
void FormRecurrentLoops(const ComputationNodeBasePtr& rootNode);
|
||||
void DetermineStrongSCCs(const ComputationNodeBasePtr& rootNode);
|
||||
void DetermineStrongSCCsR(ComputationNodeBasePtr cur, std::list<ComputationNodeBasePtr>& sccStack, size_t& index, size_t& loopId);
|
||||
void DetermineSCCs(const ComputationNodeBasePtr& rootNode);
|
||||
void DetermineSCCsR(ComputationNodeBasePtr cur, std::list<ComputationNodeBasePtr>& sccStack, size_t& index, size_t& loopId);
|
||||
void UniqRecurrentLoops();
|
||||
void DetermineLoopForwardOrder(std::unordered_set<ComputationNodeBasePtr>& visited, std::unordered_set<ComputationNodeBasePtr>& recStack, std::list<ComputationNodeBasePtr>& nodesStack, ComputationNodeBasePtr cur);
|
||||
void GatherLoopNodesR(const ComputationNodeBasePtr& rootNode, std::unordered_set<ComputationNodeBasePtr>& visited, std::map<int, std::list<ComputationNodeBasePtr>>& recurrentResult, std::list<ComputationNodeBasePtr>& noRecurrentResult);
|
||||
|
|
|
@ -34,10 +34,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Is often called before ValidateNetwork() on a root; will be called from inside ValidateNetwork() as well.
|
||||
// This function is called for multiple nodes, e.g. eval and training criterion. I.e. it must be able to add to a previous result. E.g. it does not clear the m_visited flags at start. This seems brittle.
|
||||
// BUGBUG: m_visited is also used by ValidateSubNetwork(). Hence, it may be in unexpected state when calling into this multiple times.
|
||||
// BUGBUG: This currently does not handle nested loops. To handle that:
|
||||
// - loops are isolated by a ReconcileMBLayout--loop determination should see right through it, and then include everything inside
|
||||
// - ...? Need to figure this out.
|
||||
void ComputationNetwork::FormRecurrentLoops(const ComputationNodeBasePtr& rootNode)
|
||||
{
|
||||
// determine the strongly connected cliques -> m_recurrentInfo[]
|
||||
DetermineStrongSCCs(rootNode);
|
||||
DetermineSCCs(rootNode);
|
||||
|
||||
list<ComputationNodeBasePtr>& nodes = GetEvalOrder(rootNode, true/*set m_visitedOrder*/);
|
||||
|
||||
|
@ -170,18 +173,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
// get the strongly connected components from the graph
|
||||
// This sets index, lowLink, m_visited, and m_inStack.
|
||||
void ComputationNetwork::DetermineStrongSCCs(const ComputationNodeBasePtr& rootNode)
|
||||
void ComputationNetwork::DetermineSCCs(const ComputationNodeBasePtr& rootNode)
|
||||
{
|
||||
// notice that this graph including graphs from a parent networks if two or more networks are connected via PairNetworkNode
|
||||
list<ComputationNodeBasePtr> sccStack;
|
||||
size_t index = 0;
|
||||
size_t loopId = 0;
|
||||
if (!rootNode->m_visited)
|
||||
DetermineStrongSCCsR(rootNode, sccStack, index, loopId);
|
||||
DetermineSCCsR(rootNode, sccStack, index, loopId);
|
||||
}
|
||||
|
||||
// (recursive part of DetermineStrongSCCs())
|
||||
void ComputationNetwork::DetermineStrongSCCsR(ComputationNodeBasePtr cur,
|
||||
// (recursive part of DetermineSCCs())
|
||||
void ComputationNetwork::DetermineSCCsR(ComputationNodeBasePtr cur,
|
||||
list<ComputationNodeBasePtr>& sccStack,
|
||||
size_t& index, size_t& loopId)
|
||||
{
|
||||
|
@ -203,7 +206,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
if (!cur->Inputs(i)->m_visited)
|
||||
{
|
||||
DetermineStrongSCCsR(cur->Inputs(i), sccStack, index, loopId);
|
||||
DetermineSCCsR(cur->Inputs(i), sccStack, index, loopId);
|
||||
cur->m_lowLink = min(cur->m_lowLink, cur->Inputs(i)->m_lowLink);
|
||||
}
|
||||
else if (cur->Inputs(i)->m_inStack)
|
||||
|
|
|
@ -200,8 +200,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
int m_visitedOrder; // remembers order in which nodes were visited by EnumerateNodes(), but gets updated
|
||||
bool m_visited; // note: also used by ValidateSubNetwork()
|
||||
int m_indexInLoop;
|
||||
// only used inside DetermineStrongSCCs():
|
||||
int m_index; // index denoting order in which nodes were visited in DetermineStrongSCCs()
|
||||
// only used inside DetermineSCCs():
|
||||
int m_index; // index denoting order in which nodes were visited in DetermineSCCs()
|
||||
int m_lowLink; // min of m_index over all nodes within a single loop
|
||||
bool m_inStack;
|
||||
};
|
||||
|
@ -941,25 +941,30 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
public:
|
||||
static bool MaskMissingColumnsToZero(Matrix<ElemType>& matrixToBeMasked, const MBLayoutPtr & pMBLayout, const FrameRange & frameRange)
|
||||
{
|
||||
//fprintf(stderr, "masking column range %d\n", (int)frameRange.timeIdxInSeq);
|
||||
return MaskMissingColumnsTo(matrixToBeMasked, pMBLayout, frameRange, (ElemType)0);
|
||||
}
|
||||
|
||||
void /*ComputationNodeBase::*/MaskMissingValuesColumnsToZero(const FrameRange & frameRange) override final
|
||||
{
|
||||
//fprintf(stderr, "%ls %ls m_functionValues ", NodeName().c_str(), OperationName().c_str());
|
||||
MaskMissingColumnsToZero(*m_functionValues, m_pMBLayout, frameRange);
|
||||
}
|
||||
void /*ComputationNodeBase::*/MaskMissingGradientColumnsToZero(const FrameRange & frameRange) override final
|
||||
{
|
||||
//fprintf(stderr, "%ls %ls m_gradientValues ", NodeName().c_str(), OperationName().c_str());
|
||||
MaskMissingColumnsToZero(*m_gradientValues, m_pMBLayout, frameRange);
|
||||
}
|
||||
|
||||
// for debugging, set the gaps to NaN instead (to track whether it bubbles up somewhere)
|
||||
void InvalidateMissingValuesColumns(const FrameRange & frameRange) override final
|
||||
{
|
||||
//fprintf(stderr, "invalidating %ls %ls m_functionValues column range %d\n", NodeName().c_str(), OperationName().c_str(), (int)frameRange.timeIdxInSeq);
|
||||
MaskMissingColumnsTo(*m_functionValues, m_pMBLayout, frameRange, Matrix<ElemType>::MakeNan(__LINE__));
|
||||
}
|
||||
void InvalidateMissingGradientColumns(const FrameRange & frameRange) override final
|
||||
{
|
||||
//fprintf(stderr, "invalidating %ls %ls m_gradientValues column range %d\n", NodeName().c_str(), OperationName().c_str(), (int)frameRange.timeIdxInSeq);
|
||||
MaskMissingColumnsTo(*m_gradientValues, m_pMBLayout, frameRange, Matrix<ElemType>::MakeNan(__LINE__));
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче