RNNNodes finalized initial design

Jasha Droppo 2016-04-22 15:48:33 -07:00
Parent 29c065ae7c
Commit bbd04b3a90
15 changed files with 741 additions and 1 deletion


@@ -12,6 +12,7 @@
#include "LinearAlgebraNodes.h"
#include "RecurrentNodes.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "NonlinearityNodes.h"
#include "ReshapingNodes.h"
#include "InputAndParamNodes.h"


@@ -11,6 +11,7 @@
#include "NDLNetworkBuilder.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "DeprecatedNodes.h"
#include "EvaluationNodes.h"
#include "InputAndParamNodes.h"


@@ -12,6 +12,7 @@
#include "LinearAlgebraNodes.h"
#include "NonlinearityNodes.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "RecurrentNodes.h"
#include "ReshapingNodes.h"
#include "TrainingNodes.h"


@@ -12,6 +12,7 @@
#include "ComputationNode.h"
#include "ConvolutionalNodes.h"
#include "RNNNodes.h"
#include "DeprecatedNodes.h"
#include "EvaluationNodes.h"
#include "InputAndParamNodes.h"
@@ -134,6 +135,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateNode(const std::wstring& node
else if (nodeType == OperationNameOf(BatchNormalizationNode)) return New<BatchNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ConvolutionNode)) return New<ConvolutionNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PoolingNode)) return New<PoolingNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(RNNNode)) return New<RNNNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SparseInputValue)) return New<SparseInputValue<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(InputValue)) return New<InputValue<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(LearnableParameter)) return New<LearnableParameter<ElemType>>(forward<_Types>(_Args)...);
@@ -222,6 +224,12 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Creat
return net.AddNodeToNetWithElemType(New<SparseInputValue<ElemType>>(net.GetDeviceId(), inputName, imageLayout, dynamicAxisName));
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::CreateRNNNode(const std::wstring& nodeName)
{
return net.AddNodeToNetWithElemType(New<RNNNode<ElemType>>(net.GetDeviceId(), nodeName));
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::CreateConvolutionNode(const std::wstring& nodeName,
const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels,


@@ -146,6 +146,7 @@
<ClInclude Include="ConvolutionalNodes.h" />
<ClInclude Include="DeprecatedNodes.h" />
<ClInclude Include="PreComputeNodes.h" />
<ClInclude Include="RNNNodes.h" />
<ClInclude Include="SpecialPurposeNodes.h" />
<ClInclude Include="EvaluationNodes.h" />
<ClInclude Include="InputAndParamNodes.h" />


@@ -141,6 +141,9 @@
<ClInclude Include="DeprecatedNodes.h">
<Filter>Nodes</Filter>
</ClInclude>
<ClInclude Include="RNNNodes.h">
<Filter>Nodes</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="Common">


@@ -0,0 +1,169 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "Basics.h"
#include "ComputationNode.h"
#include "Matrix.h"
#include "TensorView.h"
#include "RNNNodes.h"
#include <unordered_set>
#include <map>
#include <string>
#include <vector>
#include <stdexcept>
#include <list>
#include <memory>
#include <algorithm>
#include <utility>
#include <assert.h>
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// RNNNode
// -----------------------------------------------------------------------
template<class ElemType>
RNNNode<ElemType>::RNNNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name), m_numLayers(7), m_numHidden(123)
{
}
// This constructor helps with BrainScript integration
template<class ElemType>
RNNNode<ElemType>::RNNNode(const ScriptableObjects::IConfigRecordPtr configp)
: Base(configp->Get(L"deviceId"), L"<placeholder>"), m_numHidden(configp->Get(L"numHidden")), m_numLayers(configp->Get(L"numLayers")),
BackwardDataCalledYet(false)
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
template<class ElemType>
void RNNNode<ElemType>::Save(File& fstream) const
{
Base::Save(fstream);
// todo: save RNN topology
fstream << m_numHidden;
fstream << m_numLayers;
}
template<class ElemType>
void RNNNode<ElemType>::Load(File& fstream, size_t modelVersion)
{
Base::Load(fstream, modelVersion);
// load RNN topology
fstream >> m_numHidden;
fstream >> m_numLayers;
}
template<class ElemType>
TensorView<ElemType> RNNNode<ElemType>::TensorHelper(int inputIndex/*-1 for output*/, bool gradient/*instead of value*/, const FrameRange& fr)
{
auto input = inputIndex < 0 ? this : Input(inputIndex).get();
return gradient ? input->GradientTensorFor(SIZE_MAX, fr) : input->ValueTensorFor(SIZE_MAX, fr);
}
template<class ElemType>
void RNNNode<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
{
shapeY = shapeX;
shapeY.SwapDimsInPlace(1, 2);
TensorView<ElemType> Y(matY, TensorShape(shapeY.GetDims()));
TensorView<ElemType> X(matX, shapeY);
Y.AssignCopyOf(X);
shapeY = Y.GetShape();
};
template<class ElemType>
void RNNNode<ElemType>::ForwardProp(const FrameRange& fr)
{
// The parameters are stored in a column matrix
Matrix<ElemType>& paramW = Input(1)->Value();
TensorView<ElemType> outputY = ValueTensorFor(SIZE_MAX, fr);
m_transposedInput->Resize(Input(0)->Value());
TransposeHelper(Input(0)->ValuePtr(), Input(0)->GetTensorSliceFor(SIZE_MAX, fr), m_transposedInput, shapeXT);
m_transposedOutput->Resize(this->Value());
shapeYT = TensorShape(this->GetTensorSliceFor(SIZE_MAX, fr));
shapeYT.SwapDimsInPlace(1, 2);
shapeYT = TensorShape(shapeYT.GetDims());
m_transposedOutput->RNNForward(*m_transposedInput, shapeXT, paramW, shapeYT, m_numLayers, m_numHidden);
TensorShape shapeY;
TransposeHelper(m_transposedOutput, TensorShape(shapeYT.GetDims()), this->ValuePtr(), shapeY);
BackwardDataCalledYet = false;
}
template<class ElemType>
void RNNNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
{
// ensure BackwardData is the first method called
if (!BackwardDataCalledYet)
{
Matrix<ElemType>& paramW = Input(1)->Value();
m_transposedDOutput->Resize(this->Gradient());
TransposeHelper(this->GradientPtr(), this->GetTensorSliceFor(SIZE_MAX, fr), m_transposedDOutput, shapeYT);
m_transposedDInput->Resize(Input(1)->Gradient());
m_transposedOutput->RNNBackwardData(*m_transposedDOutput, shapeYT, paramW, *m_transposedDInput, shapeXT);
BackwardDataCalledYet = true;
}
if (inputIndex == 1) // parameters
{
Matrix<ElemType>& paramDW = Input(1)->Gradient();
m_transposedOutput->RNNBackwardWeights(*m_transposedInput, shapeXT, *m_transposedOutput, shapeYT, paramDW);
}
else if (inputIndex == 0) // data
{
TensorShape tmp;
TransposeHelper(m_transposedDInput, shapeXT, Input(0)->GradientPtr(), tmp);
}
}
template<class ElemType>
void RNNNode<ElemType>::Validate(bool isFinalValidationPass)
{
// N.B.: I need both of these lines.
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// get tensor shapes
auto dimsA = Input(1)->GetSampleLayout().GetDims();
auto dimsB = Input(0)->GetSampleLayout().GetDims();
string dimsAstring = string(Input(1)->GetSampleLayout()); // for error messages
string dimsBstring = string(Input(0)->GetSampleLayout());
// validate and infer
if (isFinalValidationPass || (dimsA.size() > 0 && dimsB.size() > 0)) // only if we got at least some input dimensions to work with or need to wrap up
{
// now determine result dimensions
// bugbug - could want to squash output dims, need to reduce?
auto dimsC = dimsB;
//dimsC.resize(m_outputRank); // output dims
dimsC[0] = 2 * m_numHidden;
/// N.B. - this is the magical call, the reason for the function
/// dimensions would be outputRank * numSamples * minibatch * time
SetDims(TensorShape(dimsC), HasMBLayout());
// update dimensions of A
// update if LearnableParameter
// Input(0)->ValidateInferInputDimsFrom(TensorShape(dimsA));
}
};
template class RNNNode<float>;
template class RNNNode<double>;
} } }
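Editorial note: the heart of ForwardProp above is re-ordering the minibatch so that the cuDNN RNN sees time-major data. TransposeHelper does this with a TensorView copy after SwapDimsInPlace(1, 2); the sketch below spells out the same permutation with explicit index arithmetic, assuming the packed [inputSize, seqLength, miniBatch] source layout described in the code. The function name and the use of std::vector are illustrative only, not part of the commit.

#include <cstddef>
#include <vector>

// Copy a packed [inputSize, seqLength, miniBatch] buffer (strides [1, inputSize, inputSize*seqLength])
// into a dense [inputSize, miniBatch, seqLength] buffer (strides [1, inputSize, inputSize*miniBatch]),
// i.e. the swap of dims 1 and 2 that TransposeHelper materializes via AssignCopyOf.
static void TransposeSeqAndBatch(const std::vector<float>& src, std::vector<float>& dst,
                                 size_t inputSize, size_t seqLength, size_t miniBatch)
{
    dst.resize(inputSize * seqLength * miniBatch);
    for (size_t b = 0; b < miniBatch; b++)
        for (size_t t = 0; t < seqLength; t++)
            for (size_t i = 0; i < inputSize; i++)
            {
                size_t srcIdx = i + t * inputSize + b * inputSize * seqLength; // source element [i, t, b]
                size_t dstIdx = i + b * inputSize + t * inputSize * miniBatch; // target element [i, b, t]
                dst[dstIdx] = src[srcIdx];
            }
}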


@@ -0,0 +1,84 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "Basics.h"
#include "ComputationNode.h"
#include "Matrix.h"
#include "TensorView.h"
#include <unordered_set>
#include <map>
#include <string>
#include <vector>
#include <stdexcept>
#include <list>
#include <memory>
#include <algorithm>
#include <utility>
#include <assert.h>
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// RNNNode
// -----------------------------------------------------------------------
template <class ElemType>
class RNNNode : public ComputationNode<ElemType>, public NumInputs<2>
{
typedef ComputationNode<ElemType> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"RNN"; }
using Base::OperationName;
public:
RNNNode(DEVICEID_TYPE deviceId, const wstring& name);
RNNNode(const ScriptableObjects::IConfigRecordPtr configp);
void Save(File& fstream) const;
virtual void Load(File& fstream, size_t modelVersion) override;
public:
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override;
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override;
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override;
// request matrices needed to do node function value evaluation
virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool)
{
Base::RequestMatricesBeforeForwardProp(matrixPool);
RequestMatrixFromPool(m_transposedInput, matrixPool);
RequestMatrixFromPool(m_transposedOutput, matrixPool);
RequestMatrixFromPool(m_transposedDInput, matrixPool);
RequestMatrixFromPool(m_transposedDOutput, matrixPool);
}
// Is the output value of the computation node needed for computing
virtual bool OutputUsedInComputingInputNodesGradients() const { return false; }
// Is the output value of the specified input node needed for computing
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const { return false; }
protected:
bool BackwardDataCalledYet;
TensorShape shapeXT;
TensorShape shapeYT;
shared_ptr<Matrix<ElemType>> m_transposedInput;
shared_ptr<Matrix<ElemType>> m_transposedOutput;
shared_ptr<Matrix<ElemType>> m_transposedDInput;
shared_ptr<Matrix<ElemType>> m_transposedDOutput;
private:
TensorView<ElemType> TensorHelper(int inputIndex/*-1 for output*/, bool gradient/*instead of value*/, const FrameRange& fr);
void TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY);
size_t m_numLayers;
size_t m_numHidden;
};
} } }


@@ -73,7 +73,7 @@ protected:
m_inOutCuDnnT.UpdateBatchSize(srcGrad.GetNumCols());
cudnnBatchNormMode_t mode = m_spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
// REVIEW alexeyk: remove once Philly is upgraded to prod version. Also change betaParamDiff to 1 and update CNTK BN engine.
#if CUDNN_PATCHLEVEL >= 7
#if (CUDNN_MAJOR >=5 || CUDNN_PATCHLEVEL >= 7)
CUDNN_CALL(cudnnBatchNormalizationBackward(*m_cudnn, mode, &C::One, &C::One, &C::One, &C::Zero, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(srcGrad), m_inOutCuDnnT, ptr(grad),
m_scaleBiasCuDnnT, ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
#else


@@ -138,9 +138,12 @@ public:
}
// Must use CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING to get the same results as in reference engine.
RuntimeError("NOT IMPLEMENTED - Need to update cudnnSetPoolingNdDescriptor() for cudnn5 signature");
#if 0
CUDNN_CALL(cudnnSetPoolingNdDescriptor(m_pool,
kind == PoolKind::Max ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING,
(int)dims.size(), dims.data(), pad.data(), stride.data()));
#endif
}
~CuDnnPool()

Source/Math/CuDnnRNN.cpp (new file, 450 lines)

@@ -0,0 +1,450 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "Matrix.h"
#include "GPUMatrix.h"
#include <typeinfo>
#include <typeindex>
#include "CuDnnCommon.h"
template <>
const char* CudaErrString<cudnnStatus_t>(cudnnStatus_t x)
{
return cudnnGetErrorString(x);
}
namespace Microsoft { namespace MSR { namespace CNTK {
static bool IsGpu(DEVICEID_TYPE deviceId)
{
return deviceId >= 0;
}
class CuDnnDropout
{
CuDnn::ptr_t m_cudnn;
unsigned long long m_seed = 0xdeadbeefull;
public:
CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 0xdeadbeefull)
: m_dropoutDesc(nullptr), m_cudnn(CuDnn::Instance())
{
CUDNN_CALL(cudnnCreateDropoutDescriptor(&m_dropoutDesc));
size_t stateSize;
void *states;
CUDNN_CALL(cudnnDropoutGetStatesSize(*m_cudnn, &stateSize));
// bugbug: possible leak. Does CuDnn release this for us?
CUDA_CALL(cudaMalloc(&states, stateSize));
CUDNN_CALL(cudnnSetDropoutDescriptor(m_dropoutDesc,
*m_cudnn,
dropout,
states,
stateSize,
seed));
}
~CuDnnDropout()
{
if (m_dropoutDesc != nullptr)
{
cudnnDestroyDropoutDescriptor(m_dropoutDesc);
m_dropoutDesc = nullptr;
}
}
operator cudnnDropoutDescriptor_t() const
{
return m_dropoutDesc;
}
DISABLE_COPY_AND_MOVE(CuDnnDropout);
private:
cudnnDropoutDescriptor_t m_dropoutDesc;
};
class CuDnnRNN
{
CuDnnDropout m_dropout;
public:
CuDnnRNN(cudnnDataType_t dataType)
: m_rnnDesc(nullptr)
{
CUDNN_CALL(cudnnCreateRNNDescriptor(&m_rnnDesc));
// hard code these for now, expose other types later.
cudnnRNNMode_t RNNMode = CUDNN_LSTM;
int hiddenSize = 512;
int seqLength = 512;
int numLayers = 6;
bool bidirectional = true;
CUDNN_CALL(cudnnSetRNNDescriptor(m_rnnDesc,
hiddenSize,
seqLength,
numLayers,
m_dropout,
CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation
bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
RNNMode,
dataType));
}
~CuDnnRNN()
{
if (m_rnnDesc != nullptr)
{
cudnnDestroyRNNDescriptor(m_rnnDesc);
m_rnnDesc = nullptr;
}
}
operator cudnnRNNDescriptor_t() const
{
return m_rnnDesc;
}
DISABLE_COPY_AND_MOVE(CuDnnRNN);
private:
cudnnRNNDescriptor_t m_rnnDesc;
};
class CuDnnFilter
{
CuDnn::ptr_t m_cudnn;
public:
CuDnnFilter(const CuDnnRNN& rnn, const cudnnTensorDescriptor_t *xDesc) :
m_cudnn(CuDnn::Instance())
{
CUDNN_CALL(cudnnCreateFilterDescriptor(&m_filterDesc));
size_t filterSize;
CUDNN_CALL(cudnnGetRNNParamsSize(*m_cudnn, rnn, xDesc, &filterSize));
int dimW[3];
// bugbug: hard-wired for float
dimW[0] = (int)(filterSize / sizeof(float));
dimW[1] = 1;
dimW[2] = 1;
CUDNN_CALL(cudnnSetFilterNdDescriptor(m_filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dimW));
}
~CuDnnFilter()
{
if (m_filterDesc != nullptr)
{
cudnnDestroyFilterDescriptor(m_filterDesc);
m_filterDesc = nullptr;
}
}
operator cudnnFilterDescriptor_t() const
{
return m_filterDesc;
}
DISABLE_COPY_AND_MOVE(CuDnnFilter);
private:
cudnnFilterDescriptor_t m_filterDesc;
};
template <class ElemType>
class CuDnnRNNExecutor
{
CuDnn::ptr_t m_cudnn;
cudnnDataType_t m_dataType;
using Mat = Matrix<ElemType>;
public:
CuDnnRNNExecutor(const TensorShape &inputShape, const TensorShape &outputShape, DEVICEID_TYPE deviceId) :
m_cudnn(CuDnn::Instance()),
m_dataType(CuDnnTensor::GetDataType<ElemType>())
{
}
protected:
void EnsureCompatible() override
{
if (!IsGpu(m_deviceId))
RuntimeError("cuDNN convolution engine supports GPU devices only.");
}
void EnsureRNNInitialized() override
{
if (m_rnnT == nullptr)
{
m_rnnT = std::make_unique<CuDnnRNN>(m_dataType);
}
}
void ForwardCore(const TensorShape& in, const Mat& weights, TensorShape& out, Mat& workspace, Mat& reserve) override
{
// get input data layout
// source shape, stride is [inputSize, seqLength, miniBatch], [1, inputSize, inputSize*seqLength]
// target shape, stride is [inputsize, miniBatch, seqLength], [1, inputSize*seqLength, inputSize]
size_t inputSize = in.GetDim(0);
size_t seqLength = in.GetDim(1);
size_t miniBatch = in.GetDim(2);
int dimX[3] = { (int) inputSize, (int) miniBatch, 1 };
int strideX[3] = { 1, dimX[0] * dimX[1], dimX[0] };
vector<cudnnTensorDescriptor_t> xDesc(seqLength);
for (int i = 0; i < seqLength; i++) {
CUDNN_CALL(cudnnCreateTensorDescriptor(&xDesc[i]));
CUDNN_CALL(cudnnSetTensorNdDescriptor(xDesc[i], CUDNN_DATA_FLOAT, 3, dimX, strideX));
}
// get output data layout
// source shape, stride is [outputSize, seqLength, miniBatch], [1, outputSize, outputSize*seqLength]
// target shape, stride is [outputSize, miniBatch, seqLength], [1, outputSize*seqLength, outputSize]
size_t outputSize = out.GetDim(0);
if (out.GetDim(1) != seqLength)
RuntimeError("CuDnn ForwardCore: Output sequence length doesn't match input sequence length");
if (out.GetDim(2) != miniBatch)
RuntimeError("CuDnn ForwardCore: Output minibatch size doesn't match input minibatch size");
int dimY[3] = { (int) outputSize, (int) miniBatch, 1 };
int strideY[3] = { 1, dimY[0] * dimY[1], dimY[0] };
vector<cudnnTensorDescriptor_t> yDesc(seqLength);
for (int i = 0; i < seqLength; i++) {
CUDNN_CALL(cudnnCreateTensorDescriptor(&yDesc[i]));
CUDNN_CALL(cudnnSetTensorNdDescriptor(yDesc[i], CUDNN_DATA_FLOAT, 3, dimY, strideY));
}
// ensure workspace and reserve are large enough
{
size_t workSize;
size_t reserveSize;
// Need for every pass
CUDNN_CALL(cudnnGetRNNWorkspaceSize(m_cudnn, m_rnnT, xDesc, &workSize));
// Only needed in training, can't be touched between passes.
CUDNN_CALL(cudnnGetRNNTrainingReserveSize(m_cudnn, m_rnnT, xDesc, &reserveSize));
// convert from bytes to ElemType
workSize = (workSize + sizeof(ElemType) - 1) / (sizeof(ElemType));
reserveSize = (reserveSize + sizeof(ElemType) - 1) / sizeof(ElemType);
reserve.Resize(reserveSize);
workspace.Resize(workSize);
}
CUDNN_CALL(cudnnRNNForwardTraining(m_cudnn,
m_rnnT,
xDesc.data(), ptr(in),
0, nullptr,
0, nullptr,
wDesc, ptr(weights),
yDesc.data(), ptr(out),
0, nullptr,
0, nullptr,
ptr(workspace), workSize,
ptr(reserveSpace), reserveSize));
}
void BackwardDataCore(const Mat& srcGrad, const Mat& kernel, Mat& grad, Mat& workspace) override
{
size_t batchSize = srcGrad.GetNumCols();
// Find best algo and allocate temp buffer, if needed.
auto finder = [this](int& calgo, cudnnConvolutionBwdDataAlgoPerf_t algoPerf[MaxAlgoCount]) -> cudnnStatus_t
{
return cudnnFindConvolutionBackwardDataAlgorithm(*m_cudnn, *m_kernelT, m_outT, *m_conv, m_inT, MaxAlgoCount, &calgo, algoPerf);
};
auto staticFinder = [this](cudnnConvolutionBwdDataAlgo_t& algo) -> cudnnStatus_t
{
return cudnnGetConvolutionBackwardDataAlgorithm(*m_cudnn, *m_kernelT, m_outT, *m_conv, m_inT, CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE, 0, &algo);
};
FindBestAlgo(batchSize, m_backDataAlgo, finder, staticFinder);
if (m_backDataAlgo.Algo.memory > 0)
workspace.Resize((m_backDataAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
// Compute gradients with respect to the output tensor (data).
CUDNN_CALL(cudnnConvolutionBackwardData(*m_cudnn, &C::One, *m_kernelT, ptr(kernel), m_outT, ptr(srcGrad), *m_conv, m_backDataAlgo.Algo.algo,
ptr(workspace), m_backDataAlgo.Algo.memory, &C::One, m_inT, ptr(grad)));
}
void BackwardKernelCore(const Mat& srcGrad, const Mat& in, Mat& kernelGrad, bool /*allowReuse*/, Mat& workspace) override
{
size_t batchSize = in.GetNumCols();
// Find best algo and allocate temp buffer, if needed.
auto finder = [this](int& calgo, cudnnConvolutionBwdFilterAlgoPerf_t algoPerf[MaxAlgoCount]) -> cudnnStatus_t
{
return cudnnFindConvolutionBackwardFilterAlgorithm(*m_cudnn, m_inT, m_outT, *m_conv, *m_kernelT, MaxAlgoCount, &calgo, algoPerf);
};
auto staticFinder = [this](cudnnConvolutionBwdFilterAlgo_t& algo) -> cudnnStatus_t
{
return cudnnGetConvolutionBackwardFilterAlgorithm(*m_cudnn, m_inT, m_outT, *m_conv, *m_kernelT, CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE, 0, &algo);
};
FindBestAlgo(batchSize, m_backFiltAlgo, finder, staticFinder);
if (m_backFiltAlgo.Algo.memory > 0)
workspace.Resize((m_backFiltAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
// Compute gradients with respect to the output tensor (data).
CUDNN_CALL(cudnnConvolutionBackwardFilter(*m_cudnn, &C::One, m_inT, ptr(in), m_outT, ptr(srcGrad), *m_conv, m_backFiltAlgo.Algo.algo,
ptr(workspace), m_backFiltAlgo.Algo.memory, &C::One, *m_kernelT, ptr(kernelGrad)));
}
private:
using C = Consts<ElemType>;
template <typename TAlgo, typename TFinder, typename TStaticFinder>
void FindBestAlgo(size_t batchSize, TAlgo& algo, TFinder finder, TStaticFinder staticFinder)
{
if (!algo.NeedAutotuning(batchSize))
return;
m_inT.UpdateBatchSize(batchSize);
m_outT.UpdateBatchSize(batchSize);
using CuDnnAlgoT = decltype(TAlgo::Algo);
CuDnnAlgoT algoPerf[MaxAlgoCount];
int calgo = 0;
cudnnStatus_t err = finder(calgo, algoPerf);
// Alloc failed - usually means cuDNN runtime auto-tuner could not allocate workspace.
// In such case, use static auto-tuner with no workspace.
if (err == CUDNN_STATUS_ALLOC_FAILED)
{
decltype(CuDnnAlgoT::algo) noMemAlgo;
CUDNN_CALL(staticFinder(noMemAlgo));
algo.CurMBSize = batchSize;
algo.Algo = algoPerf[0];
algo.Algo.algo = noMemAlgo;
algo.Algo.memory = 0;
algo.Algo.status = CUDNN_STATUS_SUCCESS;
algo.NoWorkspaceAlgo = noMemAlgo;
return;
}
CUDNN_CALL(err);
assert(calgo > 0);
size_t inputSampleSize = m_geometry->InputShape().GetNumElements();
size_t maxMem = m_maxTempMemSizeInSamples == 0 ? (std::numeric_limits<size_t>::max)() : inputSampleSize * m_maxTempMemSizeInSamples * sizeof(ElemType);
// Find best (fastest) algorithm which satisfies workspace requirements.
auto res = std::find_if(algoPerf, algoPerf + calgo,
[=](const CuDnnAlgoT& cur)
{
return cur.status == CUDNN_STATUS_SUCCESS && cur.memory <= maxMem;
});
if (res == algoPerf + calgo)
RuntimeError("cuDNN could not find suitable algorithm for the current convolution configuration.");
algo.CurMBSize = batchSize;
algo.Algo = *res;
// Find fastest algorithm that does NOT require workspace. It is used as a fallback algo in Forward function.
res = std::find_if(algoPerf, algoPerf + calgo,
[](const CuDnnAlgoT& cur)
{
return cur.status == CUDNN_STATUS_SUCCESS && cur.memory == 0;
});
if (res == algoPerf + calgo)
{
// In theory, this should never happen.
RuntimeError("cuDNN could not find no-workspace algorithm for the current convolution configuration.");
}
else
algo.NoWorkspaceAlgo = (*res).algo;
}
static ElemType* ptr(Mat& src)
{
return src.Data();
}
static const ElemType* ptr(const Mat& src)
{
return src.Data();
}
private:
template <typename T>
struct ConvAlgoInfo
{
using CuDnnAlgoT = decltype(T::algo);
ConvAlgoInfo()
: CurMBSize(0)
{
Algo.status = CUDNN_STATUS_NOT_INITIALIZED;
NoWorkspaceAlgo = (CuDnnAlgoT)-1;
}
// Current mini-batch size, needed for re-computing statistics in auto-tuner.
size_t CurMBSize;
T Algo;
CuDnnAlgoT NoWorkspaceAlgo;
bool NeedAutotuning(size_t batchSize)
{
// Need to re-run auto-tuner in case minibatch size is increased.
// If minibatch size is decreased we assume that previously selected algorithm requires less or the same amount of workspace.
// This is done to avoid re-running auto-tuner every time in case minibatch size changes frequently (e.g. when distributed reading is enabled).
// REVIEW alexeyk: potentially, this might cause some perf issues if better (faster) algo can be selected for a smaller minibatch.
// We also need to reset auto-tuning status at the beginning of each epoch but ComputationNode currently does not provide such notification.
// We assume no other dimensions of tensors can change so we don't check it.
// REVIEW alexeyk: review once we get response from NVIDIA.
return (Algo.status != CUDNN_STATUS_SUCCESS || batchSize > CurMBSize);
}
};
CuDnn::ptr_t m_cudnn;
cudnnDataType_t m_dataType;
CuDnnTensor m_inT;
CuDnnTensor m_outT;
// Convolution specific.
std::unique_ptr<CuDnnKernel> m_kernelT;
std::unique_ptr<CuDnnConv> m_conv;
// Pooling specific.
std::unique_ptr<CuDnnPool> m_pool;
ConvAlgoInfo<cudnnConvolutionFwdAlgoPerf_t> m_fwdAlgo;
ConvAlgoInfo<cudnnConvolutionBwdDataAlgoPerf_t> m_backDataAlgo;
ConvAlgoInfo<cudnnConvolutionBwdFilterAlgoPerf_t> m_backFiltAlgo;
};
} } }
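Editorial note: CuDnnRNN.cpp wraps each cuDNN descriptor it owns (dropout, RNN, filter) in a small RAII class: create the handle in the constructor, destroy it in the destructor, expose it through a conversion operator, and forbid copy/move. Below is a minimal stand-alone sketch of that pattern; the FakeDescriptor type and helper functions are placeholders so the sketch compiles without cuDNN.

#include <cstdlib>

typedef void* FakeDescriptor_t;                           // stand-in for a cudnnXxxDescriptor_t handle
static void CreateFakeDescriptor(FakeDescriptor_t* d)  { *d = std::malloc(1); }
static void DestroyFakeDescriptor(FakeDescriptor_t d)  { std::free(d); }

class DescriptorWrapper
{
public:
    DescriptorWrapper() : m_desc(nullptr)
    {
        CreateFakeDescriptor(&m_desc);                    // real code: CUDNN_CALL(cudnnCreateXxxDescriptor(&m_desc))
    }
    ~DescriptorWrapper()
    {
        if (m_desc != nullptr)
        {
            DestroyFakeDescriptor(m_desc);                // real code: cudnnDestroyXxxDescriptor(m_desc)
            m_desc = nullptr;
        }
    }
    operator FakeDescriptor_t() const { return m_desc; }  // pass the wrapper wherever the raw handle is expected

    DescriptorWrapper(const DescriptorWrapper&) = delete; // real code: DISABLE_COPY_AND_MOVE(DescriptorWrapper)
    DescriptorWrapper& operator=(const DescriptorWrapper&) = delete;

private:
    FakeDescriptor_t m_desc;
};

Keeping destruction inside the wrapper is what lets CuDnnRNNExecutor hold these objects by value or unique_ptr without any explicit cleanup code.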


@@ -3138,6 +3138,12 @@ void GPUMatrix<ElemType>::BatchNormalizationBackward(const GPUMatrix<ElemType>&
in.Data(), Data(), grad.Data(), scale.Data(), scaleGrad.Data(), biasGrad.Data(), saveMean.Data(), saveInvStdDev.Data(), GetStream());
}
#pragma region RNN Functions
template <class ElemType>
void GPUMatrix<ElemType>::RNNForward(const GPUMatrix<ElemType>& w, int numLayers, bool bidirectional, const GPUMatrix<ElemType> &output) const
{
}
#pragma region Static BLAS Functions
// float/double overloads of cublasSgemm()/cublasDgemm()
static cublasStatus_t cublas_gemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* A, int lda, const float* B, int ldb, const float* beta, float* C, int ldc)


@@ -141,6 +141,7 @@ private:
#pragma warning(push)
#pragma warning(disable : 4251)
mutable std::unique_ptr<conc_stack<std::unique_ptr<GPUMatrix<ElemType>>>> m_workspace;
mutable std::unique_ptr<struct RNNInfo> m_rnnworkspace;
#pragma warning(pop)
private:
@@ -447,6 +448,11 @@ public:
void BatchNormalizationBackward(const GPUMatrix<ElemType>& in, GPUMatrix<ElemType>& grad, const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& saveMean, const GPUMatrix<ElemType>& saveInvStdDev,
GPUMatrix<ElemType>& scaleGrad, GPUMatrix<ElemType>& biasGrad) const;
// RNN support functions
void RNNForward(const GPUMatrix<ElemType>& w, int numLayers, bool bidirectional, const GPUMatrix<ElemType> &output) const;
void RNNBackwardData() const;
void RNNBackwardGradient() const;
public:
// static BLAS functions
static void MultiplyAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const bool transposeA, const GPUMatrix<ElemType>& b, const bool transposeB, ElemType beta, GPUMatrix<ElemType>& c);


@@ -190,6 +190,7 @@ if exist "$(CuDnnDll)" (xcopy /Y "$(CuDnnDll)" "$(OutputPath)")
<FileType>CppCode</FileType>
</CudaCompile>
<ClCompile Include="CuDnnCommon.cpp" />
<ClCompile Include="CuDnnRNN.cpp" />
<ClCompile Include="GPUDataTransferer.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>


@@ -51,6 +51,9 @@
<ClCompile Include="CuDnnCommon.cpp">
<Filter>GPU\CuDnn</Filter>
</ClCompile>
<ClCompile Include="CuDnnRNN.cpp">
<Filter>GPU\RNN</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\Common\Include\File.h">
@@ -168,5 +171,8 @@
<Filter Include="GPU\CuDnn">
<UniqueIdentifier>{05351afa-de95-40c8-830a-d70eede55dc0}</UniqueIdentifier>
</Filter>
<Filter Include="GPU\RNN">
<UniqueIdentifier>{da8154d0-d1f4-44a5-b64a-5e955ead56fa}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>