DelayNode finally implemented with recurrent hookup, by taking a lambda to the evaluation of the inputs rather than evaluating them right away;

runtime object construction now passes around shared_ptr<ConfigRecord> instead of const ConfigRecord &, in order to allow for late evaluation--especially MakeRuntimeObject();
new helper base class RecurrentComputationNode
This commit is contained in:
Frank Seide 2015-08-28 11:42:24 -07:00
Родитель 180b405769
Коммит 8694640708
4 изменённых файлов: 61 добавлений и 35 удалений

Просмотреть файл

@ -294,26 +294,33 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
RowSliceNode(vector<ComputationNodePtr> && inputs, size_t firstRow, size_t numRows, const wstring & tag) : UnaryComputationNode(move(inputs), tag), firstRow(firstRow), numRows(numRows) { }
/*ComputationNode::*/ const wchar_t * OperationName() const { return L"RowSlice"; }
};
// DelayNode is special in that it may for cycles.
// Specifically, to break circular references, DelayNode does not resolve its input arg (a ComputationNode), but rather keeps the ConfigValuePtr for now.
// The ConfigValuePtr is meant to be unresolved, i.e. a lambda that will resolve its arg when accessing the value for the first time.
// I.e. after construction, DelayNode can be referenced, but it cannot perform any operation on its argument, since it does not know it yet.
// ComputationNetwork knows to call FinalizeInit() to resolve this, at a time when pointers for anythin this may reference
// from its or outer scope have been created (if those pointers are to Delay nodes in turn, those would again resolve in their
// Nodes deriving from RecurrentComputationNode are special in that it may involve cycles.
// Specifically, to break circular references, RecurrentComputationNode does not resolve its inputs arg (ComputationNodes),
// but rather keeps a lambda to do so later.
// By contract, the network builders will know to call FinalizeInit() on such nodes at the right time (before traversing its children to allow for more nodes to be created)/
// I.e. after construction, a RecurrentComputationNode can be referenced, but it cannot perform any operation on its inputs, since it does not know them yet.
// ComputationNetwork knows to call FinalizeInit() to resolve this, at a time when pointers for anything this may reference
// from its or outer scope have been created (if those pointers involve recurrent nodes in turn, those would again resolve in their
// later FinalizeInit() call, which may yet again create new nodes etc.).
struct DelayNode : public ComputationNode, public MustFinalizeInit
struct RecurrentComputationNode : public ComputationNode, public MustFinalizeInit
{
ConfigValuePtr argUnresolved;
ComputationNodePtr arg;
int deltaT;
function<vector<ComputationNodePtr>()> GetInputsLambda;
public:
DelayNode(ConfigValuePtr argUnresolved, int deltaT, const wstring & tag) : argUnresolved(argUnresolved), deltaT(deltaT) { SetTag(tag); }
RecurrentComputationNode(function<vector<ComputationNodePtr>()> GetInputsLambda) : GetInputsLambda(GetInputsLambda) { }
// FinalizeInit() is called form NDLNetworkBuilder when collecting all nodes; this is where we can lazily evaluate the recurrent connections.
/*MustFinalizeInit::*/ void FinalizeInit()
{
AttachInputs(vector<ComputationNodePtr>(1,argUnresolved)); // the implied type cast resolves it
argUnresolved = ConfigValuePtr(); // and free any references it may hold
vector<ComputationNodePtr> inputs = GetInputsLambda(); // this evaluates the nodes, and possibly creates local downstream pieces of the graph
AttachInputs(move(inputs));
GetInputsLambda = []() -> vector<ComputationNodePtr> { LogicError("RecurrentComputationNode::FinalizeInit: called twice"); }; // avoid it being called twice
// dim?
}
};
struct DelayNode : public RecurrentComputationNode
{
int deltaT;
public:
DelayNode(function<vector<ComputationNodePtr>()> GetInputsLambda, int deltaT, const wstring & tag) : RecurrentComputationNode(GetInputsLambda), deltaT(deltaT) { SetTag(tag); }
/*ComputationNode::*/ const wchar_t * OperationName() const { return L"Delay"; }
};
class InputValue : public ComputationNode
@ -356,12 +363,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
}
// factory function for ComputationNodes
template<>
shared_ptr<ComputationNode> MakeRuntimeObject<ComputationNode>(const ConfigRecord & config)
shared_ptr<ComputationNode> MakeRuntimeObject<ComputationNode>(const ConfigRecordPtr configp)
{
let & config = *configp;
let classIdParam = config[L"class"];
wstring classId = classIdParam;
let tagp = config.Find(L"tag");
wstring tag = tagp ? *tagp : wstring();
// TODO: factor these GetInputs() calls out
if (classId == L"LearnableParameterNode")
return make_shared<LearnableParameter>(config[L"outDim"], config[L"inDim"], tag);
else if (classId == L"PlusNode")
@ -372,7 +381,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
return make_shared<TimesNode>(GetInputs(config, 2, L"TimesNode"), tag);
else if (classId == L"DiagTimesNode")
return make_shared<DiagTimesNode>(GetInputs(config, 2, L"DiagTimesNode"), tag);
// BUGBUG: ScaleNode is given a BoxOf<Double>, not ComputationNode
// BUGBUG: ScaleNode is given a BoxOf<Double>, not ComputationNode; need to create a Const first
else if (classId == L"ScaleNode")
return make_shared<ScaleNode>(GetInputs(config, 2, L"ScaleNode"), tag);
else if (classId == L"LogNode")
@ -391,8 +400,23 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
return make_shared<CrossEntropyWithSoftmaxNode>(GetInputs(config, 2, L"CrossEntropyWithSoftmaxNode"), tag);
else if (classId == L"ErrorPredictionNode")
return make_shared<ErrorPredictionNode>(GetInputs(config, 2, L"ErrorPredictionNode"), tag);
else if (classId == L"DelayNode")
return make_shared<DelayNode>(config[L"input"], config[L"deltaT"], tag);
else
throw EvaluationError(L"unknown ComputationNode class " + classId, classIdParam.GetLocation());
}
// factory function for RecurrentComputationNodes
// The difference to the above is that the children are not resolved immediately but later during network connection.
// This takes the record as a shared_ptr so that we can keep it inside a lambda.
template<>
shared_ptr<RecurrentComputationNode> MakeRuntimeObject<RecurrentComputationNode>(const ConfigRecordPtr configp)
{
let & config = *configp;
let classIdParam = config[L"class"];
wstring classId = classIdParam;
let tagp = config.Find(L"tag");
wstring tag = tagp ? *tagp : wstring();
// instead of passing the array of input nodes, we pass a lambda that computes this array in the network-gathering path in NDLComputationNetwork
if (classId == L"DelayNode")
return make_shared<DelayNode>([configp](){ return GetInputs(configp, 1, L"DelayNode"); }, config[L"deltaT"], tag);
else
throw EvaluationError(L"unknown ComputationNode class " + classId, classIdParam.GetLocation());
}
@ -424,8 +448,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
set<ComputationNodePtr> outputs; // all output nodes
set<ComputationNodePtr> parameters; // all parameter nodes
public:
NDLComputationNetwork(const ConfigRecord & config)
NDLComputationNetwork(const ConfigRecordPtr configp)
{
let & config = *configp;
deque<ComputationNodePtr> workList;
// flatten the set of all nodes
// we collect all ComputationNodes from the config; that's it
@ -620,8 +645,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
}
};
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord &);
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord &);
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr);
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr);
// =======================================================================
// Evaluator -- class for evaluating a syntactic parse tree
@ -687,14 +712,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
struct ConfigurableRuntimeType
{
bool isConfigRecord;
function<ConfigValuePtr(const ConfigRecord &, TextLocation, const wstring &)> construct; // lambda to construct an object of this class
function<ConfigValuePtr(const ConfigRecordPtr, TextLocation, const wstring &)> construct; // lambda to construct an object of this class
};
template<class C>
static ConfigurableRuntimeType MakeRuntimeTypeConstructor()
{
ConfigurableRuntimeType info;
info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
{
return ConfigValuePtr(MakeRuntimeObject<C>(config), location, exprPath);
};
@ -705,7 +730,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
static ConfigurableRuntimeType MakeExperimentalComputationNetworkConstructor()
{
ConfigurableRuntimeType info;
info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
{
return ConfigValuePtr(MakeExperimentalComputationNetwork(config), location, exprPath);
};
@ -715,7 +740,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
static ConfigurableRuntimeType MakeExperimentalComputationNodeConstructor()
{
ConfigurableRuntimeType info;
info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
{
return ConfigValuePtr(MakeExperimentalComputationNode(config), location, exprPath);
};
@ -731,6 +756,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
#define DefineRuntimeType(T) { L#T, MakeRuntimeTypeConstructor<T>() }
// ComputationNodes
DefineRuntimeType(ComputationNode),
DefineRuntimeType(RecurrentComputationNode),
// other relevant classes
DefineRuntimeType(NDLComputationNetwork), // currently our fake
// Functions
@ -879,15 +905,15 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
if (newIter == configurableRuntimeTypes.end())
LogicError("unknown magic runtime-object class");
// form the ConfigRecord
ConfigRecord config(nullptr);
auto config = make_shared<ConfigRecord>(nullptr);
// Note on scope: This config holds the arguments of the XXXNode runtime-object instantiations.
// When they fetch their parameters, they should only look in this record, not in any parent scope (if they don't find what they are looking for, it's a bug in this routine here).
// The values themselves are already in ConfigValuePtr form, so we won't need any scope lookups there either.
config.Add(L"class", e->location, ConfigValuePtr(make_shared<String>(classId), e->location, exprPath));
config->Add(L"class", e->location, ConfigValuePtr(make_shared<String>(classId), e->location, exprPath));
vector<ConfigValuePtr> inputs;
inputs.push_back(leftVal);
inputs.push_back(rightVal);
config.Add(L"inputs", leftVal.GetLocation(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetLocation(), exprPath));
config->Add(L"inputs", leftVal.GetLocation(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetLocation(), exprPath));
// instantiate
let value = newIter->second.construct(config, e->location, exprPath);
let valueWithName = dynamic_cast<HasName*>(value.get());
@ -978,7 +1004,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
// form the config record
let dictExpr = e->args[0];
let argsExprPath = newIter->second.isConfigRecord ? L"" : exprPath; // reset expr-name path if object exposes a dictionary
let value = newIter->second.construct(*ConfigRecordFromDictExpression(dictExpr, scope, argsExprPath), e->location, exprPath); // this constructs it
let value = newIter->second.construct(ConfigRecordFromDictExpression(dictExpr, scope, argsExprPath), e->location, exprPath); // this constructs it
// if object has a name, we set it
let valueWithName = dynamic_cast<HasName*>(value.get());
if (valueWithName)

Просмотреть файл

@ -263,7 +263,7 @@ namespace Microsoft{ namespace MSR { namespace CNTK { namespace Config {
// create a runtime object from its type --general case
// There can be specializations of this that instantiate objects that do not take ConfigRecords or involve mapping like ComputationNode.
template<typename C>
shared_ptr<C> MakeRuntimeObject(const ConfigRecord & config)
shared_ptr<C> MakeRuntimeObject(const ConfigRecordPtr config)
{
return make_shared<C>(config);
}

Просмотреть файл

@ -78,7 +78,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
// initialize a ComputationNetwork<ElemType> from a ConfigRecord
template<typename ElemType>
shared_ptr<ComputationNetwork<ElemType>> CreateComputationNetwork(const ConfigRecord & config)
shared_ptr<ComputationNetwork<ElemType>> CreateComputationNetwork(const ConfigRecordPtr config)
{
DEVICEID_TYPE deviceId = -1; // (DEVICEID_TYPE)(int)config[L"deviceId"];
auto net = make_shared<ComputationNetwork<ElemType>>(deviceId);
@ -161,7 +161,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
}
// create a ComputationNetwork<ElemType> from a config--this implements "new ExperimentalComputationNetwork [ ... ]" in the added config snippet above
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord & config)
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr config)
{
wstring precision = config[L"precision"]; // TODO: we need to look those up while traversing upwards
if (precision == L"float")
@ -184,7 +184,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
}
// create a ComputationNetwork<ElemType> from a config--this implements "new ExperimentalComputationNetwork [ ... ]" in the added config snippet above
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord & config)
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr config)
{
wstring precision = L"float"; // config[L"precision"]; // TODO: we need to look those up while traversing upwards
if (precision == L"float")

Просмотреть файл

@ -11,8 +11,8 @@ using namespace Microsoft::MSR::CNTK::Config;
#endif
namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord &) { return nullptr; }
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord &) { return nullptr; }
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr) { return nullptr; }
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr) { return nullptr; }
}}}}
#if 0
@ -91,7 +91,7 @@ L"PerDimMeanVarNormalization(feat,mean,invStdDev, tag='') = new ComputationNode
L"Parameter(outD, inD, tag='parameter') = new ComputationNode [ class = 'LearnableParameterNode' ; outDim = outD ; inDim = inD /*; tag = tag*/ ]\n"
L"Input(dim,tag='features') = Parameter(dim,1,tag=tag) // TODO: for now \n"
L"RowSlice(firstRow, rows, features, tag='') = new ComputationNode [ class = 'RowSliceNode' ; inputs = features ; first = firstRow ; num = rows /* ; tag = tag */ ]\n"
L"Delay(in, delay, tag='') = new ComputationNode [ class = 'DelayNode' ; input = in ; deltaT = -delay /* ; tag = tag */ ]\n"
L"Delay(in, delay, tag='') = new RecurrentComputationNode [ class = 'DelayNode' ; inputs = in ; deltaT = -delay /* ; tag = tag */ ]\n"
L"Sigmoid(z, tag='') = new ComputationNode [ class = 'SigmoidNode' ; inputs = z /* ; tag = tag */ ]\n"
L"Log(z, tag='') = new ComputationNode [ class = 'LogNode' ; inputs = z /* ; tag = tag */ ]\n"
L"CrossEntropyWithSoftmax(labels, outZ, tag='') = new ComputationNode [ class = 'CrossEntropyWithSoftmaxNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"