DelayedValueNodeBase::BackpropTo() now also uses a masked operation
This commit is contained in:
Родитель
484059e055
Коммит
61cf0f76d1
|
@ -27,9 +27,10 @@ DelayedValueNodeBase<ElemType, direction>::DelayedValueNodeBase(DEVICEID_TYPE de
|
|||
ElemType initialActivationValue, const TensorShape& sampleLayout,
|
||||
size_t timeStep) :
|
||||
Base(deviceId, name),
|
||||
m_initialActivationValueMatrix(make_shared<Matrix<ElemType>>(deviceId)),
|
||||
m_sourceFrameValidMatrix(make_shared<Matrix<ElemType>>(deviceId)),
|
||||
m_delayedValue(make_shared<Matrix<ElemType>>(deviceId)),
|
||||
m_initialActivationValueMatrix(make_shared<Matrix<ElemType>>(deviceId))
|
||||
m_zeroMatrix(make_shared<Matrix<ElemType>>(deviceId)),
|
||||
m_delayedValue(make_shared<Matrix<ElemType>>(deviceId))
|
||||
{
|
||||
m_initialActivationValue = initialActivationValue;
|
||||
m_timeStep = 1;
|
||||
|
@ -37,6 +38,8 @@ DelayedValueNodeBase<ElemType, direction>::DelayedValueNodeBase(DEVICEID_TYPE de
|
|||
SetDims(sampleLayout, HasMBLayout() /*false at this point*/);
|
||||
m_initialActivationValueMatrix->Resize(1, 1);
|
||||
m_initialActivationValueMatrix->SetValue(m_initialActivationValue);
|
||||
m_zeroMatrix->Resize(1, 1);
|
||||
m_zeroMatrix->SetValue((ElemType)0);
|
||||
m_timeStep = (int)timeStep;
|
||||
}
|
||||
|
||||
|
@ -49,6 +52,7 @@ template<class ElemType, int direction>
|
|||
auto node = dynamic_pointer_cast<DelayedValueNodeBase<ElemType, direction /*, SequenceStart_or_End*/>>(nodeP);
|
||||
node->m_timeStep = m_timeStep;
|
||||
node->m_initialActivationValue = m_initialActivationValue;
|
||||
node->m_initialActivationValueMatrix->SetValue(m_initialActivationValue);
|
||||
node->m_delayedValue->SetValue(*m_delayedValue);
|
||||
if (m_delayedActivationMBLayout)
|
||||
(node->m_delayedActivationMBLayout = make_shared<MBLayout>())->CopyFrom(m_delayedActivationMBLayout);
|
||||
|
@ -83,7 +87,10 @@ template<class ElemType, int direction>
|
|||
m_delayedValue->Resize(m_sampleLayout.GetNumElements(), 0); // Note: If we try to access history in first minibatch, we shall crash. It would be a consequence of a missing sentence-begin flag
|
||||
|
||||
if (modelVersion >= CNTK_MODEL_VERSION_2)
|
||||
{
|
||||
fstream >> m_initialActivationValue;
|
||||
m_initialActivationValueMatrix->SetValue(m_initialActivationValue);
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType, int direction>
|
||||
|
@ -101,9 +108,10 @@ template<class ElemType, int direction>
|
|||
fstream << m_initialActivationValue;
|
||||
}
|
||||
|
||||
// determine which parallel sequences should be copied; return true if we can do all (where we don't care about gaps)
|
||||
// determine which parallel sequences have a valid source frame to be copied/propagated
|
||||
// Remember that we ask about copying from the delayed position to the current position.
|
||||
// Gaps will be considered as "shall be copied", since that gives us a higher chance of collating all.
|
||||
// This function also determines whether we can do it all (allValid) or none (!anyValid).
|
||||
// Gaps will be considered as "valid", since that gives us a higher chance of collating all.
|
||||
template<class ElemType, int direction>
|
||||
/*private*/ void DelayedValueNodeBase<ElemType, direction>::DetermineValidMask(const FrameRange& frDelayed, bool& anyValid, bool& allValid)
|
||||
{
|
||||
|
@ -113,7 +121,7 @@ template<class ElemType, int direction>
|
|||
anyValid = true;
|
||||
return;
|
||||
}
|
||||
// create a vector of 0 and 1 for every parallel sequence
|
||||
// create a vector of 0 and 1 for every parallel sequence:
|
||||
// 0 --> source frame invalid: do not copy/propagate
|
||||
// 1 --> source frame valid (or target is gap): copy/propagate
|
||||
let S = GetNumParallelSequences();
|
||||
|
@ -131,7 +139,18 @@ template<class ElemType, int direction>
|
|||
anyValid = numValid > 0;
|
||||
allValid = numValid == S;
|
||||
if (allValid)
|
||||
return; // all valid (or gap): just copy all --breakpoint has not been hit; that's not right (gaps!)
|
||||
sin(1.0); // all valid (or gap): just copy all --breakpoint has not been hit; that's not right (gaps!)
|
||||
}
|
||||
|
||||
// convert the m_sourceFrameValid vector into a (potentially GPU-side) TensorView
|
||||
template<class ElemType, int direction>
|
||||
/*private*/ TensorView<ElemType> DelayedValueNodeBase<ElemType, direction>::MakeMaskTensor(size_t rank) const
|
||||
{
|
||||
// send to Matrix object (likely living on a GPU)
|
||||
m_sourceFrameValidMatrix->SetValue(1, m_sourceFrameValid.size(), m_deviceId, const_cast<ElemType*>(m_sourceFrameValid.data()), matrixFlagNormal);
|
||||
// tensor shape is a 1-frame sequence, one element per parallel sequence.
|
||||
auto tensorShape = TensorShape(1).AppendInPlace(rank, GetMBLayout()->GetNumParallelSequences());
|
||||
return TensorView<ElemType>(m_sourceFrameValidMatrix, tensorShape);
|
||||
}
|
||||
|
||||
// This function assumes EndForwardProp() to be called after the iteration loop.
|
||||
|
@ -158,15 +177,13 @@ template<class ElemType, int direction>
|
|||
// compute logical position of delayed value
|
||||
assert(m_timeStep > 0);
|
||||
|
||||
//size_t t = fr.t();
|
||||
|
||||
// determine the parallel sequences to mask
|
||||
bool anyValid, allValid;
|
||||
DetermineValidMask(frDelayed, anyValid, allValid);
|
||||
|
||||
// source tensor --considering truncated BPTT
|
||||
size_t rank = DetermineElementwiseTensorRank();
|
||||
TensorView<ElemType> inp;
|
||||
TensorView<ElemType> src;
|
||||
int t_delayed = (int)(fr.t() + direction * m_timeStep); // this might end up outside the current window
|
||||
if (t_delayed < 0) // handle special case of truncated BPTT
|
||||
{
|
||||
|
@ -179,7 +196,7 @@ template<class ElemType, int direction>
|
|||
auto tensorShape = GetTensorShape(rank);
|
||||
auto slice = TensorSliceWithMBLayoutFor(tensorShape.GetDims(), FrameRange(m_delayedActivationMBLayout, t_delayed/*<0*/ + T_delayedActivation), m_delayedActivationMBLayout);
|
||||
tensorShape.NarrowTo(slice);
|
||||
inp = TensorView<ElemType>(m_delayedValue, tensorShape);
|
||||
src = TensorView<ElemType>(m_delayedValue, tensorShape);
|
||||
}
|
||||
else
|
||||
LogicError("The delay node tries to access past values that are out of bound, possibly because there is no sentence start marker in the MBLayout.");
|
||||
|
@ -187,10 +204,10 @@ template<class ElemType, int direction>
|
|||
else if (t_delayed >= GetNumTimeSteps()) // truncated BPTT goes left-to-right only
|
||||
LogicError("The delay node tries to access future values that are out of bound, possibly because there is no sentence end marker in the MBLayout.");
|
||||
else // regular case
|
||||
inp = Input(0)->ValueTensorFor(rank, frDelayed);
|
||||
src = Input(0)->ValueTensorFor(rank, frDelayed);
|
||||
|
||||
// target tensor
|
||||
auto out = ValueTensorFor(rank, fr);
|
||||
auto tgt = ValueTensorFor(rank, fr);
|
||||
|
||||
// init value tensor (a [1] tensor with broadcasting)
|
||||
TensorView<ElemType> init(m_initialActivationValueMatrix, TensorShape(1));
|
||||
|
@ -198,20 +215,15 @@ template<class ElemType, int direction>
|
|||
// now perform the copy operation
|
||||
if (allValid) // all frames are valid: copy as one tensor-copy operation
|
||||
{
|
||||
out.AssignCopyOf(inp);
|
||||
tgt.AssignCopyOf(src);
|
||||
}
|
||||
else if (!anyValid) // no frame is valid: initialize from init value
|
||||
{
|
||||
out.AssignCopyOf(init);
|
||||
tgt.AssignCopyOf(init);
|
||||
}
|
||||
else // some are valid, some are not: use a OpCond to select 'inp' for valid and 'init' for invalid frames
|
||||
else // some are valid, some are not: use an OpCond to select 'src' for valid and 'init' for invalid frames
|
||||
{
|
||||
// treat mask like a 1-frame sequence, one element per parallel sequence
|
||||
m_sourceFrameValidMatrix->SetValue(1, m_sourceFrameValid.size(), m_deviceId, m_sourceFrameValid.data(), matrixFlagNormal);
|
||||
auto tensorShape = TensorShape(1).AppendInPlace(rank, GetMBLayout()->GetNumParallelSequences());
|
||||
TensorView<ElemType> cond(m_sourceFrameValidMatrix, tensorShape);
|
||||
// now assign either input or init value, based on the mask
|
||||
out.AssignCondOf(cond, inp, init);
|
||||
tgt.AssignCondOf(MakeMaskTensor(rank), src, init); // assign either input or init value, based on the mask
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -256,43 +268,34 @@ template<class ElemType, int direction>
|
|||
return;
|
||||
}
|
||||
|
||||
// we backpropagated into the delayed frame
|
||||
FrameRange frDelayed = fr.WithTimeOffset(direction * m_timeStep);
|
||||
|
||||
// if delayed input is within valid time range then add its gradient
|
||||
size_t rank = DetermineElementwiseTensorRank();
|
||||
size_t t = fr.t();
|
||||
int t_delayed = (int) (t + direction * m_timeStep); // this might end up outside the current window
|
||||
if (t_delayed >= 0 && t_delayed < GetNumTimeSteps()) // only propagate if our source is inside the minibatch
|
||||
{
|
||||
size_t S = m_pMBLayout->GetNumParallelSequences();
|
||||
// Boundary frames must not propagate. Gaps must also not propagate.
|
||||
// if there is a boundary in this frame, we treat each stream separately; otherwise we do all in one go
|
||||
if (m_pMBLayout->IsGap(fr) || m_pMBLayout->IsBeyondStartOrEnd(frDelayed)) // true if at least one parallel sequence we pull the gradient for has a boundary
|
||||
// we backpropagated into the delayed frame
|
||||
FrameRange frTgt = fr.WithTimeStep(t_delayed); // target frame
|
||||
FrameRange frSrc = frTgt.WithTimeOffset(direction * -m_timeStep); // source frame
|
||||
|
||||
// determine the parallel sequences to mask
|
||||
bool anyValid, allValid;
|
||||
DetermineValidMask(frSrc, anyValid, allValid);
|
||||
|
||||
auto src = GradientTensorFor(rank, frSrc); // incoming gradient from top
|
||||
auto tgt = Input(0)->GradientTensorFor(rank, frTgt); // outgoing gradient to input
|
||||
TensorView<ElemType> zero(m_zeroMatrix, TensorShape(1));
|
||||
|
||||
if (allValid) // all valid: just jam it over in one go
|
||||
{
|
||||
#if 0
|
||||
m_backpropMask.resize(S);
|
||||
for (size_t s = 0; s < S; s++)
|
||||
m_backpropMask[s] = (m_pMBLayout->IsGap(fr.Sequence(s)) || m_pMBLayout->IsBeyondStartOrEnd(frDelayed.Sequence(s))) ? 0 : 1;
|
||||
m_backpropMaskMatrix.SetValue(m_backpropMask.data()); ...
|
||||
#endif
|
||||
for (size_t s = 0; s < S; s++)
|
||||
{
|
||||
if (!m_pMBLayout->IsGap(fr.Sequence(s)) && !m_pMBLayout->IsBeyondStartOrEnd(frDelayed.Sequence(s))) // don't propagate boundary frames or gaps
|
||||
{
|
||||
auto frm = GradientTensorFor(rank, fr.Sequence(s));
|
||||
auto to = Input(0)->GradientTensorFor(rank, frDelayed.Sequence(s));
|
||||
to.AddCopyOf(frm);
|
||||
}
|
||||
}
|
||||
tgt.AddCopyOf(src);
|
||||
}
|
||||
else // operate on entire time step in one go (over all parallel sequences)
|
||||
else if (anyValid) // some are valid, some are not: use an OpCond to select 'src' for valid and 'zero' for invalid frames
|
||||
{
|
||||
// TODO: change this to a TensorView operation
|
||||
auto frm = GradientTensorFor(rank, fr);
|
||||
auto to = Input(0)->GradientTensorFor(rank, frDelayed);
|
||||
to.AddCopyOf(frm);
|
||||
tgt.AddCondOf(MakeMaskTensor(rank), src, zero); // now add either source or zero value, based on the mask
|
||||
}
|
||||
else // none valid: nothing to back-prop
|
||||
;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ class DelayedValueNodeBase : public ComputationNode<ElemType>, public IRecurrent
|
|||
|
||||
private:
|
||||
void DetermineValidMask(const FrameRange& frDelayed, bool& anyValid, bool& allValid);
|
||||
TensorView<ElemType> MakeMaskTensor(size_t rank) const;
|
||||
|
||||
protected:
|
||||
DelayedValueNodeBase(DEVICEID_TYPE deviceId, const wstring& name, ElemType initialActivationValue, const TensorShape& sampleLayout, size_t timeStep);
|
||||
|
@ -74,13 +75,15 @@ public:
|
|||
|
||||
protected:
|
||||
ElemType m_initialActivationValue; // starting value for hidden activation vector at boundary
|
||||
shared_ptr<Matrix<ElemType>> m_initialActivationValueMatrix; // ...and as a potentially GPU-side object
|
||||
int m_timeStep; // delay in frames (typ. 1)
|
||||
|
||||
function<void()> m_attachInputsFn; // for late expansion of inputs (scripting)
|
||||
|
||||
vector<ElemType> m_sourceFrameValid; // mask for copying/propagating source frames is prepared here...
|
||||
shared_ptr<Matrix<ElemType>> m_sourceFrameValidMatrix; // ...and used from here
|
||||
|
||||
shared_ptr<Matrix<ElemType>> m_initialActivationValueMatrix; // potentially GPU-side versions
|
||||
shared_ptr<Matrix<ElemType>> m_sourceFrameValidMatrix;
|
||||
shared_ptr<Matrix<ElemType>> m_zeroMatrix; // constant [1]-dimensional 0 used for backprop
|
||||
|
||||
shared_ptr<Matrix<ElemType>> m_delayedValue; // saves the activation of the previous step that this node points to
|
||||
MBLayoutPtr m_delayedActivationMBLayout; // layout for m_delayedValue
|
||||
|
|
Загрузка…
Ссылка в новой задаче