ComputationNodes constructed from BrainScript are now constructed via the config constructor. Had to implement most special constructors (those that differ from the default) were mising
This commit is contained in:
Родитель
46206f5ed4
Коммит
add5afcce5
|
@ -82,5 +82,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
return ImageLayout(std::vector<size_t> { channels, width, height });
|
||||
}
|
||||
// and use this one when the data is a plain vector
|
||||
static inline ImageLayout ImageLayoutVector(size_t n)
|
||||
{
|
||||
return ImageLayout(std::vector<size_t> { 1, 1, n }); // for now storing it as a 3D object as well --TODO: fix this
|
||||
}
|
||||
// TODO: we need a constructor from config; that will generalize
|
||||
|
||||
}}}
|
||||
|
|
|
@ -110,7 +110,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// check more types
|
||||
else if (nodeType == OperationNameOf(AveragePoolingNode)) return New<AveragePoolingNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(ConvolutionNode)) return New<ConvolutionNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == InputValue<ElemType>::SparseTypeName()) return New<InputValue<ElemType>>(forward<_Types>(_Args).../*, true*/); // TODO: will go away; we will have a separate type SparseInputValue instead
|
||||
else if (nodeType == InputValue<ElemType>::SparseTypeName()) LogicError("Node type 'SparseInputValue' temporarily not supported. Will come back as its own proper type.");//return New<InputValue<ElemType>>(forward<_Types>(_Args).../*, true*/); // TODO: will go away; we will have a separate type SparseInputValue instead
|
||||
else if (nodeType == OperationNameOf(InputValue)) return New<InputValue<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(LearnableParameter)) return New<LearnableParameter<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(MaxPoolingNode)) return New<MaxPoolingNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
|
@ -136,12 +136,19 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
wstring precision = configp->Get(L"precision"); // dispatch on ElemType
|
||||
wstring operationName = configp->Get(L"operation");
|
||||
ComputationNodeBasePtr node;
|
||||
if (precision == L"float")
|
||||
return CreateNode<float>(operationName, configp);
|
||||
node = CreateNode<float>(operationName, configp);
|
||||
else if (precision == L"double")
|
||||
return CreateNode<double>(operationName, configp);
|
||||
node = CreateNode<double>(operationName, configp);
|
||||
else
|
||||
RuntimeError("NewStandardNode: Invalid value '%ls' for 'precision' parameter. Must be 'float' or 'double'.", precision.c_str());
|
||||
// add a tag
|
||||
// Tags are used to declare special node types tp ComputationNetwork.
|
||||
const auto nodeWithTag = dynamic_pointer_cast<ScriptableObjects::WithTag>(node);
|
||||
if (nodeWithTag)
|
||||
nodeWithTag->SetTag(configp->Get(L"tag"));
|
||||
return node;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
|
|
|
@ -362,15 +362,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
static vector<ComputationNodeBasePtr> GetInputsFromConfig(const ScriptableObjects::IConfigRecordPtr configp)
|
||||
{
|
||||
vector<ComputationNodeBasePtr> inputs;
|
||||
const auto inputsArg = configp->Get(L"inputs");
|
||||
if (inputsArg.Is<ComputationNodeBase>()) // single arg
|
||||
inputs.push_back(inputsArg);
|
||||
else // a whole vector
|
||||
const auto * inputsArg = configp->Find(L"inputs");
|
||||
if (inputsArg)
|
||||
{
|
||||
ScriptableObjects::ConfigArrayPtr inputsArray = inputsArg;
|
||||
const auto range = inputsArray->GetIndexRange();
|
||||
for (int i = range.first; i <= range.second; i++) // pull them. This will resolve all of them.
|
||||
inputs.push_back(inputsArray->At(i, [](const wstring &){ LogicError("GetInputs: out of bounds index while iterating??"); }));
|
||||
if (inputsArg->Is<ComputationNodeBase>()) // single arg
|
||||
inputs.push_back(*inputsArg);
|
||||
else // a whole vector
|
||||
{
|
||||
ScriptableObjects::ConfigArrayPtr inputsArray = *inputsArg;
|
||||
const auto range = inputsArray->GetIndexRange();
|
||||
for (int i = range.first; i <= range.second; i++) // pull them. This will resolve all of them.
|
||||
inputs.push_back(inputsArray->At(i, [](const wstring &){ LogicError("GetInputs: out of bounds index while iterating??"); }));
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
@ -742,13 +745,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// for testing number of inputs in constructor (where we don't yet have dynamic_cast), construct this object with the config record
|
||||
// It will verify the number of elements on the 'inputs' argument
|
||||
NumInputs(const ScriptableObjects::IConfigRecordPtr configp) { Check(configp); }
|
||||
#if 0
|
||||
ScriptableObjects::IConfigRecordPtr Check(const ScriptableObjects::IConfigRecordPtr configp) const
|
||||
{
|
||||
auto * val = configp->Find(L"inputs");
|
||||
size_t numInputs = val ? ComputationNodeBase::GetInputsFromConfig(configp).size() : 0;
|
||||
size_t numInputs = ComputationNodeBase::GetInputsFromConfig(configp).size();
|
||||
if (numInputs != GetExpectedNumInputs())
|
||||
{
|
||||
// print an error. For that, find at least one argument
|
||||
auto * val = configp->Find(L"inputs");
|
||||
if (!val) // if there is no 'inputs' then get the first item of this config record for a Fail() function
|
||||
{
|
||||
auto members = configp->GetMemberIds();
|
||||
|
@ -762,6 +766,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
return configp;
|
||||
}
|
||||
#endif
|
||||
size_t GetExpectedNumInputs() const override final { return m_numInputs; }
|
||||
};
|
||||
|
||||
|
@ -798,11 +803,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Nodes with NumInputs<> should say DeclareConstructorFromConfigWithNumInputs(ClassName), and nodes without DeclareConstructorFromConfig(ClassName).
|
||||
// The macro will forward to the regular constructor of the node (which may do more than just calling the base constructor), and then attach the inputs from config.
|
||||
#define DeclareConstructorFromConfig(C) C(const ScriptableObjects::IConfigRecordPtr configp) : C(configp->Get(L"deviceId"), L"<placeholder>") { AttachInputs(configp); }
|
||||
#ifdef _MSC_VER
|
||||
#define DeclareConstructorFromConfigWithNumInputs(C) C(const ScriptableObjects::IConfigRecordPtr configp) : C(configp->Get(L"deviceId"), L"<placeholder>") { AttachInputs(NumInputs::Check(configp)); }
|
||||
#else
|
||||
#define DeclareConstructorFromConfigWithNumInputs DeclareConstructorFromConfig // standard C++ is too stupid to accept NumInputs without template arguments, which is the whole point here
|
||||
#endif
|
||||
#define DeclareConstructorFromConfigWithNumInputs(C) C(const ScriptableObjects::IConfigRecordPtr configp) : C(configp->Get(L"deviceId"), L"<placeholder>") { AttachInputs(configp, this->GetExpectedNumInputs()); }
|
||||
|
||||
virtual ~ComputationNode()
|
||||
{
|
||||
|
@ -851,7 +852,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Note: Nodes with variable number of inputs will not derive from NumInputs<>, but instead check their inputs in Validate().
|
||||
void AttachInputs(const std::vector<ComputationNodeBasePtr>& inputs)
|
||||
{
|
||||
wstring name = NodeName(); name;
|
||||
#ifdef _DEBUG
|
||||
wstring name = NodeName(); name; // (for easier debugging)
|
||||
#endif
|
||||
const auto * pNumInputs = dynamic_cast<INumInputs*>(this); // if this class also derives from NumInputs<N> then N is the expected number of inputs
|
||||
if (pNumInputs && pNumInputs->GetExpectedNumInputs() != inputs.size())
|
||||
RuntimeError("%ls operation '%ls' expects %d inputs (given: %d)", OperationName().c_str(), NodeName().c_str(), (int)pNumInputs->GetExpectedNumInputs(), (int)inputs.size());
|
||||
|
@ -865,9 +868,28 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
protected:
|
||||
// AttachInputs() from config
|
||||
void AttachInputs(const ScriptableObjects::IConfigRecordPtr configp)
|
||||
void AttachInputs(const ScriptableObjects::IConfigRecordPtr configp, size_t expectedNumInputs = SIZE_MAX)
|
||||
{
|
||||
AttachInputs(GetInputsFromConfig(configp));
|
||||
const auto inputs = GetInputsFromConfig(configp);
|
||||
if (expectedNumInputs != SIZE_MAX)
|
||||
{
|
||||
if (inputs.size() != expectedNumInputs)
|
||||
{
|
||||
// print an error. For that, find at least one argument
|
||||
auto * val = configp->Find(L"inputs");
|
||||
if (!val) // if there is no 'inputs' then get the first item of this config record for a Fail() function
|
||||
{
|
||||
auto members = configp->GetMemberIds();
|
||||
if (members.size() > 0)
|
||||
val = configp->Find(members.front());
|
||||
}
|
||||
if (val)
|
||||
val->Fail(msra::strfun::wstrprintf(L"Expected %d inputs, but %d were given.", (int)expectedNumInputs, (int)inputs.size()));
|
||||
else
|
||||
InvalidArgument("Expected %d inputs, but %d were given.", (int)expectedNumInputs, (int)inputs.size());
|
||||
}
|
||||
}
|
||||
AttachInputs(inputs);
|
||||
}
|
||||
public:
|
||||
|
||||
|
@ -1496,7 +1518,7 @@ protected: \
|
|||
using Base::ChildrenSize; using Base::ClearGradientForChildren; using Base::VerifyDims; \
|
||||
using Base::ConstOnes; \
|
||||
using Base::GetImageLayout; using Base::InferImageDimsFromInput; using Base::InferImageDimsFromInputs; using Base::InferMBLayoutFromInputsForStandardCase; \
|
||||
using Base::CopyTo; using Base::CreateUniqNodeName; using Base::DetachInputs; \
|
||||
using Base::CopyTo; using Base::CreateUniqNodeName; using Base::DetachInputs; using Base::GetInputsFromConfig; \
|
||||
using Base::DumpNodeInfo; using Base::EnumerateNodes; \
|
||||
using Base::HasMBLayout; using Base::GetMBLayout; using Base::LinkToMBLayout; \
|
||||
using Base::Inputs; using Base::SetInput; \
|
||||
|
|
|
@ -38,7 +38,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName() { return L"Convolution"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(ConvolutionNode);
|
||||
ConvolutionNode(DEVICEID_TYPE deviceId, const wstring & name) :
|
||||
Base(deviceId, name),
|
||||
m_kernelWidth(SIZE_MAX), m_kernelHeight(SIZE_MAX),
|
||||
|
@ -46,7 +45,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_horizontalSubsample(SIZE_MAX), m_verticalSubsample(SIZE_MAX),
|
||||
m_zeroPadding(false), m_maxTempMemSizeInSamples(SIZE_MAX)
|
||||
{
|
||||
m_imageLayout = ImageLayoutWHC(1, 1, 0); // TODO: what is this magic #channels == 0?
|
||||
m_imageLayout = ImageLayoutWHC(1, 1, 0); // TODO: what is this magic #channels == 0? Can this even be initialized at this time, or only inferred?
|
||||
}
|
||||
ConvolutionNode(DEVICEID_TYPE deviceId, const wstring & name, const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels, const size_t horizontalSubsample, const size_t verticalSubsample, const bool zeroPadding = false, const size_t maxTempMemSizeInSamples = 0) :
|
||||
Base(deviceId, name),
|
||||
|
@ -56,6 +55,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
m_imageLayout = ImageLayoutWHC(1, 1, outputChannels);
|
||||
}
|
||||
ConvolutionNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
ConvolutionNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"kernelWidth"), configp->Get(L"kernelHeight"), configp->Get(L"outputChannels"),
|
||||
configp->Get(L"horizontalSubsample"), configp->Get(L"verticalSubsample"),
|
||||
configp->Get(L"zeroPadding"), configp->Get(L"maxTempMemSizeInSamples"))
|
||||
{
|
||||
// weightNodeName, inputValueNodeName, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, zeroPadding = false, maxTempMemSizeInSamples = 0
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
virtual void SaveToFile(File& fstream) const override
|
||||
{
|
||||
|
@ -411,6 +418,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_windowWidth(windowWidth), m_windowHeight(windowHeight),
|
||||
m_horizontalSubsample(horizontalSubsample), m_verticalSubsample(verticalSubsample)
|
||||
{ }
|
||||
PoolingNodeBase(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
PoolingNodeBase(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"windowWidth"), configp->Get(L"windowHeight"), configp->Get(L"horizontalSubsample"), configp->Get(L"verticalSubsample"))
|
||||
{
|
||||
// input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
virtual void SaveToFile(File& fstream) const override
|
||||
{
|
||||
|
@ -540,13 +553,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef PoolingNodeBase<ElemType> Base; UsingPoolingNodeBaseMembers;
|
||||
static const std::wstring TypeName() { return L"MaxPooling"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(MaxPoolingNode);
|
||||
MaxPoolingNode(DEVICEID_TYPE deviceId, const wstring & name) : Base(deviceId, name) { }
|
||||
MaxPoolingNode(DEVICEID_TYPE deviceId, const wstring & name, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample) :
|
||||
Base(deviceId, name, windowWidth, windowHeight, horizontalSubsample, verticalSubsample)
|
||||
{ }
|
||||
MaxPoolingNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
Base(configp)
|
||||
{ }
|
||||
|
||||
/*implement*/ void ComputeInputPartialV(const Matrix<ElemType> &gradientValues, Matrix<ElemType> &inputGradientValues, const Matrix<ElemType> &input0, const Matrix<ElemType> &functionValues)
|
||||
virtual void ComputeInputPartialV(const Matrix<ElemType> &gradientValues, Matrix<ElemType> &inputGradientValues, const Matrix<ElemType> &input0, const Matrix<ElemType> &functionValues) override
|
||||
{
|
||||
inputGradientValues.AddMaxPoolingGradient(gradientValues, input0, functionValues, m_inputImageLayout.GetNumChannels(),
|
||||
m_inputImageLayout.GetWidth(), m_inputImageLayout.GetHeight(), m_inputSizePerSample,
|
||||
|
@ -554,7 +569,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_windowWidth, m_windowHeight, m_horizontalSubsample, m_verticalSubsample);
|
||||
}
|
||||
|
||||
/*implement*/ void EvaluateThisNodeV(Matrix<ElemType> &functionValues, const Matrix<ElemType> &input0)
|
||||
virtual void EvaluateThisNodeV(Matrix<ElemType> &functionValues, const Matrix<ElemType> &input0) override
|
||||
{
|
||||
functionValues.AssignMaxPoolingResult(input0, m_inputImageLayout.GetNumChannels(),
|
||||
m_inputImageLayout.GetWidth(), m_inputImageLayout.GetHeight(), m_inputSizePerSample,
|
||||
|
@ -576,13 +591,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef PoolingNodeBase<ElemType> Base; UsingPoolingNodeBaseMembers;
|
||||
static const std::wstring TypeName() { return L"AveragePooling"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(AveragePoolingNode);
|
||||
AveragePoolingNode(DEVICEID_TYPE deviceId, const wstring & name) : Base(deviceId, name) { }
|
||||
AveragePoolingNode(DEVICEID_TYPE deviceId, const wstring & name, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample) :
|
||||
Base(deviceId, name, windowWidth, windowHeight, horizontalSubsample, verticalSubsample)
|
||||
{ }
|
||||
AveragePoolingNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
Base(configp)
|
||||
{ }
|
||||
|
||||
/*implement*/ void ComputeInputPartialV(const Matrix<ElemType> &gradientValues, Matrix<ElemType> &inputGradientValues, const Matrix<ElemType> &/*input0*/, const Matrix<ElemType> &/*functionValues*/)
|
||||
virtual void ComputeInputPartialV(const Matrix<ElemType> &gradientValues, Matrix<ElemType> &inputGradientValues, const Matrix<ElemType> &/*input0*/, const Matrix<ElemType> &/*functionValues*/) override
|
||||
{
|
||||
inputGradientValues.AddAveragePoolingGradient(gradientValues, m_inputImageLayout.GetNumChannels(),
|
||||
m_inputImageLayout.GetWidth(), m_inputImageLayout.GetHeight(), m_inputSizePerSample,
|
||||
|
@ -590,7 +607,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_windowWidth, m_windowHeight, m_horizontalSubsample, m_verticalSubsample);
|
||||
}
|
||||
|
||||
/*implement*/ void EvaluateThisNodeV(Matrix<ElemType> &functionValues, const Matrix<ElemType> &input0)
|
||||
virtual void EvaluateThisNodeV(Matrix<ElemType> &functionValues, const Matrix<ElemType> &input0) override
|
||||
{
|
||||
functionValues.AssignAveragePoolingResult(input0, m_inputImageLayout.GetNumChannels(),
|
||||
m_inputImageLayout.GetWidth(), m_inputImageLayout.GetHeight(), m_inputSizePerSample,
|
||||
|
|
|
@ -56,7 +56,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
LearnableParameter(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
LearnableParameter(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"rows"), configp->Get(L"cols"))
|
||||
{
|
||||
NumInputs::Check(configp);
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
// parameters[rows, [cols=1]] plus other optional parameters (needGradient=[true|false], init=[uniform|gaussian|fixedvalue], initValueScale=[1|float], value=[0|float])
|
||||
// TODO: "needGradient" should be renamed to better match m_parameterUpdateRequired
|
||||
SetParameterUpdateRequired(configp->Get(L"needGradient"));
|
||||
|
@ -248,7 +248,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_parameterUpdateRequired = false;
|
||||
}
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(InputValue);
|
||||
InputValue(DEVICEID_TYPE deviceId, const wstring & name) :
|
||||
Base(deviceId, name)
|
||||
{
|
||||
|
@ -268,7 +267,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
if (rows * cols == 0)
|
||||
LogicError("This InputValue dimension is 0.");
|
||||
|
||||
m_imageLayout = ImageLayoutWHC(1, rows, 1);
|
||||
m_imageLayout = ImageLayoutVector(rows);
|
||||
Init(rows, cols, isSparse);
|
||||
}
|
||||
InputValue(DEVICEID_TYPE deviceId, const wstring & name, const ImageLayout & imageLayout, size_t numImages, bool isSparse = false) :
|
||||
|
@ -284,6 +283,27 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
Init(rows, cols, isSparse);
|
||||
}
|
||||
InputValue(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
InputValue(configp->Get(L"deviceId"), L"<placeholder>")
|
||||
{
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
bool isSparse = configp->Get(L"isSparse"); // TODO: no, this must go into a separate type SparseInputValue
|
||||
bool isImage = configp->Get(L"isImage");
|
||||
if (!isImage)
|
||||
{
|
||||
size_t rows = configp->Get(L"rows");
|
||||
size_t cols = configp->Get(L"cols");
|
||||
m_imageLayout = ImageLayoutVector(rows); // no tensor, just a vector
|
||||
Init(rows, cols, isSparse);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_imageLayout = ImageLayoutWHC(configp->Get(L"imageWidth"), configp->Get(L"imageHeight"), configp->Get(L"imageChannels"));
|
||||
size_t rows = m_imageLayout.GetNumElements();
|
||||
size_t cols = configp->Get(L"numImages"); // this is actually the MB size
|
||||
Init(rows, cols, isSparse);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void SaveToFile(File& fstream) const override
|
||||
{
|
||||
|
|
|
@ -273,7 +273,7 @@ namespace Microsoft { namespace MSR { namespace ScriptableObjects {
|
|||
#endif
|
||||
|
||||
// temporary code for BrainScript update (using register)
|
||||
#if 1
|
||||
#if 0
|
||||
template<> shared_ptr<Object> MakeRuntimeObject<ComputationNode<float>>(const IConfigRecordPtr configp)
|
||||
{
|
||||
return DualPrecisionHelpers<float, ComputationNode<float>>::MakeRuntimeObject(configp);
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
// TODO: 'direction' is really too general. signOfTimeOffset?
|
||||
template<class ElemType, int direction/*-1 for Past/left-to-right or +1 for Future/right-to-left*/, MinibatchPackingFlags SequenceStart_or_End/*-Start or -End*/>
|
||||
class DelayedValueNodeBase : public ComputationNode<ElemType>, public NumInputs<1>
|
||||
class DelayedValueNodeBase : public ComputationNode<ElemType>, public ILateAttachingNode, public NumInputs<1>
|
||||
{
|
||||
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName() { return L"DelayedValue"; }
|
||||
|
@ -46,8 +46,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_isHistoryCarryOverManagedExternally = false; // used for PairNetworkNode/PastValueNode combination
|
||||
}
|
||||
protected:
|
||||
//virtual ComputationNodeBase * NewThis(DEVICEID_TYPE deviceId, const wstring & name) = 0;
|
||||
//DeclareConstructorFromConfigWithNumInputs(DelayedValueNodeBase);
|
||||
DelayedValueNodeBase(DEVICEID_TYPE deviceId, const wstring & name) :
|
||||
Base(deviceId, name),
|
||||
m_delayedActivation(deviceId), m_pShiftedMBLayout(make_shared<MBLayout>())
|
||||
|
@ -68,6 +66,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
//m_gradientValues->Resize(row_size, col_size);
|
||||
//m_gradientValues->SetValue(0.0f);
|
||||
}
|
||||
DelayedValueNodeBase(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
DelayedValueNodeBase(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"defaultHiddenActivation"), configp->Get(L"rows"), configp->Get(L"cols"), configp->Get(L"timeStep"))
|
||||
{
|
||||
// We do NOT attach the inputs, as we cannot resolve them without causing a circular reference.
|
||||
// Instead, we capture them in a lambda, which will be called by ComputationNetwork during the build process through LateAttachInputs() below.
|
||||
// This is a contract between ComputationNetwork and this specific node type.
|
||||
m_attachInputsFn = [this, configp]() // This is the lambda to complete the process. Note that config captured as a shared_ptr.
|
||||
{
|
||||
AttachInputs(GetInputsFromConfig(configp)); // this is executed by network builder while iterating the nodes
|
||||
};
|
||||
}
|
||||
virtual void /*ILateAttachingNode::*/LateAttachInputs() override final
|
||||
{
|
||||
m_attachInputsFn();
|
||||
m_attachInputsFn = [](){ LogicError("LateAttachingNode::AttachInputs: must only be called once"); };
|
||||
}
|
||||
public:
|
||||
void SaveToFile(File& fstream) const
|
||||
{
|
||||
|
@ -353,7 +367,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
MBLayoutPtr m_delayedActivationMBLayout; // layout for m_delayedActivation
|
||||
int m_timeStep; // delay in frames (typ. 1)
|
||||
MBLayoutPtr m_pShiftedMBLayout; // individual sentence boundary information --TODO: do we actually need this separate variable?
|
||||
bool m_isHistoryCarryOverManagedExternally; // for PastValueNode only
|
||||
bool m_isHistoryCarryOverManagedExternally; // for PastValueNode only
|
||||
function<void()> m_attachInputsFn; // for late expansion of inputs (scripting)
|
||||
};
|
||||
|
||||
#define UsingDelayedValueNodeMembers UsingComputationNodeMembersBoilerplate; \
|
||||
|
@ -372,13 +387,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef DelayedValueNodeBase<ElemType, -1, MinibatchPackingFlags::SequenceStart> Base; UsingDelayedValueNodeMembers;
|
||||
static const std::wstring TypeName() { return L"PastValue"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(PastValueNode);
|
||||
PastValueNode(DEVICEID_TYPE deviceId, const wstring & name) :
|
||||
Base(deviceId, name)
|
||||
{ }
|
||||
PastValueNode(DEVICEID_TYPE deviceId, const wstring & name, ElemType initialActivationValue, size_t row_size, size_t col_size, size_t timeStep) :
|
||||
Base(deviceId, name, initialActivationValue, row_size, col_size, timeStep)
|
||||
{ }
|
||||
PastValueNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
Base(configp)
|
||||
{ }
|
||||
};
|
||||
|
||||
template class PastValueNode<float>;
|
||||
|
@ -396,13 +413,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef DelayedValueNodeBase<ElemType, +1, MinibatchPackingFlags::SequenceEnd> Base; UsingDelayedValueNodeMembers;
|
||||
static const std::wstring TypeName() { return L"FutureValue"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(FutureValueNode);
|
||||
FutureValueNode(DEVICEID_TYPE deviceId, const wstring & name) :
|
||||
Base(deviceId, name)
|
||||
{ }
|
||||
FutureValueNode(DEVICEID_TYPE deviceId, const wstring & name, ElemType initialActivationValue, size_t row_size, size_t col_size, size_t timeStep) :
|
||||
Base(deviceId, name, initialActivationValue, row_size, col_size, timeStep)
|
||||
{ }
|
||||
FutureValueNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
Base(configp)
|
||||
{ }
|
||||
};
|
||||
|
||||
template class FutureValueNode<float>;
|
||||
|
|
|
@ -164,12 +164,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef ReinterpretNodeBase<ElemType> Base; UsingReinterpretNodeBaseMembers;
|
||||
static const std::wstring TypeName() { return L"Reshape"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(ReshapeNode);
|
||||
ReshapeNode(DEVICEID_TYPE deviceId, const wstring & name, size_t numRows = 0, const ImageLayout & imageLayout = ImageLayoutWHC(0,0,0)) :
|
||||
Base(deviceId, name),
|
||||
m_numTargetRows(numRows),
|
||||
m_targetImageLayout(imageLayout)
|
||||
{ }
|
||||
ReshapeNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
ReshapeNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"numRows"), ImageLayoutWHC(configp->Get(L"imageWidth"), configp->Get(L"imageHeight"), configp->Get(L"imageChannels")))
|
||||
{
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
|
||||
{
|
||||
|
@ -483,12 +487,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName() { return L"RowSlice"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(RowSliceNode);
|
||||
RowSliceNode(DEVICEID_TYPE deviceId, const wstring & name, size_t startIndex = 0, size_t numRows = 0) :
|
||||
Base(deviceId, name),
|
||||
m_startIndex(startIndex),
|
||||
m_sliceHeight(numRows)
|
||||
{ }
|
||||
RowSliceNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
RowSliceNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"startIndex"), configp->Get(L"numRows"))
|
||||
{
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
|
||||
{
|
||||
|
@ -638,11 +646,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName() { return L"RowRepeat"; }
|
||||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(RowRepeatNode);
|
||||
RowRepeatNode(DEVICEID_TYPE deviceId, const wstring & name, size_t numRepeats = 1) :
|
||||
Base(deviceId, name),
|
||||
m_numRepeat(numRepeats)
|
||||
{ }
|
||||
RowRepeatNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
RowRepeatNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"numRepeats"))
|
||||
{
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче