This commit is contained in:
Vadim Mazalov 2019-01-22 15:50:36 -08:00
Родитель b891ec0759
Коммит a394906917
7 изменённых файлов: 55 добавлений и 14 удалений

Просмотреть файл

@ -109,7 +109,7 @@ struct MBLayout
// -------------------------------------------------------------------
MBLayout(size_t numParallelSequences, size_t numTimeSteps, const std::wstring &name)
: m_distanceToStart(CPUDEVICE), m_distanceToEnd(CPUDEVICE), m_columnsValidityMask(CPUDEVICE), m_rightSplice(0)
: m_distanceToStart(CPUDEVICE), m_distanceToEnd(CPUDEVICE), m_columnsValidityMask(CPUDEVICE), m_rightSplice(0), m_rightLookAhead(0)
{
Init(numParallelSequences, numTimeSteps);
SetUniqueAxisName(name != L"" ? name : L"DynamicAxis");
@ -142,6 +142,7 @@ struct MBLayout
m_columnsValidityMask.SetValue(other->m_columnsValidityMask);
m_writable = other->m_writable;
m_rightSplice = other->m_rightSplice;
m_rightLookAhead = other->m_rightLookAhead;
if (!keepName)
m_axisName = other->m_axisName;
@ -169,6 +170,7 @@ struct MBLayout
m_columnsValidityMask = std::move(other->m_columnsValidityMask);
m_writable = other->m_writable;
m_rightSplice = other->m_rightSplice;
m_rightLookAhead = other->m_rightLookAhead;
m_axisName = std::move(other->m_axisName);
}
@ -199,12 +201,14 @@ public:
m_writable = true;
}
void Init(size_t numParallelSequences, size_t numTimeSteps, size_t rightSplice)
void Init(size_t numParallelSequences, size_t numTimeSteps, size_t rightSplice, size_t rightLookAhead)
{
Init(numParallelSequences, numTimeSteps);
m_rightSplice = rightSplice;
if (numTimeSteps < rightSplice)
m_rightSplice = 0;
m_rightLookAhead = rightLookAhead;
}
// packing algorithm
@ -508,6 +512,11 @@ public:
return m_rightSplice;
}
size_t RightLookAhead() const
{
return m_rightLookAhead;
}
// test boundary flags for a specific condition
bool IsBeyondStartOrEnd(const FrameRange& fr) const;
bool IsGap(const FrameRange& fr) const;
@ -603,6 +612,7 @@ private:
vector<SequenceInfo> m_sequences;
// right splice for latency control blstm
size_t m_rightSplice;
size_t m_rightLookAhead;
private:
// -------------------------------------------------------------------
@ -1123,9 +1133,12 @@ static inline std::pair<size_t, size_t> ColumnRangeWithMBLayoutFor(size_t numCol
template <class ElemType>
static inline Matrix<ElemType> DataWithMBLayoutFor(const Matrix<ElemType> &data,
const FrameRange &fr /*select frame or entire batch*/,
const MBLayoutPtr &pMBLayout /*the MB layout of 'data'*/)
const MBLayoutPtr &pMBLayout /*the MB layout of 'data'*/,
bool applyLookAhead = false)
{
auto columnRange = ColumnRangeWithMBLayoutFor(data.GetNumCols(), fr, pMBLayout);
if (applyLookAhead)
columnRange.second = columnRange.second - fr.m_pMBLayout->RightLookAhead() * fr.m_pMBLayout->GetNumParallelSequences();
return data.ColumnSlice(columnRange.first, columnRange.second);
}
@ -1243,7 +1256,10 @@ static inline void MaskMissingColumnsTo(Matrix<ElemType>& matrixToMask, const MB
auto matrixSliceToMask = DataWithMBLayoutFor(matrixToMask, fr, pMBLayout);
if ((matrixSliceToMask.GetNumCols() % maskSlice.GetNumCols()) != 0)
{
fprintf(stderr, "matrixSliceToMask %zu maskSlice %zu", matrixSliceToMask.GetNumCols(), maskSlice.GetNumCols());
LogicError("MaskMissingColumnsTo: The number of columns of the matrix slice to be masked is not a multiple of the number of columns of the mask slice.");
}
matrixSliceToMask.MaskColumnsValue(maskSlice, val, matrixSliceToMask.GetNumCols() / maskSlice.GetNumCols());
}

Просмотреть файл

@ -1630,11 +1630,11 @@ public:
// function to access any input and output, value and gradient, whole batch or single frame
// Note: This returns a reference into 'data' in the form of a column slice, i.e. a small matrix object that just points into 'data'.
Matrix<ElemType> DataFor(Matrix<ElemType>& data, const FrameRange& fr /*select frame or entire batch*/)
Matrix<ElemType> DataFor(Matrix<ElemType>& data, const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false)
{
try
{
return DataWithMBLayoutFor(data, fr, m_pMBLayout);
return DataWithMBLayoutFor(data, fr, m_pMBLayout, applyLookAhead);
}
catch (const std::exception& e) // catch the error and rethrow it with the node name attached
{
@ -1648,17 +1648,17 @@ public:
}
#endif
Matrix<ElemType> ValueFor (const FrameRange& fr /*select frame or entire batch*/) { return DataFor(Value(), fr); }
Matrix<ElemType> GradientFor(const FrameRange& fr /*select frame or entire batch*/) { return DataFor(Gradient(), fr); }
Matrix<ElemType> ValueFor (const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false) { return DataFor(Value(), fr, applyLookAhead); }
Matrix<ElemType> GradientFor(const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false) { return DataFor(Gradient(), fr, applyLookAhead); }
#if 0 // causes grief with gcc
Matrix<ElemType> ValueFor (const FrameRange& fr /*select frame or entire batch*/) const { return DataFor(Value(), fr); }
Matrix<ElemType> GradientFor(const FrameRange& fr /*select frame or entire batch*/) const { return DataFor(Gradient(), fr); }
#endif
// use the following two versions if you assume the inputs may contain gaps that must be set to zero because you want to reduce over frames with a BLAS operation
Matrix<ElemType> MaskedValueFor(const FrameRange& fr /*select frame or entire batch*/)
Matrix<ElemType> MaskedValueFor(const FrameRange& fr /*select frame or entire batch*/, bool applyLookAhead = false)
{
MaskMissingValueColumnsToZero(fr);
return ValueFor(fr);
return ValueFor(fr, applyLookAhead);
}
Matrix<ElemType> MaskedGradientFor(const FrameRange& fr /*select frame or entire batch*/)
{

Просмотреть файл

@ -132,9 +132,18 @@ public:
virtual void BackpropToNonLooping(size_t inputIndex) override
{
FrameRange fr(InputRef(0).GetMBLayout());
auto input0 = InputRef(0).ValueFor(fr);
vector<ElemType> zeros(m_logSoftmaxOfRight->GetNumRows(), 0.0);
for (size_t colIndex = m_logSoftmaxOfRight->GetNumCols() - fr.m_pMBLayout->GetNumParallelSequences() * fr.m_pMBLayout->RightLookAhead(); colIndex < m_logSoftmaxOfRight->GetNumCols(); colIndex++)
{
m_softmaxOfRight->SetColumn(&zeros[0], colIndex);
input0.SetColumn(&zeros[0], colIndex);
}
// left input is scalar
if (inputIndex == 0) // left derivative
{
Gradient().Print("CrossEntropyWithSoftmax LEFT-gradientValues");
#if DUMPOUTPUT
m_logSoftmaxOfRight->Print("CrossEntropyWithSoftmax Partial-logSoftmaxOfRight");
Gradient().Print("CrossEntropyWithSoftmax Partial-gradientValues");
@ -150,20 +159,25 @@ public:
else if (inputIndex == 1) // right derivative
{
// Gradient().Print("CrossEntropyWithSoftmax Right-gradientValues");
#if DUMPOUTPUT
m_softmaxOfRight->Print("CrossEntropyWithSoftmax Partial-softmaxOfRight");
InputRef(0).ValueFor(fr).Print("CrossEntropyWithSoftmax Partial-inputFunctionValues");
Gradient().Print("CrossEntropyWithSoftmax Partial-gradientValues");
InputRef(1).GradientFor(fr).Print("CrossEntropyWithSoftmaxNode Partial-Right-in");
#endif
/*m_softmaxOfRight->Print("CrossEntropyWithSoftmax Partial-softmaxOfRight");
input0.Print("CrossEntropyWithSoftmax input0");*/
auto gradient = InputRef(1).GradientFor(fr);
Matrix<ElemType>::AddScaledDifference(Gradient(), *m_softmaxOfRight, InputRef(0).ValueFor(fr), gradient);
Matrix<ElemType>::AddScaledDifference(Gradient(), *m_softmaxOfRight, input0, gradient);
//gradient.Print("CrossEntropyWithSoftmax gradient");
#if DUMPOUTPUT
InputRef(1).GradientFor(fr).Print("CrossEntropyWithSoftmaxNode Partial-Right");
#endif
#ifdef _DEBUG
InputRef(1).InvalidateMissingGradientColumns(fr); // TODO: This should not be necessary.
//InputRef(1).InvalidateMissingGradientColumns(fr); // TODO: This should not be necessary.
#endif
}
}
@ -188,6 +202,7 @@ public:
// BUGBUG: No need to compute m_softmaxOfRight in ForwardProp, should be moved to BackpropTo().
m_softmaxOfRight->SetValue(*m_logSoftmaxOfRight);
m_softmaxOfRight->InplaceExp();
// flatten all gaps to zero, such that gaps will contribute zero to the sum
MaskMissingColumnsToZero(*m_logSoftmaxOfRight, InputRef(1).GetMBLayout(), fr);
// reduce over all frames

Просмотреть файл

@ -88,6 +88,10 @@ CompositeDataReader::CompositeDataReader(const ConfigParameters& config) :
if (m_rightSplice > m_truncationLength)
InvalidArgument("rightSplice should not be greater than truncation length");
m_rightLookAhead = config(L"rightLookAhead", 0);
if (m_rightLookAhead > m_truncationLength)
InvalidArgument("rightLookAhead should not be greater than truncation length");
m_precision = config("precision", "float");
// Creating deserializers.
@ -361,6 +365,7 @@ void CompositeDataReader::StartEpoch(const EpochConfiguration& cfg, const std::m
{
config.m_truncationSize = m_truncationLength;
config.m_rightSplice = m_rightSplice;
config.m_rightLookAhead = m_rightLookAhead;
}
ReaderBase::StartEpoch(config, inputDescriptions);

Просмотреть файл

@ -96,6 +96,9 @@ private:
// rightSplice(nr) for LC-BLSTM
size_t m_rightSplice;
// look ahead window in truncated BPTT chunk
size_t m_rightLookAhead;
};
}

Просмотреть файл

@ -28,7 +28,7 @@ using MSR_CNTK::MBLayoutPtr;
struct ReaderConfiguration
{
ReaderConfiguration()
: m_numberOfWorkers(0), m_workerRank(0), m_minibatchSizeInSamples(0), m_truncationSize(0), m_maxErrors(0), m_rightSplice(0)
: m_numberOfWorkers(0), m_workerRank(0), m_minibatchSizeInSamples(0), m_truncationSize(0), m_maxErrors(0), m_rightSplice(0), m_rightLookAhead(0)
{}
size_t m_numberOfWorkers; // Number of the Open MPI workers for the current epoch
@ -36,6 +36,7 @@ struct ReaderConfiguration
size_t m_minibatchSizeInSamples; // Maximum minibatch size for the epoch in samples
size_t m_truncationSize; // Truncation size in samples for truncated BPTT mode.
size_t m_rightSplice; // RightSplice for latency control BLSTM
size_t m_rightLookAhead; // Look-ahead window in a BPTT truncated chunk
size_t m_maxErrors; // Max number of errors to ignore
// This flag indicates whether the minibatches are allowed to overlap the boundary

Просмотреть файл

@ -168,6 +168,7 @@ void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, co
auto oldMinibatchSize = m_config.m_minibatchSizeInSamples;
auto oldTruncationSize = m_config.m_truncationSize;
m_config.m_rightSplice = config.m_rightSplice;
m_config.m_rightLookAhead = config.m_rightLookAhead;
PackerBase::SetConfiguration(config, memoryProviders);
@ -228,7 +229,7 @@ Minibatch TruncatedBPTTPacker::ReadMinibatch()
// all mblayouts should match anyway.
mbSeqIdToCorpusSeqId.clear();
m_currentLayouts[streamIndex]->Init(m_numParallelSequences, m_config.m_truncationSize, m_config.m_rightSplice);
m_currentLayouts[streamIndex]->Init(m_numParallelSequences, m_config.m_truncationSize, m_config.m_rightSplice, m_config.m_rightLookAhead);
size_t sequenceId = 0;
for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
{