CNTK/Source/ActionsLib/NDLUtil.h

169 строки
7.2 KiB
C++

//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "Basics.h"
#include "NetworkDescriptionLanguage.h"
#include "ComputationNetwork.h"
#include "NDLNetworkBuilder.h"
#include <string>
#include "Config.h"
#include <stdexcept>
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
// Forward declaration of the NDL node evaluator. NDLUtil::ProcessPassNDLScript
// constructs one of these from the network and hands it to NDLScript::Evaluate;
// the full definition is not part of this header.
template <typename ElemType>
class NDLNodeEvaluatorImpl;
template <class ElemType>
class NDLUtil
{
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
private:
ComputationNetworkPtr m_net;
public:
NDLUtil(ComputationNetworkPtr net)
: m_net(net)
{
}
// ProcessNDLConfig - Process the NDL script from a configuration string value
// config - configuration string containing script
void ProcessNDLConfig(const ConfigValue& config, bool fullValidate = false)
{
NDLScript<ElemType> script(config);
ProcessNDLScript(&script, ndlPassAll, nullptr, fullValidate);
}
// ProcessNDLScript - Process the NDL script
// netNdl - netNDL structure
// ndlPassUntil - complete processing through this pass, all passes if ndlPassAll
// fullValidate - validate as a complete network? (false if this might be a snippet of a full network)
void ProcessNDLScript(NetNdl<ElemType>* netNdl, NDLPass ndlPassUntil = ndlPassAll, bool fullValidate = false)
{
ProcessNDLScript(netNdl->ndl, ndlPassUntil, netNdl->lastNode, fullValidate);
}
// ProcessNDLScript - Process the NDL script
// script - NDL Script to process
// ndlPassUntil - complete processing through this pass, all passes if ndlPassAll
// skipThrough - [in/out] for iterative processing, a pointer to an array of NDLNode*, one for each pass
// the pointer will be updated to last node processed for that pass, can be NULL if all node processing is desired
// fullValidate - validate as a complete network? (false if this might be a snippet of a full network)
void ProcessNDLScript(NDLScript<ElemType>* script, NDLPass ndlPassUntil = ndlPassAll, NDLNode<ElemType>** skipThrough = nullptr, bool fullValidate = false, const std::wstring& dumpFileName = L"")
{
// if we don't have a script yet, don't bother
if (script == nullptr)
return;
// set the Computational network in the script, so we can do name lookup in the model
script->SetComputationNetwork(m_net);
// loop through the different passes, processing as we go
// skipThrough (when not null) is a pointer to the following structure in the NetNdl class:
// NDLNode<ElemType>* lastNode[ndlPassMax]; // last node we evaluated for each pass
NDLNode<ElemType>* lastNode = nullptr;
for (NDLPass ndlPass = ndlPassInitial; ndlPass <= ndlPassUntil; ++ndlPass)
{
NDLNode<ElemType>* skipThroughNode = skipThrough ? *skipThrough : nullptr;
lastNode = ProcessPassNDLScript(script, ndlPass, skipThroughNode, fullValidate, dumpFileName);
if (skipThrough)
{
*skipThrough = lastNode;
skipThrough++;
}
}
}
// ProcessPassNDLScript - Process a pass of the NDL script
// script - NDL Script to process
// ndlPass - complete processing for this pass, all passes if ndlPassAll
// skipThrough - for iterative processing, skip through this node in the script (used for in-line MEL processing)
// fullValidate - validate as a complete network? (false if this might be a snippet of a full network)
// returns: last NDL node processed
NDLNode<ElemType>* ProcessPassNDLScript(NDLScript<ElemType>* script, NDLPass ndlPass, NDLNode<ElemType>* skipThrough = nullptr, bool fullValidate = false, const std::wstring& dumpFileName = L"")
{
if (ndlPass == ndlPassFinal)
{
// make sure to clear the caches so we pick up the new nodes
m_net->InvalidateCompiledNetwork();
// if requested then dump the nodes
// Note: This happens on the invalidated network.
if (dumpFileName != L"")
m_net->DumpAllNodesToFile(false, true, dumpFileName);
}
NDLNodeEvaluatorImpl<ElemType> ndlEvaluator(m_net);
NDLNode<ElemType>* lastNode = script->Evaluate(ndlEvaluator, L"", ndlPass, skipThrough);
if (ndlPass == ndlPassResolve)
SetOutputNodes(script);
return lastNode;
}
// CheckOutputNodes - check output nodes
// symbolName - name of the computation nodes we are collecting
// compNodes - array of computation nodes
void CheckOutputNodes(NDLScript<ElemType>* script, std::string symbolName, std::wstring groupTag)
{
NDLNode<ElemType>* nodeArray = script->FindSymbol(symbolName);
bool valid = m_net->FeatureNodes().size() > 0; // see if it's already valid
if (!valid && nodeArray) // otherwise, see if we found a symbol
{
NDLType outputType = nodeArray->GetType();
// accept either an array of nodes, or a single node
valid = (outputType == ndlTypeArray || outputType == ndlTypeFunction || outputType == ndlTypeMacroCall);
}
if (!valid)
RuntimeError("Invalid network node definition for '%s', nonexistant or wrong type", symbolName.c_str());
if (nodeArray)
{
vector<NDLNode<ElemType>*> nodes;
if (nodeArray->GetType() == ndlTypeArray)
nodes = nodeArray->GetParameters();
else
nodes.push_back(nodeArray);
for (size_t i = 0; i < nodes.size(); i++)
{
// get the computation node
auto cnNode = ComputationNode<ElemType>::FromVoidPtr(nodes[i]->GetEvalValue());
// if no evaluation value exists throw an error
if (!cnNode)
RuntimeError("Invalid node '%s' as a(n) %s node, nonexistant or wrong type", nodes[i]->GetName().c_str(), msra::strfun::utf8(groupTag).c_str());
// add to the desired node group
m_net->AddToNodeGroup(groupTag, cnNode);
}
}
}
// SetOutputNodes - Set the output nodes for the Computational Network
// NOTE: seems to be specific to NDLBuilderImpl, should be in a derived class for that execution engine
void SetOutputNodes(NDLScript<ElemType>* script)
{
// NOTE: all optional parameter nodes (i.e. tag="feature") have already been processed in ProcessOptionalParameters()
// handle the alternate way of specifying nodes, the array of nodes method
// parameter name node-group tag
CheckOutputNodes(script, "featureNodes" , L"feature" );
CheckOutputNodes(script, "labelNodes" , L"label" );
CheckOutputNodes(script, "criterionNodes", L"criterion" );
CheckOutputNodes(script, "evalNodes" , L"evaluation");
CheckOutputNodes(script, "outputNodes" , L"output" );
// legacy name:
CheckOutputNodes(script, "criteriaNodes" , L"criterion" );
}
};
// Explicit instantiations of NDLUtil for the float and double element types.
// NOTE(review): these are explicit instantiation definitions inside a header; if this
// header is included from more than one translation unit, confirm that is acceptable
// for the toolchains in use.
template class NDLUtil<float>;
template class NDLUtil<double>;
}}}