Support GetEvalOder call in MEL.
This commit is contained in:
Родитель
ce8ee688d8
Коммит
fcb74cac78
|
@ -258,13 +258,20 @@ public:
|
|||
m_evalOrders[rootNode] = nodes;
|
||||
}
|
||||
|
||||
bool EvalOrderExists(const ComputationNodeBasePtr& rootNode) const
|
||||
{
|
||||
return m_evalOrders.find(rootNode) != m_evalOrders.end();
|
||||
}
|
||||
|
||||
// get depth-first traversal order
|
||||
// TODO: This is currently not immutable because it gets patched w.r.t. recurrent loops. Ideally we don't patch. Need to review and verify that it is sufficient.
|
||||
const std::list<ComputationNodeBasePtr>& GetEvalOrder(const ComputationNodeBasePtr& rootNode) const
|
||||
{
|
||||
auto iter = m_evalOrders.find(rootNode);
|
||||
if (iter == m_evalOrders.end())
|
||||
{
|
||||
LogicError("GetEvalOrder: Called without prior call to FormEvalOrder() for %ls %ls operation", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
|
|
|
@ -76,6 +76,9 @@ void ComputationNetwork::CopySubTree(const ComputationNetwork& fromNet,
|
|||
|
||||
ComputationNodeBasePtr fromRoot = fromNet.GetNodeFromName(fromName);
|
||||
|
||||
if (!fromNet.EvalOrderExists(fromRoot))
|
||||
const_cast<ComputationNetwork&>(fromNet).FormEvalOrder(fromRoot);
|
||||
|
||||
for (const auto& fromNode : fromNet.GetEvalOrder(fromRoot)) // BUGBUG: This probably will fail because the precomputed eval orders are invalid at this point.
|
||||
{
|
||||
wstring fromNodeName = fromNode->NodeName();
|
||||
|
@ -353,6 +356,9 @@ void ComputationNetwork::SetLearnableNodesBelowLearningRateMultiplier(const floa
|
|||
else
|
||||
{
|
||||
// for calculating a specific node
|
||||
if (!EvalOrderExists(rootNode))
|
||||
const_cast<ComputationNetwork&>(*this).FormEvalOrder(rootNode);
|
||||
|
||||
for (const auto& node : GetAllNodesForRoot(rootNode))
|
||||
{
|
||||
if (node->OperationName() == OperationNameOf(LearnableParameter))
|
||||
|
|
Загрузка…
Ссылка в новой задаче