Finished refactoring of BN node/engine.
This commit is contained in:
Parent: 3288306efe
Commit: bf413614f8
@@ -6,7 +6,7 @@

#include "Basics.h"
#include "ComputationNode.h"
#include "ConvolutionEngine.h"
#include "BatchNormalizationEngine.h"

#include <map>
#include <string>
@@ -1657,16 +1657,12 @@ public:
const Matrix<ElemType>& scale = Input(1)->Value();
const Matrix<ElemType>& bias = Input(2)->Value();

size_t batchSize = sliceInputValue.GetNumCols();
m_inT->setN(batchSize);
assert(m_convEng != nullptr);

auto sliceInputGrad = Input(0)->GradientFor(fr);
m_dScale->Resize(scale);
m_dBias->Resize(bias);
// Compute all derivatives in one step. Save derivatives with respect to scale and bias in temp matrices.
m_convEng->BackwardNormalizeBatch(*m_inT, sliceInputValue, sliceOutputGrad, sliceInputGrad, *m_scaleBiasT, scale, m_spatial,
*m_saveMean, *m_saveInvStdDev, *m_dScale, *m_dBias);
m_bnEng->Backward(sliceInputValue, sliceOutputGrad, sliceInputGrad, scale,
*m_saveMean, *m_saveInvStdDev, *m_dScale, *m_dBias);
}
else if (inputIndex == 1) // derivative with respect to the scale
{
@@ -1707,14 +1703,11 @@ public:

Matrix<ElemType> sliceOutputValue = ValueFor(fr);

size_t batchSize = sliceInputValue.GetNumCols();
m_inT->setN(batchSize);
assert(m_convEng != nullptr);
#if NANCHECK
sliceInputValue.HasNan("BatchNormalization-input");
#endif
if (m_eval)
m_convEng->NormalizeBatchInference(*m_inT, sliceInputValue, *m_scaleBiasT, scale, bias, m_spatial, runMean, runInvStdDev, sliceOutputValue);
m_bnEng->NormalizeBatchInference(sliceInputValue, scale, bias, runMean, runInvStdDev, sliceOutputValue);
else
{
double expAvgFactor;
@@ -1736,8 +1729,8 @@ public:
m_saveMean->Resize(runMean);
m_saveInvStdDev->Resize(runMean);

m_convEng->NormalizeBatch(*m_inT, sliceInputValue, *m_scaleBiasT, scale, bias, m_spatial, expAvgFactor, runMean, runInvStdDev,
sliceOutputValue, m_epsilon, *m_saveMean, *m_saveInvStdDev);
m_bnEng->Forward(sliceInputValue, scale, bias, expAvgFactor, runMean, runInvStdDev,
sliceOutputValue, m_epsilon, *m_saveMean, *m_saveInvStdDev);

m_mbCount++;
}
@@ -1772,24 +1765,10 @@ public:

auto shape = GetSampleLayout();

if (m_factory == nullptr)
m_factory = ConvolutionEngineFactory<ElemType>::Create(m_deviceId, ConvolutionEngineFactory<ElemType>::EngineType::Auto, m_imageLayoutKind);
if (m_convEng == nullptr)
m_convEng = m_factory->CreateConvEngine(m_deviceId, m_imageLayoutKind, 0, m_useCntkEngine ? BatchNormImpl::Cntk : BatchNormImpl::CuDnn);
if (m_spatial)
if (m_bnEng == nullptr)
{
auto dims = ImageDimensions(shape, m_imageLayoutKind);
if (m_inT == nullptr)
m_inT = m_factory->CreateTensor(dims.m_width, dims.m_height, dims.m_numChannels, 1);
if (m_scaleBiasT == nullptr)
m_scaleBiasT = m_factory->CreateTensor(1, 1, dims.m_numChannels, 1);
}
else
{
if (m_inT == nullptr)
m_inT = m_factory->CreateTensor(shape.GetNumElements(), 1, 1, 1);
if (m_scaleBiasT == nullptr)
m_scaleBiasT = m_factory->CreateTensor(shape.GetNumElements(), 1, 1, 1);
m_bnEng = BatchNormEngine<ElemType>::Create(m_deviceId, shape, m_spatial, m_imageLayoutKind,
m_useCntkEngine ? BatchNormEngineKind::Cntk : BatchNormEngineKind::CuDnn);
}
}
}
@@ -1869,10 +1848,7 @@ private:
// Stores bias derivatives.
shared_ptr<Matrix<ElemType>> m_dBias;

std::unique_ptr<ConvolutionEngineFactory<ElemType>> m_factory;
std::unique_ptr<ConvolutionEngine<ElemType>> m_convEng;
std::unique_ptr<ConvolutionTensor4D> m_inT;
std::unique_ptr<ConvolutionTensor4D> m_scaleBiasT;
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;
};

template class BatchNormalizationNode<float>;
@@ -13,33 +13,50 @@ template <class ElemType>
void BatchNormEngine<ElemType>::Forward(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev,
Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev)
{
assert(in.GetNumRows() == out.GetNumRows());
assert(in.GetNumRows() == m_inOutT.GetNumElements());
assert(out.GetNumRows() == m_inOutT.GetNumElements());
assert(in.GetNumCols() == out.GetNumCols());
assert(std::isfinite(epsilon) && epsilon > 0);
assert(std::isfinite(expAvgFactor) && (0 < expAvgFactor && expAvgFactor <= 1));
if (!m_spatial)
{
assert(in.GetNumRows() == scale.GetNumRows());
assert(in.GetNumRows() == bias.GetNumRows());
assert(in.GetNumRows() == runMean.GetNumRows());
assert(in.GetNumRows() == runInvStdDev.GetNumRows());
assert(in.GetNumRows() == saveMean.GetNumRows());
assert(in.GetNumRows() == saveInvStdDev.GetNumRows());
assert(m_inOutT.GetNumElements() == scale.GetNumRows());
assert(m_inOutT.GetNumElements() == bias.GetNumRows());
assert(m_inOutT.GetNumElements() == runMean.GetNumRows());
assert(m_inOutT.GetNumElements() == runInvStdDev.GetNumRows());
assert(m_inOutT.GetNumElements() == saveMean.GetNumRows());
assert(m_inOutT.GetNumElements() == saveInvStdDev.GetNumRows());
}
else
{
assert((in.GetNumRows() % scale.GetNumRows()) == 0);
assert((in.GetNumRows() % bias.GetNumRows()) == 0);
assert((in.GetNumRows() % runMean.GetNumRows()) == 0);
assert((in.GetNumRows() % runInvStdDev.GetNumRows()) == 0);
assert((in.GetNumRows() % saveMean.GetNumRows()) == 0);
assert((in.GetNumRows() % saveInvStdDev.GetNumRows()) == 0);
assert((m_inOutT.GetNumElements() % scale.GetNumRows()) == 0);
assert((m_inOutT.GetNumElements() % bias.GetNumRows()) == 0);
assert((m_inOutT.GetNumElements() % runMean.GetNumRows()) == 0);
assert((m_inOutT.GetNumElements() % runInvStdDev.GetNumRows()) == 0);
assert((m_inOutT.GetNumElements() % saveMean.GetNumRows()) == 0);
assert((m_inOutT.GetNumElements() % saveInvStdDev.GetNumRows()) == 0);
}

EnsureCompatible();
ForwardCore(in, scale, bias, expAvgFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev);
}

template <class ElemType>
void BatchNormEngine<ElemType>::ForwardInference(const Mat& in, const Mat& scale, const Mat& bias,
const Mat& runMean, const Mat& runInvStdDev, Mat& out)
{
EnsureCompatible();
ForwardInferenceCore(in, scale, bias, runMean, runInvStdDev, out);
}

template <class ElemType>
void BatchNormEngine<ElemType>::Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale,
const Mat& saveMean, const Mat& saveInvStdDev, Mat& scaleGrad, Mat& biasGrad)
{
EnsureCompatible();
BackwardCore(in, srcGrad, grad, scale, saveMean, saveInvStdDev, scaleGrad, biasGrad);
}

template class BatchNormEngine<float>;
template class BatchNormEngine<double>;
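Editor's note: the assert block in Forward above encodes the two supported parameter layouts. In per-activation mode, scale/bias and all statistics have one row per input element; in spatial mode, the input row count only needs to be a multiple of the per-channel row count. A minimal sketch of that shape arithmetic, with hypothetical sizes (not part of the commit):

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t W = 2, H = 2, C = 2;      // hypothetical (W, H, C) sample layout
        const std::size_t inRows = W * H * C;       // rows of 'in' per sample: 8
        const std::size_t spatialScaleRows = C;     // spatial: one scale/bias row per channel
        const std::size_t perActScaleRows = inRows; // per-activation: one row per element
        assert(inRows % spatialScaleRows == 0);     // mirrors the spatial branch of the asserts
        assert(inRows == perActScaleRows);          // mirrors the non-spatial branch
        return 0;
    }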
@@ -52,8 +69,8 @@ public:

public:
CntkBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout)
: Base(deviceId, inOutT, scaleBiasT, spatial, imageLayout)
bool spatial, ImageLayoutKind imageLayout)
: Base(deviceId, inOutT, spatial, imageLayout)
{
}

@@ -61,7 +78,6 @@ protected:
using Base::m_deviceId;
using Base::m_imageLayout;
using Base::m_inOutT;
using Base::m_scaleBiasT;
using Base::m_spatial;

void EnsureCompatible() override
@@ -75,34 +91,47 @@ protected:
{
in.BatchNormalizationForward(scale, bias, expAvgFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev);
}

void ForwardInferenceCore(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out) override
{
in.BatchNormalizationForwardInference(scale, bias, runMean, runInvStdDev, out);
}

void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev,
Mat& scaleGrad, Mat& biasGrad) override
{
srcGrad.BatchNormalizationBackward(in, grad, scale, saveMean, saveInvStdDev, scaleGrad, biasGrad);
}
};

template class CntkBatchNormEngine<float>;
template class CntkBatchNormEngine<double>;

template <typename T>
bool HasFlag(T src, T testFlag)
{
return ((int)src & (int)testFlag) != 0;
}

template <class ElemType>
std::unique_ptr<BatchNormEngine<ElemType>> BatchNormEngine<ElemType>::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout,
bool spatial, ImageLayoutKind imageLayout,
BatchNormEngineKind enabledEngines = BatchNormEngineKind::All)
{
if (spatial && imageLayout == ImageLayoutKind::HWC)
InvalidArgument("Batch normalization is not supported for legacy(HWC) layout. Please use cudnn(CHW) layout instead.");

auto isEnabled = [=](BatchNormEngineKind eng) { return ((int)enabledEngines & (int)eng) != 0; };
// Use CNTK as default batch norm engine.
if (isEnabled(BatchNormEngineKind::Cntk))
if (HasFlag(enabledEngines, BatchNormEngineKind::Cntk))
{
fprintf(stderr, "Using CNTK batch normalization engine.\n");
return std::make_unique<CntkBatchNormEngine<ElemType>>(deviceId, inOutT, scaleBiasT, spatial, imageLayout);
return std::make_unique<CntkBatchNormEngine<ElemType>>(deviceId, inOutT, spatial, imageLayout);
}

if (isEnabled(BatchNormEngineKind::CuDnn))
if (HasFlag(enabledEngines, BatchNormEngineKind::CuDnn))
{
fprintf(stderr, "Using cuDNN batch normalization engine.\n");
return CuDnnBatchNormEngineFactory<ElemType>::Create(deviceId, inOutT, scaleBiasT, spatial, imageLayout);
return CuDnnBatchNormEngineFactory<ElemType>::Create(deviceId, inOutT, spatial, imageLayout);
}

RuntimeError("Failed to find appropriate batch normalization engine.");
RuntimeError("Could not find appropriate batch normalization engine.");
}

} } }
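Editor's note: for orientation, a minimal sketch of how a caller might drive the refactored engine API shown above. The shapes and deviceId are illustrative; SingleMatrix stands for Matrix<float> as in the CNTK Math sources, and only calls introduced by this commit are used:

    void BatchNormEngineSketch(SingleMatrix& in, SingleMatrix& out,
                               SingleMatrix& scale, SingleMatrix& bias,
                               SingleMatrix& runMean, SingleMatrix& runInvStdDev,
                               SingleMatrix& saveMean, SingleMatrix& saveInvStdDev)
    {
        TensorShape inOutT(2, 2, 2); // hypothetical (W, H, C) sample layout
        bool spatial = true;         // scale/bias/statistics then have one row per channel
        auto eng = BatchNormEngine<float>::Create(0 /*deviceId*/, inOutT, spatial,
                                                  ImageLayoutKind::CHW, BatchNormEngineKind::Cntk);

        // Training-time forward: updates running stats and fills saveMean/saveInvStdDev for Backward.
        eng->Forward(in, scale, bias, /*expAvgFactor=*/1.0, runMean, runInvStdDev,
                     out, /*epsilon=*/1e-5, saveMean, saveInvStdDev);
        // Inference-time forward: uses only the running statistics.
        eng->ForwardInference(in, scale, bias, runMean, runInvStdDev, out);
    }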
@@ -37,22 +37,21 @@ public:
void Forward(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev,
Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev);

//void ForwardInference(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev,
// Mat& out);
void ForwardInference(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out);

//void Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev,
// Mat& scaleGrad, Mat& biasGrad);
void Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev,
Mat& scaleGrad, Mat& biasGrad);

static std::unique_ptr<BatchNormEngine<ElemType>> Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout,
bool spatial, ImageLayoutKind imageLayout,
BatchNormEngineKind enabledEngines = BatchNormEngineKind::All);

DISABLE_COPY_AND_MOVE(BatchNormEngine);

protected:
BatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout)
: m_deviceId(deviceId), m_inOutT(inOutT), m_scaleBiasT(scaleBiasT), m_spatial(spatial), m_imageLayout(imageLayout)
bool spatial, ImageLayoutKind imageLayout)
: m_deviceId(deviceId), m_inOutT(inOutT), m_spatial(spatial), m_imageLayout(imageLayout)
{
}
@@ -61,10 +60,14 @@ protected:
virtual void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev,
Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) = 0;

virtual void ForwardInferenceCore(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out) = 0;

virtual void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev,
Mat& scaleGrad, Mat& biasGrad) = 0;

protected:
DEVICEID_TYPE m_deviceId;
TensorShape m_inOutT;
TensorShape m_scaleBiasT;
bool m_spatial;
ImageLayoutKind m_imageLayout;
};
@@ -4293,8 +4293,52 @@ void CPUMatrix<ElemType>::BatchNormalizationForward(const CPUMatrix<ElemType>& s
CPUMatrix<ElemType>& out, double epsilon, CPUMatrix<ElemType>& saveMean, CPUMatrix<ElemType>& saveInvStdDev) const
{
UNUSED(scale); UNUSED(bias); UNUSED(expAvgFactor); UNUSED(runMean); UNUSED(runInvStdDev); UNUSED(out); UNUSED(epsilon); UNUSED(saveMean); UNUSED(saveInvStdDev);
RuntimeError("Not yet implemented.");
}

template <class ElemType>
void CPUMatrix<ElemType>::BatchNormalizationForwardInference(const CPUMatrix<ElemType>& scale, const CPUMatrix<ElemType>& bias,
const CPUMatrix<ElemType>& runMean, const CPUMatrix<ElemType>& runInvStdDev,
CPUMatrix<ElemType>& out) const
{
assert((GetNumRows() % scale.GetNumRows()) == 0);

bool spatial = GetNumRows() != scale.GetNumRows();
if (spatial)
{
size_t spatialSize = GetNumRows() / scale.GetNumRows();
#pragma omp parallel for
for (long icol = 0; icol < out.GetNumCols(); icol++)
{
for (long irow = 0; irow < out.GetNumRows(); irow++)
{
size_t imap = irow / spatialSize;
out(irow, icol) = scale(imap, 0) * ((*this)(irow, icol) - runMean(imap, 0)) * runInvStdDev(imap, 0) + bias(imap, 0);
}
}
}
else
{
#pragma omp parallel for
for (long icol = 0; icol < out.GetNumCols(); icol++)
{
for (long irow = 0; irow < out.GetNumRows(); irow++)
{
out(irow, icol) = scale(irow, 0) * ((*this)(irow, icol) - runMean(irow, 0)) * runInvStdDev(irow, 0) + bias(irow, 0);
}
}
}
}

template <class ElemType>
void CPUMatrix<ElemType>::BatchNormalizationBackward(const CPUMatrix<ElemType>& in, CPUMatrix<ElemType>& grad, const CPUMatrix<ElemType>& scale, const CPUMatrix<ElemType>& saveMean, const CPUMatrix<ElemType>& saveInvStdDev,
CPUMatrix<ElemType>& scaleGrad, CPUMatrix<ElemType>& biasGrad) const
{
UNUSED(in); UNUSED(grad); UNUSED(scale); UNUSED(saveMean); UNUSED(saveInvStdDev); UNUSED(scaleGrad); UNUSED(biasGrad);
RuntimeError("Not yet implemented.");
}

#pragma region Static BLAS Functions

/// <summary>Matrix-matrix multiply with col-major matrices (a and b may be transposed): c = alpha * op(a) * op(b) + beta*c</summary>
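Editor's note: restating what the two loops of BatchNormalizationForwardInference above compute (no new behavior). With j ranging over columns (samples) and parameter row k = floor(i / spatialSize) in the spatial branch, or k = i in the per-activation branch:

    out(i, j) = scale(k) * (x(i, j) - runMean(k)) * runInvStdDev(k) + bias(k)

i.e. the standard batch-normalization inference transform, y_ij = gamma_k * (x_ij - mu_k) * invStdDev_k + beta_k, applied with the running statistics.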
@@ -335,6 +335,9 @@ public:

void BatchNormalizationForward(const CPUMatrix<ElemType>& scale, const CPUMatrix<ElemType>& bias, double expAvgFactor, CPUMatrix<ElemType>& runMean, CPUMatrix<ElemType>& runInvStdDev,
CPUMatrix<ElemType>& out, double epsilon, CPUMatrix<ElemType>& saveMean, CPUMatrix<ElemType>& saveInvStdDev) const;
void BatchNormalizationForwardInference(const CPUMatrix<ElemType>& scale, const CPUMatrix<ElemType>& bias, const CPUMatrix<ElemType>& runMean, const CPUMatrix<ElemType>& runInvStdDev, CPUMatrix<ElemType>& out) const;
void BatchNormalizationBackward(const CPUMatrix<ElemType>& in, CPUMatrix<ElemType>& grad, const CPUMatrix<ElemType>& scale, const CPUMatrix<ElemType>& saveMean, const CPUMatrix<ElemType>& saveInvStdDev,
CPUMatrix<ElemType>& scaleGrad, CPUMatrix<ElemType>& biasGrad) const;

public:
static int SetNumThreads(int numThreads); // note: this does not depend on <ElemType>, i.e. you can call it on any <ElemType>
@@ -20,11 +20,11 @@ public:

public:
CuDnnBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout)
: Base(deviceId, inOutT, scaleBiasT, spatial, imageLayout),
bool spatial, ImageLayoutKind imageLayout)
: Base(deviceId, inOutT, spatial, imageLayout),
m_cudnn(CuDnn::Instance()),
m_inOutCuDnnT(inOutT, CuDnnTensor::GetDataType<ElemType>()),
m_scaleBiasCuDnnT(scaleBiasT, CuDnnTensor::GetDataType<ElemType>())
m_inOutCuDnnT(GetInOutTensor(inOutT), CuDnnTensor::GetDataType<ElemType>()),
m_scaleBiasCuDnnT(GetScaleBiasTensor(inOutT, spatial), CuDnnTensor::GetDataType<ElemType>())
{
}

@@ -32,13 +32,14 @@ protected:
using Base::m_deviceId;
using Base::m_imageLayout;
using Base::m_inOutT;
using Base::m_scaleBiasT;
using Base::m_spatial;

void EnsureCompatible() override
{
if (m_spatial && m_imageLayout == ImageLayoutKind::HWC)
InvalidArgument("cuDNN batch normalization supports only cudnn(CHW) layout.");
if (m_inOutT.GetRank() > 4)
InvalidArgument("cuDNN batch normalization supports tensors of max 4 dimensions.");
}

void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev,
@@ -53,6 +54,29 @@ protected:
epsilon, ptr(saveMean), ptr(saveInvStdDev)));
}

void ForwardInferenceCore(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out) override
{
m_inOutCuDnnT.UpdateBatchSize(in.GetNumCols());
cudnnBatchNormMode_t mode = m_spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
CUDNN_CALL(cudnnBatchNormalizationForwardInference(*m_cudnn, mode, &C::One, &C::Zero, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(out),
m_scaleBiasCuDnnT, ptr(scale), ptr(bias), ptr(runMean), ptr(runInvStdDev), CUDNN_BN_MIN_EPSILON));
}

void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev,
Mat& scaleGrad, Mat& biasGrad) override
{
m_inOutCuDnnT.UpdateBatchSize(srcGrad.GetNumCols());
cudnnBatchNormMode_t mode = m_spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
// REVIEW alexeyk: remove once Philly is upgraded to prod version.
#if CUDNN_PATCHLEVEL >= 7
CUDNN_CALL(cudnnBatchNormalizationBackward(*m_cudnn, mode, &C::One, &C::One, &C::One, &C::One, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(srcGrad), m_inOutCuDnnT, ptr(grad),
m_scaleBiasCuDnnT, ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
#else
CUDNN_CALL(cudnnBatchNormalizationBackward(*m_cudnn, mode, &C::One, &C::One, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(srcGrad), m_inOutCuDnnT, ptr(grad),
m_scaleBiasCuDnnT, ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
#endif
}

private:
template <typename ElemType>
static ElemType* ptr(Matrix<ElemType>& src)
@@ -65,6 +89,29 @@ private:
return src.BufferPointer();
}

static TensorShape GetInOutTensor(const TensorShape& inOutT)
{
// cuDNN supports only 3D and 4D tensors (in cuDNN docs it's 4D and 5D dues to N dimension)
// even for non-spatial inputs so expand the tensor if needed.
if (inOutT.GetRank() > 2)
return inOutT;
SmallVector<size_t> v(std::max(inOutT.GetRank(), (size_t)3), 1);
for (size_t i = 0; i < inOutT.GetRank(); i++)
v[i] = inOutT[i];
return TensorShape(v);
}

static TensorShape GetScaleBiasTensor(const TensorShape& inOutT, bool spatial)
{
if (!spatial)
return GetInOutTensor(inOutT);

const auto& t = GetInOutTensor(inOutT);
SmallVector<size_t> v(t.GetRank(), 1);
v[v.size() - 1] = t[t.GetRank() - 1];
return TensorShape(v);
}

private:
using C = Consts<ElemType>;
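Editor's note: a concrete illustration of the two tensor helpers above (expected shapes only; the inputs are hypothetical and not part of the commit):

    // GetInOutTensor(TensorShape(2))          -> TensorShape(2, 1, 1)     (rank padded to 3 for cuDNN)
    // GetInOutTensor(TensorShape(2, 2, 2048)) -> TensorShape(2, 2, 2048)  (rank > 2, returned as-is)
    // GetScaleBiasTensor(TensorShape(2, 2, 2048), /*spatial=*/true)  -> TensorShape(1, 1, 2048)
    // GetScaleBiasTensor(TensorShape(2), /*spatial=*/false)          -> TensorShape(2, 1, 1)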
@@ -78,13 +125,46 @@ template class CuDnnBatchNormEngine<double>;

template <typename ElemType>
std::unique_ptr<BatchNormEngine<ElemType>> CuDnnBatchNormEngineFactory<ElemType>::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial,
ImageLayoutKind imageLayout)
bool spatial, ImageLayoutKind imageLayout)
{
return std::make_unique<CuDnnBatchNormEngine<ElemType>>(deviceId, inOutT, scaleBiasT, spatial, imageLayout);
return std::make_unique<CuDnnBatchNormEngine<ElemType>>(deviceId, inOutT, spatial, imageLayout);
}

template class CuDnnBatchNormEngineFactory<float>;
template class CuDnnBatchNormEngineFactory<double>;

CudaTimer::~CudaTimer()
{
// TODO: Should not throw if std::uncaught_exception()
if (m_start != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_start)));
if (m_stop != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_stop)));
}
void CudaTimer::Start()
{
cudaEvent_t start;
cudaEvent_t stop;
if (m_start != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_start)));
if (m_stop != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_stop)));
CUDA_CALL(cudaEventCreate(&start));
CUDA_CALL(cudaEventCreate(&stop));
m_start = start;
m_stop = stop;
CUDA_CALL(cudaEventRecord(start, GetStream()));
}
void CudaTimer::Stop()
{
CUDA_CALL(cudaEventRecord(reinterpret_cast<cudaEvent_t>(m_stop), GetStream()));
CUDA_CALL(cudaEventSynchronize(reinterpret_cast<cudaEvent_t>(m_stop)));
}
float CudaTimer::Elapsed()
{
float ms;
CUDA_CALL(cudaEventElapsedTime(&ms, reinterpret_cast<cudaEvent_t>(m_start), reinterpret_cast<cudaEvent_t>(m_stop)));
return ms;
}

} } }
@@ -977,38 +977,4 @@ bool CuDnnConvolutionEngineFactory<ElemType>::IsSupported(ConvolveGeometryPtr ge
template class CuDnnConvolutionEngineFactory<float>;
template class CuDnnConvolutionEngineFactory<double>;

CudaTimer::~CudaTimer()
{
// TODO: Should not throw if std::uncaught_exception()
if (m_start != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_start)));
if (m_stop != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_stop)));
}
void CudaTimer::Start()
{
cudaEvent_t start;
cudaEvent_t stop;
if (m_start != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_start)));
if (m_stop != nullptr)
CUDA_CALL(cudaEventDestroy(reinterpret_cast<cudaEvent_t>(m_stop)));
CUDA_CALL(cudaEventCreate(&start));
CUDA_CALL(cudaEventCreate(&stop));
m_start = start;
m_stop = stop;
CUDA_CALL(cudaEventRecord(start, GetStream()));
}
void CudaTimer::Stop()
{
CUDA_CALL(cudaEventRecord(reinterpret_cast<cudaEvent_t>(m_stop), GetStream()));
CUDA_CALL(cudaEventSynchronize(reinterpret_cast<cudaEvent_t>(m_stop)));
}
float CudaTimer::Elapsed()
{
float ms;
CUDA_CALL(cudaEventElapsedTime(&ms, reinterpret_cast<cudaEvent_t>(m_start), reinterpret_cast<cudaEvent_t>(m_stop)));
return ms;
}

} } }
@@ -25,42 +25,10 @@ class CuDnnBatchNormEngineFactory
{
public:
static std::unique_ptr<BatchNormEngine<ElemType>> Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
const TensorShape& scaleBiasT, bool spatial,
ImageLayoutKind imageLayout);
bool spatial, ImageLayoutKind imageLayout);
};

//template <class ElemType>
//class CuDnnConvolutionEngineFactory : public ConvolutionEngineFactory<ElemType>
//{
//public:
// using Base = ConvolutionEngineFactory<ElemType>;
// using typename Base::Tensor4D;
// using typename Base::Tensor4DPtr;
// using typename Base::Filter;
// using typename Base::FilterPtr;
// using typename Base::ConvDesc;
// using typename Base::ConvDescPtr;
// using typename Base::PoolDesc;
// using typename Base::PoolDescPtr;
//
// using typename Base::ConvEnginePtr;
// using typename Base::PoolEnginePtr;
//
//public:
// Tensor4DPtr CreateTensor(size_t w, size_t h, size_t c, size_t n) override;
// FilterPtr CreateFilter(size_t w, size_t h, size_t c, size_t k) override;
// ConvDescPtr CreateConvDescriptor(const Tensor4D& inT, const Filter& filterT,
// size_t wStride, size_t hStride, bool padding) override;
// PoolDescPtr CreatePoolDescriptor(typename PoolDesc::PoolKind kind, size_t w, size_t h, size_t wStride, size_t hStride, size_t wPad, size_t hPad) override;
//
// ConvEnginePtr CreateConvEngine(DEVICEID_TYPE deviceId, ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples, BatchNormImpl bnImpl) override;
// PoolEnginePtr CreatePoolEngine(DEVICEID_TYPE deviceId, ImageLayoutKind imageLayout) override;
//
// static bool IsSupported(DEVICEID_TYPE deviceId);
//};
//

// REVIEW alexeyk: wrong place. It is currently used only in unit tests but I can't add it there because of the build issues.
// REVIEW alexeyk: wrong place? It is currently used only in unit tests but I can't add it there because of the build issues.
// Timer that can be used to measure CUDA calls.
// Uses CUDA event and will synchronize(!) the stream when Stop is called.
class MATH_API CudaTimer
@@ -79,4 +47,5 @@ private:
void* m_start;
void* m_stop;
};

} } }
@@ -3053,6 +3053,8 @@ template <class ElemType>
void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& bias, double expAvgFactor, GPUMatrix<ElemType>& runMean, GPUMatrix<ElemType>& runInvStdDev,
GPUMatrix<ElemType>& out, double epsilon, GPUMatrix<ElemType>& saveMean, GPUMatrix<ElemType>& saveInvStdDev) const
{
assert((GetNumRows() % scale.GetNumRows()) == 0);

bool spatial = GetNumRows() != scale.GetNumRows();
size_t vectorSize = GetNumRows();
size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1;
@@ -3064,8 +3066,8 @@ void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& s
SyncGuard syncGuard;
if (spatial)
{
Call<ComputeSpatialBatchMeanAndInvStdDev, ElemType>(spatialSize, vectorSize, spatialSize, batchSize, m_pArray,
expAvgFactor, runMean.m_pArray, runInvStdDev.m_pArray, epsilon,
Call<ComputeSpatialBatchMeanAndInvStdDev, ElemType>(spatialSize, vectorSize, spatialSize, batchSize, m_pArray,
expAvgFactor, runMean.m_pArray, runInvStdDev.m_pArray, epsilon,
saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
}
else
@@ -3075,8 +3077,59 @@ void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& s
saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
}
Call<NormalizeBatchTraining, ElemType>(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize,
spatial, m_pArray, out.m_pArray, scale.m_pArray, bias.m_pArray,
saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
spatial, m_pArray, out.m_pArray, scale.m_pArray, bias.m_pArray,
saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
}

template <class ElemType>
void GPUMatrix<ElemType>::BatchNormalizationForwardInference(const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& bias,
const GPUMatrix<ElemType>& runMean, const GPUMatrix<ElemType>& runInvStdDev,
GPUMatrix<ElemType>& out) const
{
assert((GetNumRows() % scale.GetNumRows()) == 0);

bool spatial = GetNumRows() != scale.GetNumRows();
size_t vectorSize = GetNumRows();
size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1;
size_t batchSize = GetNumCols();

assert(0 < vectorSize && vectorSize <= std::numeric_limits<int>::max());
assert(0 < batchSize && batchSize <= std::numeric_limits<int>::max());

SyncGuard syncGuard;
Call<NormalizeBatchTraining, ElemType>(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize,
spatial, m_pArray, out.m_pArray, scale.m_pArray, bias.m_pArray,
runMean.m_pArray, runInvStdDev.m_pArray, GetStream());
}

template <class ElemType>
void GPUMatrix<ElemType>::BatchNormalizationBackward(const GPUMatrix<ElemType>& in, GPUMatrix<ElemType>& grad, const GPUMatrix<ElemType>& scale,
const GPUMatrix<ElemType>& saveMean, const GPUMatrix<ElemType>& saveInvStdDev,
GPUMatrix<ElemType>& scaleGrad, GPUMatrix<ElemType>& biasGrad) const
{
assert((GetNumRows() % scale.GetNumRows()) == 0);

bool spatial = GetNumRows() != scale.GetNumRows();
size_t vectorSize = GetNumRows();
size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1;
size_t batchSize = GetNumCols();

assert(0 < vectorSize && vectorSize <= std::numeric_limits<int>::max());
assert(0 < batchSize && batchSize <= std::numeric_limits<int>::max());

SyncGuard syncGuard;
if (spatial)
{
Call<ComputeSpatialScaleAndBiasGradients, ElemType>(spatialSize, vectorSize, spatialSize, batchSize, in.m_pArray, m_pArray, scaleGrad.m_pArray, biasGrad.m_pArray,
saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
}
else
{
Call<ComputeScaleAndBiasGradients, ElemType>(vectorSize, vectorSize, batchSize, in.m_pArray, m_pArray, scaleGrad.m_pArray, biasGrad.m_pArray,
saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
}
Call<BackpropagateBatchNormGradients, ElemType>(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial,
in.m_pArray, m_pArray, grad.m_pArray, scale.m_pArray, scaleGrad.m_pArray, biasGrad.m_pArray, saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream());
}

#pragma region Static BLAS Functions
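Editor's note: the inference path above reuses the NormalizeBatchTraining kernel of the training path; the only difference is which statistics are passed to it (a restatement of the code above, not new behavior):

    // training  forward: Call<NormalizeBatchTraining, ...>(..., saveMean, saveInvStdDev, ...) // fresh batch statistics
    // inference forward: Call<NormalizeBatchTraining, ...>(..., runMean,  runInvStdDev,  ...) // running statistics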
@@ -419,6 +419,9 @@ public:

void BatchNormalizationForward(const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& bias, double expAvgFactor, GPUMatrix<ElemType>& runMean, GPUMatrix<ElemType>& runInvStdDev,
GPUMatrix<ElemType>& out, double epsilon, GPUMatrix<ElemType>& saveMean, GPUMatrix<ElemType>& saveInvStdDev) const;
void BatchNormalizationForwardInference(const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& bias, const GPUMatrix<ElemType>& runMean, const GPUMatrix<ElemType>& runInvStdDev, GPUMatrix<ElemType>& out) const;
void BatchNormalizationBackward(const GPUMatrix<ElemType>& in, GPUMatrix<ElemType>& grad, const GPUMatrix<ElemType>& scale, const GPUMatrix<ElemType>& saveMean, const GPUMatrix<ElemType>& saveInvStdDev,
GPUMatrix<ElemType>& scaleGrad, GPUMatrix<ElemType>& biasGrad) const;

public:
// static BLAS functions
@@ -4150,6 +4150,45 @@ void Matrix<ElemType>::BatchNormalizationForward(const Matrix<ElemType>& scale,
NOT_IMPLEMENTED);
}

template <class ElemType>
void Matrix<ElemType>::BatchNormalizationForwardInference(const Matrix<ElemType>& scale, const Matrix<ElemType>& bias,
const Matrix<ElemType>& runMean, const Matrix<ElemType>& runInvStdDev,
Matrix<ElemType>& out) const
{
DecideAndMoveToRightDevice(*this, out);

// REVIEW alexeyk: add sparse version.
DISPATCH_MATRIX_ON_FLAG(this,
this,
m_CPUMatrix->BatchNormalizationForwardInference(*(scale.m_CPUMatrix), *(bias.m_CPUMatrix),
*(runMean.m_CPUMatrix), *(runInvStdDev.m_CPUMatrix),
*(out.m_CPUMatrix)),
m_GPUMatrix->BatchNormalizationForwardInference(*(scale.m_GPUMatrix), *(bias.m_GPUMatrix),
*(runMean.m_GPUMatrix), *(runInvStdDev.m_GPUMatrix),
*(out.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}

template <class ElemType>
void Matrix<ElemType>::BatchNormalizationBackward(const Matrix<ElemType>& in, Matrix<ElemType>& grad, const Matrix<ElemType>& scale, const Matrix<ElemType>& saveMean, const Matrix<ElemType>& saveInvStdDev,
Matrix<ElemType>& scaleGrad, Matrix<ElemType>& biasGrad) const
{
DecideAndMoveToRightDevice(*this, grad);

// REVIEW alexeyk: add sparse version.
DISPATCH_MATRIX_ON_FLAG(this,
this,
m_CPUMatrix->BatchNormalizationBackward(*(in.m_CPUMatrix), *(grad.m_CPUMatrix), *(scale.m_CPUMatrix),
*(saveMean.m_CPUMatrix), *(saveInvStdDev.m_CPUMatrix),
*(scaleGrad.m_CPUMatrix), *(biasGrad.m_CPUMatrix)),
m_GPUMatrix->BatchNormalizationBackward(*(in.m_GPUMatrix), *(grad.m_GPUMatrix), *(scale.m_GPUMatrix),
*(saveMean.m_GPUMatrix), *(saveInvStdDev.m_GPUMatrix),
*(scaleGrad.m_GPUMatrix), *(biasGrad.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}

#pragma region Static BLAS Functions

template <class ElemType>
@@ -470,6 +470,9 @@ public:

void BatchNormalizationForward(const Matrix<ElemType>& scale, const Matrix<ElemType>& bias, double expAvgFactor, Matrix<ElemType>& runMean, Matrix<ElemType>& runInvStdDev,
Matrix<ElemType>& out, double epsilon, Matrix<ElemType>& saveMean, Matrix<ElemType>& saveInvStdDev) const;
void BatchNormalizationForwardInference(const Matrix<ElemType>& scale, const Matrix<ElemType>& bias, const Matrix<ElemType>& runMean, const Matrix<ElemType>& runInvStdDev, Matrix<ElemType>& out) const;
void BatchNormalizationBackward(const Matrix<ElemType>& in, Matrix<ElemType>& grad, const Matrix<ElemType>& scale, const Matrix<ElemType>& saveMean, const Matrix<ElemType>& saveInvStdDev,
Matrix<ElemType>& scaleGrad, Matrix<ElemType>& biasGrad) const;

public:
// TODO: why are these not static? And why are they here?
@@ -19,9 +19,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Test {
using vec = std::vector<float>;
using BNEng = BatchNormEngine<float>;

std::vector<std::tuple<TensorShape, TensorShape, size_t, bool, double>> GenerateBNTestConfigs()
std::vector<std::tuple<TensorShape, size_t, bool, double>> GenerateBNTestConfigs()
{
std::vector<std::tuple<TensorShape, TensorShape, size_t, bool, double>> res;
std::vector<std::tuple<TensorShape, size_t, bool, double>> res;
// REVIEW alexeyk: how to test batches > 512? cuDNN does not support that so there is no baseline.
double expAvgFactor = 1;
// Per activation (non-spatial)
@@ -33,7 +33,7 @@ std::vector<std::tuple<TensorShape, TensorShape, size_t, bool, double>> Generate
{
for (size_t w : {6, 17, 126, 2048})
{
res.push_back(std::make_tuple(TensorShape(w, h, c), TensorShape(w, h, c), n, false, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(w, h, c), n, false, expAvgFactor));
}
}
}
@@ -47,21 +47,24 @@ std::vector<std::tuple<TensorShape, TensorShape, size_t, bool, double>> Generate
{
for (size_t w : {2, 11, 16})
{
res.push_back(std::make_tuple(TensorShape(w, h, c), TensorShape(1, 1, c), n, true, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(w, h, c), n, true, expAvgFactor));
}
}
}
}
// For perf testing (similar to first layers of ResNet).
res.push_back(std::make_tuple(TensorShape(56, 56, 64), TensorShape(1, 1, 64), 64, true, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(56, 56, 64), 64, true, expAvgFactor));
// Next test will fail in cuDNN due to bug we discovered (and reported to NVIDIA).
//res.push_back(std::make_tuple(std::move(fact.CreateTensor(2, 2, 2048, 2)), true));
res.push_back(std::make_tuple(TensorShape(2, 2, 2048), TensorShape(1, 1, 2048), 64, true, expAvgFactor));
//res.push_back(std::make_tuple(TensorShape(2, 2, 2048), 2, true, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(2, 2, 2048), 64, true, expAvgFactor));

// Test running mean/isd.
expAvgFactor = 0.1;
res.push_back(std::make_tuple(TensorShape(2, 2, 2), TensorShape(2, 2, 2), 8, false, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(2, 2, 2), TensorShape(1, 1, 2), 8, true, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(2, 2, 2), 8, false, expAvgFactor));
res.push_back(std::make_tuple(TensorShape(2, 2, 2), 8, true, expAvgFactor));

// Test 1D tensor expansion (cuDNN supports 3D and 4D tensors only).
res.push_back(std::make_tuple(TensorShape(2), 8, false, expAvgFactor));

return res;
}
@@ -89,16 +92,15 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
for (const auto& cfg : GenerateBNTestConfigs())
{
const auto& inOutT = std::get<0>(cfg);
const auto& scaleT = std::get<1>(cfg);
size_t batchSize = std::get<2>(cfg);
bool spatial = std::get<3>(cfg);
double expAvg = std::get<4>(cfg);
size_t batchSize = std::get<1>(cfg);
bool spatial = std::get<2>(cfg);
double expAvg = std::get<3>(cfg);
double eps = 1e-5; // CUDNN_BN_MIN_EPSILON

auto engCudnn = BNEng::Create(baseDeviceId, inOutT, scaleT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn);
auto engCntk = BNEng::Create(deviceId, inOutT, scaleT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk);
auto engCudnn = BNEng::Create(baseDeviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn);
auto engCntk = BNEng::Create(deviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk);

size_t crow = inOutT[0] * inOutT[1] * inOutT[2];
size_t crow = inOutT.GetNumElements();
size_t ccol = batchSize;

vec buf(crow * ccol);
@@ -106,7 +108,7 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
SingleMatrix in(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix inB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal);

size_t crowScaleBias = scaleT.GetNumElements();
size_t crowScaleBias = spatial ? inOutT[2] : inOutT.GetNumElements();
buf.resize(crowScaleBias);

std::generate(begin(buf), end(buf), [&] { return nd(rng); });
@@ -145,7 +147,7 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
time2.Stop();

std::stringstream tmsg;
tmsg << "inOut tensor: " << (std::string)inOutT << ", scaleBias tensor: " << (std::string)scaleT
tmsg << "inOut tensor: " << (std::string)inOutT
<< ", spatial = " << (spatial ? "true" : "false")
<< ", expAvg = " << expAvg << ")";
std::string msg = " are not equal, " + tmsg.str();
@@ -194,100 +196,104 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
}
}

//BOOST_AUTO_TEST_CASE(BatchNormalizationForwardInference)
//{
// if (!IsCuDnnSupported())
// return;
//
// std::mt19937 rng(0);
// std::normal_distribution<float> nd;
//
// auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix
// {
// data.resize(r * 3 * c);
// std::fill(begin(data), end(data), std::numeric_limits<float>::quiet_NaN());
// std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); });
// buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data());
// // Get center slice.
// return buf.ColumnSlice(c, c);
// };
//
// for (int deviceId : {0})
// {
// auto fact = ConvFact::Create(deviceId, ConvFact::EngineType::Auto, ImageLayoutKind::CHW);
// auto engCudnn = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::CuDnn);
// auto engCntk = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::Cntk);
// for (auto& cfg : GenerateBNTestConfigs(*fact))
// {
// auto& t = *std::move(std::get<0>(cfg));
// bool spatial = std::get<1>(cfg);
//
// size_t crow = t.w() * t.h() * t.c();
// size_t ccol = t.n();
//
// vec buf(crow * t.n());
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix in(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
//
// Tensor4DPtr scaleBiasT = spatial ? fact->CreateTensor(1, 1, t.c(), 1) : fact->CreateTensor(t.w(), t.h(), t.c(), 1);
// size_t crowScaleBias = scaleBiasT->w() * scaleBiasT->h() * scaleBiasT->c();
// buf.resize(crowScaleBias);
//
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix bias(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
//
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix runMean(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix runInvStdDev(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
//
// SingleMatrix outBuf(deviceId);
// SingleMatrix out = initMat(outBuf, crow, ccol, buf);
// SingleMatrix outExp(out);
//
// CudaTimer time1;
// time1.Start();
// engCntk->NormalizeBatchInference(t, in, *scaleBiasT, scale, bias, spatial, runMean, runInvStdDev, out);
// time1.Stop();
//
// CudaTimer time2;
// time2.Start();
// engCudnn->NormalizeBatchInference(t, in, *scaleBiasT, scale, bias, spatial, runMean, runInvStdDev, outExp);
// time2.Stop();
//
// std::stringstream tmsg;
// tmsg << "tensor: (w = " << t.w() << ", h = " << t.h() << ", c = " << t.c() << ", n = " << t.n() << ", spatial = " << (spatial ? "true" : "false") << ")";
// std::string msg = " are not equal, " + tmsg.str();
// std::string msgNan = " has NaNs, " + tmsg.str();
// std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str();
//
// float relErr = Err<float>::Rel;
// float absErr = Err<float>::Abs;
// std::string emsg;
//
// BOOST_REQUIRE_MESSAGE(!out.HasNan("out"), "out" << msgNan);
// BOOST_REQUIRE_MESSAGE(CheckEqual(out, outExp, emsg, relErr, absErr * 20), "out" << msg << ". " << emsg);
// BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crow * 2 * ccol, "out" << msgNotNan);
// // REVIEW alexeyk: add cases for testing numerical stability.
//
//#ifndef _DEBUG
// float elapsedCntk = time1.Elapsed();
// float elapsedCudnn = time2.Elapsed();
// // Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation.
// // For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN.
// if (crow >= 32 && ccol >= 32)
// {
// // Use conservative estimates.
// int speedup = 2;
// BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn,
// "CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str());
// }
//#endif
// }
// }
//}
BOOST_AUTO_TEST_CASE(BatchNormalizationForwardInference)
{
std::mt19937 rng(0);
std::normal_distribution<float> nd;

auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix
{
data.resize(r * 3 * c);
std::fill(begin(data), end(data), std::numeric_limits<float>::quiet_NaN());
std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); });
buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data());
// Get center slice.
return buf.ColumnSlice(c, c);
};

int baseDeviceId = 0;
for (int deviceId : {0})
{
for (const auto& cfg : GenerateBNTestConfigs())
{
const auto& inOutT = std::get<0>(cfg);
size_t batchSize = std::get<1>(cfg);
bool spatial = std::get<2>(cfg);

auto engCudnn = BNEng::Create(baseDeviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn);
auto engCntk = BNEng::Create(deviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk);

size_t crow = inOutT.GetNumElements();
size_t ccol = batchSize;

vec buf(crow * ccol);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix in(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix inB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal);

size_t crowScaleBias = spatial ? inOutT[2] : inOutT.GetNumElements();
buf.resize(crowScaleBias);

std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix scaleB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix bias(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix biasB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal);

SingleMatrix runMeanBuf(deviceId);
SingleMatrix runMean = initMat(runMeanBuf, crowScaleBias, 1, buf);
SingleMatrix runMeanB(runMean.DeepClone(), baseDeviceId);
SingleMatrix runInvStdDevBuf(deviceId);
SingleMatrix runInvStdDev = initMat(runInvStdDevBuf, crowScaleBias, 1, buf);
SingleMatrix runInvStdDevB(runInvStdDev.DeepClone(), baseDeviceId);

SingleMatrix outBuf(deviceId);
SingleMatrix out = initMat(outBuf, crow, ccol, buf);
SingleMatrix outB(out.DeepClone(), baseDeviceId);

CudaTimer time1;
time1.Start();
engCntk->ForwardInference(in, scale, bias, runMean, runInvStdDev, out);
time1.Stop();

CudaTimer time2;
time2.Start();
engCudnn->ForwardInference(inB, scaleB, biasB, runMeanB, runInvStdDevB, outB);
time2.Stop();

std::stringstream tmsg;
tmsg << "inOut tensor: " << (std::string)inOutT
<< ", spatial = " << (spatial ? "true" : "false");
std::string msg = " are not equal, " + tmsg.str();
std::string msgNan = " has NaNs, " + tmsg.str();
std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str();

float relErr = Err<float>::Rel;
float absErr = Err<float>::Abs;
std::string emsg;

BOOST_REQUIRE_MESSAGE(!out.HasNan("out"), "out" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(out, outB, emsg, relErr, absErr * 20), "out" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crow * 2 * ccol, "out" << msgNotNan);
// REVIEW alexeyk: add cases for testing numerical stability.

#ifndef _DEBUG
float elapsedCntk = time1.Elapsed();
float elapsedCudnn = time2.Elapsed();
// Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation.
// For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN.
if (crow >= 32 && ccol >= 32)
{
// Use conservative estimates.
int speedup = 2;
BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn,
"CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str());
}
#endif
}
}
}
//
//BOOST_AUTO_TEST_CASE(BatchNormalizationForwardInferenceCpu)
//{
@@ -373,118 +379,122 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
// BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crow * 2 * ccol, "out" << msgNotNan);
// }
//}
//
//BOOST_AUTO_TEST_CASE(BatchNormalizationBackward)
//{
// if (!IsCuDnnSupported())
// return;
//
// std::mt19937 rng(0);
// std::normal_distribution<float> nd;
//
// auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix
// {
// data.resize(r * 3 * c);
// std::fill(begin(data), end(data), std::numeric_limits<float>::quiet_NaN());
// std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); });
// buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data());
// // Get center slice.
// return buf.ColumnSlice(c, c);
// };
//
// for (int deviceId : {0})
// {
// auto fact = ConvFact::Create(deviceId, ConvFact::EngineType::Auto, ImageLayoutKind::CHW);
// auto engCudnn = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::CuDnn);
// auto engCntk = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::Cntk);
// for (auto& cfg : GenerateBNTestConfigs(*fact))
// {
// auto& t = *std::move(std::get<0>(cfg));
// bool spatial = std::get<1>(cfg);
//
// size_t crow = t.w() * t.h() * t.c();
// size_t ccol = t.n();
//
// vec buf(crow * t.n());
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix x(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
//
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix dy(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
//
// Tensor4DPtr scaleBiasT = spatial ? fact->CreateTensor(1, 1, t.c(), 1) : fact->CreateTensor(t.w(), t.h(), t.c(), 1);
// size_t crowScaleBias = scaleBiasT->w() * scaleBiasT->h() * scaleBiasT->c();
// buf.resize(crowScaleBias);
//
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
//
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix saveMean(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
//
// std::generate(begin(buf), end(buf), [&] { return nd(rng); });
// SingleMatrix saveInvStdDev(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
//
// SingleMatrix dScaleBuf(deviceId);
// SingleMatrix dScale = initMat(dScaleBuf, crowScaleBias, 1, buf);
// SingleMatrix dScaleExp(dScale);
// SingleMatrix dBiasBuf(deviceId);
// SingleMatrix dBias = initMat(dBiasBuf, crowScaleBias, 1, buf);
// SingleMatrix dBiasExp(dBias);
//
// SingleMatrix dxBuf(deviceId);
// SingleMatrix dx = initMat(dxBuf, crow, ccol, buf);
// SingleMatrix dxExp(dx);
//
// CudaTimer time1;
// time1.Start();
// engCntk->BackwardNormalizeBatch(t, x, dy, dx, *scaleBiasT, scale, spatial, saveMean, saveInvStdDev, dScale, dBias);
// time1.Stop();
//
// CudaTimer time2;
// time2.Start();
// engCudnn->BackwardNormalizeBatch(t, x, dy, dxExp, *scaleBiasT, scale, spatial, saveMean, saveInvStdDev, dScaleExp, dBiasExp);
// time2.Stop();
//
// std::stringstream tmsg;
// tmsg << "tensor: (w = " << t.w() << ", h = " << t.h() << ", c = " << t.c() << ", n = " << t.n() << ", spatial = " << (spatial ? "true" : "false") << ")";
// std::string msg = " are not equal, " + tmsg.str();
// std::string msgNan = " has NaNs, " + tmsg.str();
// std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str();
//
// float relErr = Err<float>::Rel;
// float absErr = Err<float>::Abs;
// std::string emsg;
//
// BOOST_REQUIRE_MESSAGE(!dx.HasNan("dx"), "dx" << msgNan);
// BOOST_REQUIRE_MESSAGE(CheckEqual(dx, dxExp, emsg, relErr * 16, absErr * 8), "dx" << msg << ". " << emsg);
// BOOST_REQUIRE_MESSAGE(CountNans(dxBuf) == crow * 2 * ccol, "out" << msgNotNan);
// // REVIEW alexeyk: add cases for testing numerical stability.
//
// BOOST_REQUIRE_MESSAGE(!dScale.HasNan("dScale"), "dScale" << msgNan);
// BOOST_REQUIRE_MESSAGE(CheckEqual(dScale, dScaleExp, emsg, relErr * 32, absErr * 8), "dScale" << msg << ". " << emsg);
// BOOST_REQUIRE_MESSAGE(CountNans(dScaleBuf) == crowScaleBias * 2, "dScale" << msgNotNan);
//
// BOOST_REQUIRE_MESSAGE(!dBias.HasNan("dBias"), "dBias" << msgNan);
// BOOST_REQUIRE_MESSAGE(CheckEqual(dBias, dBiasExp, emsg, relErr * 32, absErr * 8), "dBias" << msg << ". " << emsg);
// BOOST_REQUIRE_MESSAGE(CountNans(dBiasBuf) == crowScaleBias * 2, "dBias" << msgNotNan);
//
//#ifndef _DEBUG
// float elapsedCntk = time1.Elapsed();
// float elapsedCudnn = time2.Elapsed();
// // Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation.
// // For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN.
// if (crow >= 32 && ccol >= 32)
// {
// // Use conservative estimates.
// float speedup = 1.3f;
// BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn,
// "CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str());
// }
//#endif
// }
// }
//}

BOOST_AUTO_TEST_CASE(BatchNormalizationBackward)
{
std::mt19937 rng(0);
std::normal_distribution<float> nd;

auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix
{
data.resize(r * 3 * c);
std::fill(begin(data), end(data), std::numeric_limits<float>::quiet_NaN());
std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); });
buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data());
// Get center slice.
return buf.ColumnSlice(c, c);
};

int baseDeviceId = 0;
for (int deviceId : {0})
{
for (const auto& cfg : GenerateBNTestConfigs())
{
const auto& inOutT = std::get<0>(cfg);
size_t batchSize = std::get<1>(cfg);
bool spatial = std::get<2>(cfg);

auto engCudnn = BNEng::Create(baseDeviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn);
auto engCntk = BNEng::Create(deviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk);

size_t crow = inOutT.GetNumElements();
size_t ccol = batchSize;

vec buf(crow * ccol);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix x(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix xB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal);

std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix dy(crow, ccol, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix dyB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal);

size_t crowScaleBias = spatial ? inOutT[2] : inOutT.GetNumElements();
buf.resize(crowScaleBias);

std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix scaleB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal);

std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix saveMean(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix saveMeanB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal);

std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix saveInvStdDev(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix saveInvStdDevB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal);

SingleMatrix dScaleBuf(deviceId);
SingleMatrix dScale = initMat(dScaleBuf, crowScaleBias, 1, buf);
SingleMatrix dScaleB(dScale.DeepClone(), baseDeviceId);
SingleMatrix dBiasBuf(deviceId);
SingleMatrix dBias = initMat(dBiasBuf, crowScaleBias, 1, buf);
SingleMatrix dBiasB(dBias.DeepClone(), baseDeviceId);

SingleMatrix dxBuf(deviceId);
SingleMatrix dx = initMat(dxBuf, crow, ccol, buf);
SingleMatrix dxB(dx.DeepClone(), baseDeviceId);

CudaTimer time1;
time1.Start();
engCntk->Backward(x, dy, dx, scale, saveMean, saveInvStdDev, dScale, dBias);
time1.Stop();

CudaTimer time2;
time2.Start();
engCudnn->Backward(xB, dyB, dxB, scaleB, saveMeanB, saveInvStdDevB, dScaleB, dBiasB);
time2.Stop();

std::stringstream tmsg;
tmsg << "inOut tensor: " << (std::string)inOutT
<< ", spatial = " << (spatial ? "true" : "false");
std::string msg = " are not equal, " + tmsg.str();
std::string msgNan = " has NaNs, " + tmsg.str();
std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str();

float relErr = Err<float>::Rel;
float absErr = Err<float>::Abs;
std::string emsg;

BOOST_REQUIRE_MESSAGE(!dx.HasNan("dx"), "dx" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(dx, dxB, emsg, relErr * 16, absErr * 8), "dx" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(dxBuf) == crow * 2 * ccol, "out" << msgNotNan);
// REVIEW alexeyk: add cases for testing numerical stability.

BOOST_REQUIRE_MESSAGE(!dScale.HasNan("dScale"), "dScale" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(dScale, dScaleB, emsg, relErr * 32, absErr * 8), "dScale" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(dScaleBuf) == crowScaleBias * 2, "dScale" << msgNotNan);

BOOST_REQUIRE_MESSAGE(!dBias.HasNan("dBias"), "dBias" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(dBias, dBiasB, emsg, relErr * 32, absErr * 8), "dBias" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(dBiasBuf) == crowScaleBias * 2, "dBias" << msgNotNan);

#ifndef _DEBUG
float elapsedCntk = time1.Elapsed();
float elapsedCudnn = time2.Elapsed();
// Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation.
// For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN.
if (crow >= 32 && ccol >= 32)
{
// Use conservative estimates.
float speedup = 1.3f;
BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn,
"CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str());
}
#endif
}
}
}

BOOST_AUTO_TEST_SUITE_END()