272 строки
15 KiB
C++
272 строки
15 KiB
C++
//
|
|
// Copyright (c) Microsoft. All rights reserved.
|
|
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
|
//
|
|
// NetworkDescriptionLanguage.cpp : Code used to interpret the Network Description Language.
|
|
//
|
|
|
|
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
|
|
|
|
#include "NetworkDescriptionLanguage.h"
|
|
#include "NDLNetworkBuilder.h"
|
|
|
|
#include "ConvolutionalNodes.h"
|
|
#include "RNNNodes.h"
|
|
#include "DeprecatedNodes.h"
|
|
#include "EvaluationNodes.h"
|
|
#include "InputAndParamNodes.h"
|
|
#include "LinearAlgebraNodes.h"
|
|
#include "NonlinearityNodes.h"
|
|
#include "PreComputeNodes.h"
|
|
#include "ReshapingNodes.h"
|
|
#include "RecurrentNodes.h"
|
|
#include "SpecialPurposeNodes.h"
|
|
#include "TrainingNodes.h"
|
|
|
|
using namespace std;
|
|
|
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
|
|
|
// DuplicateNode - Duplicate a node in a macro as needed (it might already exist)
|
|
// node - node we are duplicating
|
|
// return - the new duplicated node if it didn't exist, or the previously duplicated node if it already did
|
|
template <typename ElemType>
|
|
NDLNode<ElemType>* NDLScript<ElemType>::DuplicateNode(NDLNode<ElemType>* node)
|
|
{
|
|
NDLNode<ElemType>* newNode = node->Copy();
|
|
m_children.push_back(newNode);
|
|
newNode->SetParentScript(this);
|
|
return newNode;
|
|
}
|
|
|
|
template <typename ElemType>
|
|
NDLScript<ElemType>::NDLScript(const NDLScript& copyMe)
|
|
: ConfigParser(copyMe)
|
|
{
|
|
m_baseName = copyMe.m_baseName;
|
|
m_scriptString = copyMe.m_scriptString;
|
|
m_macroNode = copyMe.m_macroNode;
|
|
m_noDefinitions = copyMe.m_noDefinitions; // no definitions can be made in this script, interpret all macro/function names as calls
|
|
m_definingMacro = false; // not defining when expanding macros (only reason to call this method
|
|
m_cn = copyMe.m_cn; // computation network to use for backup symbol lookup. Used for MEL where NDL and network nodes are mixed
|
|
|
|
// script lines in parsed node order
|
|
for (NDLNode<ElemType>* node : copyMe.m_script)
|
|
{
|
|
// duplicate this node
|
|
NDLNode<ElemType>* newNode = DuplicateNode(node);
|
|
AddSymbol(newNode->GetName(), newNode);
|
|
|
|
// now get the parameters to the functions added
|
|
ConfigValue value = newNode->GetParamString();
|
|
ParseParameters(newNode, value, true /*createNew*/);
|
|
|
|
// add it to the new script
|
|
m_script.push_back(newNode);
|
|
}
|
|
|
|
// now search the symbol table for other symbols that haven't been copied yet
|
|
// this happens for constants defined in macros and such
|
|
for (std::pair<std::string, NDLNode<ElemType>*> pair : copyMe.m_symbols)
|
|
{
|
|
// if we can't find the symbol in the copied symbol table, copy it here
|
|
if (m_symbols.find(pair.first) == end(m_symbols))
|
|
{
|
|
// duplicate this node
|
|
NDLNode<ElemType>* newNode = DuplicateNode(pair.second);
|
|
AddSymbol(pair.first, newNode);
|
|
// anything that takes parameters should be evaluated in the script loop
|
|
assert(newNode->GetParamString().empty());
|
|
}
|
|
}
|
|
// NOTE: the child nodes get populated as the nodes are duplicated in the loop above
|
|
// we shouldn't try to duplicate them separately
|
|
}
|
|
|
|
// copy constructor, creates a new disconnected copy of this node
|
|
// doesn't copy everything, so use for macro expansion only (it's private)
|
|
// copyMe - node to copy
|
|
template <typename ElemType>
|
|
NDLNode<ElemType>::NDLNode(const NDLNode<ElemType>& copyMe)
|
|
{
|
|
m_name = copyMe.m_name; // value on the left of the equals
|
|
m_value = copyMe.m_value; // value on the right of the equals (CN node name, or value)
|
|
m_parent = copyMe.m_parent; // parent script
|
|
m_type = copyMe.m_type; // type of node
|
|
m_paramString = copyMe.m_paramString; // parameter of a function/array
|
|
m_paramMacro = copyMe.m_paramMacro; // parameter of a macro (the variables used in the macro definition)
|
|
// don't copy over m_parameters, they will be reparsed after the copy
|
|
|
|
m_eval = nullptr; // pointer to an arbitrary eval structure
|
|
// script for macro calls, need to expand the macro for each call
|
|
// if it's not expanded the evalValue will be overwitten on multiple calls to a macro
|
|
m_script = (copyMe.m_script) ? new NDLScript<ElemType>(*copyMe.m_script) : nullptr;
|
|
}
|
|
template <typename ElemType>
|
|
NDLScript<ElemType>::NDLScript(const NDLScript&& moveMe)
|
|
: ConfigParser(move(moveMe))
|
|
{
|
|
m_baseName = move(moveMe.m_baseName);
|
|
m_scriptString = move(moveMe.m_scriptString);
|
|
m_script = move(moveMe.m_script); // script lines in parsed node order, macros will have definition followed by body
|
|
m_symbols = move(moveMe.m_symbols); // symbol table
|
|
m_macroNode = move(moveMe.m_macroNode); // set when interpretting a macro definition
|
|
m_noDefinitions = move(moveMe.m_noDefinitions); // no definitions can be made in this script, interpret all macro/function names as calls
|
|
m_definingMacro = move(moveMe.m_definingMacro);
|
|
m_children = move(moveMe.m_children); // child nodes. Note that m_script nodes may not be children of this object, they include macro nodes
|
|
m_cn = move(moveMe.m_cn); // computation network to use for backup symbol lookup. Used for MEL where NDL and network nodes are mixed
|
|
}
|
|
|
|
// EqualInsensitive - check to see if two nodes are equal
|
|
// string1 - [in,out] string to compare, if comparision is equal insensitive but not sensitive, will replace with sensitive version
|
|
// string2 - second string to compare
|
|
// alternate - alternate naming of the string
|
|
// return - true if strings are equal insensitive and modifies string1 to sensitive version if different
|
|
bool EqualInsensitive(std::wstring& string1, const std::wstring& string2, const wchar_t* alternate /*=NULL*/)
|
|
{
|
|
bool equal = EqualCI(string1, string2) ||
|
|
(alternate && EqualCI(string1, alternate));
|
|
|
|
if (equal)
|
|
string1 = string2;
|
|
|
|
return equal;
|
|
}
|
|
|
|
// ++ operator for this enum, so loops work
|
|
NDLPass& operator++(NDLPass& ndlPass)
|
|
{
|
|
assert(ndlPass != ndlPassMax);
|
|
ndlPass = static_cast<NDLPass>(ndlPass + 1);
|
|
return ndlPass;
|
|
}
|
|
|
|
// CheckFunction - check to see if we match a function name
|
|
// string1 - [in,out] string to compare, if comparision is equal and at least half the full node name will replace with full node name
|
|
// allowUndeterminedVariable - [out] set to true if undetermined variables (symbols yet to be defined) are allowed here
|
|
// return - true if function name found
|
|
bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
|
|
{
|
|
if (allowUndeterminedVariable)
|
|
*allowUndeterminedVariable = true; // be default we allow undetermined variables
|
|
|
|
wstring nodeType = msra::strfun::utf16(p_nodeType);
|
|
bool ret = false;
|
|
if (EqualInsensitive(nodeType, OperationNameOf(AbsNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(BatchNormalizationNode))) ret = true;
|
|
#ifdef COMING_SOON
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CRFNode), L"CRF")) ret = true;
|
|
#endif
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode), L"CBCEWithSM")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ClassificationErrorNode), L"ErrorPrediction")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(EqualNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(GreaterEqualNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(GreaterNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LessEqualNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LessNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(NotEqualNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ClipNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ConvolutionNode), L"Convolve")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CropNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(PoolingNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CosDistanceNode), L"CosDist")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CosDistanceWithNegativeSamplesNode), L"CosWithNegSamples")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CosineNode), L"Cos")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CrossEntropyNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(CrossEntropyWithSoftmaxNode), L"CEWithSM")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(DiagTimesNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(DiagonalNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(DropoutNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(DummyCriterionNode), L"DummyCriterion")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ElementTimesNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ExpNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(FloorNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(FutureValueNode))) ret = true;
|
|
#ifdef COMING_SOON
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(GMMLogLikelihoodNode), L"GMMLL")) ret = true;
|
|
#endif
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(HardmaxNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(IfNode), L"If")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(InputValue), L"Input")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(InvStdDevNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(KhatriRaoProductNode), L"ColumnwiseCrossProduct")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LearnableParameter), L"Parameter")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LogNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LogPlusNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LogSoftmaxNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LogisticNode), L"Logistic")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(LookupTableNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(MatrixL1RegNode), L"L1Reg")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(MatrixL2RegNode), L"L2Reg")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(MaxPoolingNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(MaxUnpoolingNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(MeanNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(MinusNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(NegateNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(PastValueNode), L"Delay")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(PerDimMeanVarDeNormalizationNode), L"PerDimMVDeNorm")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(PerDimMeanVarNormalizationNode), L"PerDimMVNorm")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(PlusNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ReciprocalNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ReconcileDynamicAxisNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(RectifiedLinearNode), L"ReLU")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ReshapeNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(ROIPoolingNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(RowRepeatNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(RowStackNode))) ret = true;
|
|
#ifdef COMING_SOON
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SequenceDecoderNode), L"SEWithSM")) ret = true;
|
|
#endif
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SequenceWithSoftmaxNode), L"SEWithSM")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SigmoidNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SinNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SoftmaxNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SparseInputValue), L"SparseInput")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SqrtNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SquareErrorNode), L"SE")) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SumColumnElementsNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(SumElementsNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(TanhNode))) ret = true;
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(TimesNode))) ret = true;
|
|
//else if (EqualInsensitive(nodeType, OperationNameOf(TransposeDimensionsNode))) ret = true; // not supported from NDL, use Transpose()
|
|
else if (EqualInsensitive(nodeType, OperationNameOf(TransposeTimesNode))) ret = true;
|
|
// legacy names:
|
|
else if (EqualInsensitive(nodeType, L"ColumnElementTimes")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"Constant", L"Const")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"ImageInput", L"Image")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"ImageParameter")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"RowElementTimes")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"RowSlice")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"Scale")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"SparseImageInput", L"SparseImage")) ret = true;
|
|
else if (EqualInsensitive(nodeType, L"Transpose")) ret = true;
|
|
|
|
// return the actual node name in the parameter if we found something
|
|
if (ret)
|
|
p_nodeType = msra::strfun::utf8(nodeType);
|
|
return ret;
|
|
}
|
|
|
|
template <typename ElemType>
|
|
NDLScript<ElemType> NDLScript<ElemType>::s_global("global");
|
|
|
|
// declare the static variables from the classes
|
|
template <>
|
|
NDLScript<float> NDLScript<float>::s_global{};
|
|
template <>
|
|
NDLScript<double> NDLScript<double>::s_global{};
|
|
|
|
template <>
|
|
int NDLNode<float>::s_nameCounter = 0;
|
|
template <>
|
|
int NDLNode<double>::s_nameCounter = 0;
|
|
|
|
template class NDLNode<float>;
|
|
template class NDLNode<double>;
|
|
|
|
template class NDLScript<float>;
|
|
template class NDLScript<double>;
|
|
|
|
}}}
|