folded SGD::ForwardBackward() back inline, since it is such an integral piece that it warrants being front and center

Frank Seide 2015-12-04 20:14:35 -08:00
Parent 9d5ea88604
Commit 73fa9f8004
2 changed files with 61 additions and 73 deletions

View file

@@ -1659,28 +1659,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         return result;
     }
 
-    template<class ElemType>
-    void SGD<ElemType>::ForwardBackward(ComputationNetwork& net,
-                                        const std::vector<ComputationNodeBasePtr>& evalNodes,
-                                        shared_ptr<ComputationNodeBase> criterionNode,
-                                        bool isLRLargeEnough)
-    {
-        // evaluate eval nodes
-        // The bulk of this evaluation is reused in ComputeGradient() below.
-        net.ForwardProp(evalNodes);
-
-        // compute the gradient
-        // This is where the magic happens, baby!!
-
-        // forward prop
-        net.ForwardProp(criterionNode);
-
-        // backprop
-        // only compute gradient when learning rate is large enough
-        if (isLRLargeEnough)
-            net.Backprop(criterionNode);
-    }
-
     template<class ElemType>
     size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
                                         ComputationNetworkPtr refNet,
@@ -1766,27 +1744,28 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             refNet->StartEvaluateMinibatchLoop(refNode);
         }
 
-        DataReaderHelpers::SubminibatchDispatcher<ElemType> smbDisplatcher;
-        size_t samplesInRAM = m_maxSamplesInRAM;
-        // convert it to SubminibatchRequested
-        size_t numSubminibatchRequested = 0;
-        if (samplesInRAM < SIZE_MAX) // if samplesInRAM = 0 , we will not use subminibatch dispatcher
+        // prepare for sub-minibatching
+        // Sub-minibatching is used if a single minibatch is too large to fit into GPU RAM.
+        DataReaderHelpers::SubminibatchDispatcher<ElemType> smbDispatcher;
+        size_t numSubminibatchesNeeded = 0;
+        if (m_maxSamplesInRAM < SIZE_MAX) // user-specified maximum number of samples that fit into GPU RAM; or 0 if not enabled
         {
-            size_t nParallelSequences = trainSetDataReader->GetNumParallelSequences();
-            size_t estimatedMBSize = tunedMBSize * nParallelSequences;
-            numSubminibatchRequested = (size_t)std::ceil( (float)estimatedMBSize / samplesInRAM);
+            // into how many pieces would we need to break the minibatch?
+            // TODO: The following calculation relies on the ill-devised definition of "minibatch" of the current truncated BPTT implementation. Adapt this once fixed.
+            size_t numParallelSequences = trainSetDataReader->GetNumParallelSequences();
+            size_t estimatedMBSize = tunedMBSize * numParallelSequences;
+            numSubminibatchesNeeded = (size_t)std::ceil((float)estimatedMBSize / m_maxSamplesInRAM);
         }
-        if (numSubminibatchRequested > 1) // only use subminibatch dispatcher if more than 1 subminibatch is required
-        {
-            smbDisplatcher.Init(net, learnableNodes, criterionNodes, evaluationNodes);
-        }
-        size_t actualNumSubminibatch=0;
+        // this is non-trivial; we need a manager object to handle this
+        if (numSubminibatchesNeeded > 1)
+            smbDispatcher.Init(net, learnableNodes, criterionNodes, evaluationNodes);
 
-        // Attemps to compute the error signal for the whole utterance, which will
+        // The following is a special feature only supported by the Kaldi2Reader for more efficient sequence training.
+        // This attempts to compute the error signal for the whole utterance, which will
         // be fed to the neural network as features. Currently it is a workaround
         // for the two-forward-pass sequence and ctc training, which allows
-        // processing more utterances at the same time. Only used in Kaldi2Reader.
-        // TODO: move the two-forward-pass support out of the reader.
+        // processing more utterances at the same time.
+        // TODO: move the two-forward-pass support out of the reader, and make it a first-class citizen.
         AttemptUtteranceDerivativeFeatures(net, trainSetDataReader, featureNodes, inputMatrices);
 
         fprintf(stderr, "\nStarting minibatch loop");
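
To make the sub-minibatch arithmetic above concrete, here is a toy, self-contained calculation; the numbers are invented for illustration, and only the formula follows the added lines:

    #include <cmath>
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Invented example values; only the formula mirrors the diff above.
        size_t tunedMBSize          = 4096; // minibatch size as configured
        size_t numParallelSequences = 4;    // parallel sequences reported by the reader
        size_t maxSamplesInRAM      = 8192; // user-specified GPU-RAM budget, in samples

        size_t estimatedMBSize = tunedMBSize * numParallelSequences;     // 16384 samples
        size_t numSubminibatchesNeeded =
            (size_t)std::ceil((float)estimatedMBSize / maxSamplesInRAM); // ceil(16384 / 8192) = 2

        printf("split %zu samples into %zu sub-minibatches\n", estimatedMBSize, numSubminibatchesNeeded);
        return 0;
    }

With a budget of 8192 samples and an estimated minibatch of 16384 samples, the minibatch is split into two sub-minibatches; leaving the budget at SIZE_MAX disables splitting altogether.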
@@ -1799,9 +1778,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         {
             fprintf(stderr, ", distributed reading is ENABLED");
         }
-        if (numSubminibatchRequested > 1)
+        if (numSubminibatchesNeeded > 1)
         {
-            fprintf(stderr, ", with maximum %d samples in RAM", (int)samplesInRAM);
+            fprintf(stderr, ", with maximum %d samples in RAM", (int)m_maxSamplesInRAM);
         }
         fprintf(stderr, ".\n");
 
@@ -1822,15 +1801,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 nSamplesSinceLastModelSync += actualMBSize;
 
-                if (numSubminibatchRequested > 1)
-                {
-                    actualNumSubminibatch = smbDisplatcher.GetMinibatchIntoCache(*trainSetDataReader, *net, *inputMatrices, numSubminibatchRequested);
-                }
-                else
-                {
-                    actualNumSubminibatch = 1;
-                }
-
                 // node data was changed
                 // TODO: move this to that function as well--just tired of passing everything as arguments
                 // TODO: We should do this right after the GetMinibatch() call, since that's where these changed.
@@ -1845,7 +1815,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 if (m_doGradientCheck && GradientCheck(net, criterionNodes, learnableNodes, 0) == false)
                     LogicError("cannot pass gradient checker");
 #endif
 
-                // TODO: currently only support one node regularization
+                // TODO: currently we only support one node for regularization
                 if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode)
                 {
 #if 0 // TODO: where does refNet get its features from?
@@ -1866,31 +1836,50 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                         dynamic_pointer_cast<ComputationNode<ElemType>>(labelNodes[0])->Output());
                 }
 
-                //compute eval node first since when gradient is computed the forward function values
-                //may be changed and need to be recomputed when gradient and function value share the same matrix
-                if (actualNumSubminibatch > 1)
-                {
-                    for (size_t ismb = 0; ismb < actualNumSubminibatch; ismb++)
-                    {
-                        smbDisplatcher.GetSubMinibatchToNet(ismb);
-                        ComputationNetwork::UpdateEvalTimeStamps(featureNodes);
-                        ComputationNetwork::UpdateEvalTimeStamps(labelNodes);
-                        ForwardBackward(*net, evaluationNodes, criterionNodes[0], learnRatePerSample > 0.01 * m_minLearnRate);
-                        smbDisplatcher.DoneWithCurrentSubMinibatch(ismb);
-                    }
-                    smbDisplatcher.DoneWithCurrentMinibatch();
-                }
-                else
-                {
-                    ForwardBackward(*net, evaluationNodes, criterionNodes[0], learnRatePerSample > 0.01 * m_minLearnRate);
-                }
-            }
+                // do forward and back propagation
+                // We optionally break the minibatch into sub-minibatches.
+                // This, when enabled, is used when a full minibatch does not fit into GPU RAM.
+                size_t actualNumSubminibatches = numSubminibatchesNeeded == 1 ? 1 : smbDispatcher.GetMinibatchIntoCache(*trainSetDataReader, *net, *inputMatrices, numSubminibatchesNeeded);
+                for (size_t ismb = 0; ismb < actualNumSubminibatches; ismb++)
+                {
+                    if (actualNumSubminibatches > 1)
+                    {
+                        smbDispatcher.GetSubMinibatchToNet(ismb); // get sub-minibatch from full-size one
+                        ComputationNetwork::UpdateEvalTimeStamps(featureNodes);
+                        ComputationNetwork::UpdateEvalTimeStamps(labelNodes);
+                    }
+
+                    // ===========================================================
+                    // forward prop for evaluate eval nodes
+                    // ===========================================================
+
+                    // compute eval nodes first, since when the gradient is computed, the forward function values
+                    // may be changed and would need to be recomputed if gradient and function value share the same matrix
+                    net->ForwardProp(evaluationNodes); // the bulk of this evaluation is reused in ComputeGradient() below
+
+                    // ===========================================================
+                    // forward prop for training criterion
+                    // ===========================================================
+
+                    net->ForwardProp(criterionNodes[0]);
+
+                    // ===========================================================
+                    // backprop
+                    // ===========================================================
+
+                    if (learnRatePerSample > 0.01 * m_minLearnRate) // only compute gradient when learning rate is large enough
+                        net->Backprop(criterionNodes[0]);
+
+                    // house-keeping for sub-minibatching
+                    if (actualNumSubminibatches > 1)
+                        smbDispatcher.DoneWithCurrentSubMinibatch(ismb); // page state out
+                } // end sub-minibatch loop
+
+                if (actualNumSubminibatches > 1)
+                    smbDispatcher.DoneWithCurrentMinibatch();
+            } // if (actualMBSize > 0)
 
+            // Some labels may be missing (e.g. forced alignment failed, or being gaps due to packing parallel sequences).
             //for now since we share the same label masking flag we call this on the network.
             //Later, when we apply different labels on different nodes
             //we need to add code to call this function multiple times, one for each criterion node
+            // for progress and statistics, we should only count frames that are not gaps
             size_t numSamplesWithLabel = net->GetNumSamplesWithLabel(actualMBSize);
 
             // Sum of actualMBSize across all nodes when using parallel training
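
The "compute eval nodes first" comment above encodes an ordering constraint: when a node's function values and its gradient share the same matrix, running backprop can overwrite the forward results, so any eval-node output must be read beforehand. Below is a deliberately artificial sketch of that hazard, assuming a single shared buffer (a simplification for illustration, not CNTK's actual memory-sharing scheme):

    #include <cstdio>
    #include <vector>

    // Toy stand-in for a node whose function values and gradient share one buffer.
    struct SharedBufferNode
    {
        std::vector<float> buffer;

        void ForwardProp() { buffer.assign(4, 1.0f); }  // pretend the activation is 1.0
        void Backprop()    { buffer.assign(4, -0.5f); } // gradient reuses (and clobbers) the same storage
        float ReadOutput() const { return buffer[0]; }
    };

    int main()
    {
        SharedBufferNode node;

        node.ForwardProp();
        float evalValue = node.ReadOutput(); // 1.0 -- read BEFORE backprop, matching the diff's ordering

        node.Backprop();
        float stale = node.ReadOutput();     // -0.5 -- the function value is gone after backprop

        printf("read before backprop: %g, read after backprop: %g\n", evalValue, stale);
        return 0;
    }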

View file

@@ -493,7 +493,6 @@ protected:
 
 private:
     int SGDTrace(FILE *__restrict __stream, const char *__restrict __format, ...);
-    void ForwardBackward(ComputationNetwork& net,const std::vector<ComputationNodeBasePtr>& evalNodes,shared_ptr<ComputationNodeBase> criterionNode,bool dobackpropogate);
 };
 
 }}}
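
Taken together, the two files leave TrainOneEpoch's inner loop with the shape sketched below. The stub types, method bodies, and print statements are invented stand-ins; only the loop structure and the dispatcher method names mirror the diff:

    #include <cstddef>
    #include <cstdio>

    // Invented stubs; names mirror the diff, bodies are placeholders.
    struct StubDispatcher
    {
        size_t GetMinibatchIntoCache(size_t numNeeded) { return numNeeded; } // cache the full minibatch, report pieces
        void GetSubMinibatchToNet(size_t ismb)        { printf("  load sub-minibatch %zu\n", ismb); }
        void DoneWithCurrentSubMinibatch(size_t ismb) { printf("  page out sub-minibatch %zu\n", ismb); }
        void DoneWithCurrentMinibatch()               { printf("  merge results for the full minibatch\n"); }
    };

    struct StubNet
    {
        void ForwardPropEval()      { printf("  forward prop: eval nodes\n"); }
        void ForwardPropCriterion() { printf("  forward prop: criterion\n"); }
        void Backprop()             { printf("  backprop\n"); }
    };

    int main()
    {
        StubDispatcher smbDispatcher;
        StubNet net;
        size_t numSubminibatchesNeeded = 2; // pretend the minibatch has to be split in two
        bool lrLargeEnough = true;          // stands in for learnRatePerSample > 0.01 * m_minLearnRate

        size_t actualNumSubminibatches = numSubminibatchesNeeded == 1
                                       ? 1
                                       : smbDispatcher.GetMinibatchIntoCache(numSubminibatchesNeeded);
        for (size_t ismb = 0; ismb < actualNumSubminibatches; ismb++)
        {
            if (actualNumSubminibatches > 1)
                smbDispatcher.GetSubMinibatchToNet(ismb);
            net.ForwardPropEval();      // eval first; see the buffer-sharing note above
            net.ForwardPropCriterion();
            if (lrLargeEnough)
                net.Backprop();
            if (actualNumSubminibatches > 1)
                smbDispatcher.DoneWithCurrentSubMinibatch(ismb);
        }
        if (actualNumSubminibatches > 1)
            smbDispatcher.DoneWithCurrentMinibatch();
        return 0;
    }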