Merge branch 'jdroppo/cudnn-rnn-lstm' of https://github.com/Microsoft/cntk into fseide/cudnn5

Frank Seide, 2016-08-22 20:11:07 -07:00
Parents: 4983c46deb 2fa1b7033d
Commit: 1f9c539c61
26 changed files with 1125 additions and 53 deletions

View file

@@ -92,7 +92,7 @@ SRC:=
all : buildall
# Set up basic nvcc options and add CUDA targets from above
CUFLAGS = -m 64
ifdef CUDA_PATH
ifndef GDK_INCLUDE_PATH
@@ -261,7 +261,7 @@ READER_SRC =\
$(SOURCEDIR)/Readers/ReaderLib/TruncatedBpttPacker.cpp \
$(SOURCEDIR)/Readers/ReaderLib/PackerBase.cpp \
$(SOURCEDIR)/Readers/ReaderLib/FramePacker.cpp \
$(SOURCEDIR)/Readers/ReaderLib/ChunkCache.cpp \
COMMON_SRC =\
$(SOURCEDIR)/Common/Config.cpp \

View file

@@ -12,6 +12,7 @@
#include "LinearAlgebraNodes.h"
#include "RecurrentNodes.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "NonlinearityNodes.h"
#include "ReshapingNodes.h"
#include "InputAndParamNodes.h"

View file

@@ -11,6 +11,7 @@
#include "NDLNetworkBuilder.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "DeprecatedNodes.h"
#include "EvaluationNodes.h"
#include "InputAndParamNodes.h"

View file

@@ -497,6 +497,7 @@ PerDimMeanVarDeNormalization(dataVectorSequence, meanVector, invStdDevVector, ta
PerDimMeanVarNormalization (x, mean, invStdDev) = (x - mean) .* invStdDev
Reciprocal(z, tag='') = new ComputationNode [ operation = 'Reciprocal' ; inputs = z /*plus the function args*/ ]
//# the following is a temporary workaround until we have the C++ version
RNN(A, B, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = new ComputationNode [ operation = 'RNN' ; inputs = ( A : B ) /*plus the function args*/ ]
Scale(scalarScalingFactor, matrix, tag='') = new ComputationNode [ operation = 'Scale' ; inputs = (scalarScalingFactor : matrix) /*plus the function args*/ ]
# TODO: Scale = ElementTimes
ScatterPacked(cond, indexSequence, sourceData, tag='') = new ComputationNode [ operation = 'ScatterPacked' ; inputs = (cond : indexSequence : sourceData) /*plus the function args*/ ]

View file

@@ -12,6 +12,7 @@
#include "LinearAlgebraNodes.h"
#include "NonlinearityNodes.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "RecurrentNodes.h"
#include "ReshapingNodes.h"
#include "TrainingNodes.h"

View file

@@ -12,6 +12,7 @@
#include "ComputationNode.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "DeprecatedNodes.h"
#include "EvaluationNodes.h"
#include "InputAndParamNodes.h"
@@ -92,6 +93,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(RectifiedLinearNode)) return New<RectifiedLinearNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ReduceElementsNode)) return New<ReduceElementsNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ReshapeNode)) return New<ReshapeNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(RNNNode)) return New<RNNNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(RowRepeatNode)) return New<RowRepeatNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(RowStackNode)) return New<RowStackNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ScatterPackedNode)) return New<ScatterPackedNode<ElemType>>(forward<_Types>(_Args)...);

View file

@@ -98,6 +98,7 @@
<ClInclude Include="ConvolutionalNodes.h" />
<ClInclude Include="DeprecatedNodes.h" />
<ClInclude Include="PreComputeNodes.h" />
<ClInclude Include="RNNNodes.h" />
<ClInclude Include="SpecialPurposeNodes.h" />
<ClInclude Include="EvaluationNodes.h" />
<ClInclude Include="InputAndParamNodes.h" />
@@ -122,6 +123,7 @@
<ClCompile Include="ComputationNodeScripting.cpp" />
<ClCompile Include="InputAndParamNodes.cpp" />
<ClCompile Include="ReshapingNodes.cpp" />
<ClCompile Include="RNNNodes.cpp" />
<ClCompile Include="SpecialPurposeNodes.cpp" />
<ClCompile Include="stdafx.cpp" />
</ItemGroup>

View file

@@ -40,6 +40,9 @@
<ClCompile Include="InputAndParamNodes.cpp">
<Filter>Nodes</Filter>
</ClCompile>
<ClCompile Include="RNNNodes.cpp">
<Filter>Nodes</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\Common\Include\fileutil.h">
@@ -132,6 +135,9 @@
<ClInclude Include="DeprecatedNodes.h">
<Filter>Nodes</Filter>
</ClInclude>
<ClInclude Include="RNNNodes.h">
<Filter>Nodes</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="Common">

View file

@@ -0,0 +1,367 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "Basics.h"
#include "ComputationNode.h"
#include "Matrix.h"
#include "TensorView.h"
#include "RNNNodes.h"
#include <unordered_set>
#include <map>
#include <string>
#include <vector>
#include <stdexcept>
#include <list>
#include <memory>
#include <algorithm>
#include <utility>
#include <assert.h>
namespace Microsoft { namespace MSR { namespace CNTK {
vector<size_t> numSequencesForFrame;
// -----------------------------------------------------------------------
// RNNNode
// -----------------------------------------------------------------------
template<class ElemType>
RNNNode<ElemType>::RNNNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name),
m_rnnParameters(0, 0, 0, L"LSTM"),
m_BackwardDataCalledYet(false)
{
}
// This constructor helps with BrainScript integration
template<class ElemType>
RNNNode<ElemType>::RNNNode(const ScriptableObjects::IConfigRecordPtr configp)
: Base(configp->Get(L"deviceId"), L"<placeholder>"),
m_rnnParameters(configp->Get(L"bidirectional"), configp->Get(L"numLayers"), configp->Get(L"hiddenSize"), configp->Get(L"rnnMode")),
m_BackwardDataCalledYet(false)
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
template<class ElemType>
/*virtual*/ void RNNNode<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
{
Base::CopyTo(nodeP, newName, flags);
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<RNNNode<ElemType>>(nodeP);
node->m_rnnParameters = m_rnnParameters;
}
}
template<class ElemType>
void RNNNode<ElemType>::Save(File& fstream) const
{
Base::Save(fstream);
m_rnnParameters.Write(fstream);
}
template<class ElemType>
void RNNNode<ElemType>::Load(File& fstream, size_t modelVersion)
{
Base::Load(fstream, modelVersion);
m_rnnParameters.Read(fstream);
}
template<class ElemType>
TensorView<ElemType> RNNNode<ElemType>::TensorHelper(int inputIndex/*-1 for output*/, bool gradient/*instead of value*/, const FrameRange& fr)
{
auto input = inputIndex < 0 ? this : Input(inputIndex).get();
return gradient ? input->GradientTensorFor(SIZE_MAX, fr) : input->ValueTensorFor(SIZE_MAX, fr);
}
template<class ElemType>
void RNNNode<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
{
shapeY = shapeX;
shapeY.SwapDimsInPlace(1, 2);
TensorView<ElemType> Y(matY, TensorShape(shapeY.GetDims()));
TensorView<ElemType> X(matX, shapeY);
Y.AssignCopyOf(X);
shapeY = Y.GetShape();
};
template<class ElemType>
void RNNNode<ElemType>::ForwardProp(const FrameRange& fr)
{
// ComputationNode derived classes are guaranteed to have a MBLayout
if (!this->HasMBLayout())
{
LogicError("RNNNode must operate on minibatches");
}
// The parameters are stored in a column matrix
Matrix<ElemType>& paramW = Input(1)->Value();
// Detect frame mode. Bugbug: should never revisit this decision after the first data comes through
MBLayoutPtr mb = this->GetMBLayout();
bool frameMode = (mb->GetNumTimeSteps() == 1) ? true : false;
if (frameMode)
{
TensorView<ElemType> outputY = ValueTensorFor(SIZE_MAX, fr);
// ensure enough storage.
m_transposedOutput->Resize(this->Value());
m_transposedInput->Resize(Input(0)->Value());
// For windowed LSTM, CNTK is providing data with the second dimension being time-like and the third dimension
// being minibatch index. CuDnn expects the second dimension to be minibatch index, and the third dimension
// to be time-like. This sequence of operations creates a transposed copy of the data in m_transposedInput
// and shapeXT
TransposeHelper(Input(0)->ValuePtr(), Input(0)->GetTensorSliceFor(SIZE_MAX, fr), m_transposedInput, shapeXT);
// Similarly, we will eventually need to transpose the output. Generate the necessary shape here, and do
// the transposition after RNNForward() returns.
// create the necessary shape.
shapeYT = TensorShape(this->GetTensorSliceFor(SIZE_MAX, fr));
// this swap results in a shape with swapped dimensions, but also swapped strides
shapeYT.SwapDimsInPlace(1, 2);
// this copy is necessary so that the strides are dense.
shapeYT = TensorShape(shapeYT.GetDims());
// create a vector with the correct number of timesteps(shapeXT[2]) containing the sequence count (shapeXT[1])
numSequencesForFrame = vector<size_t>(shapeXT[2], shapeXT[1]);
try
{
m_transposedOutput->RNNForward(*m_transposedInput, paramW, shapeXT[0], shapeYT[0], numSequencesForFrame, m_rnnParameters, *m_reserve, *m_workspace);
}
catch (exception e)
{
fprintf(stderr, "|m_transposedInput|=%ld\n", m_transposedInput->GetNumElements());
fprintf(stderr, "|m_reserve|=%ld\n", m_reserve->GetNumElements());
fprintf(stderr, "|m_workspace|=%ld\n", m_workspace->GetNumElements());
fprintf(stderr, "shapeXT=%s\n", ((std::string)shapeXT).c_str());
fprintf(stderr, "shapeYT=%s\n", ((std::string)shapeYT).c_str());
fprintf(stderr, "numSequencesForFrame=[");
for (size_t x : numSequencesForFrame) fprintf(stderr, "%ld, ", x);
fprintf(stderr, "\n");
throw e;
}
// No one uses shapeY, but it is necessary
TensorShape shapeY;
TransposeHelper(m_transposedOutput, TensorShape(shapeYT.GetDims()), this->ValuePtr(), shapeY);
}
else
{
shapeXT = TensorShape(Input(0)->GetTensorSliceFor(SIZE_MAX, fr));
shapeYT = TensorShape(this->GetTensorSliceFor(SIZE_MAX, fr));
// This changes the data from "minibatch packing" in Input(0)->Value() to "dense CuDNN packing" in m_transposedInput
this->PackSequencesForCuDNN(Input(0)->Value(), *m_transposedInput, numSequencesForFrame);
// ensure enough storage
m_transposedOutput->Resize(this->Value().GetNumRows(), m_transposedInput->GetNumCols());
m_transposedOutput->RNNForward(*m_transposedInput, paramW, shapeXT[0], shapeYT[0], numSequencesForFrame, m_rnnParameters, *m_reserve, *m_workspace);
this->UnpackSequencesFromCuDNN(*m_transposedOutput, this->Value());
}
m_BackwardDataCalledYet = false;
}
template<class ElemType>
void RNNNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
{
MBLayoutPtr mb = this->GetMBLayout();
bool frameMode = (mb->GetNumTimeSteps() == 1) ? true : false;
// ensure BackwardData is the first method called, as required by CuDnn API
if (!m_BackwardDataCalledYet)
{
Matrix<ElemType>& paramW = Input(1)->Value();
if (frameMode)
{
// To obey the data layout constraints of CuDnn, we take the derivative we're given,
// and transpose it before feeding to the interface.
m_transposedDOutput->Resize(this->Gradient());
TransposeHelper(this->GradientPtr(), this->GetTensorSliceFor(SIZE_MAX, fr), m_transposedDOutput, shapeYT);
}
else
{
m_transposedDOutput->DoGatherColumnsOf(0.0, *(this->m_packingIndex), this->Gradient(), 1.0);
}
// Ensure enough space for the result
m_transposedDInput->Resize(Input(0)->Value().GetNumRows(), m_transposedDOutput->GetNumCols());
// Do the work
try
{
m_transposedOutput->RNNBackwardData(*m_transposedDOutput, paramW, *m_transposedDInput, m_rnnParameters, *m_reserve, *m_workspace);
}
catch (exception e)
{
fprintf(stderr, "|m_transposedDOutput|=%ld\n", m_transposedDOutput->GetNumElements());
fprintf(stderr, "|paramW|=%ld\n", paramW.GetNumElements());
fprintf(stderr, "|m_transposedDInput|=%ld\n", m_transposedDInput->GetNumElements());
fprintf(stderr, "|m_reserve|=%ld\n", m_reserve->GetNumElements());
fprintf(stderr, "|m_workspace|=%ld\n", m_workspace->GetNumElements());
fprintf(stderr, "shapeXT=%s\n", ((std::string)shapeXT).c_str());
fprintf(stderr, "shapeYT=%s\n", ((std::string)shapeYT).c_str());
fprintf(stderr, "numSequencesForFrame=[");
for (size_t x : numSequencesForFrame) fprintf(stderr, "%ld, ", x);
fprintf(stderr, "\n");
throw e;
}
m_BackwardDataCalledYet = true;
}
if (inputIndex == 1) // parameters
{
Matrix<ElemType>& paramDW = Input(1)->Gradient();
try
{
m_transposedOutput->RNNBackwardWeights(*m_transposedInput, *m_transposedOutput, paramDW, m_rnnParameters, *m_reserve, *m_workspace);
}
catch (exception e)
{
fprintf(stderr, "|m_transposedInput|=%ld\n", m_transposedInput->GetNumElements());
fprintf(stderr, "|m_transposedOutput|=%ld\n", m_transposedOutput->GetNumElements());
fprintf(stderr, "|paramDW|=%ld\n", paramDW.GetNumElements());
fprintf(stderr, "|m_reserve|=%ld\n", m_reserve->GetNumElements());
fprintf(stderr, "|m_workspace|=%ld\n", m_workspace->GetNumElements());
fprintf(stderr, "shapeXT=%s\n", ((std::string)shapeXT).c_str());
fprintf(stderr, "shapeYT=%s\n", ((std::string)shapeYT).c_str());
fprintf(stderr, "numSequencesForFrame=[");
for (size_t x : numSequencesForFrame) fprintf(stderr, "%ld, ", x);
fprintf(stderr, "\n");
throw e;
}
}
else if (inputIndex == 0) // data
{
// all of the work was done above, where RNNBackwardData is called. Now, just unpack the result.
if (frameMode)
{
TensorShape tmp;
TransposeHelper(m_transposedDInput, shapeXT, Input(0)->GradientPtr(), tmp);
}
else
{
Input(0)->Gradient().DoScatterColumnsOf(1.0, *(this->m_packingIndex), *m_transposedDInput, 1.0);
}
}
}
template<class ElemType>
void RNNNode<ElemType>::Validate(bool isFinalValidationPass)
{
// N.B.: I need both of these lines.
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// get tensor shapes
auto dimsA = Input(1)->GetSampleLayout().GetDims(); // parameters
auto dimsB = Input(0)->GetSampleLayout().GetDims(); // data
// validate and infer
if (isFinalValidationPass || (dimsA.size() > 0 && dimsB.size() > 0)) // only if we got at least some input dimensions to work with or need to wrap up
{
// now determine result dimensions
auto dimsC = dimsB;
// output dims - bugbug: this is hard-coded for bidirectional models
dimsC[0] = (m_rnnParameters.m_bidirectional ? 2 : 1) * m_rnnParameters.m_hiddenSize;
// N.B. - this is the magical call, the reason for the function
// dimensions would be outputRank * numSamples * minibatch * time.
// This call establishes outputRank * numSamples, the rest will be filled in
// dynamically through the MBLayout.
SetDims(TensorShape(dimsC), HasMBLayout());
}
};
template<class ElemType>
void RNNNode<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
{
MBLayoutPtr mb = this->GetMBLayout();
if (mb->HasSequenceBeyondBegin())
RuntimeError("Invalid MBLayout: Only whole-utterance processing is supported");
#if 0
BUGBUG: Disable this check to mask a problem with the way EvalReader creates segments.
if (mb->HasSequenceBeyondEnd())
RuntimeError("Invalid MBLayout: Only whole-utterance processing is supported");
#endif
// retrieve only the non-gap sequences
vector<MBLayout::SequenceInfo> seq;
std::copy_if(
mb->GetAllSequences().begin(),
mb->GetAllSequences().end(),
back_inserter<vector<MBLayout::SequenceInfo>>(seq),
[](const MBLayout::SequenceInfo& x) { return x.seqId != GAP_SEQUENCE_ID; });
// sequenceOrder[i] will eventually be the i'th longest sequence,
// after sorting from longest to shortest. Ties are broken by the sequence id.
size_t numSequences = seq.size();
vector<size_t> sequenceOrder(numSequences);
for (size_t j = 0; j<numSequences; j++)
sequenceOrder[j] = j;
sort(sequenceOrder.begin(), sequenceOrder.end(), [&](size_t a, size_t b)
{
// sort in decreasing order of length
if (seq[a].GetNumTimeSteps() > seq[b].GetNumTimeSteps())
return true;
// break ties with increasing seqId
else if (seq[a].GetNumTimeSteps() == seq[b].GetNumTimeSteps())
return seq[a].seqId < seq[b].seqId;
return false;
}
);
size_t maxSeqLength = seq[sequenceOrder[0]].GetNumTimeSteps();
// BUGBUG: This forces the sequences to fit, due to a very bad convention in the evaldll interface.
if (maxSeqLength > mb->GetNumTimeSteps())
maxSeqLength = mb->GetNumTimeSteps();
// a count of how many sequences are packed for a particular frame.
// reset to zero, and compute from current layout information
// this information is useful when creating the tensor descriptors for CuDNN.
numSequencesForFrame.resize(maxSeqLength);
fill(numSequencesForFrame.begin(), numSequencesForFrame.end(), 0L);
// make sure the index is on CPU so we can use SetValue()
m_packingIndex->TransferToDeviceIfNotThere(-1, true, false, false);
// Reserve one element for every valid sample. DoGatherColumnsOf() requires it to be a row vector
m_packingIndex->Resize(1, mb->GetActualNumSamples());
size_t dst_frame = 0;
for (size_t fr = 0; fr < maxSeqLength; fr++)
{
for (size_t j = 0; j < numSequences && seq[sequenceOrder[j]].GetNumTimeSteps()>fr; j++)
{
m_packingIndex->SetValue(0, dst_frame++, (ElemType)mb->GetColumnIndex(seq[sequenceOrder[j]], fr));
numSequencesForFrame[fr]++;
}
}
// this->gather(beta,idx,a,alpha) operation is defined as
// *this[:,j] = a[:,idx[j]] * alpha + *this[:,j] * beta
dst.DoGatherColumnsOf(0.0, *(this->m_packingIndex), src, 1.0);
}
template<class ElemType>
void RNNNode<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
{
// this->scatter(beta,ndx,a,alpha) operation is defined as
// *this[:,idx[j]] = a[:,j] * alpha + *this[:,idx[j]] * beta
dst.DoScatterColumnsOf(0.0, *(this->m_packingIndex), src, 1.0);
}
template class RNNNode<float>;
template class RNNNode<double>;
} } }
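PackSequencesForCuDNN above sorts the non-gap sequences from longest to shortest and then, frame by frame, gathers the columns of every sequence that is still active, recording the per-frame count in numSequencesForFrame. A minimal standalone sketch of that ordering and counting step, using plain std::vector in place of CNTK's MBLayout and Matrix types (all names below are illustrative, not CNTK APIs):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

// Toy stand-in for MBLayout::SequenceInfo: just an id and a length.
struct SeqInfo { size_t seqId; size_t numTimeSteps; };

// Returns, for each frame, how many sequences are still active, and fills
// 'order' with sequence indices sorted longest-first (ties broken by seqId),
// mirroring the ordering used when building m_packingIndex.
std::vector<size_t> CountSequencesPerFrame(const std::vector<SeqInfo>& seq, std::vector<size_t>& order)
{
    order.resize(seq.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](size_t a, size_t b)
    {
        if (seq[a].numTimeSteps != seq[b].numTimeSteps)
            return seq[a].numTimeSteps > seq[b].numTimeSteps; // longest first
        return seq[a].seqId < seq[b].seqId;                   // break ties by id
    });

    size_t maxLen = seq.empty() ? 0 : seq[order[0]].numTimeSteps;
    std::vector<size_t> numSequencesForFrame(maxLen, 0);
    for (size_t fr = 0; fr < maxLen; fr++)
        for (size_t j = 0; j < order.size() && seq[order[j]].numTimeSteps > fr; j++)
            numSequencesForFrame[fr]++;
    return numSequencesForFrame;
}

int main()
{
    std::vector<SeqInfo> seq = { { 0, 3 }, { 1, 5 }, { 2, 5 }, { 3, 1 } };
    std::vector<size_t> order;
    auto counts = CountSequencesPerFrame(seq, order);
    for (size_t fr = 0; fr < counts.size(); fr++)
        std::printf("frame %zu: %zu active sequences\n", fr, counts[fr]);
    // expected counts: 4, 3, 3, 2, 2
    return 0;
}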

View file

@@ -0,0 +1,113 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "Basics.h"
#include "ComputationNode.h"
#include "Matrix.h"
#include "TensorView.h"
#include "RNNCommon.h"
#include <unordered_set>
#include <map>
#include <string>
#include <vector>
#include <stdexcept>
#include <list>
#include <memory>
#include <algorithm>
#include <utility>
#include <assert.h>
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// RNNNode (data, weights)
// -----------------------------------------------------------------------
template <class ElemType>
class RNNNode : public ComputationNode<ElemType>, public NumInputs<2>
{
typedef ComputationNode<ElemType> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"RNN"; }
using Base::OperationName;
public:
RNNNode(DEVICEID_TYPE deviceId, const wstring& name);
RNNNode(const ScriptableObjects::IConfigRecordPtr configp);
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override;
virtual void Save(File& fstream) const;
virtual void Load(File& fstream, size_t modelVersion) override;
public:
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override;
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override;
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override;
// request matrices needed to do node function value evaluation
virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool)
{
Base::RequestMatricesBeforeForwardProp(matrixPool);
RequestMatrixFromPool(m_transposedInput, matrixPool);
RequestMatrixFromPool(m_transposedOutput, matrixPool);
RequestMatrixFromPool(m_reserve, matrixPool);
RequestMatrixFromPool(m_workspace, matrixPool);
RequestMatrixFromPool(m_packingIndex, matrixPool);
}
// request matrices needed to do node derivative value evaluation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
Base::RequestMatricesBeforeBackprop(matrixPool);
RequestMatrixFromPool(m_transposedDInput, matrixPool);
RequestMatrixFromPool(m_transposedDOutput, matrixPool);
}
// release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
{
Base::ReleaseMatricesAfterBackprop(matrixPool);
ReleaseMatrixToPool(m_transposedInput, matrixPool);
ReleaseMatrixToPool(m_transposedOutput, matrixPool);
ReleaseMatrixToPool(m_transposedDInput, matrixPool);
ReleaseMatrixToPool(m_transposedDOutput, matrixPool);
#if 0
ReleaseMatrixToPool(m_reserve, matrixPool);
ReleaseMatrixToPool(m_workspace, matrixPool);
ReleaseMatrixToPool(m_packingIndex, matrixPool);
#endif
}
// Is the output value of the computation node needed for computing the gradients of the input nodes?
virtual bool OutputUsedInComputingInputNodesGradients() const { return false; }
// Is the output value of the specified input node needed for computing the gradients of the input nodes?
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const { return false; }
protected:
bool m_BackwardDataCalledYet;
TensorShape shapeXT;
TensorShape shapeYT;
shared_ptr<Matrix<ElemType>> m_transposedInput;
shared_ptr<Matrix<ElemType>> m_transposedOutput;
shared_ptr<Matrix<ElemType>> m_transposedDInput;
shared_ptr<Matrix<ElemType>> m_transposedDOutput;
shared_ptr<Matrix<ElemType>> m_workspace;
shared_ptr<Matrix<ElemType>> m_reserve;
shared_ptr<Matrix<ElemType>> m_packingIndex;
private:
TensorView<ElemType> TensorHelper(int inputIndex/*-1 for output*/, bool gradient/*instead of value*/, const FrameRange& fr);
void TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY);
void PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame);
void UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst);
RnnParameters m_rnnParameters;
};
}}}

View file

@@ -89,6 +89,7 @@ CuDnn::ptr_t CuDnn::Instance()
cudnnHandle_t* cudnn = new cudnnHandle_t;
CUDNN_CALL(cudnnCreate(cudnn));
CUDNN_CALL(cudnnSetStream(*cudnn, GetStream()));
fprintf(stderr, "CuDnn::Instance()::createNew() deviceId=%d\n", deviceId);
return cudnn;
};
@@ -97,6 +98,7 @@ CuDnn::ptr_t CuDnn::Instance()
assert(*src != nullptr);
auto err = cudnnDestroy(*src);
assert(err == CUDNN_STATUS_SUCCESS);
fprintf(stderr, "CuDnn::Instance()::m_instance.destroy()\n");
#ifdef NDEBUG
UNUSED(err);
#endif

Source/Math/CuDnnRNN.cpp (new file, 175 lines)
View file

@@ -0,0 +1,175 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "Matrix.h"
#include "GPUMatrix.h"
#include "TensorShape.h"
#include "TensorView.h"
#include <typeinfo>
#include <typeindex>
#include "CuDnnCommon.h"
#include "CuDnnRNN.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
class CuDnnTensorDescriptor
{
private:
cudnnTensorDescriptor_t m_tensorDesc;
public:
CuDnnTensorDescriptor(size_t hiddenSize, size_t miniBatch, size_t numLayers)
{
cudnnDataType_t m_dataType = CuDnnTensor::GetDataType<ElemType>();
int dimA[3] = { (int)hiddenSize, (int)miniBatch, (int)numLayers };
int strideA[3] = { 1, dimA[0], dimA[0] * dimA[1] };
CUDNN_CALL(cudnnCreateTensorDescriptor(&m_tensorDesc));
CUDNN_CALL(cudnnSetTensorNdDescriptor(m_tensorDesc, m_dataType, 3, dimA, strideA));
}
~CuDnnTensorDescriptor()
{
cudnnDestroyTensorDescriptor(m_tensorDesc);
}
operator cudnnTensorDescriptor_t() const
{
return m_tensorDesc;
}
DISABLE_COPY_AND_MOVE(CuDnnTensorDescriptor);
};
template <class ElemType>
void CuDnnRNNExecutor<ElemType>::SetDescriptors(size_t dim, const vector<size_t>& numSequencesForFrame, vector<cudnnTensorDescriptor_t>& descriptors)
{
for (size_t i = 0; i < numSequencesForFrame.size(); i++)
{
if (descriptors.size() <= i)
{
descriptors.push_back(cudnnTensorDescriptor_t());
CUDNN_CALL(cudnnCreateTensorDescriptor(&descriptors[i]));
}
int dims[3] = { (int)numSequencesForFrame[i], (int)dim, 1 };
int strides[3] = { dims[2] * dims[1], dims[2], 1 };
CUDNN_CALL(cudnnSetTensorNdDescriptor(descriptors[i], CUDNN_DATA_FLOAT, 3, dims, strides));
}
}
template <class ElemType>
void CuDnnRNNExecutor<ElemType>::ForwardCore(
const GPUMatrix<ElemType>& weightsW,
const GPUMatrix<ElemType>& inputX, GPUMatrix<ElemType>& outputY,
const vector<size_t>& numSequencesForFrame,
const RnnParameters& rnnParameters,
GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace
)
{
// test that the RNN shape is correct
if (!m_rnnT->IsCompatable(rnnParameters))
LogicError("RNN Layout has changed during processing");
if (m_yDim != (m_rnnT->isBidirectional() ? 2 : 1) * m_rnnT->GetNumHidden())
InvalidArgument("CuDnn ForwardCore: Output leading dimension must be twice hidden size for bidirectional networks");
// set up the input and output descriptors
SetDescriptors(m_xDim, numSequencesForFrame, xDesc);
SetDescriptors(m_yDim, numSequencesForFrame, yDesc);
// ensure workspace and reserve are large enough
m_seqLength = numSequencesForFrame.size();
size_t workSize;
size_t reserveSize;
// Needed for every pass
CUDNN_CALL(cudnnGetRNNWorkspaceSize(*m_cudnn, *m_rnnT, (int)m_seqLength, xDesc.data(), &workSize));
// Only needed in training, can't be touched between passes.
CUDNN_CALL(cudnnGetRNNTrainingReserveSize(*m_cudnn, *m_rnnT, (int)m_seqLength, xDesc.data(), &reserveSize));
// convert from bytes to ElemType
workSize = (workSize + sizeof(ElemType) - 1) / (sizeof(ElemType));
reserveSize = (reserveSize + sizeof(ElemType) - 1) / sizeof(ElemType);
reserve.Resize(reserveSize, 1);
workspace.Resize(workSize, 1);
wDesc = make_unique<CuDnnFilter<ElemType>>(*m_rnnT, xDesc[0]);
if (wDesc->GetSize() != weightsW.GetNumElements())
InvalidArgument("RNN needs %ld parameters, but %ld were allocated", wDesc->GetSize(), weightsW.GetNumElements());
CUDNN_CALL(cudnnRNNForwardTraining(
*m_cudnn, *m_rnnT,
(int)m_seqLength,
xDesc.data(), inputX.Data(),
0, 0,
0, 0,
*wDesc, weightsW.Data(),
yDesc.data(), outputY.Data(),
0, 0,
0, 0,
workspace.Data(), workspace.GetNumElements()*sizeof(ElemType),
reserve.Data(), reserve.GetNumElements()*sizeof(ElemType)));
m_BackwardDataCalledYet = false;
}
template <class ElemType>
void CuDnnRNNExecutor<ElemType>::BackwardDataCore(
const GPUMatrix<ElemType>& outputY, const GPUMatrix<ElemType>& outputDY, const GPUMatrix<ElemType>& weightsW, GPUMatrix<ElemType>& dx,
const RnnParameters& rnnParameters,
GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace
)
{
// test that the RNN shape is correct
if (!m_rnnT->IsCompatable(rnnParameters))
LogicError("RNN Layout has changed during processing");
if (!m_BackwardDataCalledYet)
{
CUDNN_CALL(cudnnRNNBackwardData(
*m_cudnn, *m_rnnT,
(int)m_seqLength,
yDesc.data(), outputY.Data(),
yDesc.data(), outputDY.Data(),
0, 0,
0, 0,
*wDesc, weightsW.Data(),
0, 0,
0, 0,
xDesc.data(), dx.Data(),
0, 0,
0, 0,
workspace.Data(), workspace.GetNumElements()*sizeof(ElemType),
reserve.Data(), reserve.GetNumElements()*sizeof(ElemType)));
}
m_BackwardDataCalledYet = true;
}
template <class ElemType>
void CuDnnRNNExecutor<ElemType>::BackwardWeightsCore(const GPUMatrix<ElemType>& inputX, const GPUMatrix<ElemType>& outputY, GPUMatrix<ElemType>& dw,
const RnnParameters& rnnParameters,
GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace
)
{
// test that the RNN shape is correct
if (!m_rnnT->IsCompatable(rnnParameters))
LogicError("RNN Layout has changed during processing");
if (!m_BackwardDataCalledYet)
LogicError("out of order calling you have been very bad");
CUDNN_CALL(cudnnRNNBackwardWeights(
*m_cudnn, *m_rnnT,
(int)m_seqLength,
xDesc.data(), inputX.Data(),
0, 0,
yDesc.data(), outputY.Data(),
workspace.Data(), workspace.GetNumElements()*sizeof(ElemType),
*wDesc, dw.Data(),
reserve.Data(), reserve.GetNumElements()*sizeof(ElemType)));
}
template class CuDnnRNNExecutor<double>;
template class CuDnnRNNExecutor<float>;
} } }
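Two details in SetDescriptors and ForwardCore above are easy to miss: every frame gets a fully packed 3-D tensor descriptor whose strides are derived directly from its dims, and cudnnGetRNNWorkspaceSize/cudnnGetRNNTrainingReserveSize return byte counts that must be rounded up to whole ElemType elements before the backing matrices are resized. A self-contained sketch of just that arithmetic, with no cuDNN calls (names are illustrative):

#include <array>
#include <cstddef>
#include <cstdio>

// Dims/strides for one frame's packed descriptor: (numSequences, dim, 1),
// fully packed, as in SetDescriptors() above.
struct FrameDesc { std::array<int, 3> dims, strides; };

FrameDesc MakeFrameDesc(size_t numSequences, size_t dim)
{
    FrameDesc d;
    d.dims = { (int)numSequences, (int)dim, 1 };
    d.strides = { d.dims[2] * d.dims[1], d.dims[2], 1 };
    return d;
}

// cuDNN reports workspace/reserve sizes in bytes; the matrices that back them
// are allocated in elements, so round up to a whole number of elements.
size_t BytesToElements(size_t sizeInBytes, size_t elemSize)
{
    return (sizeInBytes + elemSize - 1) / elemSize;
}

int main()
{
    FrameDesc d = MakeFrameDesc(/*numSequences=*/7, /*dim=*/512);
    std::printf("dims=(%d,%d,%d) strides=(%d,%d,%d)\n",
                d.dims[0], d.dims[1], d.dims[2], d.strides[0], d.strides[1], d.strides[2]);
    std::printf("1000 bytes -> %zu floats\n", BytesToElements(1000, sizeof(float))); // 250
    std::printf("1001 bytes -> %zu floats\n", BytesToElements(1001, sizeof(float))); // 251
    return 0;
}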

Source/Math/CuDnnRNN.h (new file, 222 lines)
View file

@@ -0,0 +1,222 @@
#pragma once
#include "Matrix.h"
#include "GPUMatrix.h"
#include "TensorShape.h"
#include <typeinfo>
#include <typeindex>
#include "CuDnnCommon.h"
#include "RNNCommon.h"
namespace Microsoft { namespace MSR { namespace CNTK {
class CuDnnDropout
{
CuDnn::ptr_t m_cudnn;
unsigned long long m_seed = 0xdeadbeefull;
public:
CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 0xdeadbeefull)
: m_dropoutDesc(nullptr), m_cudnn(CuDnn::Instance())
{
CUDNN_CALL(cudnnCreateDropoutDescriptor(&m_dropoutDesc));
size_t stateSize;
void *states;
CUDNN_CALL(cudnnDropoutGetStatesSize(*m_cudnn, &stateSize));
// bugbug: possible leak. Does CuDnn release this for us?
CUDA_CALL(cudaMalloc(&states, stateSize));
fprintf(stderr, "CuDnnDropout()\n");
CUDNN_CALL(cudnnSetDropoutDescriptor(m_dropoutDesc,
*m_cudnn,
dropout,
states,
stateSize,
seed));
}
~CuDnnDropout()
{
fprintf(stderr, "~CuDnnDropout()\n");
if (m_dropoutDesc != nullptr)
{
cudnnDestroyDropoutDescriptor(m_dropoutDesc);
m_dropoutDesc = nullptr;
}
}
operator cudnnDropoutDescriptor_t() const
{
return m_dropoutDesc;
}
DISABLE_COPY_AND_MOVE(CuDnnDropout);
private:
cudnnDropoutDescriptor_t m_dropoutDesc;
};
template <class ElemType>
class CuDnnRNN
{
private:
cudnnDataType_t m_dataType;
cudnnRNNDescriptor_t m_rnnDesc;
CuDnnDropout m_dropout;
RnnParameters m_rnnParameters;
cudnnRNNMode_t GetMode()
{
if (m_rnnParameters.m_rnnMode == wstring(L"LSTM"))
return cudnnRNNMode_t::CUDNN_LSTM;
if (m_rnnParameters.m_rnnMode == wstring(L"GRU"))
return cudnnRNNMode_t::CUDNN_GRU;
if (m_rnnParameters.m_rnnMode == wstring(L"RNN_RELU"))
return cudnnRNNMode_t::CUDNN_RNN_RELU;
if (m_rnnParameters.m_rnnMode == wstring(L"RNN_TANH"))
return cudnnRNNMode_t::CUDNN_RNN_TANH;
InvalidArgument("RNN Mode set to %ls, but supported values are LSTM, GRU, RNN_RELU, RNN_TANH.", m_rnnParameters.m_rnnMode.c_str());
}
public:
CuDnnRNN(const RnnParameters& rnnParameters)
: m_rnnDesc(nullptr), m_dropout(0.0f), m_rnnParameters(rnnParameters),
m_dataType(CuDnnTensor::GetDataType<ElemType>())
{
fprintf(stderr, "CuDnnRNN()\n");
CUDNN_CALL(cudnnCreateRNNDescriptor(&m_rnnDesc));
CUDNN_CALL(cudnnSetRNNDescriptor(m_rnnDesc,
(int)m_rnnParameters.m_hiddenSize,
(int)m_rnnParameters.m_numLayers,
m_dropout,
CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation
m_rnnParameters.m_bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
GetMode(),
m_dataType));
}
~CuDnnRNN()
{
fprintf(stderr, "~CuDnnRNN()\n");
if (m_rnnDesc != nullptr)
{
cudnnDestroyRNNDescriptor(m_rnnDesc);
m_rnnDesc = nullptr;
}
}
bool IsCompatable(const RnnParameters& rnnParameters) const
{
return this->m_rnnParameters == rnnParameters;
}
operator cudnnRNNDescriptor_t() const
{
return m_rnnDesc;
}
bool isBidirectional() const { return m_rnnParameters.m_bidirectional; }
size_t GetNumLayers() { return m_rnnParameters.m_numLayers; }
size_t GetNumHidden() { return m_rnnParameters.m_hiddenSize; }
DISABLE_COPY_AND_MOVE(CuDnnRNN);
};
template <class ElemType>
class CuDnnFilter
{
cudnnDataType_t m_dataType;
CuDnn::ptr_t m_cudnn;
size_t m_filterSize;
public:
CuDnnFilter(const CuDnnRNN<ElemType>& rnn, const cudnnTensorDescriptor_t& xDesc) :
m_cudnn(CuDnn::Instance()), m_dataType(CuDnnTensor::GetDataType<ElemType>())
{
CUDNN_CALL(cudnnCreateFilterDescriptor(&m_filterDesc));
try
{
size_t filterSize;
CUDNN_CALL(cudnnGetRNNParamsSize(*m_cudnn, rnn, xDesc, &filterSize, m_dataType));
size_t dataSize = 2; // CUDNN_DATA_HALF
if (m_dataType == cudnnDataType_t::CUDNN_DATA_DOUBLE)
dataSize = 8;
else if (m_dataType == cudnnDataType_t::CUDNN_DATA_FLOAT)
dataSize = 4;
// convert from bytes to items
m_filterSize = (filterSize + dataSize - 1) / dataSize;
int dimW[3] = { (int)m_filterSize, 1, 1 };
CUDNN_CALL(cudnnSetFilterNdDescriptor(m_filterDesc, m_dataType, CUDNN_TENSOR_NCHW, 3, dimW));
}
catch (exception e)
{
cudnnDestroyFilterDescriptor(m_filterDesc);
m_filterDesc = nullptr;
throw e;
}
}
~CuDnnFilter()
{
assert(m_filterDesc != nullptr);
cudnnDestroyFilterDescriptor(m_filterDesc);
}
size_t GetSize() { return m_filterSize; }
operator cudnnFilterDescriptor_t() const
{
return m_filterDesc;
}
DISABLE_COPY_AND_MOVE(CuDnnFilter);
private:
cudnnFilterDescriptor_t m_filterDesc;
};
template <class ElemType>
class CuDnnRNNExecutor
{
CuDnn::ptr_t m_cudnn;
cudnnDataType_t m_dataType;
size_t m_xDim, m_yDim;
public:
CuDnnRNNExecutor(size_t xDim, size_t yDim, const RnnParameters& rnnParameters ) :
m_cudnn(CuDnn::Instance()),
m_xDim(xDim), m_yDim(yDim),
m_seqLength(0),
m_dataType(CuDnnTensor::GetDataType<ElemType>()),
m_BackwardDataCalledYet(false)
{
fprintf(stderr, "CuDnnRNNExecutor()\n");
m_rnnT = std::make_unique<CuDnnRNN<ElemType>>(rnnParameters);
}
void ForwardCore(const GPUMatrix<ElemType>& weightsW, const GPUMatrix<ElemType>& inputX, GPUMatrix<ElemType>& outputY, const vector<size_t>& numSequencesForFrame, const RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace);
void BackwardWeightsCore(const GPUMatrix<ElemType>& inputX, const GPUMatrix<ElemType>& outputY, GPUMatrix<ElemType>& dw, const RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace);
void BackwardDataCore(const GPUMatrix<ElemType>& outputY, const GPUMatrix<ElemType>& outputDY, const GPUMatrix<ElemType>& w, GPUMatrix<ElemType>& dx, const RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace);
protected:
std::unique_ptr<CuDnnFilter<ElemType>> wDesc;
vector<cudnnTensorDescriptor_t> xDesc;
vector<cudnnTensorDescriptor_t> yDesc;
private:
static ElemType* ptr(GPUMatrix<ElemType>& src)
{
return src.Data();
}
static const ElemType* ptr(const GPUMatrix<ElemType>& src)
{
return src.Data();
}
void SetDescriptors(size_t dim, const vector<size_t>& numSequencesForFrame, vector<cudnnTensorDescriptor_t>& descriptors);
private:
std::unique_ptr<CuDnnRNN<ElemType>> m_rnnT;
bool m_BackwardDataCalledYet;
size_t m_seqLength;
};
} } }
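The m_BackwardDataCalledYet flag in CuDnnRNNExecutor (and the matching flag in RNNNode) exists because cuDNN requires the data gradient to be computed before the weight gradient within a pass. A stripped-down sketch of that ordering guard, with the cuDNN calls reduced to comments (the class below is illustrative, not part of CNTK):

#include <cstdio>
#include <stdexcept>

// Minimal model of the call-order guard used by CuDnnRNNExecutor:
// Forward() resets the flag, BackwardData() sets it, and BackwardWeights()
// refuses to run unless BackwardData() has already been called this pass.
class OrderedBackprop
{
    bool m_backwardDataCalledYet = false;
public:
    void Forward()      { m_backwardDataCalledYet = false; /* cudnnRNNForwardTraining(...) */ }
    void BackwardData() { m_backwardDataCalledYet = true;  /* cudnnRNNBackwardData(...) */ }
    void BackwardWeights()
    {
        if (!m_backwardDataCalledYet)
            throw std::logic_error("BackwardWeights called before BackwardData");
        /* cudnnRNNBackwardWeights(...) */
    }
};

int main()
{
    OrderedBackprop rnn;
    rnn.Forward();
    rnn.BackwardData();
    rnn.BackwardWeights(); // fine: the data gradient was computed first
    rnn.Forward();
    try { rnn.BackwardWeights(); } // wrong order after a new forward pass
    catch (const std::logic_error& e) { std::printf("caught: %s\n", e.what()); }
    return 0;
}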

View file

@@ -26,6 +26,7 @@
#include <memory>
#include "CntkBatchNormalization.cuh"
#include "Convolution.cuh"
#include "CuDnnRNN.h"
#pragma comment(lib, "cudart.lib") // instruct linker to reference these libs
#pragma comment(lib, "cublas.lib")
@@ -3252,6 +3253,37 @@ void GPUMatrix<ElemType>::BatchNormalizationBackward(const GPUMatrix<ElemType>&
in.Data(), Data(), grad.Data(), scale.Data(), mbStatsWeight, scaleGrad.Data(), biasGrad.Data(), saveMean.Data(), saveInvStdDev.Data(), GetStream());
}
#pragma region RNN Functions
template<class ElemType>
struct GPUMatrix<ElemType>::RNNWrapper
{
std::unique_ptr<CuDnnRNNExecutor<ElemType>> m_rnnExecutor;
};
template <class ElemType>
void GPUMatrix<ElemType>::RNNForward(const GPUMatrix<ElemType> &inputX, const GPUMatrix<ElemType> &paramW, size_t xDim, size_t yDim, const vector<size_t>& numSequencesForFrame, const RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace)
{
// numLayers, hiddenSize are input parameters
if (!m_RNNWrapper)
m_RNNWrapper = std::make_unique<RNNWrapper>();
if (!m_RNNWrapper->m_rnnExecutor)
m_RNNWrapper->m_rnnExecutor = std::make_unique<CuDnnRNNExecutor<ElemType>>(xDim, yDim, rnnParameters);
m_RNNWrapper->m_rnnExecutor->ForwardCore(paramW, inputX, *this, numSequencesForFrame, rnnParameters, reserve, workspace);
}
template <class ElemType>
void GPUMatrix<ElemType>::RNNBackwardData(const GPUMatrix<ElemType>& outputDY, const GPUMatrix<ElemType>& paramW, GPUMatrix<ElemType>& outputDX, const RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace)
{
m_RNNWrapper->m_rnnExecutor->BackwardDataCore(*this, outputDY, paramW, outputDX, rnnParameters, reserve, workspace);
}
template <class ElemType>
void GPUMatrix<ElemType>::RNNBackwardWeights(const GPUMatrix<ElemType>& inputX, const GPUMatrix<ElemType>& outputY, GPUMatrix<ElemType>& dw, const RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace)
{
m_RNNWrapper->m_rnnExecutor->BackwardWeightsCore(inputX, outputY, dw, rnnParameters, reserve, workspace);
}
#pragma region Static BLAS Functions
// float/double overloads of cublasSgemm()/cublasDgemm()
static cublasStatus_t cublas_gemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* A, int lda, const float* B, int ldb, const float* beta, float* C, int ldc)

View file

@@ -167,6 +167,8 @@ private:
#pragma warning(push)
#pragma warning(disable : 4251)
mutable std::unique_ptr<conc_stack<std::unique_ptr<GPUMatrix<ElemType>>>> m_workspace;
struct RNNWrapper;
mutable std::unique_ptr<struct RNNWrapper> m_RNNWrapper;
#pragma warning(pop)
private:
@@ -474,6 +476,11 @@ public:
const GPUMatrix<ElemType>& saveMean, const GPUMatrix<ElemType>& saveInvStdDev,
GPUMatrix<ElemType>& scaleGrad, GPUMatrix<ElemType>& biasGrad) const;
// RNN support functions
void RNNForward(const GPUMatrix<ElemType>& inputX, const GPUMatrix<ElemType>& paramW, size_t xDim, size_t yDim, const vector<size_t>& numSequencesForFrame, const struct RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace);
void RNNBackwardData(const GPUMatrix<ElemType>& outputDY, const GPUMatrix<ElemType>& paramW, GPUMatrix<ElemType>& outputDX, const struct RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace);
void RNNBackwardWeights(const GPUMatrix<ElemType>& inputX, const GPUMatrix<ElemType>& outputY, GPUMatrix<ElemType>& dw, const struct RnnParameters& rnnParameters, GPUMatrix<ElemType>& reserve, GPUMatrix<ElemType>& workspace);
public:
// static BLAS functions
static void MultiplyAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const bool transposeA, const GPUMatrix<ElemType>& b, const bool transposeB, ElemType beta, GPUMatrix<ElemType>& c);

View file

@@ -42,6 +42,10 @@
#define IDX2C(i, j, ld) (((j) * (ld)) + (i)) // 0 based indexing
// TODO: This condition seems wrong, it should be:
// !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 600
// NVIDIA should fix their CUDA 8.0 headers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// CUDA atomicAdd() only exists for 'float'. This is the 'double' version.
// TODO: This may need to be guarded by CUDA version; newer devices may support this.
static __inline__ __device__ double atomicAdd(double* address, double val)
@@ -55,6 +59,7 @@ static __inline__ __device__ double atomicAdd(double* address, double val)
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
// TODO: replace this with TensorOps.h LogAdd(). It differs in using ElemType throughout, while this one seems to use 'double' versions of exp() and log().
// The 'k' in the name is to avoid naming conflicts with various versions of logadd() that are defined throughout the codebase.
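The guarded double-precision atomicAdd above supplies an atomic add for devices older than compute capability 6.0 by spinning on atomicCAS over the value's bit pattern. The same compare-and-swap loop can be written in portable C++ with std::atomic, which may make the retry logic easier to follow (a sketch, not the CUDA device code itself):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Portable C++ version of the compare-and-swap loop that the guarded
// double-precision atomicAdd above implements with atomicCAS on bit patterns.
double AtomicAddDouble(std::atomic<double>& target, double val)
{
    double old = target.load();
    // Keep retrying until no other thread has changed 'target' in between;
    // on failure, compare_exchange_weak refreshes 'old' with the current value.
    while (!target.compare_exchange_weak(old, old + val))
        ;
    return old; // value before this thread's addition, like atomicAdd()
}

int main()
{
    std::atomic<double> sum{ 0.0 };
    std::vector<std::thread> threads;
    for (int t = 0; t < 4; t++)
        threads.emplace_back([&] { for (int i = 0; i < 100000; i++) AtomicAddDouble(sum, 1.0); });
    for (auto& th : threads)
        th.join();
    std::printf("sum = %f\n", sum.load()); // 400000 regardless of interleaving
    return 0;
}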

View file

@@ -170,9 +170,10 @@
<ClInclude Include="ConvolutionEngine.h" />
<ClInclude Include="ConvolveGeometry.h" />
<ClInclude Include="CPUMatrix.h" />
<ClInclude Include="CPURNGHandle.h" />
<ClInclude Include="CPURNGHandle.h" />
<ClInclude Include="MatrixQuantizerImpl.h" />
<ClInclude Include="RNGHandle.h" />
<ClInclude Include="RNGHandle.h" />
<ClInclude Include="RNNCommon.h" />
<ClInclude Include="TensorOps.h" />
<ClInclude Include="TensorView.h" />
<ClInclude Include="Quantizers.h" />
@@ -199,7 +200,7 @@
<ClCompile Include="BlockHandlerAVX.cpp" />
<ClCompile Include="BlockHandlerSSE.cpp" />
<ClCompile Include="ConvolutionEngine.cpp" />
<ClCompile Include="CPURNGHandle.cpp" />
<ClCompile Include="CPURNGHandle.cpp" />
<ClCompile Include="CPUSparseMatrix.cpp" />
<ClCompile Include="CUDAPageLockedMemAllocator.cpp" />
<ClCompile Include="dllmain.cpp">
@@ -213,7 +214,7 @@
<ClCompile Include="NoGPU.cpp" />
<ClCompile Include="Matrix.cpp" />
<ClCompile Include="QuantizedMatrix.cpp" />
<ClCompile Include="RNGHandle.cpp" />
<ClCompile Include="RNGHandle.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>
</ClCompile>

View file

@@ -111,6 +111,9 @@
<ClInclude Include="CPURNGHandle.h">
<Filter>CPU</Filter>
</ClInclude>
<ClInclude Include="RNNCommon.h">
<Filter>RNN</Filter>
</ClInclude>
<ClInclude Include="BlockHandlerAVX.h">
<Filter>CPU</Filter>
</ClInclude>
@@ -173,5 +176,8 @@
<Filter Include="BatchNormalization">
<UniqueIdentifier>{8f982dac-298d-4e48-b060-8e6cba5ff554}</UniqueIdentifier>
</Filter>
<Filter Include="RNN">
<UniqueIdentifier>{ee6bf704-73f0-488d-8432-0d23f034de88}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>

View file

@@ -120,6 +120,7 @@ if exist "$(CuDnnDll)" xcopy /D /Y "$(CuDnnDll)" "$(OutputPath)"
<ClInclude Include="cudalib.h" />
<ClInclude Include="CuDnnCommon.h" />
<ClInclude Include="CuDnnFactories.h" />
<ClInclude Include="CuDnnRNN.h" />
<ClInclude Include="GPUDataTransferer.h" />
<ClInclude Include="GPURNGHandle.h" />
<ClInclude Include="GPUTensor.h" />
@@ -165,6 +166,7 @@ if exist "$(CuDnnDll)" xcopy /D /Y "$(CuDnnDll)" "$(OutputPath)"
<FileType>CppCode</FileType>
</CudaCompile>
<ClCompile Include="CuDnnCommon.cpp" />
<ClCompile Include="CuDnnRNN.cpp" />
<ClCompile Include="GPUDataTransferer.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>

View file

@@ -54,6 +54,9 @@
<ClCompile Include="CuDnnCommon.cpp">
<Filter>GPU\CuDnn</Filter>
</ClCompile>
<ClCompile Include="CuDnnRNN.cpp">
<Filter>GPU\RNN</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\Common\Include\File.h">
@@ -128,6 +131,9 @@
<ClInclude Include="GPURNGHandle.h">
<Filter>GPU</Filter>
</ClInclude>
<ClInclude Include="CuDnnRNN.h">
<Filter>GPU\RNN</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="GPUMatrix.h">
@@ -174,5 +180,8 @@
<Filter Include="GPU\CuDnn">
<UniqueIdentifier>{05351afa-de95-40c8-830a-d70eede55dc0}</UniqueIdentifier>
</Filter>
<Filter Include="GPU\RNN">
<UniqueIdentifier>{da8154d0-d1f4-44a5-b64a-5e955ead56fa}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>

View file

@@ -4313,6 +4313,50 @@ void Matrix<ElemType>::BatchNormalizationBackward(const Matrix<ElemType>& in, Ma
NOT_IMPLEMENTED);
}
template <class ElemType>
void Matrix<ElemType>::RNNForward(const Matrix<ElemType> &inputX, const Matrix<ElemType> &paramW, size_t xDim, size_t yDim, const vector<size_t>& numSequencesForFrame, const RnnParameters& rnnParameters, Matrix<ElemType>& reserve, Matrix<ElemType>& workspace)
{
DecideAndMoveToRightDevice(*this, inputX, paramW);
// move reserve/workspace to the consensus device
reserve._transferToDevice(GetDeviceId());
workspace._transferToDevice(GetDeviceId());
DISPATCH_MATRIX_ON_FLAG(this, this,
NOT_IMPLEMENTED,
m_GPUMatrix->RNNForward(*(inputX.m_GPUMatrix), *(paramW.m_GPUMatrix), xDim, yDim, numSequencesForFrame, rnnParameters, *(reserve.m_GPUMatrix), *(workspace.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
template <class ElemType>
void Matrix<ElemType>::RNNBackwardData(const Matrix<ElemType>& outputDY, const Matrix<ElemType>& paramW, Matrix<ElemType>& outputDX, const RnnParameters& rnnParameters, Matrix<ElemType>& reserve, Matrix<ElemType>& workspace)
{
DecideAndMoveToRightDevice(*this, outputDY, paramW, outputDX);
// move reserve/workspace to the consensus device
reserve._transferToDevice(GetDeviceId());
workspace._transferToDevice(GetDeviceId());
DISPATCH_MATRIX_ON_FLAG(this, this,
NOT_IMPLEMENTED,
m_GPUMatrix->RNNBackwardData(*(outputDY.m_GPUMatrix), *(paramW.m_GPUMatrix), *(outputDX.m_GPUMatrix), rnnParameters, *(reserve.m_GPUMatrix), *(workspace.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
template <class ElemType>
void Matrix<ElemType>::RNNBackwardWeights(const Matrix<ElemType>& inputX, const Matrix<ElemType>& outputY, Matrix<ElemType>& dw, const RnnParameters& rnnParameters, Matrix<ElemType>& reserve, Matrix<ElemType>& workspace)
{
DecideAndMoveToRightDevice(*this, inputX, outputY, dw);
// move reserve/workspace to the consensus device
reserve._transferToDevice(GetDeviceId());
workspace._transferToDevice(GetDeviceId());
DISPATCH_MATRIX_ON_FLAG(this,
this,
NOT_IMPLEMENTED,
m_GPUMatrix->RNNBackwardWeights(*(inputX.m_GPUMatrix), *(outputY.m_GPUMatrix), *(dw.m_GPUMatrix), rnnParameters, *(reserve.m_GPUMatrix), *(workspace.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
#pragma region Static BLAS Functions
template <class ElemType>

View file

@@ -506,6 +506,10 @@ public:
void BatchNormalizationBackward(const Matrix<ElemType>& in, Matrix<ElemType>& grad, const Matrix<ElemType>& scale, double blendFactor, const Matrix<ElemType>& saveMean, const Matrix<ElemType>& saveInvStdDev,
Matrix<ElemType>& scaleGrad, Matrix<ElemType>& biasGrad) const;
void RNNForward(const Matrix<ElemType>& inputX, const Matrix<ElemType>& paramW, size_t xDim, size_t yDim, const vector<size_t>& numSequencesForFrame, const struct RnnParameters& rnnParameters, Matrix<ElemType>& reserve, Matrix<ElemType>& workspace);
void RNNBackwardData(const Matrix<ElemType>& outputDY, const Matrix<ElemType>& paramW, Matrix<ElemType>& outputDX, const struct RnnParameters& rnnParameters, Matrix<ElemType>& reserve, Matrix<ElemType>& workspace);
void RNNBackwardWeights(const Matrix<ElemType>& inputX, const Matrix<ElemType>& outputY, Matrix<ElemType>& dw, const struct RnnParameters& rnnParameters, Matrix<ElemType>& reserve, Matrix<ElemType>& workspace);
public:
// TODO: why are these not static? And why are they here?
ElemType Exp10(ElemType num);

Source/Math/RNNCommon.h (new file, 51 lines)
View file

@@ -0,0 +1,51 @@
#pragma once
#include "File.h"
#include <string>
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
struct RnnParameters
{
bool m_bidirectional;
size_t m_numLayers;
size_t m_hiddenSize;
wstring m_rnnMode;
RnnParameters(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& rnnMode)
: m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_rnnMode(rnnMode)
{}
bool operator==(const RnnParameters& other) const
{
return
m_bidirectional == other.m_bidirectional &&
m_numLayers == other.m_numLayers &&
m_hiddenSize == other.m_hiddenSize &&
m_rnnMode == other.m_rnnMode;
}
void Read(File& stream)
{
size_t bidirectional;
stream >> bidirectional; m_bidirectional = !!bidirectional;
stream >> m_numLayers;
stream >> m_hiddenSize;
stream >> m_rnnMode;
}
void Write(File& stream) const
{
size_t bidirectional = m_bidirectional ? 1 : 0;
stream << bidirectional;
stream << m_numLayers;
stream << m_hiddenSize;
stream << m_rnnMode;
}
private:
// disallow public default constructor
RnnParameters() {}
};
} } }
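RnnParameters above is persisted as four fields in a fixed order (the bidirectional flag widened to a size_t, then numLayers, hiddenSize, and the rnnMode string), and its operator== is what CuDnnRNN::IsCompatable relies on to detect a layout change between calls. A small round-trip sketch of the same pattern, using std::stringstream as a stand-in for CNTK's File stream (the struct below is a hypothetical mirror, not the real type):

#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical stand-in for RnnParameters, serialized in the same field order.
struct RnnParams
{
    bool bidirectional;
    size_t numLayers, hiddenSize;
    std::string rnnMode;

    bool operator==(const RnnParams& o) const
    {
        return bidirectional == o.bidirectional && numLayers == o.numLayers &&
               hiddenSize == o.hiddenSize && rnnMode == o.rnnMode;
    }
    void Write(std::ostream& s) const
    {
        s << (size_t)(bidirectional ? 1 : 0) << ' ' << numLayers << ' '
          << hiddenSize << ' ' << rnnMode << '\n';
    }
    void Read(std::istream& s)
    {
        size_t bi;
        s >> bi >> numLayers >> hiddenSize >> rnnMode;
        bidirectional = !!bi;
    }
};

int main()
{
    RnnParams a{ true, 3, 512, "LSTM" }, b{ false, 0, 0, "" };
    std::stringstream model;
    a.Write(model); // same field order as RnnParameters::Write above
    b.Read(model);
    assert(a == b); // round-trip preserves the RNN layout
    return 0;
}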

View file

@@ -146,7 +146,6 @@ public:
shared_ptr<Matrix<ElemType>> AsMatrix() const;
const TensorShape& GetShape() const { return m_shape; }
private:
// -------------------------------------------------------------------
// accessors
// -------------------------------------------------------------------
@@ -155,6 +154,7 @@ private:
Matrix<ElemType>& GetSOB() { return *m_sob; }
friend Test::TensorTest<ElemType>;
private:
// -------------------------------------------------------------------
// sob members
// -------------------------------------------------------------------

View file

@@ -46,7 +46,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
bool useParallelTrain,
StreamMinibatchInputs& inputMatrices,
size_t& actualMBSize,
const MPIWrapperPtr& mpi)
const MPIWrapperPtr& mpi,
size_t dataDecimationFactor = 0
)
{
// Reading consists of a sequence of Reader API calls:
// - GetMinibatch() --fills the inputMatrices and copies the MBLayout from Reader into inputMatrices
@@ -93,6 +95,22 @@
DecimateMinibatchInPlace<ElemType>(inputMatrices, mpi->NumNodesInUse(), mpi->CurrentNodeRank(), pMBLayout);
}
// This will automatically discard a large fraction of the data, useful if the training data is known to be highly correlated
if (dataDecimationFactor)
{
auto& pMBLayout = net->GetMBLayoutPtrOfNetwork();
// Verify that there's indeed a single layout
for (const auto& iter : inputMatrices)
{
assert(iter.second.pMBLayout == pMBLayout);
// TODO: This must be a runtime check, not an assert().
UNUSED(iter);
}
DecimateMinibatchInPlace<ElemType>(inputMatrices, dataDecimationFactor, 0, pMBLayout);
}
NotifyChangedNodes<ElemType>(net, inputMatrices);
// get MB size and tell Network to update its nodes' buffers based on what's in the input matrices
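The new dataDecimationFactor path above reuses DecimateMinibatchInPlace, passing the decimation factor where the MPI node count would normally go and keeping rank 0's share, so roughly one sequence in every dataDecimationFactor survives. A toy sketch of that keep-one-in-N idea on a plain list of sequence ids (assuming a simple round-robin assignment; the real routine operates on the MBLayout and the input matrices):

#include <cstddef>
#include <cstdio>
#include <vector>

// Keep only the sequences assigned to 'rank' out of 'factor' equal shares,
// here with a simple round-robin assignment (an assumption for illustration).
std::vector<int> Decimate(const std::vector<int>& sequences, size_t factor, size_t rank)
{
    std::vector<int> kept;
    for (size_t i = 0; i < sequences.size(); i++)
        if (i % factor == rank)
            kept.push_back(sequences[i]);
    return kept;
}

int main()
{
    std::vector<int> seqIds = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    auto kept = Decimate(seqIds, /*dataDecimationFactor=*/4, /*rank=*/0);
    for (int id : kept)
        std::printf("%d ", id); // 0 4 8 : about 1/4 of the data survives
    std::printf("\n");
    return 0;
}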

View file

@@ -2616,27 +2616,27 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
else
{
size_t numMPIWorkers = pMPI->NumNodesInUse();
const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record()));
m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none"));
m_parallelizationStartEpochNum = configParallelTrain(L"parallelizationStartEpoch", (int)1) - 1; // Epoch numbers internally are 0 based
m_enableDistributedMBReading = configParallelTrain(L"distributedMBReading", false);
m_syncStatsTrace = configParallelTrain(L"syncPerfStats", (int)0);
const ConfigRecordType& configParallelTrain(configSGD(L"ParallelTrain", ConfigRecordType::Record()));
m_parallelizationMethod = ParseParallelizationMethod(configParallelTrain(L"parallelizationMethod", L"none"));
m_parallelizationStartEpochNum = configParallelTrain(L"parallelizationStartEpoch", (int) 1) - 1; // Epoch numbers internally are 0 based
m_enableDistributedMBReading = configParallelTrain(L"distributedMBReading", false);
m_syncStatsTrace = configParallelTrain(L"syncPerfStats", (int) 0);
if (configParallelTrain.Exists(L"DataParallelSGD"))
{
const ConfigRecordType& configDataParallelSGD(configParallelTrain(L"DataParallelSGD", ConfigRecordType::Record()));
size_t defaultGradientBits = 8 * sizeofElemType;
m_numGradientBits = configDataParallelSGD(L"gradientBits", defaultGradientBits);
m_zeroThresholdFor1Bit = configDataParallelSGD(L"useZeroThresholdFor1BitQuantization", true);
m_bufferedAsyncGradientAggregation = configDataParallelSGD(L"useBufferedAsyncGradientAggregation", false);
if (configParallelTrain.Exists(L"DataParallelSGD"))
{
const ConfigRecordType& configDataParallelSGD(configParallelTrain(L"DataParallelSGD", ConfigRecordType::Record()));
size_t defaultGradientBits = 8 * sizeofElemType;
m_numGradientBits = configDataParallelSGD(L"gradientBits", defaultGradientBits);
m_zeroThresholdFor1Bit = configDataParallelSGD(L"useZeroThresholdFor1BitQuantization", true);
m_bufferedAsyncGradientAggregation = configDataParallelSGD(L"useBufferedAsyncGradientAggregation", false);
if ( m_numGradientBits < 1 || m_numGradientBits > (8 * sizeofElemType) )
{
InvalidArgument("gradientBits must be in the range [1, 32] when using precision=float and in range [1, 64] when using precision=double!");
}
}
if (configParallelTrain.Exists(L"ModelAveragingSGD"))
{
const ConfigRecordType& configMASGD(configParallelTrain(L"ModelAveragingSGD", ConfigRecordType::Record()));
InvalidArgument("gradientBits must be in the range [1, 32] when using precision=float and in range [1, 64] when using precision=double!");
}
}
if (configParallelTrain.Exists(L"ModelAveragingSGD"))
{
const ConfigRecordType& configMASGD(configParallelTrain(L"ModelAveragingSGD", ConfigRecordType::Record()));
if (configMASGD.Exists(L"blockSizePerWorker") && configMASGD.Exists(L"blockSize"))
{
InvalidArgument("It is only allowed to set blockSizePerWorker or blockSize, not both of them");
@@ -2655,8 +2655,8 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_modelAggregationBlockSize = 40000 * numMPIWorkers; // default value
}
#if 1 // legacy option
if (configMASGD.Exists(L"syncFrequencyInFrames"))
{
if (configMASGD.Exists(L"syncFrequencyInFrames"))
{
if (configMASGD.Exists(L"blockSizePerWorker") || configMASGD.Exists(L"blockSize"))
InvalidArgument("syncFrequencyInFrames is a deprecated alias of blockSizePerWorker. It is not allowed to specify both of them");
m_modelAggregationBlockSize = configMASGD(L"syncFrequencyInFrames");
@@ -2672,15 +2672,15 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_modelAggregationBlockSize = configMASGD(L"syncPeriod");
m_modelAggregationBlockSize *= numMPIWorkers;
fprintf(stderr, "WARNING: option syncPeroid in ModelAveragingSGD is going to be deprecated. Please use blockSizePerWorker instead in the future.\n");
}
#endif
}
if (configParallelTrain.Exists(L"BlockMomentumSGD"))
{
#endif
}
if (configParallelTrain.Exists(L"BlockMomentumSGD"))
{
#ifndef CNTK_PARALLEL_TRAINING_SUPPORT
InvalidArgument("BlockMomentumSGD is not enabled in this version.\n");
InvalidArgument("BlockMomentumSGD is not enabled in this version.\n");
#else
const ConfigRecordType& configBMSGD(configParallelTrain(L"BlockMomentumSGD", ConfigRecordType::Record()));
const ConfigRecordType& configBMSGD(configParallelTrain(L"BlockMomentumSGD", ConfigRecordType::Record()));
if (configBMSGD.Exists(L"blockSize") && configBMSGD.Exists(L"blockSizePerWorker"))
{
InvalidArgument("It is only allowed to set blockSizePerWorker or blockSize, not both of them");
@@ -2710,34 +2710,34 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
fprintf(stderr, "WARNING: option syncPeroid in BlockMomentumSGD is going to be deprecated. Please use blockSizePerWorker instead in the future.\n");
}
#endif
m_resetSGDMomentum = configBMSGD(L"resetSGDMomentum", true);
m_useNesterovBlockMomentum = configBMSGD(L"useNesterovMomentum", true);
m_blockLearningRate = configBMSGD(L"blockLearningRate", 1.0);
m_resetSGDMomentum = configBMSGD(L"resetSGDMomentum", true);
m_useNesterovBlockMomentum = configBMSGD(L"useNesterovMomentum", true);
m_blockLearningRate = configBMSGD(L"blockLearningRate", 1.0);
if (configBMSGD.Exists(L"blockMomentumPerSync") && configBMSGD.Exists(L"blockMomentumAsTimeConstant"))
{
InvalidArgument("It is only allowed to set either blockMomentumPerSync or blockMomentumAsTimeConstant, not both of them");
}
else if (configBMSGD.Exists(L"blockMomentumAsTimeConstant"))
{
m_blockMomentumAsTimeConstant = configBMSGD(L"blockMomentumAsTimeConstant");
}
if (configBMSGD.Exists(L"blockMomentumPerSync") && configBMSGD.Exists(L"blockMomentumAsTimeConstant"))
{
InvalidArgument("It is only allowed to set either blockMomentumPerSync or blockMomentumAsTimeConstant, not both of them");
}
else if (configBMSGD.Exists(L"blockMomentumAsTimeConstant"))
{
m_blockMomentumAsTimeConstant = configBMSGD(L"blockMomentumAsTimeConstant");
}
#if 1 // This option "blockMomentumPerSync" is going to be deprecated in the future
else if (configBMSGD.Exists(L"blockMomentumPerSync"))
{
double blockMomentum = configBMSGD(L"blockMomentumPerSync");
else if (configBMSGD.Exists(L"blockMomentumPerSync"))
{
double blockMomentum = configBMSGD(L"blockMomentumPerSync");
m_blockMomentumAsTimeConstant = BlockMomentumSGD<double>::Momentum2TimeConstant(blockMomentum, m_modelAggregationBlockSize);
}
}
#endif
else /*if (!configBMSGD.Exists(L"blockMomentumPerSync") && !configBMSGD.Exists(L"blockMomentumAsTimeConstant"))*/
{
else /*if (!configBMSGD.Exists(L"blockMomentumPerSync") && !configBMSGD.Exists(L"blockMomentumAsTimeConstant"))*/
{
double blockMomentum = 1.0 - 1.0 / (double)numMPIWorkers; // this is a default value which ensures each block update contributes equally
m_blockMomentumAsTimeConstant = BlockMomentumSGD<double>::Momentum2TimeConstant(blockMomentum, m_modelAggregationBlockSize);
}
}
#endif
InitializeAndCheckBlockMomentumSGDParameters();
}
}
} // if (!pMPI)
} // if (configSGD.Exists(L"ParallelTrain"))
}