// CNTK/Source/ComputationNetworkLib/ComputationNetworkEvaluatio...
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#include "Basics.h"
#include "ComputationNode.h"
#include "ComputationNetwork.h"
#include "RecurrentNodes.h"
#include "InputAndParamNodes.h"
#include <string>
#include <vector>
#include <list>
#include <set>
#include <algorithm>
#include <map>
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
// This source file contains methods related to evaluation (forward prop, backprop), network validation, and matrix memory allocation (memory sharing).
// -----------------------------------------------------------------------
// forward and backward propagation
// -----------------------------------------------------------------------
// MAIN ENTRY POINT for evaluating one minibatch (forward prop)
// This calls ForwardProp() on all nodes in order of data flow through the network.
// By default, the network is applied to all frames in a minibatch in parallel (PAR mode, a "map" operation)
// Recurrent loops must be treated differently:
// - a recurrent loop is the loop of nodes that make up computation for one time step (e.g. Times -> Plus -> Sigmoid -> Delay)
// - these must be executed frame by frame (SEQuential) rather than as a map
// - such a loop is treated as if it were a little nested network; this is done inside SEQTraversalFlowControlNodes
// - these little nested networks are defined in the execution network in the form of nested sentinel nodes of type SEQTraversalFlowControlNode
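//
// For illustration (hypothetical model): a recurrence such as
//   h(t) = Sigmoid(Plus(Times(W, x(t)), PastValue(h)))
// shows up in the execution network as a single sentinel replacing its members:
//   [ x, W, SEQTraversalFlowControlNode{ Times, Plus, Sigmoid, PastValue }, ... ]
// Nodes outside the sentinel run once over all frames (PAR); the sentinel runs
// its members once per time step (SEQ).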
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode)
{
VerifyIsCompiled("ForwardProp");
// traverse all nodes in the pre-determined evaluation order
GetNestedNetwork(rootNode)->ForwardProp(FrameRange(nullptr));
}
// set the gradient matrix of a node to a 1x1 matrix containing 1.0
// Returns false if the node is not a ComputationNode<ElemType>.
template <class ElemType>
static bool SetGradientToScalarOne(ComputationNodeBasePtr nodep)
{
auto node = dynamic_pointer_cast<ComputationNode<ElemType>>(nodep);
bool hasMatchingType = (node != nullptr);
if (hasMatchingType)
{
node->Value().VerifySize(1, 1);
node->Gradient().Resize(1, 1);
node->Gradient().SetValue((ElemType) 1.0);
}
return hasMatchingType;
}
// MAIN ENTRY POINT for evaluation followed by gradient computation (forward prop then back prop)
// The typical calling pattern is:
// - ForwardProp() for eval nodes
// - ForwardProp() for the training criterion (which will reuse computation results from the previous step)
// - Backprop() for the training criterion
void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // training criterion to compute the gradients for
{
// reset all gradients to zero (actually, internally, this is lazy, but we don't care here)
ZeroGradients(rootNode);
// initialize root gradient with a scalar value of 1.0
if (!SetGradientToScalarOne<float>(rootNode) && !SetGradientToScalarOne<double>(rootNode))
LogicError("Backprop: Training criterion is neither ComputationNode<float> nor ComputationNode<double>.");
// backpropagate through the network
GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
}
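// Usage sketch (caller code assumed, e.g. a training loop), following the calling pattern above:
//   net.ForwardProp(evalNode);      // eval criteria
//   net.ForwardProp(criterionNode); // training criterion; reuses results computed above
//   net.Backprop(criterionNode);    // gradient computation, backwards from the criterion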
void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
{
if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
fprintf(stderr, "FormNestedNetwork: WARNING: Was called twice for %ls %ls operation\n", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());
m_nestedNetworks[rootNode] = make_shared<PARTraversalFlowControlNode>(m_allSEQNodes, GetEvalOrder(rootNode));
}
ComputationNodeBasePtr ComputationNetwork::GetNestedNetwork(const ComputationNodeBasePtr& rootNode)
{
if (m_nestedNetworks.find(rootNode) == m_nestedNetworks.end())
LogicError("GetNestedNetwork: Called without prior call to FormNestedNetwork() for %ls %ls operation", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());
return m_nestedNetworks[rootNode];
}
// -----------------------------------------------------------------------
// PARTraversalFlowControlNode methods -- implements PAR traversal
//
// This implements an outer loop over non-recurrent nodes, where each node can be
// executed in PAR mode; that is, all samples are independent and allow for
// concurrent computation in bulk CUDA launches.
// -----------------------------------------------------------------------
ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(const std::vector<shared_ptr<SEQTraversalFlowControlNode>>& recurrentInfo, const std::list<ComputationNodeBasePtr>& allNodes /*must be in eval order*/)
{
// traverse the network in evaluation order and create a new list that replaces all recurrence by a SEQTraversalFlowControlNode
set<shared_ptr<IComputationNode>> loopsSeen; // for consistency check only
for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end();)
{
shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(recurrentInfo, *nodeIter); // check if this node participates in a recurrent loop
if (recInfo) // node is part of a SEQ loop: gather all of them. The nodes must be consecutive in 'allNodes'
{
// instead of the node itself, include the sentinel SEQTraversalFlowControlNode in our list
m_nestedNodes.push_back(recInfo);
// and verify that we only encountered the loop once (all nodes should have been consecutive)
if (!loopsSeen.insert(recInfo).second)
LogicError("PARTraversalFlowControlNode: members of loop %ls are not consecutive in node list.", recInfo->NodeName().c_str());
// consume all nodes that are part of the same loop (they are all consecutive)
while (nodeIter != allNodes.end() && (*nodeIter)->IsPartOfLoop() && FindInRecurrentLoops(recurrentInfo, *nodeIter) == recInfo)
nodeIter++;
}
else // regular top-level node (non-looping, PAR)
{
m_nestedNodes.push_back(*nodeIter);
nodeIter++; // and consume this node
}
}
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange& fr) /*override*/
{
for (auto& node : m_nestedNodes)
{
if (node->IsOutOfDateWrtInputs())
{
node->BeginForwardProp();
node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
node->EndForwardProp();
node->BumpEvalTimeStamp();
}
}
}
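// Note: the IsOutOfDateWrtInputs() check above is what makes the typical calling
// pattern cheap: calling ForwardProp() for an eval node and then for the training
// criterion re-executes only nodes whose inputs carry a newer time stamp; the
// shared sub-graph is computed once.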
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
// process nodes in pre-determined order
for (auto pnode = m_nestedNodes.rbegin(); pnode != m_nestedNodes.rend(); pnode++) // iterate backwards over evaluation order
{
auto& node = *pnode;
node->BeginBackprop();
node->Backprop(fr.WithLayout(node->GetMBLayout()), true /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
node->EndBackprop();
}
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{
}
// -----------------------------------------------------------------------
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)
//
// While PAR mode processes all samples in the minibatch independently, and thus in
// PARallel, SEQ mode honors sequential dependencies. As such, it unrolls the loop
// over time steps and runs the network once per time step.
// -----------------------------------------------------------------------
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::BeginForwardProp() /*override*/
{
// take the opportunity to check that layout is shared by all nodes in the loop
// TODO: we should do this in a constructor.
for (auto& node : m_nestedNodes)
{
if (node->GetMBLayout() != GetMBLayout())
LogicError("Evaluate: all nodes inside a recurrent loop must have a layout that is identical; mismatch found for nodes '%ls' vs. '%ls'",
node->NodeName().c_str(), m_nestedNodes[0]->NodeName().c_str());
}
// tell all that loop is about to commence
for (auto& node : m_nestedNodes)
node->BeginForwardProp();
}
// forward evaluation of a SEQTraversalFlowControlNode
// This evaluates all nodes in this FlowControlNode in SEQ mode: process the loop frame by frame in a nested loop.
// This is where the time axis changes.
// TODO: Once we do nested loops, then the FrameRange argument to this will refer to the outer loop.
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ForwardProp(const FrameRange&) /*override*/
{
// get layout associated with this loop
// All nodes share the same layout.
assert(GetMBLayout() == m_nestedNodes[0]->GetMBLayout());
// for every time step run through all nodes in this particular loop (treat the loop like a little ComputationNetwork)
// Note: Currently, this is limited to linear-time loops. But nothing stops the iteration below from being, e.g., a 2D iteration over an image,
// if we implement a corresponding FrameRangeIteration.
FrameRangeIteration range(GetMBLayout(), m_steppingDirection);
for (auto t = range.begin(); t != range.end(); t++)
{
for (auto& node : m_nestedNodes)
{
node->ForwardProp(t);
node->BumpEvalTimeStamp();
}
}
}
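// For illustration: with loop members { Times, Plus, Sigmoid, PastValue } (names
// hypothetical) and a forward-stepping layout of T frames, the nested loops above
// expand to
//   t=0: Times, Plus, Sigmoid, PastValue;  t=1: Times, Plus, Sigmoid, PastValue;  ...
// i.e. the entire loop body completes for frame t before t advances, so that
// PastValue can read the result of frame t-1.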
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndForwardProp() /*override*/
{
// tell all that loop is done --e.g. PastValueNode will capture its state for BPTT processing
for (auto& node : m_nestedNodes)
node->EndForwardProp();
}
// called before first iteration step of ComputeGradient()
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::BeginBackprop() /*override*/
{
for (auto& node2 : m_nestedNodes)
node2->BeginBackprop();
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::Backprop(const FrameRange&, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
const auto& recurrentNodes = m_nestedNodes; // BUGBUG: -ForForward?? Does this mean we can remove non-ForForward?
auto pMBLayout = recurrentNodes[0]->GetMBLayout();
FrameRangeIteration range(pMBLayout, m_steppingDirection);
for (auto t = range.rbegin(); t != range.rend(); t++) // note: reverse iteration
{
for (auto nodeIter2 = recurrentNodes.rbegin(); nodeIter2 != recurrentNodes.rend(); ++nodeIter2)
{
auto& node2 = *nodeIter2;
node2->Backprop(t, true /*childrenInThisLoop*/, false /*childrenInOuterLoop*/);
// The above flags tell Backprop() to skip back-propagation from inside a node into
// a node that is outside the loop, which is done later in EndBackprop() in PAR mode.
}
}
}
// called after last iteration step of ComputeGradient()
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndBackprop() /*override*/
{
// The following loop handles the case that a node inside the loop back-propagates a gradient into a node outside of the loop.
// For efficiency, we perform this outside the loop in PAR mode. E.g., in one LSTM speech setup, we measured 12..14% overall speed-up.
for (auto nodeIter2 = m_nestedNodes.rbegin(); nodeIter2 != m_nestedNodes.rend(); ++nodeIter2)
{
auto& node2 = *nodeIter2;
node2->Backprop(FrameRange(m_nestedNodes[0]->GetMBLayout()), false /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
}
// tell all nodes we are done for this iteration
for (auto& node2 : m_nestedNodes)
node2->EndBackprop();
}
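// Sketch of the two-phase gradient flow, for a hypothetical loop h = Sigmoid(W * x + PastValue(h)):
// - the per-frame Backprop() calls above propagate only between loop members
//   (childrenInThisLoop), e.g. Sigmoid -> Plus -> PastValue;
// - EndBackprop() then propagates from loop members into out-of-loop nodes such as
//   W and x in one bulk PAR operation over all frames (childrenInOuterLoop).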
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) /*override*/
{
for (auto& nodeLoopIter : m_nestedNodes)
nodeLoopIter->RequestMatricesBeforeForwardProp(matrixPool);
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
// TODO: should we deallocate in opposite order?
for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
{
(*nodeIter)->AllocateGradientMatricesForInputs(matrixPool);
}
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{
for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
{
if ((*nodeIter)->NeedsGradient())
(*nodeIter)->ReleaseMatricesAfterBackprop(matrixPool);
}
}
// check whether a node participates in a recurrent loop
// If so, return the SEQTraversalFlowControlNode that holds the loop's node list; otherwise return nullptr.
/*static*/ shared_ptr<ComputationNetwork::SEQTraversalFlowControlNode> ComputationNetwork::FindInRecurrentLoops(const std::vector<std::shared_ptr<SEQTraversalFlowControlNode>>& recurrentInfo, const ComputationNodeBasePtr& node)
{
// look in all recurrent loops of the network
// TODO: Check for IsPartOfLoop(). Also why not store the loop id in the node for direct lookup?
for (auto& iter : recurrentInfo)
if (std::find(iter->m_nestedNodes.begin(), iter->m_nestedNodes.end(), node) != iter->m_nestedNodes.end()) // TODO: should this loop need to be a method of SEQTraversalFlowControlNode?
return iter;
return nullptr; // not part of a recurrent loop
}
// check if any of the nodes in the recurrence IsOutOfDateWrtInputs(), with the exception of delay nodes, for which this check would fail and which must be skipped
// TODO: Would it be sufficient to check against our own time stamp, so that we can use a unified time-stamping mechanism? Then we'd not need this special check for delayed nodes; just check all inputs against our own time stamp.
bool ComputationNetwork::SEQTraversalFlowControlNode::IsOutOfDateWrtInputs() const
{
for (auto& ptr : m_nestedNodes)
{
if (ptr->IsOutOfDateWrtInputs() &&
ptr->OperationName() != OperationNameOf(PastValueNode) &&
ptr->OperationName() != OperationNameOf(FutureValueNode))
// TODO: when ShiftNode lands, check this as well. Ideally just test whether ptr is an IRecurrentNode
{
return true;
}
}
return false;
}
// TODO: do this on PARTraversalFlowControlNode
void ComputationNetwork::ResetEvalTimeStamps()
{
for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
nodeIter->second->ResetEvalTimeStamp();
}
/*static*/ void ComputationNetwork::BumpEvalTimeStamp(const vector<ComputationNodeBasePtr>& nodes)
{
for (size_t i = 0; i < nodes.size(); i++)
nodes[i]->BumpEvalTimeStamp();
}
// for debugging
void ComputationNetwork::PrintComputationTree(const ComputationNodeBasePtr& rootNode,
const bool forwardCompute,
const bool printMatrices)
{
auto nodes = GetEvalOrder(rootNode); // note: don't take a reference, since we reverse() below
if (forwardCompute)
{
fprintf(stderr, "\n\nPrinting forward-computation node order ... \n");
}
else
{
fprintf(stderr, "\n\nPrinting gradient-computation node order ... \n");
nodes.reverse();
}
if (nodes.size() == 0)
fprintf(stderr, "\n(empty)\n");
else
for (const auto& node : nodes)
node->PrintSelf(printMatrices);
}
// -----------------------------------------------------------------------
// preparation of network
// -----------------------------------------------------------------------
// called by model editing operations, such as DeleteNode(); and by RebuildNetwork()
// These invalidate any post-processed structures; if they are accessed afterwards, we will fail.
void ComputationNetwork::InvalidateCompiledNetwork()
{
m_isCompiled = false;
m_allSEQNodes.clear();
m_evalOrders.clear();
m_nestedNetworks.clear();
m_inputValues.clear();
m_learnableParameters.clear();
}
// verify that network has undergone CompileNetwork()
void ComputationNetwork::VerifyIsCompiled(const char* where) const
{
if (!IsCompiled())
LogicError("%s: A compiled network was expected.", where);
}
// CompileNetwork() -- bring network into executable state
// Call this after creation, load, and any modification.
// This method sets up all members that are cleared in InvalidateCompiledNetwork();
// TODO: This should be the only entry point, subsuming all other Validate, Build, etc. functions.
// TODO: Related functions today do lots of stuff lazily. There are redundant calls. That will all be removed.
// TODO: This is in a somewhat partial state in that we now have a global eval order (keyed by a nullptr), but don't use it yet.
void ComputationNetwork::CompileNetwork()
{
fprintf(stderr, "\nPost-processing network...\n");
// all steps below have to be repeated for all root nodes (= nodes without parents, plus PreComputeNodes)
DetermineSetOfAllRoots();
fprintf(stderr, "\n%d roots:\n", (int) m_allRoots.size());
for (const auto& root : m_allRoots)
fprintf(stderr, "\t%ls = %ls\n", root->NodeName().c_str(), root->OperationName().c_str());
// Note: Steps below are loops over root nodes. We will gradually push those loops through to the functions,
// to reduce redundant operation on shared portions of the network.
// STEP: Create a depth-first tree-traversal order through original graph for every root.
// This is used wherever a nested structure is not relevant.
FormEvalOrder(nullptr); // form the global one
for (auto& node : m_allRoots)
FormEvalOrder(node);
// STEP: form the m_inputValues and m_learnableParameters sets for this rootNode
CollectInputAndLearnableParameters(nullptr);
for (const auto& root : m_allRoots)
CollectInputAndLearnableParameters(root);
// STEP: Discover nested loops.
FormRecurrentLoops(nullptr); // form the global one --TODO: just use this; should be no need to do this for each root
for (auto& node : m_allRoots)
FormRecurrentLoops(node);
// STEP: Form nested structure of PAR and SEQ traversal nodes.
for (auto& node : m_allRoots)
FormNestedNetwork(node);
// STEP: Infer node dimensions.
ValidateNetwork();
// STEP: Optimize the network.
// :)
// STEP: Some final details.
ResetEvalTimeStamps(); // invalidate all m_value fields. Really belongs into StartEvaluateMinibatchLoop()
fprintf(stderr, "\nPost-processing network complete.\n");
m_isCompiled = true;
}
// determine the set of all root nodes
// Roots are nodes that ForwardProp() may be called for.
// - training criterion, eval criteria
// - outputs
// - PreComputeNodes
// Result is stored in m_allRoots.
// BUGBUG: In the current implementation, outputs that are also inputs to others must be specified explicitly e.g. by a tag.
void ComputationNetwork::DetermineSetOfAllRoots()
{
// start with all non-referenced nodes
set<ComputationNodeBasePtr> allNodes, referencedNodes;
for (const auto& iter : m_nameToNodeMap)
{
auto node = iter.second;
allNodes.insert(node);
for (size_t i = 0; i < node->GetNumInputs(); i++)
{
auto input = node->Input(i);
if (!input) // this may be the result of an incorrect MEL operation
{
InvalidArgument("DetermineSetOfAllRoots: Input %d of %ls %ls operation is not connected; network is malformed.",
(int) i, node->NodeName().c_str(), node->OperationName().c_str());
}
referencedNodes.insert(input);
}
}
set<ComputationNodeBasePtr> unreferencedNodes;
set_difference(allNodes.begin(), allNodes.end(), referencedNodes.begin(), referencedNodes.end(), inserter(unreferencedNodes, unreferencedNodes.end()));
// add in all explicitly specified nodes.
// TODO: This is not ideal. We will also need on-demand compilation, to allow any node to be used as an output after the fact.
set<ComputationNodeBasePtr> allKnownRoots;
for (const auto& node : FinalCriterionNodes())
allKnownRoots.insert(node);
for (const auto& node : EvaluationNodes())
allKnownRoots.insert(node);
for (const auto& node : OutputNodes())
allKnownRoots.insert(node);
for (const auto& iter : m_nameToNodeMap) // PreComputeNodes
{
auto node = iter.second;
if (node->RequiresPreCompute())
allKnownRoots.insert(node);
}
// set m_allRoots to include both non-referenced nodes and also all explicitly specified roots
m_allRoots.clear();
set_union(unreferencedNodes.begin(), unreferencedNodes.end(), allKnownRoots.begin(), allKnownRoots.end(), inserter(m_allRoots, m_allRoots.end()));
// and bring the roots into a well-defined order
// Without this, different orders were observed depending on the complexity of non-Node BrainScript expressions.
sort(m_allRoots.begin(), m_allRoots.end(), [](const ComputationNodeBasePtr& a, const ComputationNodeBasePtr& b)
{
return a->NodeName() < b->NodeName();
});
}
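// Example: in a network  features -> Times -> CrossEntropy <- labels  (hypothetical),
// the CrossEntropy node feeds no other node, so it is discovered as unreferenced;
// an output that also feeds other nodes is missed by this and must be tagged
// explicitly (see BUGBUG above).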
// -----------------------------------------------------------------------
// validation
// -----------------------------------------------------------------------
// validate the sub-network needed to evaluate a specific output node
// This calls Validate() on every node in evaluation order (allowing information to propagate forward through the net).
// This is called lazily but once only per node until next ClearCache().
// This also sets up MBLayout links.
void ComputationNetwork::ValidateNetwork()
{
// reset to a well-defined MBLayout (any meaningful layout should do here)
// Note that Validate is never called during operation. Any actual computation will cause the MBLayout to be set.
m_pMBLayout->Init(1, 0);
// set up MBLayout links of inputs (all others get propagated upwards through Validate())
// TODO: Once we support mismatching layouts, this will be more involved. For now, everything shares the one layout that the Network knows about.
for (auto node : InputNodes(nullptr))
node->LinkToMBLayout(m_pMBLayout);
// we call all nodes' Validate() in order to validate, that is, to set up the MBLayout and the FunctionValues dimensions
// A problem is that recurrent loops may require partial validation.
// Nodes validated on partial input (i.e. some children not yet validated) will be revisited.
const auto& nodes = GetEvalOrder(nullptr);
for (auto& node : nodes)
{
node->m_visited = false;
node->m_needsGradient = node->IsParameterUpdateRequired(); // these get propagated upwards in the following
}
// loop and validate until we are done
// steps:
// - validate (not final) // not final means no dimension checks
// Keep going through the list until all nodes have been validated and all inputs have been validated as well.
// - validate (final) // final means consistency checks
// Fail if anything changes during this stage.
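// Example of why multiple passes are needed: in a loop  h = Sigmoid(W * x + R * PastValue(h)),
// PastValue(h) cannot infer its dimensions until h has been validated, while h in turn
// depends on PastValue(h); the non-final passes below revisit such nodes until inference converges.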
size_t pass = 0;
size_t toValidate = nodes.size();
while (toValidate > 0)
{
pass++;
fprintf(stderr, "\n\nValidating network. %d nodes to process in pass %d.\n", (int) toValidate, (int) pass);
ValidateNodes(nodes, false /*isFinalValidationPass*/, toValidate);
}
fprintf(stderr, "\n\nValidating network, final pass.\n");
ValidateNodes(nodes, true /*isFinalValidationPass*/, toValidate);
if (toValidate != 0)
LogicError("ValidateSubNetwork: ValidateNodes(true) unexpectedly returned with work left to do.");
// propagate some info to SEQTraversalFlowControlNode
// TODO: In the future we should validate not on the flat list but the PARTraversalFlowControlNode structure. Then this will be unnecessary.
for (auto& recInfo : m_allSEQNodes)
{
auto& node = recInfo->m_sourceNode;
recInfo->m_needsGradient = node->m_needsGradient;
recInfo->LinkToMBLayout(node->GetMBLayout());
}
for (auto& node : nodes)
{
// nodes must output non-zero dimensional data, otherwise assume user error
if (node->GetSampleLayout().GetNumElements() == 0)
RuntimeError("%ls operation has 0 elements", node->NodeName().c_str());
}
fprintf(stderr, "\n\n");
// logging the non-default-layout nodes
vector<ComputationNodeBasePtr> nonDefaultNodes;
for (auto node : nodes)
{
if (node->GetMBLayout() != m_pMBLayout)
nonDefaultNodes.push_back(node);
}
if (!nonDefaultNodes.empty())
{
fprintf(stderr, "%d out of %d nodes do not share the minibatch layout with the input data.\n", (int) nonDefaultNodes.size(), (int) nodes.size());
// for (auto node : nonDefaultNodes)
// fprintf(stderr, " %ls\n", node->NodeName().c_str());
// fprintf(stderr, "\n\n");
}
}
// helper to discover dimension changes
static pair<TensorShape, bool> GetDims(const ComputationNodeBasePtr& node)
{
return make_pair(node->GetSampleLayout(), node->HasMBLayout());
}
void ComputationNetwork::ValidateNodes(list<ComputationNodeBasePtr> nodes, bool isFinalValidationPass, size_t& todo)
{
todo = 0; // returns how many nodes are to be redone
for (auto& node : nodes)
{
const auto& children = node->GetInputs();
const bool isLeaf = node->IsLeaf();
// only validate a node if it is a leaf or has at least one visited child
bool hasVisitedChild = false;
bool allChildrenVisited = true;
for (auto& child : children)
{
hasVisitedChild |= child->m_visited; // if not a single visited child then no point in validating
allChildrenVisited &= child->m_visited;
}
// skip Validate() in this pass if no child has been visited yet (unless the node is a leaf)
bool valid = false;
if (hasVisitedChild || isLeaf)
{
// leaf node or at least one visited child: it makes sense to call Validate()
// keep state
MBLayoutPtr oldMBLayoutPtr = node->GetMBLayout();
auto dim = GetDims(node);
vector<pair<TensorShape, bool>> childDims;
for (auto& child : children)
childDims.push_back(GetDims(child));
auto sampleLayout = node->GetSampleLayout();
// We do call validate(final) as many times as needed, since stuff may have changed underneath.
node->PrintSelfBeforeValidation();
node->Validate(isFinalValidationPass /*final*/); // all nodes have been visited: do verification instead of just inference
fprintf(stderr, " -> [%s%s]", string(node->GetSampleLayout()).c_str(), node->HasMBLayout() ? " x *" : "");
node->m_visited = true;
// also take the opportunity to propagate m_needsGradient
auto needsGradient = node->m_needsGradient;
for (auto& child : children) // TODO: do we need a check that this is stable if isFinalValidationPass?
node->m_needsGradient |= child->m_needsGradient;
// check state --node will be valid if all children have been visited and the node has not changed
bool unchanged = true;
unchanged &= (oldMBLayoutPtr == node->GetMBLayout());
unchanged &= (dim == GetDims(node));
vector<pair<TensorShape, bool>> newChildDims;
for (auto& child : children)
newChildDims.push_back(GetDims(child));
unchanged &= (childDims == newChildDims);
unchanged &= (sampleLayout == node->GetSampleLayout());
unchanged &= (needsGradient == node->m_needsGradient);
if (isFinalValidationPass && !unchanged)
LogicError("ValidateSubNetwork: %ls %ls operation changed during final validation.", node->NodeName().c_str(), node->OperationName().c_str());
if (isFinalValidationPass && !allChildrenVisited)
LogicError("ValidateSubNetwork: %ls %ls operation reached final validation although not all children were visited.", node->NodeName().c_str(), node->OperationName().c_str());
// a node is valid once all children have been visited and nothing changed (leaves are always valid)
valid = (allChildrenVisited && unchanged) || isLeaf;
}
// count those that we need to redo
if (!valid)
todo++;
}
}
// -----------------------------------------------------------------------
// memory allocation
// -----------------------------------------------------------------------
// mark nodes that are purely induced by parameters as non-sharable and create space for value if null
void ComputationNetwork::MarkValueNonSharableNodes()
{
const auto& nodes = GetEvalOrder(nullptr);
std::map<wstring, bool> allLeafDescendentsAreParameters;
std::list<ComputationNodeBasePtr> allLearnableParameters = GetNodesWithType(OperationNameOf(LearnableParameter));
// Note: we cannot use m_learnableParameters because we need all parameter nodes, regardless of whether they require updates or not
for (auto& node : nodes)
{
auto children = node->GetInputs();
wstring myname = node->NodeName();
bool allParameters = true;
if (children.size()) // we skip the check for leaf nodes, because all possible leaf nodes (input/parameter/precompute nodes) are already marked as non-sharable
{
for (auto child : children)
{
wstring ChildName = child->NodeName();
if (allLeafDescendentsAreParameters.find(ChildName) == allLeafDescendentsAreParameters.end())
{
// not found: it must be a leaf node (we traverse in eval order, so any non-leaf child has already been processed)
assert(child->IsLeaf() || child->IsPartOfLoop());
if (std::find(allLearnableParameters.begin(), allLearnableParameters.end(), child) != allLearnableParameters.end())
{
allLeafDescendentsAreParameters[ChildName] = true;
}
else
{
allParameters = false;
allLeafDescendentsAreParameters[ChildName] = false;
break;
}
}
else
{
if (allLeafDescendentsAreParameters[ChildName] == false)
{
allParameters = false;
break;
}
}
}
allLeafDescendentsAreParameters[myname] = allParameters;
if (allParameters)
{
node->MarkValueNonSharable();
}
}
}
}
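// Example (hypothetical): in  y = Times(Plus(W1, W2), x)  with W1, W2 LearnableParameters,
// all leaf descendants of the Plus node are parameters, so Plus is marked non-sharable
// (its value behaves like a parameter and must persist across minibatches), while Times
// also depends on the input x and its value matrix remains eligible for sharing.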
// this function will need to be called before actual validation and execution to
// predetermine how to share matrices to reduce memory usage.
// TODO: find a simple topological order and allocateEvalMatrices on that order directly
// without passing in eval, out, and train nodes.
void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBasePtr>& evalRootNodes,
const std::vector<ComputationNodeBasePtr>& outValueRootNodes,
ComputationNodeBasePtr trainRootNode)
{
// Allocate memory for forward/backward computation
fprintf(stderr, "\n\nAllocating matrices for forward and/or backward propagation.\n");
VerifyIsCompiled("AllocateAllMatrices");
// Due to special topology, if a node is solely induced by parameters, its function value should not be shared
MarkValueNonSharableNodes();
bool performingBackPropagation = (trainRootNode != nullptr);
// Create a composite Eval order with the specified nodes as roots
std::vector<ComputationNodeBasePtr> forwardPropRoots;
forwardPropRoots.insert(forwardPropRoots.end(), evalRootNodes.begin(), evalRootNodes.end());
forwardPropRoots.insert(forwardPropRoots.end(), outValueRootNodes.begin(), outValueRootNodes.end());
if (trainRootNode != nullptr)
forwardPropRoots.push_back(trainRootNode);
// For each node determine parents and whether the output of the
// node is needed during back propagation
std::unordered_map<ComputationNodeBasePtr, bool> outputValueNeededDuringBackProp;
std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>> parentsMap;
for (auto& rootNode : forwardPropRoots)
{
list<ComputationNodeBasePtr>& currentRootEvalNodes = GetEvalOrder(rootNode);
for (auto& currentNode : currentRootEvalNodes)
{
for (size_t i = 0; i < currentNode->GetNumInputs(); i++)
{
ComputationNodeBasePtr pNode = currentNode->GetInputs()[i];
parentsMap[pNode].insert(currentNode);
if (performingBackPropagation)
{
if (outputValueNeededDuringBackProp.find(pNode) == outputValueNeededDuringBackProp.end())
outputValueNeededDuringBackProp[pNode] = pNode->OutputUsedInComputingInputNodesGradients();
outputValueNeededDuringBackProp[pNode] |= currentNode->InputUsedInComputingInputNodesGradients(i);
}
else
{
outputValueNeededDuringBackProp[pNode] = false;
}
}
}
}
std::unordered_map<ComputationNodeBasePtr, int> parentCount;
for (auto& keyValue : parentsMap)
{
parentCount[keyValue.first] = keyValue.second.size();
}
// Construct the composite forward prop eval order by enumerating the
// nodes corresponding to each of our roots and then arranging them in the
// relative order that they appear in the global evaluation order
const std::list<ComputationNodeBasePtr>& allNodesEvalOrder = GetEvalOrder(nullptr);
std::list<ComputationNodeBasePtr> nodesForForwardPropRoots = ComputationNodeBase::EnumerateNodes(forwardPropRoots);
std::vector<ComputationNodeBasePtr> compositeForwardPropEvalOrder;
for (auto& node : allNodesEvalOrder)
{
if (std::find(nodesForForwardPropRoots.cbegin(), nodesForForwardPropRoots.cend(), node) != nodesForForwardPropRoots.cend())
{
compositeForwardPropEvalOrder.push_back(node);
}
}
set<ComputationNodeBasePtr> completedEvaluate;
for (auto& nodeIter : compositeForwardPropEvalOrder)
{
nodeIter->SetOutputNeededDuringBackprop(outputValueNeededDuringBackProp[nodeIter]);
if (nodeIter->IsPartOfLoop())
{
// TODO: use FormNestedNetwork() here to avoid completedEvaluate[] check
shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(m_allSEQNodes, nodeIter);
assert(recInfo != nullptr);
if (completedEvaluate.insert(recInfo).second)
{
recInfo->RequestMatricesBeforeForwardProp(m_matrixPool);
for (auto& nodeLoopIter : recInfo->m_nestedNodes)
{
ReleaseMatricesAfterEvalForChildren(nodeLoopIter, parentCount);
}
}
}
else
{
nodeIter->RequestMatricesBeforeForwardProp(m_matrixPool);
// we only release matrices for the children, since the root node's information will be used
// and should not be shared with others
ReleaseMatricesAfterEvalForChildren(nodeIter, parentCount);
}
}
if (trainRootNode != nullptr)
{
std::list<ComputationNodeBasePtr>& backPropNodes = GetEvalOrder(trainRootNode);
// now, simulate the gradient computation order to determine how to allocate matrices
set<ComputationNodeBasePtr> completedGradient;
// we need to call it here since we always compute gradients for children, and the root node is not a child of any other node
trainRootNode->RequestMatricesBeforeBackprop(m_matrixPool);
for (auto iter = backPropNodes.rbegin(); iter != backPropNodes.rend(); iter++) // for gradient computation, traverse in reverse order
{
auto n = *iter;
if (n->IsPartOfLoop())
{
std::vector<ComputationNodeBasePtr> recurrentNodes;
shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(m_allSEQNodes, n);
if (completedGradient.insert(recInfo).second)
{
// SEQ mode: allocate all in loop first, then deallocate again
// TODO: next step: use PARTraversalFlowControlNode::AllocateGradientMatricesForInputs() and ReleaseMatricesAfterBackprop()...
// BUGBUG: naw, ^^ would not work! Wrong order! Need to rethink this. Need to make AllocateEvalMatrices() and AllocateGradientMatrices() the virtual functions.
recInfo->AllocateGradientMatricesForInputs(m_matrixPool);
// Loops are computed sample by sample so we have to allocate them all
recInfo->ReleaseMatricesAfterBackprop(m_matrixPool);
}
}
else
{
// PAR mode: we can allocate and immediately deallocate one by one
n->AllocateGradientMatricesForInputs(m_matrixPool);
// Root node's information will be used and should not be shared with others; also, it is small (1x1)
if ((n != trainRootNode) && n->NeedsGradient())
n->ReleaseMatricesAfterBackprop(m_matrixPool);
}
}
}
}
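// Usage sketch (caller assumed, e.g. SGD setup code):
//   net.CompileNetwork();
//   net.AllocateAllMatrices(evaluationNodes, outputNodes, criterionNode);
// This simulates one forward pass and, if criterionNode is non-null, one backward pass
// over the composite eval order, requesting each matrix from m_matrixPool at first use
// and releasing it after last use, so that matrices with disjoint lifetimes share storage.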
void ComputationNetwork::ReleaseMatricesAfterEvalForChildren(ComputationNodeBasePtr n, std::unordered_map<ComputationNodeBasePtr, int>& parentCount)
{
for (size_t i = 0; i < n->GetNumInputs(); i++)
{
ComputationNodeBasePtr pNode = n->GetInputs()[i];
parentCount[pNode]--;
if (parentCount[pNode] == 0)
pNode->ReleaseMatricesAfterForwardProp(m_matrixPool);
}
}
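// Example: if node A feeds both B and C, parentCount[A] starts at 2; B's forward pass
// decrements it to 1, and only C's pass brings it to 0, at which point A's value
// matrix is returned to the pool for reuse by nodes evaluated later.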
} } }