Enable right context
This commit is contained in:
Родитель
b891ec0759
Коммит
a394906917
|
@ -109,7 +109,7 @@ struct MBLayout
|
|||
// -------------------------------------------------------------------
|
||||
|
||||
MBLayout(size_t numParallelSequences, size_t numTimeSteps, const std::wstring &name)
|
||||
: m_distanceToStart(CPUDEVICE), m_distanceToEnd(CPUDEVICE), m_columnsValidityMask(CPUDEVICE), m_rightSplice(0)
|
||||
: m_distanceToStart(CPUDEVICE), m_distanceToEnd(CPUDEVICE), m_columnsValidityMask(CPUDEVICE), m_rightSplice(0), m_rightLookAhead(0)
|
||||
{
|
||||
Init(numParallelSequences, numTimeSteps);
|
||||
SetUniqueAxisName(name != L"" ? name : L"DynamicAxis");
|
||||
|
@ -142,6 +142,7 @@ struct MBLayout
|
|||
m_columnsValidityMask.SetValue(other->m_columnsValidityMask);
|
||||
m_writable = other->m_writable;
|
||||
m_rightSplice = other->m_rightSplice;
|
||||
m_rightLookAhead = other->m_rightLookAhead;
|
||||
|
||||
if (!keepName)
|
||||
m_axisName = other->m_axisName;
|
||||
|
@ -169,6 +170,7 @@ struct MBLayout
|
|||
m_columnsValidityMask = std::move(other->m_columnsValidityMask);
|
||||
m_writable = other->m_writable;
|
||||
m_rightSplice = other->m_rightSplice;
|
||||
m_rightLookAhead = other->m_rightLookAhead;
|
||||
|
||||
m_axisName = std::move(other->m_axisName);
|
||||
}
|
||||
|
@ -199,12 +201,14 @@ public:
|
|||
m_writable = true;
|
||||
}
|
||||
|
||||
void Init(size_t numParallelSequences, size_t numTimeSteps, size_t rightSplice)
|
||||
void Init(size_t numParallelSequences, size_t numTimeSteps, size_t rightSplice, size_t rightLookAhead)
|
||||
{
|
||||
Init(numParallelSequences, numTimeSteps);
|
||||
m_rightSplice = rightSplice;
|
||||
if (numTimeSteps < rightSplice)
|
||||
m_rightSplice = 0;
|
||||
|
||||
m_rightLookAhead = rightLookAhead;
|
||||
}
|
||||
|
||||
// packing algorithm
|
||||
|
@ -508,6 +512,11 @@ public:
|
|||
return m_rightSplice;
|
||||
}
|
||||
|
||||
size_t RightLookAhead() const
|
||||
{
|
||||
return m_rightLookAhead;
|
||||
}
|
||||
|
||||
// test boundary flags for a specific condition
|
||||
bool IsBeyondStartOrEnd(const FrameRange& fr) const;
|
||||
bool IsGap(const FrameRange& fr) const;
|
||||
|
@ -603,6 +612,7 @@ private:
|
|||
vector<SequenceInfo> m_sequences;
|
||||
// right splice for latency control blstm
|
||||
size_t m_rightSplice;
|
||||
size_t m_rightLookAhead;
|
||||
|
||||
private:
|
||||
// -------------------------------------------------------------------
|
||||
|
@ -1123,9 +1133,12 @@ static inline std::pair<size_t, size_t> ColumnRangeWithMBLayoutFor(size_t numCol
|
|||
template <class ElemType>
|
||||
static inline Matrix<ElemType> DataWithMBLayoutFor(const Matrix<ElemType> &data,
|
||||
const FrameRange &fr /*select frame or entire batch*/,
|
||||
const MBLayoutPtr &pMBLayout /*the MB layout of 'data'*/)
|
||||
const MBLayoutPtr &pMBLayout /*the MB layout of 'data'*/,
|
||||
bool applyLookAhead = false)
|
||||
{
|
||||
auto columnRange = ColumnRangeWithMBLayoutFor(data.GetNumCols(), fr, pMBLayout);
|
||||
if (applyLookAhead)
|
||||
columnRange.second = columnRange.second - fr.m_pMBLayout->RightLookAhead() * fr.m_pMBLayout->GetNumParallelSequences();
|
||||
return data.ColumnSlice(columnRange.first, columnRange.second);
|
||||
}
|
||||
|
||||
|
@ -1243,7 +1256,10 @@ static inline void MaskMissingColumnsTo(Matrix<ElemType>& matrixToMask, const MB
|
|||
|
||||
auto matrixSliceToMask = DataWithMBLayoutFor(matrixToMask, fr, pMBLayout);
|
||||
if ((matrixSliceToMask.GetNumCols() % maskSlice.GetNumCols()) != 0)
|
||||
{
|
||||
fprintf(stderr, "matrixSliceToMask %zu maskSlice %zu", matrixSliceToMask.GetNumCols(), maskSlice.GetNumCols());
|
||||
LogicError("MaskMissingColumnsTo: The number of columns of the matrix slice to be masked is not a multiple of the number of columns of the mask slice.");
|
||||
}
|
||||
|
||||
matrixSliceToMask.MaskColumnsValue(maskSlice, val, matrixSliceToMask.GetNumCols() / maskSlice.GetNumCols());
|
||||
}
|
||||
|
|
|
@ -1630,11 +1630,11 @@ public:
|
|||
|
||||
// function to access any input and output, value and gradient, whole batch or single frame
|
||||
// Note: This returns a reference into 'data' in the form of a column slice, i.e. a small matrix object that just points into 'data'.
|
||||
Matrix<ElemType> DataFor(Matrix<ElemType>& data, const FrameRange& fr /*select frame or entire batch*/)
|
||||
Matrix<ElemType> DataFor(Matrix<ElemType>& data, const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false)
|
||||
{
|
||||
try
|
||||
{
|
||||
return DataWithMBLayoutFor(data, fr, m_pMBLayout);
|
||||
return DataWithMBLayoutFor(data, fr, m_pMBLayout, applyLookAhead);
|
||||
}
|
||||
catch (const std::exception& e) // catch the error and rethrow it with the node name attached
|
||||
{
|
||||
|
@ -1648,17 +1648,17 @@ public:
|
|||
}
|
||||
#endif
|
||||
|
||||
// Column slice of Value() selected by 'fr'; forwards to DataFor(), passing
// 'applyLookAhead' through unchanged.
Matrix<ElemType> ValueFor(const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false)
{
    return DataFor(Value(), fr, applyLookAhead);
}

// Column slice of Gradient() selected by 'fr'; forwards to DataFor(), passing
// 'applyLookAhead' through unchanged.
Matrix<ElemType> GradientFor(const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false)
{
    return DataFor(Gradient(), fr, applyLookAhead);
}
|
||||
#if 0 // causes grief with gcc
|
||||
Matrix<ElemType> ValueFor (const FrameRange& fr /*select frame or entire batch*/) const { return DataFor(Value(), fr); }
|
||||
Matrix<ElemType> GradientFor(const FrameRange& fr /*select frame or entire batch*/) const { return DataFor(Gradient(), fr); }
|
||||
#endif
|
||||
// use the following two versions if you assume the inputs may contain gaps that must be set to zero because you want to reduce over frames with a BLAS operation
|
||||
// Variant of ValueFor() for callers that reduce over frames and therefore need gaps
// to be zero: first runs MaskMissingValueColumnsToZero(fr), then returns
// ValueFor(fr, applyLookAhead).
// Note that the masking step is invoked with 'fr' only; 'applyLookAhead' affects just
// the slice that is returned.
Matrix<ElemType> MaskedValueFor(const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false)
{
    // Masking must happen before the slice is taken.
    MaskMissingValueColumnsToZero(fr);

    return ValueFor(fr, applyLookAhead);
}
|
||||
Matrix<ElemType> MaskedGradientFor(const FrameRange& fr /*select frame or entire batch*/)
|
||||
{
|
||||
|
|
|
@ -132,9 +132,18 @@ public:
|
|||
virtual void BackpropToNonLooping(size_t inputIndex) override
|
||||
{
|
||||
FrameRange fr(InputRef(0).GetMBLayout());
|
||||
auto input0 = InputRef(0).ValueFor(fr);
|
||||
vector<ElemType> zeros(m_logSoftmaxOfRight->GetNumRows(), 0.0);
|
||||
for (size_t colIndex = m_logSoftmaxOfRight->GetNumCols() - fr.m_pMBLayout->GetNumParallelSequences() * fr.m_pMBLayout->RightLookAhead(); colIndex < m_logSoftmaxOfRight->GetNumCols(); colIndex++)
|
||||
{
|
||||
m_softmaxOfRight->SetColumn(&zeros[0], colIndex);
|
||||
input0.SetColumn(&zeros[0], colIndex);
|
||||
}
|
||||
|
||||
// left input is scalar
|
||||
if (inputIndex == 0) // left derivative
|
||||
{
|
||||
Gradient().Print("CrossEntropyWithSoftmax LEFT-gradientValues");
|
||||
#if DUMPOUTPUT
|
||||
m_logSoftmaxOfRight->Print("CrossEntropyWithSoftmax Partial-logSoftmaxOfRight");
|
||||
Gradient().Print("CrossEntropyWithSoftmax Partial-gradientValues");
|
||||
|
@ -150,20 +159,25 @@ public:
|
|||
|
||||
else if (inputIndex == 1) // right derivative
|
||||
{
|
||||
// Gradient().Print("CrossEntropyWithSoftmax Right-gradientValues");
|
||||
#if DUMPOUTPUT
|
||||
m_softmaxOfRight->Print("CrossEntropyWithSoftmax Partial-softmaxOfRight");
|
||||
InputRef(0).ValueFor(fr).Print("CrossEntropyWithSoftmax Partial-inputFunctionValues");
|
||||
Gradient().Print("CrossEntropyWithSoftmax Partial-gradientValues");
|
||||
|
||||
InputRef(1).GradientFor(fr).Print("CrossEntropyWithSoftmaxNode Partial-Right-in");
|
||||
#endif
|
||||
/*m_softmaxOfRight->Print("CrossEntropyWithSoftmax Partial-softmaxOfRight");
|
||||
input0.Print("CrossEntropyWithSoftmax input0");*/
|
||||
|
||||
|
||||
auto gradient = InputRef(1).GradientFor(fr);
|
||||
Matrix<ElemType>::AddScaledDifference(Gradient(), *m_softmaxOfRight, InputRef(0).ValueFor(fr), gradient);
|
||||
Matrix<ElemType>::AddScaledDifference(Gradient(), *m_softmaxOfRight, input0, gradient);
|
||||
//gradient.Print("CrossEntropyWithSoftmax gradient");
|
||||
#if DUMPOUTPUT
|
||||
InputRef(1).GradientFor(fr).Print("CrossEntropyWithSoftmaxNode Partial-Right");
|
||||
#endif
|
||||
#ifdef _DEBUG
|
||||
InputRef(1).InvalidateMissingGradientColumns(fr); // TODO: This should not be necessary.
|
||||
//InputRef(1).InvalidateMissingGradientColumns(fr); // TODO: This should not be necessary.
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -188,6 +202,7 @@ public:
|
|||
// BUGBUG: No need to compute m_softmaxOfRight in ForwardProp, should be moved to BackpropTo().
|
||||
m_softmaxOfRight->SetValue(*m_logSoftmaxOfRight);
|
||||
m_softmaxOfRight->InplaceExp();
|
||||
|
||||
// flatten all gaps to zero, such that gaps will contribute zero to the sum
|
||||
MaskMissingColumnsToZero(*m_logSoftmaxOfRight, InputRef(1).GetMBLayout(), fr);
|
||||
// reduce over all frames
|
||||
|
|
|
@ -88,6 +88,10 @@ CompositeDataReader::CompositeDataReader(const ConfigParameters& config) :
|
|||
if (m_rightSplice > m_truncationLength)
|
||||
InvalidArgument("rightSplice should not be greater than truncation length");
|
||||
|
||||
m_rightLookAhead = config(L"rightLookAhead", 0);
|
||||
if (m_rightLookAhead > m_truncationLength)
|
||||
InvalidArgument("rightLookAhead should not be greater than truncation length");
|
||||
|
||||
m_precision = config("precision", "float");
|
||||
|
||||
// Creating deserializers.
|
||||
|
@ -361,6 +365,7 @@ void CompositeDataReader::StartEpoch(const EpochConfiguration& cfg, const std::m
|
|||
{
|
||||
config.m_truncationSize = m_truncationLength;
|
||||
config.m_rightSplice = m_rightSplice;
|
||||
config.m_rightLookAhead = m_rightLookAhead;
|
||||
}
|
||||
|
||||
ReaderBase::StartEpoch(config, inputDescriptions);
|
||||
|
|
|
@ -96,6 +96,9 @@ private:
|
|||
|
||||
// rightSplice(nr) for LC-BLSTM
|
||||
size_t m_rightSplice;
|
||||
|
||||
// look ahead window in truncated BPTT chunk
|
||||
size_t m_rightLookAhead;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ using MSR_CNTK::MBLayoutPtr;
|
|||
struct ReaderConfiguration
|
||||
{
|
||||
ReaderConfiguration()
|
||||
: m_numberOfWorkers(0), m_workerRank(0), m_minibatchSizeInSamples(0), m_truncationSize(0), m_maxErrors(0), m_rightSplice(0)
|
||||
: m_numberOfWorkers(0), m_workerRank(0), m_minibatchSizeInSamples(0), m_truncationSize(0), m_maxErrors(0), m_rightSplice(0), m_rightLookAhead(0)
|
||||
{}
|
||||
|
||||
size_t m_numberOfWorkers; // Number of the Open MPI workers for the current epoch
|
||||
|
@ -36,6 +36,7 @@ struct ReaderConfiguration
|
|||
size_t m_minibatchSizeInSamples; // Maximum minibatch size for the epoch in samples
|
||||
size_t m_truncationSize; // Truncation size in samples for truncated BPTT mode.
|
||||
size_t m_rightSplice; // RightSplice for latency control BLSTM
|
||||
size_t m_rightLookAhead; // Look-ahead window in a BPTT truncated chunk
|
||||
size_t m_maxErrors; // Max number of errors to ignore
|
||||
|
||||
// This flag indicates whether the minibatches are allowed to overlap the boundary
|
||||
|
|
|
@ -168,6 +168,7 @@ void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, co
|
|||
auto oldMinibatchSize = m_config.m_minibatchSizeInSamples;
|
||||
auto oldTruncationSize = m_config.m_truncationSize;
|
||||
m_config.m_rightSplice = config.m_rightSplice;
|
||||
m_config.m_rightLookAhead = config.m_rightLookAhead;
|
||||
|
||||
PackerBase::SetConfiguration(config, memoryProviders);
|
||||
|
||||
|
@ -228,7 +229,7 @@ Minibatch TruncatedBPTTPacker::ReadMinibatch()
|
|||
// all mblayouts should match anyway.
|
||||
mbSeqIdToCorpusSeqId.clear();
|
||||
|
||||
m_currentLayouts[streamIndex]->Init(m_numParallelSequences, m_config.m_truncationSize, m_config.m_rightSplice);
|
||||
m_currentLayouts[streamIndex]->Init(m_numParallelSequences, m_config.m_truncationSize, m_config.m_rightSplice, m_config.m_rightLookAhead);
|
||||
size_t sequenceId = 0;
|
||||
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче