DelayNode finally implemented with recurrent hookup, by taking a lambda to the evaluation of the inputs rather than evaluating them right away;
runtime object construction now passes around shared_ptr<ConfigRecord> instead of const ConfigRecord &, in order to allow for late evaluation--especially MakeRuntimeObject(); new helper base class RecurrentComputationNode
This commit is contained in:
Родитель
180b405769
Коммит
8694640708
|
@ -294,26 +294,33 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
RowSliceNode(vector<ComputationNodePtr> && inputs, size_t firstRow, size_t numRows, const wstring & tag) : UnaryComputationNode(move(inputs), tag), firstRow(firstRow), numRows(numRows) { }
|
||||
/*ComputationNode::*/ const wchar_t * OperationName() const { return L"RowSlice"; }
|
||||
};
|
||||
// DelayNode is special in that it may form cycles.
|
||||
// Specifically, to break circular references, DelayNode does not resolve its input arg (a ComputationNode), but rather keeps the ConfigValuePtr for now.
|
||||
// The ConfigValuePtr is meant to be unresolved, i.e. a lambda that will resolve its arg when accessing the value for the first time.
|
||||
// I.e. after construction, DelayNode can be referenced, but it cannot perform any operation on its argument, since it does not know it yet.
|
||||
// ComputationNetwork knows to call FinalizeInit() to resolve this, at a time when pointers for anything this may reference
|
||||
// from its or outer scope have been created (if those pointers are to Delay nodes in turn, those would again resolve in their
|
||||
// Nodes deriving from RecurrentComputationNode are special in that they may involve cycles.
|
||||
// Specifically, to break circular references, RecurrentComputationNode does not resolve its inputs arg (ComputationNodes),
|
||||
// but rather keeps a lambda to do so later.
|
||||
// By contract, the network builders will know to call FinalizeInit() on such nodes at the right time (before traversing their children to allow for more nodes to be created).
|
||||
// I.e. after construction, a RecurrentComputationNode can be referenced, but it cannot perform any operation on its inputs, since it does not know them yet.
|
||||
// ComputationNetwork knows to call FinalizeInit() to resolve this, at a time when pointers for anything this may reference
|
||||
// from its or outer scope have been created (if those pointers involve recurrent nodes in turn, those would again resolve in their
|
||||
// later FinalizeInit() call, which may yet again create new nodes etc.).
|
||||
struct DelayNode : public ComputationNode, public MustFinalizeInit
|
||||
struct RecurrentComputationNode : public ComputationNode, public MustFinalizeInit
|
||||
{
|
||||
ConfigValuePtr argUnresolved;
|
||||
ComputationNodePtr arg;
|
||||
int deltaT;
|
||||
function<vector<ComputationNodePtr>()> GetInputsLambda;
|
||||
public:
|
||||
DelayNode(ConfigValuePtr argUnresolved, int deltaT, const wstring & tag) : argUnresolved(argUnresolved), deltaT(deltaT) { SetTag(tag); }
|
||||
RecurrentComputationNode(function<vector<ComputationNodePtr>()> GetInputsLambda) : GetInputsLambda(GetInputsLambda) { }
|
||||
// FinalizeInit() is called from NDLNetworkBuilder when collecting all nodes; this is where we can lazily evaluate the recurrent connections.
|
||||
/*MustFinalizeInit::*/ void FinalizeInit()
|
||||
{
|
||||
AttachInputs(vector<ComputationNodePtr>(1,argUnresolved)); // the implied type cast resolves it
|
||||
argUnresolved = ConfigValuePtr(); // and free any references it may hold
|
||||
vector<ComputationNodePtr> inputs = GetInputsLambda(); // this evaluates the nodes, and possibly creates local downstream pieces of the graph
|
||||
AttachInputs(move(inputs));
|
||||
GetInputsLambda = []() -> vector<ComputationNodePtr> { LogicError("RecurrentComputationNode::FinalizeInit: called twice"); }; // avoid it being called twice
|
||||
// dim?
|
||||
}
|
||||
};
|
||||
// Delay node: a RecurrentComputationNode, so its inputs are supplied as a lambda that is handed to the base class rather than as already-resolved nodes.
struct DelayNode : public RecurrentComputationNode
|
||||
{
|
||||
// time offset given at construction; it is only stored here, nothing in this struct reads it
int deltaT;
|
||||
public:
|
||||
// GetInputsLambda: forwarded unchanged to the RecurrentComputationNode constructor
// deltaT:          stored in the member of the same name
// tag:             applied through SetTag()
DelayNode(function<vector<ComputationNodePtr>()> GetInputsLambda, int deltaT, const wstring & tag) : RecurrentComputationNode(GetInputsLambda), deltaT(deltaT) { SetTag(tag); }
|
||||
// operation name reported for this node type
/*ComputationNode::*/ const wchar_t * OperationName() const { return L"Delay"; }
|
||||
};
|
||||
class InputValue : public ComputationNode
|
||||
|
@ -356,12 +363,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
}
|
||||
// factory function for ComputationNodes
|
||||
template<>
|
||||
shared_ptr<ComputationNode> MakeRuntimeObject<ComputationNode>(const ConfigRecord & config)
|
||||
shared_ptr<ComputationNode> MakeRuntimeObject<ComputationNode>(const ConfigRecordPtr configp)
|
||||
{
|
||||
let & config = *configp;
|
||||
let classIdParam = config[L"class"];
|
||||
wstring classId = classIdParam;
|
||||
let tagp = config.Find(L"tag");
|
||||
wstring tag = tagp ? *tagp : wstring();
|
||||
// TODO: factor these GetInputs() calls out
|
||||
if (classId == L"LearnableParameterNode")
|
||||
return make_shared<LearnableParameter>(config[L"outDim"], config[L"inDim"], tag);
|
||||
else if (classId == L"PlusNode")
|
||||
|
@ -372,7 +381,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
return make_shared<TimesNode>(GetInputs(config, 2, L"TimesNode"), tag);
|
||||
else if (classId == L"DiagTimesNode")
|
||||
return make_shared<DiagTimesNode>(GetInputs(config, 2, L"DiagTimesNode"), tag);
|
||||
// BUGBUG: ScaleNode is given a BoxOf<Double>, not ComputationNode
|
||||
// BUGBUG: ScaleNode is given a BoxOf<Double>, not ComputationNode; need to create a Const first
|
||||
else if (classId == L"ScaleNode")
|
||||
return make_shared<ScaleNode>(GetInputs(config, 2, L"ScaleNode"), tag);
|
||||
else if (classId == L"LogNode")
|
||||
|
@ -391,8 +400,23 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
return make_shared<CrossEntropyWithSoftmaxNode>(GetInputs(config, 2, L"CrossEntropyWithSoftmaxNode"), tag);
|
||||
else if (classId == L"ErrorPredictionNode")
|
||||
return make_shared<ErrorPredictionNode>(GetInputs(config, 2, L"ErrorPredictionNode"), tag);
|
||||
else if (classId == L"DelayNode")
|
||||
return make_shared<DelayNode>(config[L"input"], config[L"deltaT"], tag);
|
||||
else
|
||||
throw EvaluationError(L"unknown ComputationNode class " + classId, classIdParam.GetLocation());
|
||||
}
|
||||
// factory function for RecurrentComputationNodes
|
||||
// The difference to the above is that the children are not resolved immediately but later during network connection.
|
||||
// This takes the record as a shared_ptr so that we can keep it inside a lambda.
|
||||
// Specialization that builds a RecurrentComputationNode from a config record.
// Reads "class" (required) and "tag" (optional) from the record; only the class id "DelayNode" is recognized.
// Throws EvaluationError, located at the "class" value, for any other class id.
template<>
|
||||
shared_ptr<RecurrentComputationNode> MakeRuntimeObject<RecurrentComputationNode>(const ConfigRecordPtr configp)
|
||||
{
|
||||
// reference used for the immediate lookups below; the pointer itself is what the lambda further down captures
let & config = *configp;
|
||||
// kept as a config value (not just a string) so that its location can be reported in the error below
let classIdParam = config[L"class"];
|
||||
wstring classId = classIdParam;
|
||||
// "tag" is optional: Find() is tested for presence, and an empty tag is used when it is absent
let tagp = config.Find(L"tag");
|
||||
wstring tag = tagp ? *tagp : wstring();
|
||||
// instead of passing the array of input nodes, we pass a lambda that computes this array in the network-gathering path in NDLComputationNetwork
|
||||
if (classId == L"DelayNode")
|
||||
// the lambda captures configp by value, so it holds its own copy of the record pointer until it is invoked; "deltaT" is read right away
// NOTE(review): GetInputs() is handed configp here, while the GetInputs() calls in the ComputationNode factory pass the record itself (config) -- confirm that a matching overload exists
return make_shared<DelayNode>([configp](){ return GetInputs(configp, 1, L"DelayNode"); }, config[L"deltaT"], tag);
|
||||
else
|
||||
// NOTE(review): the message says "ComputationNode" although this is the RecurrentComputationNode factory -- confirm whether that is intended
throw EvaluationError(L"unknown ComputationNode class " + classId, classIdParam.GetLocation());
|
||||
}
|
||||
|
@ -424,8 +448,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
set<ComputationNodePtr> outputs; // all output nodes
|
||||
set<ComputationNodePtr> parameters; // all parameter nodes
|
||||
public:
|
||||
NDLComputationNetwork(const ConfigRecord & config)
|
||||
NDLComputationNetwork(const ConfigRecordPtr configp)
|
||||
{
|
||||
let & config = *configp;
|
||||
deque<ComputationNodePtr> workList;
|
||||
// flatten the set of all nodes
|
||||
// we collect all ComputationNodes from the config; that's it
|
||||
|
@ -620,8 +645,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
}
|
||||
};
|
||||
|
||||
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord &);
|
||||
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord &);
|
||||
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr);
|
||||
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr);
|
||||
|
||||
// =======================================================================
|
||||
// Evaluator -- class for evaluating a syntactic parse tree
|
||||
|
@ -687,14 +712,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
struct ConfigurableRuntimeType
|
||||
{
|
||||
bool isConfigRecord;
|
||||
function<ConfigValuePtr(const ConfigRecord &, TextLocation, const wstring &)> construct; // lambda to construct an object of this class
|
||||
function<ConfigValuePtr(const ConfigRecordPtr, TextLocation, const wstring &)> construct; // lambda to construct an object of this class
|
||||
};
|
||||
|
||||
template<class C>
|
||||
static ConfigurableRuntimeType MakeRuntimeTypeConstructor()
|
||||
{
|
||||
ConfigurableRuntimeType info;
|
||||
info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
|
||||
info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
|
||||
{
|
||||
return ConfigValuePtr(MakeRuntimeObject<C>(config), location, exprPath);
|
||||
};
|
||||
|
@ -705,7 +730,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
static ConfigurableRuntimeType MakeExperimentalComputationNetworkConstructor()
|
||||
{
|
||||
ConfigurableRuntimeType info;
|
||||
info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
|
||||
info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
|
||||
{
|
||||
return ConfigValuePtr(MakeExperimentalComputationNetwork(config), location, exprPath);
|
||||
};
|
||||
|
@ -715,7 +740,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
static ConfigurableRuntimeType MakeExperimentalComputationNodeConstructor()
|
||||
{
|
||||
ConfigurableRuntimeType info;
|
||||
info.construct = [](const ConfigRecord & config, TextLocation location, const wstring & exprPath) // lambda to construct
|
||||
info.construct = [](const ConfigRecordPtr config, TextLocation location, const wstring & exprPath) // lambda to construct
|
||||
{
|
||||
return ConfigValuePtr(MakeExperimentalComputationNode(config), location, exprPath);
|
||||
};
|
||||
|
@ -731,6 +756,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
#define DefineRuntimeType(T) { L#T, MakeRuntimeTypeConstructor<T>() }
|
||||
// ComputationNodes
|
||||
DefineRuntimeType(ComputationNode),
|
||||
DefineRuntimeType(RecurrentComputationNode),
|
||||
// other relevant classes
|
||||
DefineRuntimeType(NDLComputationNetwork), // currently our fake
|
||||
// Functions
|
||||
|
@ -879,15 +905,15 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
if (newIter == configurableRuntimeTypes.end())
|
||||
LogicError("unknown magic runtime-object class");
|
||||
// form the ConfigRecord
|
||||
ConfigRecord config(nullptr);
|
||||
auto config = make_shared<ConfigRecord>(nullptr);
|
||||
// Note on scope: This config holds the arguments of the XXXNode runtime-object instantiations.
|
||||
// When they fetch their parameters, they should only look in this record, not in any parent scope (if they don't find what they are looking for, it's a bug in this routine here).
|
||||
// The values themselves are already in ConfigValuePtr form, so we won't need any scope lookups there either.
|
||||
config.Add(L"class", e->location, ConfigValuePtr(make_shared<String>(classId), e->location, exprPath));
|
||||
config->Add(L"class", e->location, ConfigValuePtr(make_shared<String>(classId), e->location, exprPath));
|
||||
vector<ConfigValuePtr> inputs;
|
||||
inputs.push_back(leftVal);
|
||||
inputs.push_back(rightVal);
|
||||
config.Add(L"inputs", leftVal.GetLocation(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetLocation(), exprPath));
|
||||
config->Add(L"inputs", leftVal.GetLocation(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetLocation(), exprPath));
|
||||
// instantiate
|
||||
let value = newIter->second.construct(config, e->location, exprPath);
|
||||
let valueWithName = dynamic_cast<HasName*>(value.get());
|
||||
|
@ -978,7 +1004,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
|||
// form the config record
|
||||
let dictExpr = e->args[0];
|
||||
let argsExprPath = newIter->second.isConfigRecord ? L"" : exprPath; // reset expr-name path if object exposes a dictionary
|
||||
let value = newIter->second.construct(*ConfigRecordFromDictExpression(dictExpr, scope, argsExprPath), e->location, exprPath); // this constructs it
|
||||
let value = newIter->second.construct(ConfigRecordFromDictExpression(dictExpr, scope, argsExprPath), e->location, exprPath); // this constructs it
|
||||
// if object has a name, we set it
|
||||
let valueWithName = dynamic_cast<HasName*>(value.get());
|
||||
if (valueWithName)
|
||||
|
|
|
@ -263,7 +263,7 @@ namespace Microsoft{ namespace MSR { namespace CNTK { namespace Config {
|
|||
// create a runtime object from its type --general case
|
||||
// There can be specializations of this that instantiate objects that do not take ConfigRecords or involve mapping like ComputationNode.
|
||||
template<typename C>
|
||||
shared_ptr<C> MakeRuntimeObject(const ConfigRecord & config)
|
||||
shared_ptr<C> MakeRuntimeObject(const ConfigRecordPtr config)
|
||||
{
|
||||
return make_shared<C>(config);
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
|
|||
|
||||
// initialize a ComputationNetwork<ElemType> from a ConfigRecord
|
||||
template<typename ElemType>
|
||||
shared_ptr<ComputationNetwork<ElemType>> CreateComputationNetwork(const ConfigRecord & config)
|
||||
shared_ptr<ComputationNetwork<ElemType>> CreateComputationNetwork(const ConfigRecordPtr config)
|
||||
{
|
||||
DEVICEID_TYPE deviceId = -1; // (DEVICEID_TYPE)(int)config[L"deviceId"];
|
||||
auto net = make_shared<ComputationNetwork<ElemType>>(deviceId);
|
||||
|
@ -161,7 +161,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
|
|||
}
|
||||
|
||||
// create a ComputationNetwork<ElemType> from a config--this implements "new ExperimentalComputationNetwork [ ... ]" in the added config snippet above
|
||||
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord & config)
|
||||
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr config)
|
||||
{
|
||||
wstring precision = config[L"precision"]; // TODO: we need to look those up while traversing upwards
|
||||
if (precision == L"float")
|
||||
|
@ -184,7 +184,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Config { // n
|
|||
}
|
||||
|
||||
// create a ComputationNode from a config--this implements "new ExperimentalComputationNode [ ... ]" in the added config snippet above
|
||||
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord & config)
|
||||
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr config)
|
||||
{
|
||||
wstring precision = L"float"; // config[L"precision"]; // TODO: we need to look those up while traversing upwards
|
||||
if (precision == L"float")
|
||||
|
|
|
@ -11,8 +11,8 @@ using namespace Microsoft::MSR::CNTK::Config;
|
|||
#endif
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK { namespace Config {
|
||||
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecord &) { return nullptr; }
|
||||
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecord &) { return nullptr; }
|
||||
// Placeholder definition: ignores its ConfigRecordPtr argument and always returns nullptr.
shared_ptr<Object> MakeExperimentalComputationNetwork(const ConfigRecordPtr) { return nullptr; }
|
||||
// Placeholder definition: ignores its ConfigRecordPtr argument and always returns nullptr.
shared_ptr<Object> MakeExperimentalComputationNode(const ConfigRecordPtr) { return nullptr; }
|
||||
}}}}
|
||||
|
||||
#if 0
|
||||
|
@ -91,7 +91,7 @@ L"PerDimMeanVarNormalization(feat,mean,invStdDev, tag='') = new ComputationNode
|
|||
L"Parameter(outD, inD, tag='parameter') = new ComputationNode [ class = 'LearnableParameterNode' ; outDim = outD ; inDim = inD /*; tag = tag*/ ]\n"
|
||||
L"Input(dim,tag='features') = Parameter(dim,1,tag=tag) // TODO: for now \n"
|
||||
L"RowSlice(firstRow, rows, features, tag='') = new ComputationNode [ class = 'RowSliceNode' ; inputs = features ; first = firstRow ; num = rows /* ; tag = tag */ ]\n"
|
||||
L"Delay(in, delay, tag='') = new ComputationNode [ class = 'DelayNode' ; input = in ; deltaT = -delay /* ; tag = tag */ ]\n"
|
||||
L"Delay(in, delay, tag='') = new RecurrentComputationNode [ class = 'DelayNode' ; inputs = in ; deltaT = -delay /* ; tag = tag */ ]\n"
|
||||
L"Sigmoid(z, tag='') = new ComputationNode [ class = 'SigmoidNode' ; inputs = z /* ; tag = tag */ ]\n"
|
||||
L"Log(z, tag='') = new ComputationNode [ class = 'LogNode' ; inputs = z /* ; tag = tag */ ]\n"
|
||||
L"CrossEntropyWithSoftmax(labels, outZ, tag='') = new ComputationNode [ class = 'CrossEntropyWithSoftmaxNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"
|
||||
|
|
Загрузка…
Ссылка в новой задаче