Support GetEvalOder call in MEL.

This commit is contained in:
Dong Yu 2016-08-15 16:41:16 -07:00
Родитель ce8ee688d8
Коммит fcb74cac78
2 изменённых файлов: 13 добавлений и 0 удалений

Просмотреть файл

@ -258,13 +258,20 @@ public:
m_evalOrders[rootNode] = nodes;
}
bool EvalOrderExists(const ComputationNodeBasePtr& rootNode) const
{
return m_evalOrders.find(rootNode) != m_evalOrders.end();
}
// get depth-first traversal order
// TODO: This is currently not immutable because it gets patched w.r.t. recurrent loops. Ideally we don't patch. Need to review and verify that it is sufficient.
const std::list<ComputationNodeBasePtr>& GetEvalOrder(const ComputationNodeBasePtr& rootNode) const
{
auto iter = m_evalOrders.find(rootNode);
if (iter == m_evalOrders.end())
{
LogicError("GetEvalOrder: Called without prior call to FormEvalOrder() for %ls %ls operation", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());
}
return iter->second;
}

Просмотреть файл

@ -76,6 +76,9 @@ void ComputationNetwork::CopySubTree(const ComputationNetwork& fromNet,
ComputationNodeBasePtr fromRoot = fromNet.GetNodeFromName(fromName);
if (!fromNet.EvalOrderExists(fromRoot))
const_cast<ComputationNetwork&>(fromNet).FormEvalOrder(fromRoot);
for (const auto& fromNode : fromNet.GetEvalOrder(fromRoot)) // BUGBUG: This probably will fail because the precomputed eval orders are invalid at this point.
{
wstring fromNodeName = fromNode->NodeName();
@ -353,6 +356,9 @@ void ComputationNetwork::SetLearnableNodesBelowLearningRateMultiplier(const floa
else
{
// for calculating a specific node
if (!EvalOrderExists(rootNode))
const_cast<ComputationNetwork&>(*this).FormEvalOrder(rootNode);
for (const auto& node : GetAllNodesForRoot(rootNode))
{
if (node->OperationName() == OperationNameOf(LearnableParameter))