diff --git a/MachineLearning/CNTK/ConfigEvaluator.cpp b/MachineLearning/CNTK/ConfigEvaluator.cpp
index 3601f0cd9..fa0642b3b 100644
--- a/MachineLearning/CNTK/ConfigEvaluator.cpp
+++ b/MachineLearning/CNTK/ConfigEvaluator.cpp
@@ -294,26 +294,33 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
         RowSliceNode(vector<ComputationNodePtr> && inputs, size_t firstRow, size_t numRows, const wstring & tag) : UnaryComputationNode(move(inputs), tag), firstRow(firstRow), numRows(numRows) { }
         /*ComputationNode::*/ const wchar_t * OperationName() const { return L"RowSlice"; }
     };
-    // DelayNode is special in that it may for cycles.
-    // Specifically, to break circular references, DelayNode does not resolve its input arg (a ComputationNode), but rather keeps the ConfigValuePtr for now.
-    // The ConfigValuePtr is meant to be unresolved, i.e. a lambda that will resolve its arg when accessing the value for the first time.
-    // I.e. after construction, DelayNode can be referenced, but it cannot perform any operation on its argument, since it does not know it yet.
-    // ComputationNetwork knows to call FinalizeInit() to resolve this, at a time when pointers for anythin this may reference
-    // from its or outer scope have been created (if those pointers are to Delay nodes in turn, those would again resolve in their
+    // Nodes deriving from RecurrentComputationNode are special in that they may involve cycles.
+    // Specifically, to break circular references, RecurrentComputationNode does not resolve its inputs argument (an array of ComputationNodes),
+    // but rather keeps a lambda to do so later.
+    // By contract, the network builders will know to call FinalizeInit() on such nodes at the right time (before traversing their children, to allow for more nodes to be created).
+    // I.e. after construction, a RecurrentComputationNode can be referenced, but it cannot perform any operation on its inputs, since it does not know them yet.
+    // ComputationNetwork knows to call FinalizeInit() to resolve this, at a time when pointers for anything this may reference
+    // from its or outer scope have been created (if those pointers involve recurrent nodes in turn, those would again resolve in their
     // later FinalizeInit() call, which may yet again create new nodes etc.).
-    struct DelayNode : public ComputationNode, public MustFinalizeInit
+    struct RecurrentComputationNode : public ComputationNode, public MustFinalizeInit
     {
-        ConfigValuePtr argUnresolved;
-        ComputationNodePtr arg;
-        int deltaT;
+        function<vector<ComputationNodePtr>()> GetInputsLambda;
     public:
-        DelayNode(ConfigValuePtr argUnresolved, int deltaT, const wstring & tag) : argUnresolved(argUnresolved), deltaT(deltaT) { SetTag(tag); }
+        RecurrentComputationNode(function<vector<ComputationNodePtr>()> GetInputsLambda) : GetInputsLambda(GetInputsLambda) { }
+        // FinalizeInit() is called from NDLNetworkBuilder when collecting all nodes; this is where we can lazily evaluate the recurrent connections.
         /*MustFinalizeInit::*/ void FinalizeInit()
         {
-            AttachInputs(vector<ConfigValuePtr>(1, argUnresolved));    // the implied type cast resolves it
-            argUnresolved = ConfigValuePtr();                          // and free any references it may hold
+            vector<ComputationNodePtr> inputs = GetInputsLambda();     // this evaluates the nodes, and possibly creates local downstream pieces of the graph
+            AttachInputs(move(inputs));
+            GetInputsLambda = []() -> vector<ComputationNodePtr> { LogicError("RecurrentComputationNode::FinalizeInit: called twice"); };  // avoid it being called twice
             // dim?
         }
+    };
+    struct DelayNode : public RecurrentComputationNode
+    {
+        int deltaT;
+    public:
+        DelayNode(function<vector<ComputationNodePtr>()> GetInputsLambda, int deltaT, const wstring & tag) : RecurrentComputationNode(GetInputsLambda), deltaT(deltaT) { SetTag(tag); }
         /*ComputationNode::*/ const wchar_t * OperationName() const { return L"Delay"; }
     };
     class InputValue : public ComputationNode
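Aside: the deferred-resolution pattern above can be seen in isolation in the following minimal sketch. It is a stand-alone illustration, not CNTK code: Node, LazyNode, and the FinalizeInit member here are simplified stand-ins for ComputationNode, RecurrentComputationNode, and the MustFinalizeInit contract.

#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

struct Node
{
    std::vector<std::shared_ptr<Node>> inputs;
    virtual ~Node() { }
};

// Breaks construction-order cycles: inputs are produced by a lambda that is
// only invoked once the surrounding expression graph has been fully built.
struct LazyNode : public Node
{
    std::function<std::vector<std::shared_ptr<Node>>()> getInputs;
    explicit LazyNode(std::function<std::vector<std::shared_ptr<Node>>()> f) : getInputs(std::move(f)) { }
    void FinalizeInit()
    {
        inputs = getInputs();   // resolve now; this may create further nodes
        getInputs = []() -> std::vector<std::shared_ptr<Node>> { throw std::logic_error("FinalizeInit: called twice"); };
    }
};

int main()
{
    auto a = std::make_shared<Node>();
    // at construction time, 'delay' does not yet know its inputs
    auto delay = std::make_shared<LazyNode>([a]() { return std::vector<std::shared_ptr<Node>>{ a }; });
    // ... the rest of the graph, possibly referring back to 'delay', is built here ...
    delay->FinalizeInit();      // only now are the inputs attached
    std::cout << delay->inputs.size() << "\n"; // prints 1
}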
@@ -356,12 +363,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
     }
     // factory function for ComputationNodes
     template<>
-    shared_ptr<ComputationNode> MakeRuntimeObject<ComputationNode>(const ConfigRecord & config)
+    shared_ptr<ComputationNode> MakeRuntimeObject<ComputationNode>(const ConfigRecordPtr configp)
     {
+        let & config = *configp;
         let classIdParam = config[L"class"];
         wstring classId = classIdParam;
         let tagp = config.Find(L"tag");
         wstring tag = tagp ? *tagp : wstring();
+        // TODO: factor these GetInputs() calls out
         if (classId == L"LearnableParameterNode")
             return make_shared<LearnableParameterNode>(config[L"outDim"], config[L"inDim"], tag);
         else if (classId == L"PlusNode")
@@ -372,7 +381,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
             return make_shared<TimesNode>(GetInputs(config, 2, L"TimesNode"), tag);
         else if (classId == L"DiagTimesNode")
             return make_shared<DiagTimesNode>(GetInputs(config, 2, L"DiagTimesNode"), tag);
-        // BUGBUG: ScaleNode is given a BoxOf<Double>, not ComputationNode
+        // BUGBUG: ScaleNode is given a BoxOf<Double>, not ComputationNode; need to create a Const first
         else if (classId == L"ScaleNode")
             return make_shared<ScaleNode>(GetInputs(config, 2, L"ScaleNode"), tag);
         else if (classId == L"LogNode")
@@ -391,8 +400,23 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
             return make_shared<CrossEntropyWithSoftmaxNode>(GetInputs(config, 2, L"CrossEntropyWithSoftmaxNode"), tag);
         else if (classId == L"ErrorPredictionNode")
             return make_shared<ErrorPredictionNode>(GetInputs(config, 2, L"ErrorPredictionNode"), tag);
-        else if (classId == L"DelayNode")
-            return make_shared<DelayNode>(config[L"input"], config[L"deltaT"], tag);
+        else
+            throw EvaluationError(L"unknown ComputationNode class " + classId, classIdParam.GetLocation());
+    }
+    // factory function for RecurrentComputationNodes
+    // The difference to the above is that the children are not resolved immediately but later during network connection.
+    // This takes the record as a shared_ptr so that we can keep it inside a lambda.
+    template<>
+    shared_ptr<RecurrentComputationNode> MakeRuntimeObject<RecurrentComputationNode>(const ConfigRecordPtr configp)
+    {
+        let & config = *configp;
+        let classIdParam = config[L"class"];
+        wstring classId = classIdParam;
+        let tagp = config.Find(L"tag");
+        wstring tag = tagp ? *tagp : wstring();
+        // instead of passing the array of input nodes, we pass a lambda that computes this array in the network-gathering path in NDLComputationNetwork
+        if (classId == L"DelayNode")
+            return make_shared<DelayNode>([configp]() { return GetInputs(*configp, 1, L"DelayNode"); }, config[L"deltaT"], tag);
         else
             throw EvaluationError(L"unknown ComputationNode class " + classId, classIdParam.GetLocation());
     }
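Aside: the comment "This takes the record as a shared_ptr so that we can keep it inside a lambda" is the crux of the new factory. A reference parameter could dangle by the time FinalizeInit() fires, while a shared_ptr captured by value keeps the record alive. A minimal sketch of that lifetime argument, using an assumed stand-in Record type rather than the real ConfigRecord:

#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct Record { std::string value; };

// Returning a lambda that captures the shared_ptr by value extends the
// record's lifetime until the last copy of the lambda is destroyed.
std::function<std::string()> MakeDeferredReader(std::shared_ptr<Record> rec)
{
    return [rec]() { return rec->value; };
}

int main()
{
    std::function<std::string()> read;
    {
        auto rec = std::make_shared<Record>(Record{ "deltaT = -1" });
        read = MakeDeferredReader(rec);
    } // 'rec' goes out of scope here, but the Record lives on inside the lambda
    std::cout << read() << "\n"; // safe: prints "deltaT = -1"
}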
@@ -424,8 +448,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
         set<ComputationNodePtr> outputs;    // all output nodes
         set<ComputationNodePtr> parameters; // all parameter nodes
     public:
-        NDLComputationNetwork(const ConfigRecord & config)
+        NDLComputationNetwork(const ConfigRecordPtr configp)
         {
+            let & config = *configp;
             deque<ComputationNodePtr> workList;
             // flatten the set of all nodes
             // we collect all ComputationNodes from the config; that's it
@@ -620,8 +645,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
         }
     };

-    shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord &);
-    shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord &);
+    shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr);
+    shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr);

     // =======================================================================
     // Evaluator -- class for evaluating a syntactic parse tree
@@ -687,14 +712,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
         struct ConfigurableRuntimeType
         {
             bool isConfigRecord;
-            function<ConfigValuePtr(const ConfigRecord & config, TextLocation location, const wstring & exprPath)> construct; // lambda to construct an object of this class
+            function<ConfigValuePtr(const ConfigRecordPtr config, TextLocation location, const wstring & exprPath)> construct; // lambda to construct an object of this class
         };

         template<class C>
         static ConfigurableRuntimeType MakeRuntimeTypeConstructor()
         {
             ConfigurableRuntimeType info;
-            info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
+            info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
             {
                 return ConfigValuePtr(MakeRuntimeObject<C>(config), location, exprPath);
             };
@@ -705,7 +730,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
         static ConfigurableRuntimeType MakeExperimentalComputationNetworkConstructor()
         {
             ConfigurableRuntimeType info;
-            info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
+            info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
             {
                 return ConfigValuePtr(MakeExperimentalComputationNetwork(config), location, exprPath);
             };
@@ -715,7 +740,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
         static ConfigurableRuntimeType MakeExperimentalComputationNodeConstructor()
         {
             ConfigurableRuntimeType info;
-            info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
+            info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
             {
                 return ConfigValuePtr(MakeExperimentalComputationNode(config), location, exprPath);
             };
@@ -731,6 +756,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
 #define DefineRuntimeType(T) { L#T, MakeRuntimeTypeConstructor<T>() }
             // ComputationNodes
             DefineRuntimeType(ComputationNode),
+            DefineRuntimeType(RecurrentComputationNode),
             // other relevant classes
             DefineRuntimeType(NDLComputationNetwork), // currently our fake
             // Functions
@@ -879,15 +905,15 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
                     if (newIter == configurableRuntimeTypes.end())
                         LogicError("unknown magic runtime-object class");
                     // form the ConfigRecord
-                    ConfigRecord config(nullptr);
+                    auto config = make_shared<ConfigRecord>(nullptr);
                     // Note on scope: This config holds the arguments of the XXXNode runtime-object instantiations.
                     // When they fetch their parameters, they should only look in this record, not in any parent scope (if they don't find what they are looking for, it's a bug in this routine here).
                     // The values themselves are already in ConfigValuePtr form, so we won't need any scope lookups there either.
-                    config.Add(L"class", e->location, ConfigValuePtr(make_shared<String>(classId), e->location, exprPath));
+                    config->Add(L"class", e->location, ConfigValuePtr(make_shared<String>(classId), e->location, exprPath));
                     vector<ConfigValuePtr> inputs;
                     inputs.push_back(leftVal);
                     inputs.push_back(rightVal);
-                    config.Add(L"inputs", leftVal.GetLocation(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetLocation(), exprPath));
+                    config->Add(L"inputs", leftVal.GetLocation(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetLocation(), exprPath));
                     // instantiate
                     let value = newIter->second.construct(config, e->location, exprPath);
                     let valueWithName = dynamic_cast<HasName*>(value.get());
@@ -978,7 +1004,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
                 // form the config record
                 let dictExpr = e->args[0];
                 let argsExprPath = newIter->second.isConfigRecord ? L"" : exprPath;   // reset expr-name path if object exposes a dictionary
-                let value = newIter->second.construct(*ConfigRecordFromDictExpression(dictExpr, scope, argsExprPath), e->location, exprPath); // this constructs it
+                let value = newIter->second.construct(ConfigRecordFromDictExpression(dictExpr, scope, argsExprPath), e->location, exprPath); // this constructs it
                 // if object has a name, we set it
                 let valueWithName = dynamic_cast<HasName*>(value.get());
                 if (valueWithName)
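Aside: the registry behind DefineRuntimeType reduces to a map from class-name strings to type-erased constructor lambdas. The following is a simplified, self-contained sketch under assumed stand-in types (Record, Object); the real construct lambda also threads TextLocation and exprPath through, and the real code uses the L#T macro shown above instead of a literal key.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Object { virtual ~Object() { } };
struct Record { };
using RecordPtr = std::shared_ptr<Record>;

// type-erased constructor: every entry builds "some Object" from a record
struct RuntimeType
{
    std::function<std::shared_ptr<Object>(RecordPtr)> construct;
};

template<class C>
RuntimeType MakeRuntimeTypeConstructor()
{
    RuntimeType info;
    info.construct = [](RecordPtr config) { return std::make_shared<C>(config); };
    return info;
}

struct ComputationNode : public Object
{
    explicit ComputationNode(RecordPtr) { } // would read its parameters from the record
};

int main()
{
    std::map<std::wstring, RuntimeType> types =
    {
        { L"ComputationNode", MakeRuntimeTypeConstructor<ComputationNode>() },
    };
    auto obj = types[L"ComputationNode"].construct(std::make_shared<Record>());
    std::cout << std::boolalpha << (obj != nullptr) << "\n"; // true
}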
diff --git a/MachineLearning/CNTK/ConfigEvaluator.h b/MachineLearning/CNTK/ConfigEvaluator.h
index 4bc1d77d9..6d1c41b25 100644
--- a/MachineLearning/CNTK/ConfigEvaluator.h
+++ b/MachineLearning/CNTK/ConfigEvaluator.h
@@ -263,7 +263,7 @@ namespace Microsoft{ namespace MSR { namespace CNTK { namespace Config {
     // create a runtime object from its type --general case
     // There can be specializations of this that instantiate objects that do not take ConfigRecords or involve mapping like ComputationNode.
     template<class C>
-    shared_ptr<C> MakeRuntimeObject(const ConfigRecord & config)
+    shared_ptr<C> MakeRuntimeObject(const ConfigRecordPtr config)
     {
         return make_shared<C>(config);
     }
diff --git a/MachineLearning/CNTK/ExperimentalNetworkBuilder.cpp b/MachineLearning/CNTK/ExperimentalNetworkBuilder.cpp
index 73eb9bd3c..e97be5644 100644
--- a/MachineLearning/CNTK/ExperimentalNetworkBuilder.cpp
+++ b/MachineLearning/CNTK/ExperimentalNetworkBuilder.cpp
@@ -78,7 +78,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
     // initialize a ComputationNetwork from a ConfigRecord
     template<typename ElemType>
-    shared_ptr<ComputationNetwork<ElemType>> CreateComputationNetwork(const ConfigRecord & config)
+    shared_ptr<ComputationNetwork<ElemType>> CreateComputationNetwork(const ConfigRecordPtr config)
     {
         DEVICEID_TYPE deviceId = -1; // (DEVICEID_TYPE)(int)config[L"deviceId"];
         auto net = make_shared<ComputationNetwork<ElemType>>(deviceId);
@@ -161,7 +161,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
     }

     // create a ComputationNetwork from a config--this implements "new ExperimentalComputationNetwork [ ... ]" in the added config snippet above
-    shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord & config)
+    shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr config)
     {
         wstring precision = config[L"precision"]; // TODO: we need to look those up while traversing upwards
         if (precision == L"float")
@@ -184,7 +184,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
     }

     // create a ComputationNetwork from a config--this implements "new ExperimentalComputationNetwork [ ... ]" in the added config snippet above
-    shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord & config)
+    shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr config)
     {
         wstring precision = L"float"; // config[L"precision"]; // TODO: we need to look those up while traversing upwards
         if (precision == L"float")
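Aside: the precision handling in MakeExperimentalComputationNetwork follows a common pattern where a runtime string selects a template instantiation. A stand-alone sketch of that dispatch, with Network<T> as an assumed stand-in for ComputationNetwork<ElemType>:

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

struct Object { virtual ~Object() { } };

template<typename ElemType>
struct Network : public Object
{
    const char * ElemName() const { return sizeof(ElemType) == 4 ? "float" : "double"; }
};

// runtime string -> compile-time element type
std::shared_ptr<Object> MakeNetwork(const std::wstring & precision)
{
    if (precision == L"float")
        return std::make_shared<Network<float>>();
    else if (precision == L"double")
        return std::make_shared<Network<double>>();
    else
        throw std::runtime_error("unknown precision");
}

int main()
{
    auto net = MakeNetwork(L"float");
    std::cout << dynamic_cast<Network<float>*>(net.get())->ElemName() << "\n"; // prints "float"
}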
diff --git a/MachineLearning/ParseConfig/main.cpp b/MachineLearning/ParseConfig/main.cpp
index 69f084919..5d92f969b 100644
--- a/MachineLearning/ParseConfig/main.cpp
+++ b/MachineLearning/ParseConfig/main.cpp
@@ -11,8 +11,8 @@ using namespace Microsoft::MSR::CNTK::Config;
 #endif

 namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
-    shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord &) { return nullptr; }
-    shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord &) { return nullptr; }
+    shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr) { return nullptr; }
+    shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr) { return nullptr; }
 }}}}

 #if 0
@@ -91,7 +91,7 @@ L"PerDimMeanVarNormalization(feat,mean,invStdDev, tag='') = new ComputationNode
 L"Parameter(outD, inD, tag='parameter') = new ComputationNode [ class = 'LearnableParameterNode' ; outDim = outD ; inDim = inD /*; tag = tag*/ ]\n"
 L"Input(dim,tag='features') = Parameter(dim,1,tag=tag) // TODO: for now \n"
 L"RowSlice(firstRow, rows, features, tag='') = new ComputationNode [ class = 'RowSliceNode' ; inputs = features ; first = firstRow ; num = rows /* ; tag = tag */ ]\n"
-L"Delay(in, delay, tag='') = new ComputationNode [ class = 'DelayNode' ; input = in ; deltaT = -delay /* ; tag = tag */ ]\n"
+L"Delay(in, delay, tag='') = new RecurrentComputationNode [ class = 'DelayNode' ; inputs = in ; deltaT = -delay /* ; tag = tag */ ]\n"
 L"Sigmoid(z, tag='') = new ComputationNode [ class = 'SigmoidNode' ; inputs = z /* ; tag = tag */ ]\n"
 L"Log(z, tag='') = new ComputationNode [ class = 'LogNode' ; inputs = z /* ; tag = tag */ ]\n"
 L"CrossEntropyWithSoftmax(labels, outZ, tag='') = new ComputationNode [ class = 'CrossEntropyWithSoftmaxNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"
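Aside: to see why Delay had to become a RecurrentComputationNode, consider a recurrence written in this config language. The line below is hypothetical and not part of the diff (h, W, R, and x are not defined in the snippet above): h refers to itself through Delay, so its input cannot be resolved while h is still being constructed, which is exactly the cycle that the deferred FinalizeInit() path breaks.

// hypothetical usage, in the same L"..." string style as the snippet above
L"h = Sigmoid(Plus(Times(W, x), Times(R, Delay(h, 1))))\n"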