From bf413614f819945e56fc6f5a324279282bca990e Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Tue, 15 Mar 2016 13:15:48 -0700 Subject: [PATCH] Finished refactoring of BN node/engine. --- Source/ComputationNetworkLib/TrainingNodes.h | 44 +- Source/Math/BatchNormalizationEngine.cpp | 81 +++- Source/Math/BatchNormalizationEngine.h | 19 +- Source/Math/CPUMatrix.cpp | 44 ++ Source/Math/CPUMatrix.h | 3 + Source/Math/CuDnnBatchNormalization.cu | 96 +++- Source/Math/CuDnnConvolutionEngine.cu | 34 -- Source/Math/CuDnnFactories.h | 37 +- Source/Math/GPUMatrix.cu | 61 ++- Source/Math/GPUMatrix.h | 3 + Source/Math/Matrix.cpp | 39 ++ Source/Math/Matrix.h | 3 + .../BatchNormalizatinEngineTests.cpp | 458 +++++++++--------- 13 files changed, 550 insertions(+), 372 deletions(-) diff --git a/Source/ComputationNetworkLib/TrainingNodes.h b/Source/ComputationNetworkLib/TrainingNodes.h index 40f794910..f804241a2 100644 --- a/Source/ComputationNetworkLib/TrainingNodes.h +++ b/Source/ComputationNetworkLib/TrainingNodes.h @@ -6,7 +6,7 @@ #include "Basics.h" #include "ComputationNode.h" -#include "ConvolutionEngine.h" +#include "BatchNormalizationEngine.h" #include #include @@ -1657,16 +1657,12 @@ public: const Matrix& scale = Input(1)->Value(); const Matrix& bias = Input(2)->Value(); - size_t batchSize = sliceInputValue.GetNumCols(); - m_inT->setN(batchSize); - assert(m_convEng != nullptr); - auto sliceInputGrad = Input(0)->GradientFor(fr); m_dScale->Resize(scale); m_dBias->Resize(bias); // Compute all derivatives in one step. Save derivatives with respect to scale and bias in temp matrices. - m_convEng->BackwardNormalizeBatch(*m_inT, sliceInputValue, sliceOutputGrad, sliceInputGrad, *m_scaleBiasT, scale, m_spatial, - *m_saveMean, *m_saveInvStdDev, *m_dScale, *m_dBias); + m_bnEng->Backward(sliceInputValue, sliceOutputGrad, sliceInputGrad, scale, + *m_saveMean, *m_saveInvStdDev, *m_dScale, *m_dBias); } else if (inputIndex == 1) // derivative with respect to the scale { @@ -1707,14 +1703,11 @@ public: Matrix sliceOutputValue = ValueFor(fr); - size_t batchSize = sliceInputValue.GetNumCols(); - m_inT->setN(batchSize); - assert(m_convEng != nullptr); #if NANCHECK sliceInputValue.HasNan("BatchNormalization-input"); #endif if (m_eval) - m_convEng->NormalizeBatchInference(*m_inT, sliceInputValue, *m_scaleBiasT, scale, bias, m_spatial, runMean, runInvStdDev, sliceOutputValue); + m_bnEng->NormalizeBatchInference(sliceInputValue, scale, bias, runMean, runInvStdDev, sliceOutputValue); else { double expAvgFactor; @@ -1736,8 +1729,8 @@ public: m_saveMean->Resize(runMean); m_saveInvStdDev->Resize(runMean); - m_convEng->NormalizeBatch(*m_inT, sliceInputValue, *m_scaleBiasT, scale, bias, m_spatial, expAvgFactor, runMean, runInvStdDev, - sliceOutputValue, m_epsilon, *m_saveMean, *m_saveInvStdDev); + m_bnEng->Forward(sliceInputValue, scale, bias, expAvgFactor, runMean, runInvStdDev, + sliceOutputValue, m_epsilon, *m_saveMean, *m_saveInvStdDev); m_mbCount++; } @@ -1772,24 +1765,10 @@ public: auto shape = GetSampleLayout(); - if (m_factory == nullptr) - m_factory = ConvolutionEngineFactory::Create(m_deviceId, ConvolutionEngineFactory::EngineType::Auto, m_imageLayoutKind); - if (m_convEng == nullptr) - m_convEng = m_factory->CreateConvEngine(m_deviceId, m_imageLayoutKind, 0, m_useCntkEngine ? BatchNormImpl::Cntk : BatchNormImpl::CuDnn); - if (m_spatial) + if (m_bnEng == nullptr) { - auto dims = ImageDimensions(shape, m_imageLayoutKind); - if (m_inT == nullptr) - m_inT = m_factory->CreateTensor(dims.m_width, dims.m_height, dims.m_numChannels, 1); - if (m_scaleBiasT == nullptr) - m_scaleBiasT = m_factory->CreateTensor(1, 1, dims.m_numChannels, 1); - } - else - { - if (m_inT == nullptr) - m_inT = m_factory->CreateTensor(shape.GetNumElements(), 1, 1, 1); - if (m_scaleBiasT == nullptr) - m_scaleBiasT = m_factory->CreateTensor(shape.GetNumElements(), 1, 1, 1); + m_bnEng = BatchNormEngine::Create(m_deviceId, shape, m_spatial, m_imageLayoutKind, + m_useCntkEngine ? BatchNormEngineKind::Cntk : BatchNormEngineKind::CuDnn); } } } @@ -1869,10 +1848,7 @@ private: // Stores bias derivatives. shared_ptr> m_dBias; - std::unique_ptr> m_factory; - std::unique_ptr> m_convEng; - std::unique_ptr m_inT; - std::unique_ptr m_scaleBiasT; + std::unique_ptr> m_bnEng; }; template class BatchNormalizationNode; diff --git a/Source/Math/BatchNormalizationEngine.cpp b/Source/Math/BatchNormalizationEngine.cpp index 8d1fb3df4..cee5537ad 100644 --- a/Source/Math/BatchNormalizationEngine.cpp +++ b/Source/Math/BatchNormalizationEngine.cpp @@ -13,33 +13,50 @@ template void BatchNormEngine::Forward(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) { - assert(in.GetNumRows() == out.GetNumRows()); + assert(in.GetNumRows() == m_inOutT.GetNumElements()); + assert(out.GetNumRows() == m_inOutT.GetNumElements()); assert(in.GetNumCols() == out.GetNumCols()); assert(std::isfinite(epsilon) && epsilon > 0); assert(std::isfinite(expAvgFactor) && (0 < expAvgFactor && expAvgFactor <= 1)); if (!m_spatial) { - assert(in.GetNumRows() == scale.GetNumRows()); - assert(in.GetNumRows() == bias.GetNumRows()); - assert(in.GetNumRows() == runMean.GetNumRows()); - assert(in.GetNumRows() == runInvStdDev.GetNumRows()); - assert(in.GetNumRows() == saveMean.GetNumRows()); - assert(in.GetNumRows() == saveInvStdDev.GetNumRows()); + assert(m_inOutT.GetNumElements() == scale.GetNumRows()); + assert(m_inOutT.GetNumElements() == bias.GetNumRows()); + assert(m_inOutT.GetNumElements() == runMean.GetNumRows()); + assert(m_inOutT.GetNumElements() == runInvStdDev.GetNumRows()); + assert(m_inOutT.GetNumElements() == saveMean.GetNumRows()); + assert(m_inOutT.GetNumElements() == saveInvStdDev.GetNumRows()); } else { - assert((in.GetNumRows() % scale.GetNumRows()) == 0); - assert((in.GetNumRows() % bias.GetNumRows()) == 0); - assert((in.GetNumRows() % runMean.GetNumRows()) == 0); - assert((in.GetNumRows() % runInvStdDev.GetNumRows()) == 0); - assert((in.GetNumRows() % saveMean.GetNumRows()) == 0); - assert((in.GetNumRows() % saveInvStdDev.GetNumRows()) == 0); + assert((m_inOutT.GetNumElements() % scale.GetNumRows()) == 0); + assert((m_inOutT.GetNumElements() % bias.GetNumRows()) == 0); + assert((m_inOutT.GetNumElements() % runMean.GetNumRows()) == 0); + assert((m_inOutT.GetNumElements() % runInvStdDev.GetNumRows()) == 0); + assert((m_inOutT.GetNumElements() % saveMean.GetNumRows()) == 0); + assert((m_inOutT.GetNumElements() % saveInvStdDev.GetNumRows()) == 0); } EnsureCompatible(); ForwardCore(in, scale, bias, expAvgFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev); } +template +void BatchNormEngine::ForwardInference(const Mat& in, const Mat& scale, const Mat& bias, + const Mat& runMean, const Mat& runInvStdDev, Mat& out) +{ + EnsureCompatible(); + ForwardInferenceCore(in, scale, bias, runMean, runInvStdDev, out); +} + +template +void BatchNormEngine::Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, + const Mat& saveMean, const Mat& saveInvStdDev, Mat& scaleGrad, Mat& biasGrad) +{ + EnsureCompatible(); + BackwardCore(in, srcGrad, grad, scale, saveMean, saveInvStdDev, scaleGrad, biasGrad); +} + template class BatchNormEngine; template class BatchNormEngine; @@ -52,8 +69,8 @@ public: public: CntkBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout) - : Base(deviceId, inOutT, scaleBiasT, spatial, imageLayout) + bool spatial, ImageLayoutKind imageLayout) + : Base(deviceId, inOutT, spatial, imageLayout) { } @@ -61,7 +78,6 @@ protected: using Base::m_deviceId; using Base::m_imageLayout; using Base::m_inOutT; - using Base::m_scaleBiasT; using Base::m_spatial; void EnsureCompatible() override @@ -75,34 +91,47 @@ protected: { in.BatchNormalizationForward(scale, bias, expAvgFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev); } + + void ForwardInferenceCore(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out) override + { + in.BatchNormalizationForwardInference(scale, bias, runMean, runInvStdDev, out); + } + + void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, + Mat& scaleGrad, Mat& biasGrad) override + { + srcGrad.BatchNormalizationBackward(in, grad, scale, saveMean, saveInvStdDev, scaleGrad, biasGrad); + } }; template class CntkBatchNormEngine; template class CntkBatchNormEngine; +template +bool HasFlag(T src, T testFlag) +{ + return ((int)src & (int)testFlag) != 0; +} + template std::unique_ptr> BatchNormEngine::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout, + bool spatial, ImageLayoutKind imageLayout, BatchNormEngineKind enabledEngines = BatchNormEngineKind::All) { - if (spatial && imageLayout == ImageLayoutKind::HWC) - InvalidArgument("Batch normalization is not supported for legacy(HWC) layout. Please use cudnn(CHW) layout instead."); - - auto isEnabled = [=](BatchNormEngineKind eng) { return ((int)enabledEngines & (int)eng) != 0; }; // Use CNTK as default batch norm engine. - if (isEnabled(BatchNormEngineKind::Cntk)) + if (HasFlag(enabledEngines, BatchNormEngineKind::Cntk)) { fprintf(stderr, "Using CNTK batch normalization engine.\n"); - return std::make_unique>(deviceId, inOutT, scaleBiasT, spatial, imageLayout); + return std::make_unique>(deviceId, inOutT, spatial, imageLayout); } - if (isEnabled(BatchNormEngineKind::CuDnn)) + if (HasFlag(enabledEngines, BatchNormEngineKind::CuDnn)) { fprintf(stderr, "Using cuDNN batch normalization engine.\n"); - return CuDnnBatchNormEngineFactory::Create(deviceId, inOutT, scaleBiasT, spatial, imageLayout); + return CuDnnBatchNormEngineFactory::Create(deviceId, inOutT, spatial, imageLayout); } - RuntimeError("Failed to find appropriate batch normalization engine."); + RuntimeError("Could not find appropriate batch normalization engine."); } } } } diff --git a/Source/Math/BatchNormalizationEngine.h b/Source/Math/BatchNormalizationEngine.h index 0209a0a11..8ae5a0d06 100644 --- a/Source/Math/BatchNormalizationEngine.h +++ b/Source/Math/BatchNormalizationEngine.h @@ -37,22 +37,21 @@ public: void Forward(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev); - //void ForwardInference(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, - // Mat& out); + void ForwardInference(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out); - //void Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, - // Mat& scaleGrad, Mat& biasGrad); + void Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, + Mat& scaleGrad, Mat& biasGrad); static std::unique_ptr> Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout, + bool spatial, ImageLayoutKind imageLayout, BatchNormEngineKind enabledEngines = BatchNormEngineKind::All); DISABLE_COPY_AND_MOVE(BatchNormEngine); protected: BatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout) - : m_deviceId(deviceId), m_inOutT(inOutT), m_scaleBiasT(scaleBiasT), m_spatial(spatial), m_imageLayout(imageLayout) + bool spatial, ImageLayoutKind imageLayout) + : m_deviceId(deviceId), m_inOutT(inOutT), m_spatial(spatial), m_imageLayout(imageLayout) { } @@ -61,10 +60,14 @@ protected: virtual void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) = 0; + virtual void ForwardInferenceCore(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out) = 0; + + virtual void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, + Mat& scaleGrad, Mat& biasGrad) = 0; + protected: DEVICEID_TYPE m_deviceId; TensorShape m_inOutT; - TensorShape m_scaleBiasT; bool m_spatial; ImageLayoutKind m_imageLayout; }; diff --git a/Source/Math/CPUMatrix.cpp b/Source/Math/CPUMatrix.cpp index 692eeb960..0c462be28 100644 --- a/Source/Math/CPUMatrix.cpp +++ b/Source/Math/CPUMatrix.cpp @@ -4293,8 +4293,52 @@ void CPUMatrix::BatchNormalizationForward(const CPUMatrix& s CPUMatrix& out, double epsilon, CPUMatrix& saveMean, CPUMatrix& saveInvStdDev) const { UNUSED(scale); UNUSED(bias); UNUSED(expAvgFactor); UNUSED(runMean); UNUSED(runInvStdDev); UNUSED(out); UNUSED(epsilon); UNUSED(saveMean); UNUSED(saveInvStdDev); + RuntimeError("Not yet implemented."); } +template +void CPUMatrix::BatchNormalizationForwardInference(const CPUMatrix& scale, const CPUMatrix& bias, + const CPUMatrix& runMean, const CPUMatrix& runInvStdDev, + CPUMatrix& out) const +{ + assert((GetNumRows() % scale.GetNumRows()) == 0); + + bool spatial = GetNumRows() != scale.GetNumRows(); + if (spatial) + { + size_t spatialSize = GetNumRows() / scale.GetNumRows(); +#pragma omp parallel for + for (long icol = 0; icol < out.GetNumCols(); icol++) + { + for (long irow = 0; irow < out.GetNumRows(); irow++) + { + size_t imap = irow / spatialSize; + out(irow, icol) = scale(imap, 0) * ((*this)(irow, icol) - runMean(imap, 0)) * runInvStdDev(imap, 0) + bias(imap, 0); + } + } + } + else + { +#pragma omp parallel for + for (long icol = 0; icol < out.GetNumCols(); icol++) + { + for (long irow = 0; irow < out.GetNumRows(); irow++) + { + out(irow, icol) = scale(irow, 0) * ((*this)(irow, icol) - runMean(irow, 0)) * runInvStdDev(irow, 0) + bias(irow, 0); + } + } + } +} + +template +void CPUMatrix::BatchNormalizationBackward(const CPUMatrix& in, CPUMatrix& grad, const CPUMatrix& scale, const CPUMatrix& saveMean, const CPUMatrix& saveInvStdDev, + CPUMatrix& scaleGrad, CPUMatrix& biasGrad) const +{ + UNUSED(in); UNUSED(grad); UNUSED(scale); UNUSED(saveMean); UNUSED(saveInvStdDev); UNUSED(scaleGrad); UNUSED(biasGrad); + RuntimeError("Not yet implemented."); +} + + #pragma region Static BLAS Functions /// Matrix-matrix multiply with col-major matrices (a and b may be transposed): c = alpha * op(a) * op(b) + beta*c diff --git a/Source/Math/CPUMatrix.h b/Source/Math/CPUMatrix.h index b089b1d0b..066ed6641 100644 --- a/Source/Math/CPUMatrix.h +++ b/Source/Math/CPUMatrix.h @@ -335,6 +335,9 @@ public: void BatchNormalizationForward(const CPUMatrix& scale, const CPUMatrix& bias, double expAvgFactor, CPUMatrix& runMean, CPUMatrix& runInvStdDev, CPUMatrix& out, double epsilon, CPUMatrix& saveMean, CPUMatrix& saveInvStdDev) const; + void BatchNormalizationForwardInference(const CPUMatrix& scale, const CPUMatrix& bias, const CPUMatrix& runMean, const CPUMatrix& runInvStdDev, CPUMatrix& out) const; + void BatchNormalizationBackward(const CPUMatrix& in, CPUMatrix& grad, const CPUMatrix& scale, const CPUMatrix& saveMean, const CPUMatrix& saveInvStdDev, + CPUMatrix& scaleGrad, CPUMatrix& biasGrad) const; public: static int SetNumThreads(int numThreads); // note: this does not depend on , i.e. you can call it on any diff --git a/Source/Math/CuDnnBatchNormalization.cu b/Source/Math/CuDnnBatchNormalization.cu index c49891c8b..b3e6fe3fa 100644 --- a/Source/Math/CuDnnBatchNormalization.cu +++ b/Source/Math/CuDnnBatchNormalization.cu @@ -20,11 +20,11 @@ public: public: CuDnnBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, ImageLayoutKind imageLayout) - : Base(deviceId, inOutT, scaleBiasT, spatial, imageLayout), + bool spatial, ImageLayoutKind imageLayout) + : Base(deviceId, inOutT, spatial, imageLayout), m_cudnn(CuDnn::Instance()), - m_inOutCuDnnT(inOutT, CuDnnTensor::GetDataType()), - m_scaleBiasCuDnnT(scaleBiasT, CuDnnTensor::GetDataType()) + m_inOutCuDnnT(GetInOutTensor(inOutT), CuDnnTensor::GetDataType()), + m_scaleBiasCuDnnT(GetScaleBiasTensor(inOutT, spatial), CuDnnTensor::GetDataType()) { } @@ -32,13 +32,14 @@ protected: using Base::m_deviceId; using Base::m_imageLayout; using Base::m_inOutT; - using Base::m_scaleBiasT; using Base::m_spatial; void EnsureCompatible() override { if (m_spatial && m_imageLayout == ImageLayoutKind::HWC) InvalidArgument("cuDNN batch normalization supports only cudnn(CHW) layout."); + if (m_inOutT.GetRank() > 4) + InvalidArgument("cuDNN batch normalization supports tensors of max 4 dimensions."); } void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, @@ -53,6 +54,29 @@ protected: epsilon, ptr(saveMean), ptr(saveInvStdDev))); } + void ForwardInferenceCore(const Mat& in, const Mat& scale, const Mat& bias, const Mat& runMean, const Mat& runInvStdDev, Mat& out) override + { + m_inOutCuDnnT.UpdateBatchSize(in.GetNumCols()); + cudnnBatchNormMode_t mode = m_spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION; + CUDNN_CALL(cudnnBatchNormalizationForwardInference(*m_cudnn, mode, &C::One, &C::Zero, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(out), + m_scaleBiasCuDnnT, ptr(scale), ptr(bias), ptr(runMean), ptr(runInvStdDev), CUDNN_BN_MIN_EPSILON)); + } + + void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, + Mat& scaleGrad, Mat& biasGrad) override + { + m_inOutCuDnnT.UpdateBatchSize(srcGrad.GetNumCols()); + cudnnBatchNormMode_t mode = m_spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION; + // REVIEW alexeyk: remove once Philly is upgraded to prod version. +#if CUDNN_PATCHLEVEL >= 7 + CUDNN_CALL(cudnnBatchNormalizationBackward(*m_cudnn, mode, &C::One, &C::One, &C::One, &C::One, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(srcGrad), m_inOutCuDnnT, ptr(grad), + m_scaleBiasCuDnnT, ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev))); +#else + CUDNN_CALL(cudnnBatchNormalizationBackward(*m_cudnn, mode, &C::One, &C::One, m_inOutCuDnnT, ptr(in), m_inOutCuDnnT, ptr(srcGrad), m_inOutCuDnnT, ptr(grad), + m_scaleBiasCuDnnT, ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev))); +#endif + } + private: template static ElemType* ptr(Matrix& src) @@ -65,6 +89,29 @@ private: return src.BufferPointer(); } + static TensorShape GetInOutTensor(const TensorShape& inOutT) + { + // cuDNN supports only 3D and 4D tensors (in cuDNN docs it's 4D and 5D dues to N dimension) + // even for non-spatial inputs so expand the tensor if needed. + if (inOutT.GetRank() > 2) + return inOutT; + SmallVector v(std::max(inOutT.GetRank(), (size_t)3), 1); + for (size_t i = 0; i < inOutT.GetRank(); i++) + v[i] = inOutT[i]; + return TensorShape(v); + } + + static TensorShape GetScaleBiasTensor(const TensorShape& inOutT, bool spatial) + { + if (!spatial) + return GetInOutTensor(inOutT); + + const auto& t = GetInOutTensor(inOutT); + SmallVector v(t.GetRank(), 1); + v[v.size() - 1] = t[t.GetRank() - 1]; + return TensorShape(v); + } + private: using C = Consts; @@ -78,13 +125,46 @@ template class CuDnnBatchNormEngine; template std::unique_ptr> CuDnnBatchNormEngineFactory::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, - ImageLayoutKind imageLayout) + bool spatial, ImageLayoutKind imageLayout) { - return std::make_unique>(deviceId, inOutT, scaleBiasT, spatial, imageLayout); + return std::make_unique>(deviceId, inOutT, spatial, imageLayout); } template class CuDnnBatchNormEngineFactory; template class CuDnnBatchNormEngineFactory; +CudaTimer::~CudaTimer() +{ + // TODO: Should not throw if std::uncaught_exception() + if (m_start != nullptr) + CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_start))); + if (m_stop != nullptr) + CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_stop))); +} +void CudaTimer::Start() +{ + cudaEvent_t start; + cudaEvent_t stop; + if (m_start != nullptr) + CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_start))); + if (m_stop != nullptr) + CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_stop))); + CUDA_CALL(cudaEventCreate(&start)); + CUDA_CALL(cudaEventCreate(&stop)); + m_start = start; + m_stop = stop; + CUDA_CALL(cudaEventRecord(start, GetStream())); +} +void CudaTimer::Stop() +{ + CUDA_CALL(cudaEventRecord(reinterpret_cast(m_stop), GetStream())); + CUDA_CALL(cudaEventSynchronize(reinterpret_cast(m_stop))); +} +float CudaTimer::Elapsed() +{ + float ms; + CUDA_CALL(cudaEventElapsedTime(&ms, reinterpret_cast(m_start), reinterpret_cast(m_stop))); + return ms; +} + } } } diff --git a/Source/Math/CuDnnConvolutionEngine.cu b/Source/Math/CuDnnConvolutionEngine.cu index 526ea2a8c..f6659c57e 100644 --- a/Source/Math/CuDnnConvolutionEngine.cu +++ b/Source/Math/CuDnnConvolutionEngine.cu @@ -977,38 +977,4 @@ bool CuDnnConvolutionEngineFactory::IsSupported(ConvolveGeometryPtr ge template class CuDnnConvolutionEngineFactory; template class CuDnnConvolutionEngineFactory; -CudaTimer::~CudaTimer() -{ - // TODO: Should not throw if std::uncaught_exception() - if (m_start != nullptr) - CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_start))); - if (m_stop != nullptr) - CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_stop))); -} -void CudaTimer::Start() -{ - cudaEvent_t start; - cudaEvent_t stop; - if (m_start != nullptr) - CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_start))); - if (m_stop != nullptr) - CUDA_CALL(cudaEventDestroy(reinterpret_cast(m_stop))); - CUDA_CALL(cudaEventCreate(&start)); - CUDA_CALL(cudaEventCreate(&stop)); - m_start = start; - m_stop = stop; - CUDA_CALL(cudaEventRecord(start, GetStream())); -} -void CudaTimer::Stop() -{ - CUDA_CALL(cudaEventRecord(reinterpret_cast(m_stop), GetStream())); - CUDA_CALL(cudaEventSynchronize(reinterpret_cast(m_stop))); -} -float CudaTimer::Elapsed() -{ - float ms; - CUDA_CALL(cudaEventElapsedTime(&ms, reinterpret_cast(m_start), reinterpret_cast(m_stop))); - return ms; -} - } } } diff --git a/Source/Math/CuDnnFactories.h b/Source/Math/CuDnnFactories.h index 808e6f414..14fedeb1d 100644 --- a/Source/Math/CuDnnFactories.h +++ b/Source/Math/CuDnnFactories.h @@ -25,42 +25,10 @@ class CuDnnBatchNormEngineFactory { public: static std::unique_ptr> Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT, - const TensorShape& scaleBiasT, bool spatial, - ImageLayoutKind imageLayout); + bool spatial, ImageLayoutKind imageLayout); }; -//template -//class CuDnnConvolutionEngineFactory : public ConvolutionEngineFactory -//{ -//public: -// using Base = ConvolutionEngineFactory; -// using typename Base::Tensor4D; -// using typename Base::Tensor4DPtr; -// using typename Base::Filter; -// using typename Base::FilterPtr; -// using typename Base::ConvDesc; -// using typename Base::ConvDescPtr; -// using typename Base::PoolDesc; -// using typename Base::PoolDescPtr; -// -// using typename Base::ConvEnginePtr; -// using typename Base::PoolEnginePtr; -// -//public: -// Tensor4DPtr CreateTensor(size_t w, size_t h, size_t c, size_t n) override; -// FilterPtr CreateFilter(size_t w, size_t h, size_t c, size_t k) override; -// ConvDescPtr CreateConvDescriptor(const Tensor4D& inT, const Filter& filterT, -// size_t wStride, size_t hStride, bool padding) override; -// PoolDescPtr CreatePoolDescriptor(typename PoolDesc::PoolKind kind, size_t w, size_t h, size_t wStride, size_t hStride, size_t wPad, size_t hPad) override; -// -// ConvEnginePtr CreateConvEngine(DEVICEID_TYPE deviceId, ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples, BatchNormImpl bnImpl) override; -// PoolEnginePtr CreatePoolEngine(DEVICEID_TYPE deviceId, ImageLayoutKind imageLayout) override; -// -// static bool IsSupported(DEVICEID_TYPE deviceId); -//}; -// - -// REVIEW alexeyk: wrong place. It is currently used only in unit tests but I can't add it there because of the build issues. +// REVIEW alexeyk: wrong place? It is currently used only in unit tests but I can't add it there because of the build issues. // Timer that can be used to measure CUDA calls. // Uses CUDA event and will synchronize(!) the stream when Stop is called. class MATH_API CudaTimer @@ -79,4 +47,5 @@ private: void* m_start; void* m_stop; }; + } } } diff --git a/Source/Math/GPUMatrix.cu b/Source/Math/GPUMatrix.cu index 4f580bdda..f8c69c011 100644 --- a/Source/Math/GPUMatrix.cu +++ b/Source/Math/GPUMatrix.cu @@ -3053,6 +3053,8 @@ template void GPUMatrix::BatchNormalizationForward(const GPUMatrix& scale, const GPUMatrix& bias, double expAvgFactor, GPUMatrix& runMean, GPUMatrix& runInvStdDev, GPUMatrix& out, double epsilon, GPUMatrix& saveMean, GPUMatrix& saveInvStdDev) const { + assert((GetNumRows() % scale.GetNumRows()) == 0); + bool spatial = GetNumRows() != scale.GetNumRows(); size_t vectorSize = GetNumRows(); size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1; @@ -3064,8 +3066,8 @@ void GPUMatrix::BatchNormalizationForward(const GPUMatrix& s SyncGuard syncGuard; if (spatial) { - Call(spatialSize, vectorSize, spatialSize, batchSize, m_pArray, - expAvgFactor, runMean.m_pArray, runInvStdDev.m_pArray, epsilon, + Call(spatialSize, vectorSize, spatialSize, batchSize, m_pArray, + expAvgFactor, runMean.m_pArray, runInvStdDev.m_pArray, epsilon, saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); } else @@ -3075,8 +3077,59 @@ void GPUMatrix::BatchNormalizationForward(const GPUMatrix& s saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); } Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, - spatial, m_pArray, out.m_pArray, scale.m_pArray, bias.m_pArray, - saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); + spatial, m_pArray, out.m_pArray, scale.m_pArray, bias.m_pArray, + saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); +} + +template +void GPUMatrix::BatchNormalizationForwardInference(const GPUMatrix& scale, const GPUMatrix& bias, + const GPUMatrix& runMean, const GPUMatrix& runInvStdDev, + GPUMatrix& out) const +{ + assert((GetNumRows() % scale.GetNumRows()) == 0); + + bool spatial = GetNumRows() != scale.GetNumRows(); + size_t vectorSize = GetNumRows(); + size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1; + size_t batchSize = GetNumCols(); + + assert(0 < vectorSize && vectorSize <= std::numeric_limits::max()); + assert(0 < batchSize && batchSize <= std::numeric_limits::max()); + + SyncGuard syncGuard; + Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, + spatial, m_pArray, out.m_pArray, scale.m_pArray, bias.m_pArray, + runMean.m_pArray, runInvStdDev.m_pArray, GetStream()); +} + +template +void GPUMatrix::BatchNormalizationBackward(const GPUMatrix& in, GPUMatrix& grad, const GPUMatrix& scale, + const GPUMatrix& saveMean, const GPUMatrix& saveInvStdDev, + GPUMatrix& scaleGrad, GPUMatrix& biasGrad) const +{ + assert((GetNumRows() % scale.GetNumRows()) == 0); + + bool spatial = GetNumRows() != scale.GetNumRows(); + size_t vectorSize = GetNumRows(); + size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1; + size_t batchSize = GetNumCols(); + + assert(0 < vectorSize && vectorSize <= std::numeric_limits::max()); + assert(0 < batchSize && batchSize <= std::numeric_limits::max()); + + SyncGuard syncGuard; + if (spatial) + { + Call(spatialSize, vectorSize, spatialSize, batchSize, in.m_pArray, m_pArray, scaleGrad.m_pArray, biasGrad.m_pArray, + saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); + } + else + { + Call(vectorSize, vectorSize, batchSize, in.m_pArray, m_pArray, scaleGrad.m_pArray, biasGrad.m_pArray, + saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); + } + Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial, + in.m_pArray, m_pArray, grad.m_pArray, scale.m_pArray, scaleGrad.m_pArray, biasGrad.m_pArray, saveMean.m_pArray, saveInvStdDev.m_pArray, GetStream()); } #pragma region Static BLAS Functions diff --git a/Source/Math/GPUMatrix.h b/Source/Math/GPUMatrix.h index 083ed81ee..b060c1439 100644 --- a/Source/Math/GPUMatrix.h +++ b/Source/Math/GPUMatrix.h @@ -419,6 +419,9 @@ public: void BatchNormalizationForward(const GPUMatrix& scale, const GPUMatrix& bias, double expAvgFactor, GPUMatrix& runMean, GPUMatrix& runInvStdDev, GPUMatrix& out, double epsilon, GPUMatrix& saveMean, GPUMatrix& saveInvStdDev) const; + void BatchNormalizationForwardInference(const GPUMatrix& scale, const GPUMatrix& bias, const GPUMatrix& runMean, const GPUMatrix& runInvStdDev, GPUMatrix& out) const; + void BatchNormalizationBackward(const GPUMatrix& in, GPUMatrix& grad, const GPUMatrix& scale, const GPUMatrix& saveMean, const GPUMatrix& saveInvStdDev, + GPUMatrix& scaleGrad, GPUMatrix& biasGrad) const; public: // static BLAS functions diff --git a/Source/Math/Matrix.cpp b/Source/Math/Matrix.cpp index bc10b0767..0f40cb2ae 100644 --- a/Source/Math/Matrix.cpp +++ b/Source/Math/Matrix.cpp @@ -4150,6 +4150,45 @@ void Matrix::BatchNormalizationForward(const Matrix& scale, NOT_IMPLEMENTED); } +template +void Matrix::BatchNormalizationForwardInference(const Matrix& scale, const Matrix& bias, + const Matrix& runMean, const Matrix& runInvStdDev, + Matrix& out) const +{ + DecideAndMoveToRightDevice(*this, out); + + // REVIEW alexeyk: add sparse version. + DISPATCH_MATRIX_ON_FLAG(this, + this, + m_CPUMatrix->BatchNormalizationForwardInference(*(scale.m_CPUMatrix), *(bias.m_CPUMatrix), + *(runMean.m_CPUMatrix), *(runInvStdDev.m_CPUMatrix), + *(out.m_CPUMatrix)), + m_GPUMatrix->BatchNormalizationForwardInference(*(scale.m_GPUMatrix), *(bias.m_GPUMatrix), + *(runMean.m_GPUMatrix), *(runInvStdDev.m_GPUMatrix), + *(out.m_GPUMatrix)), + NOT_IMPLEMENTED, + NOT_IMPLEMENTED); +} + +template +void Matrix::BatchNormalizationBackward(const Matrix& in, Matrix& grad, const Matrix& scale, const Matrix& saveMean, const Matrix& saveInvStdDev, + Matrix& scaleGrad, Matrix& biasGrad) const +{ + DecideAndMoveToRightDevice(*this, grad); + + // REVIEW alexeyk: add sparse version. + DISPATCH_MATRIX_ON_FLAG(this, + this, + m_CPUMatrix->BatchNormalizationBackward(*(in.m_CPUMatrix), *(grad.m_CPUMatrix), *(scale.m_CPUMatrix), + *(saveMean.m_CPUMatrix), *(saveInvStdDev.m_CPUMatrix), + *(scaleGrad.m_CPUMatrix), *(biasGrad.m_CPUMatrix)), + m_GPUMatrix->BatchNormalizationBackward(*(in.m_GPUMatrix), *(grad.m_GPUMatrix), *(scale.m_GPUMatrix), + *(saveMean.m_GPUMatrix), *(saveInvStdDev.m_GPUMatrix), + *(scaleGrad.m_GPUMatrix), *(biasGrad.m_GPUMatrix)), + NOT_IMPLEMENTED, + NOT_IMPLEMENTED); +} + #pragma region Static BLAS Functions template diff --git a/Source/Math/Matrix.h b/Source/Math/Matrix.h index 9c682a223..9608f5b6e 100644 --- a/Source/Math/Matrix.h +++ b/Source/Math/Matrix.h @@ -470,6 +470,9 @@ public: void BatchNormalizationForward(const Matrix& scale, const Matrix& bias, double expAvgFactor, Matrix& runMean, Matrix& runInvStdDev, Matrix& out, double epsilon, Matrix& saveMean, Matrix& saveInvStdDev) const; + void BatchNormalizationForwardInference(const Matrix& scale, const Matrix& bias, const Matrix& runMean, const Matrix& runInvStdDev, Matrix& out) const; + void BatchNormalizationBackward(const Matrix& in, Matrix& grad, const Matrix& scale, const Matrix& saveMean, const Matrix& saveInvStdDev, + Matrix& scaleGrad, Matrix& biasGrad) const; public: // TODO: why are these not static? And why are they here? diff --git a/Tests/UnitTests/MathTests/BatchNormalizatinEngineTests.cpp b/Tests/UnitTests/MathTests/BatchNormalizatinEngineTests.cpp index 2a54abbde..1929198db 100644 --- a/Tests/UnitTests/MathTests/BatchNormalizatinEngineTests.cpp +++ b/Tests/UnitTests/MathTests/BatchNormalizatinEngineTests.cpp @@ -19,9 +19,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Test { using vec = std::vector; using BNEng = BatchNormEngine; -std::vector> GenerateBNTestConfigs() +std::vector> GenerateBNTestConfigs() { - std::vector> res; + std::vector> res; // REVIEW alexeyk: how to test batches > 512? cuDNN does not support that so there is no baseline. double expAvgFactor = 1; // Per activation (non-spatial) @@ -33,7 +33,7 @@ std::vector> Generate { for (size_t w : {6, 17, 126, 2048}) { - res.push_back(std::make_tuple(TensorShape(w, h, c), TensorShape(w, h, c), n, false, expAvgFactor)); + res.push_back(std::make_tuple(TensorShape(w, h, c), n, false, expAvgFactor)); } } } @@ -47,21 +47,24 @@ std::vector> Generate { for (size_t w : {2, 11, 16}) { - res.push_back(std::make_tuple(TensorShape(w, h, c), TensorShape(1, 1, c), n, true, expAvgFactor)); + res.push_back(std::make_tuple(TensorShape(w, h, c), n, true, expAvgFactor)); } } } } // For perf testing (similar to first layers of ResNet). - res.push_back(std::make_tuple(TensorShape(56, 56, 64), TensorShape(1, 1, 64), 64, true, expAvgFactor)); + res.push_back(std::make_tuple(TensorShape(56, 56, 64), 64, true, expAvgFactor)); // Next test will fail in cuDNN due to bug we discovered (and reported to NVIDIA). - //res.push_back(std::make_tuple(std::move(fact.CreateTensor(2, 2, 2048, 2)), true)); - res.push_back(std::make_tuple(TensorShape(2, 2, 2048), TensorShape(1, 1, 2048), 64, true, expAvgFactor)); + //res.push_back(std::make_tuple(TensorShape(2, 2, 2048), 2, true, expAvgFactor)); + res.push_back(std::make_tuple(TensorShape(2, 2, 2048), 64, true, expAvgFactor)); // Test running mean/isd. expAvgFactor = 0.1; - res.push_back(std::make_tuple(TensorShape(2, 2, 2), TensorShape(2, 2, 2), 8, false, expAvgFactor)); - res.push_back(std::make_tuple(TensorShape(2, 2, 2), TensorShape(1, 1, 2), 8, true, expAvgFactor)); + res.push_back(std::make_tuple(TensorShape(2, 2, 2), 8, false, expAvgFactor)); + res.push_back(std::make_tuple(TensorShape(2, 2, 2), 8, true, expAvgFactor)); + + // Test 1D tensor expansion (cuDNN supports 3D and 4D tensors only). + res.push_back(std::make_tuple(TensorShape(2), 8, false, expAvgFactor)); return res; } @@ -89,16 +92,15 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain) for (const auto& cfg : GenerateBNTestConfigs()) { const auto& inOutT = std::get<0>(cfg); - const auto& scaleT = std::get<1>(cfg); - size_t batchSize = std::get<2>(cfg); - bool spatial = std::get<3>(cfg); - double expAvg = std::get<4>(cfg); + size_t batchSize = std::get<1>(cfg); + bool spatial = std::get<2>(cfg); + double expAvg = std::get<3>(cfg); double eps = 1e-5; // CUDNN_BN_MIN_EPSILON - auto engCudnn = BNEng::Create(baseDeviceId, inOutT, scaleT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn); - auto engCntk = BNEng::Create(deviceId, inOutT, scaleT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk); + auto engCudnn = BNEng::Create(baseDeviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn); + auto engCntk = BNEng::Create(deviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk); - size_t crow = inOutT[0] * inOutT[1] * inOutT[2]; + size_t crow = inOutT.GetNumElements(); size_t ccol = batchSize; vec buf(crow * ccol); @@ -106,7 +108,7 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain) SingleMatrix in(crow, ccol, buf.data(), deviceId, matrixFlagNormal); SingleMatrix inB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal); - size_t crowScaleBias = scaleT.GetNumElements(); + size_t crowScaleBias = spatial ? inOutT[2] : inOutT.GetNumElements(); buf.resize(crowScaleBias); std::generate(begin(buf), end(buf), [&] { return nd(rng); }); @@ -145,7 +147,7 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain) time2.Stop(); std::stringstream tmsg; - tmsg << "inOut tensor: " << (std::string)inOutT << ", scaleBias tensor: " << (std::string)scaleT + tmsg << "inOut tensor: " << (std::string)inOutT << ", spatial = " << (spatial ? "true" : "false") << ", expAvg = " << expAvg << ")"; std::string msg = " are not equal, " + tmsg.str(); @@ -194,100 +196,104 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain) } } -//BOOST_AUTO_TEST_CASE(BatchNormalizationForwardInference) -//{ -// if (!IsCuDnnSupported()) -// return; -// -// std::mt19937 rng(0); -// std::normal_distribution nd; -// -// auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix -// { -// data.resize(r * 3 * c); -// std::fill(begin(data), end(data), std::numeric_limits::quiet_NaN()); -// std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); }); -// buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data()); -// // Get center slice. -// return buf.ColumnSlice(c, c); -// }; -// -// for (int deviceId : {0}) -// { -// auto fact = ConvFact::Create(deviceId, ConvFact::EngineType::Auto, ImageLayoutKind::CHW); -// auto engCudnn = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::CuDnn); -// auto engCntk = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::Cntk); -// for (auto& cfg : GenerateBNTestConfigs(*fact)) -// { -// auto& t = *std::move(std::get<0>(cfg)); -// bool spatial = std::get<1>(cfg); -// -// size_t crow = t.w() * t.h() * t.c(); -// size_t ccol = t.n(); -// -// vec buf(crow * t.n()); -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix in(crow, ccol, buf.data(), deviceId, matrixFlagNormal); -// -// Tensor4DPtr scaleBiasT = spatial ? fact->CreateTensor(1, 1, t.c(), 1) : fact->CreateTensor(t.w(), t.h(), t.c(), 1); -// size_t crowScaleBias = scaleBiasT->w() * scaleBiasT->h() * scaleBiasT->c(); -// buf.resize(crowScaleBias); -// -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix bias(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix runMean(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix runInvStdDev(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// -// SingleMatrix outBuf(deviceId); -// SingleMatrix out = initMat(outBuf, crow, ccol, buf); -// SingleMatrix outExp(out); -// -// CudaTimer time1; -// time1.Start(); -// engCntk->NormalizeBatchInference(t, in, *scaleBiasT, scale, bias, spatial, runMean, runInvStdDev, out); -// time1.Stop(); -// -// CudaTimer time2; -// time2.Start(); -// engCudnn->NormalizeBatchInference(t, in, *scaleBiasT, scale, bias, spatial, runMean, runInvStdDev, outExp); -// time2.Stop(); -// -// std::stringstream tmsg; -// tmsg << "tensor: (w = " << t.w() << ", h = " << t.h() << ", c = " << t.c() << ", n = " << t.n() << ", spatial = " << (spatial ? "true" : "false") << ")"; -// std::string msg = " are not equal, " + tmsg.str(); -// std::string msgNan = " has NaNs, " + tmsg.str(); -// std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str(); -// -// float relErr = Err::Rel; -// float absErr = Err::Abs; -// std::string emsg; -// -// BOOST_REQUIRE_MESSAGE(!out.HasNan("out"), "out" << msgNan); -// BOOST_REQUIRE_MESSAGE(CheckEqual(out, outExp, emsg, relErr, absErr * 20), "out" << msg << ". " << emsg); -// BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crow * 2 * ccol, "out" << msgNotNan); -// // REVIEW alexeyk: add cases for testing numerical stability. -// -//#ifndef _DEBUG -// float elapsedCntk = time1.Elapsed(); -// float elapsedCudnn = time2.Elapsed(); -// // Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation. -// // For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN. -// if (crow >= 32 && ccol >= 32) -// { -// // Use conservative estimates. -// int speedup = 2; -// BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn, -// "CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str()); -// } -//#endif -// } -// } -//} +BOOST_AUTO_TEST_CASE(BatchNormalizationForwardInference) +{ + std::mt19937 rng(0); + std::normal_distribution nd; + + auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix + { + data.resize(r * 3 * c); + std::fill(begin(data), end(data), std::numeric_limits::quiet_NaN()); + std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); }); + buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data()); + // Get center slice. + return buf.ColumnSlice(c, c); + }; + + int baseDeviceId = 0; + for (int deviceId : {0}) + { + for (const auto& cfg : GenerateBNTestConfigs()) + { + const auto& inOutT = std::get<0>(cfg); + size_t batchSize = std::get<1>(cfg); + bool spatial = std::get<2>(cfg); + + auto engCudnn = BNEng::Create(baseDeviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn); + auto engCntk = BNEng::Create(deviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk); + + size_t crow = inOutT.GetNumElements(); + size_t ccol = batchSize; + + vec buf(crow * ccol); + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix in(crow, ccol, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix inB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal); + + size_t crowScaleBias = spatial ? inOutT[2] : inOutT.GetNumElements(); + buf.resize(crowScaleBias); + + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix scaleB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal); + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix bias(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix biasB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal); + + SingleMatrix runMeanBuf(deviceId); + SingleMatrix runMean = initMat(runMeanBuf, crowScaleBias, 1, buf); + SingleMatrix runMeanB(runMean.DeepClone(), baseDeviceId); + SingleMatrix runInvStdDevBuf(deviceId); + SingleMatrix runInvStdDev = initMat(runInvStdDevBuf, crowScaleBias, 1, buf); + SingleMatrix runInvStdDevB(runInvStdDev.DeepClone(), baseDeviceId); + + SingleMatrix outBuf(deviceId); + SingleMatrix out = initMat(outBuf, crow, ccol, buf); + SingleMatrix outB(out.DeepClone(), baseDeviceId); + + CudaTimer time1; + time1.Start(); + engCntk->ForwardInference(in, scale, bias, runMean, runInvStdDev, out); + time1.Stop(); + + CudaTimer time2; + time2.Start(); + engCudnn->ForwardInference(inB, scaleB, biasB, runMeanB, runInvStdDevB, outB); + time2.Stop(); + + std::stringstream tmsg; + tmsg << "inOut tensor: " << (std::string)inOutT + << ", spatial = " << (spatial ? "true" : "false"); + std::string msg = " are not equal, " + tmsg.str(); + std::string msgNan = " has NaNs, " + tmsg.str(); + std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str(); + + float relErr = Err::Rel; + float absErr = Err::Abs; + std::string emsg; + + BOOST_REQUIRE_MESSAGE(!out.HasNan("out"), "out" << msgNan); + BOOST_REQUIRE_MESSAGE(CheckEqual(out, outB, emsg, relErr, absErr * 20), "out" << msg << ". " << emsg); + BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crow * 2 * ccol, "out" << msgNotNan); + // REVIEW alexeyk: add cases for testing numerical stability. + +#ifndef _DEBUG + float elapsedCntk = time1.Elapsed(); + float elapsedCudnn = time2.Elapsed(); + // Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation. + // For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN. + if (crow >= 32 && ccol >= 32) + { + // Use conservative estimates. + int speedup = 2; + BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn, + "CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str()); + } +#endif + } + } +} // //BOOST_AUTO_TEST_CASE(BatchNormalizationForwardInferenceCpu) //{ @@ -373,118 +379,122 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain) // BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crow * 2 * ccol, "out" << msgNotNan); // } //} -// -//BOOST_AUTO_TEST_CASE(BatchNormalizationBackward) -//{ -// if (!IsCuDnnSupported()) -// return; -// -// std::mt19937 rng(0); -// std::normal_distribution nd; -// -// auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix -// { -// data.resize(r * 3 * c); -// std::fill(begin(data), end(data), std::numeric_limits::quiet_NaN()); -// std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); }); -// buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data()); -// // Get center slice. -// return buf.ColumnSlice(c, c); -// }; -// -// for (int deviceId : {0}) -// { -// auto fact = ConvFact::Create(deviceId, ConvFact::EngineType::Auto, ImageLayoutKind::CHW); -// auto engCudnn = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::CuDnn); -// auto engCntk = fact->CreateConvEngine(deviceId, ImageLayoutKind::CHW, 0, BatchNormImpl::Cntk); -// for (auto& cfg : GenerateBNTestConfigs(*fact)) -// { -// auto& t = *std::move(std::get<0>(cfg)); -// bool spatial = std::get<1>(cfg); -// -// size_t crow = t.w() * t.h() * t.c(); -// size_t ccol = t.n(); -// -// vec buf(crow * t.n()); -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix x(crow, ccol, buf.data(), deviceId, matrixFlagNormal); -// -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix dy(crow, ccol, buf.data(), deviceId, matrixFlagNormal); -// -// Tensor4DPtr scaleBiasT = spatial ? fact->CreateTensor(1, 1, t.c(), 1) : fact->CreateTensor(t.w(), t.h(), t.c(), 1); -// size_t crowScaleBias = scaleBiasT->w() * scaleBiasT->h() * scaleBiasT->c(); -// buf.resize(crowScaleBias); -// -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix saveMean(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// -// std::generate(begin(buf), end(buf), [&] { return nd(rng); }); -// SingleMatrix saveInvStdDev(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); -// -// SingleMatrix dScaleBuf(deviceId); -// SingleMatrix dScale = initMat(dScaleBuf, crowScaleBias, 1, buf); -// SingleMatrix dScaleExp(dScale); -// SingleMatrix dBiasBuf(deviceId); -// SingleMatrix dBias = initMat(dBiasBuf, crowScaleBias, 1, buf); -// SingleMatrix dBiasExp(dBias); -// -// SingleMatrix dxBuf(deviceId); -// SingleMatrix dx = initMat(dxBuf, crow, ccol, buf); -// SingleMatrix dxExp(dx); -// -// CudaTimer time1; -// time1.Start(); -// engCntk->BackwardNormalizeBatch(t, x, dy, dx, *scaleBiasT, scale, spatial, saveMean, saveInvStdDev, dScale, dBias); -// time1.Stop(); -// -// CudaTimer time2; -// time2.Start(); -// engCudnn->BackwardNormalizeBatch(t, x, dy, dxExp, *scaleBiasT, scale, spatial, saveMean, saveInvStdDev, dScaleExp, dBiasExp); -// time2.Stop(); -// -// std::stringstream tmsg; -// tmsg << "tensor: (w = " << t.w() << ", h = " << t.h() << ", c = " << t.c() << ", n = " << t.n() << ", spatial = " << (spatial ? "true" : "false") << ")"; -// std::string msg = " are not equal, " + tmsg.str(); -// std::string msgNan = " has NaNs, " + tmsg.str(); -// std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str(); -// -// float relErr = Err::Rel; -// float absErr = Err::Abs; -// std::string emsg; -// -// BOOST_REQUIRE_MESSAGE(!dx.HasNan("dx"), "dx" << msgNan); -// BOOST_REQUIRE_MESSAGE(CheckEqual(dx, dxExp, emsg, relErr * 16, absErr * 8), "dx" << msg << ". " << emsg); -// BOOST_REQUIRE_MESSAGE(CountNans(dxBuf) == crow * 2 * ccol, "out" << msgNotNan); -// // REVIEW alexeyk: add cases for testing numerical stability. -// -// BOOST_REQUIRE_MESSAGE(!dScale.HasNan("dScale"), "dScale" << msgNan); -// BOOST_REQUIRE_MESSAGE(CheckEqual(dScale, dScaleExp, emsg, relErr * 32, absErr * 8), "dScale" << msg << ". " << emsg); -// BOOST_REQUIRE_MESSAGE(CountNans(dScaleBuf) == crowScaleBias * 2, "dScale" << msgNotNan); -// -// BOOST_REQUIRE_MESSAGE(!dBias.HasNan("dBias"), "dBias" << msgNan); -// BOOST_REQUIRE_MESSAGE(CheckEqual(dBias, dBiasExp, emsg, relErr * 32, absErr * 8), "dBias" << msg << ". " << emsg); -// BOOST_REQUIRE_MESSAGE(CountNans(dBiasBuf) == crowScaleBias * 2, "dBias" << msgNotNan); -// -//#ifndef _DEBUG -// float elapsedCntk = time1.Elapsed(); -// float elapsedCudnn = time2.Elapsed(); -// // Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation. -// // For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN. -// if (crow >= 32 && ccol >= 32) -// { -// // Use conservative estimates. -// float speedup = 1.3f; -// BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn, -// "CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str()); -// } -//#endif -// } -// } -//} + +BOOST_AUTO_TEST_CASE(BatchNormalizationBackward) +{ + std::mt19937 rng(0); + std::normal_distribution nd; + + auto initMat = [&](SingleMatrix& buf, size_t r, size_t c, vec& data) -> SingleMatrix + { + data.resize(r * 3 * c); + std::fill(begin(data), end(data), std::numeric_limits::quiet_NaN()); + std::generate(begin(data) + r * c, begin(data) + 2 * r * c, [&] { return nd(rng); }); + buf.SetValue(r, 3 * c, buf.GetDeviceId(), data.data()); + // Get center slice. + return buf.ColumnSlice(c, c); + }; + + int baseDeviceId = 0; + for (int deviceId : {0}) + { + for (const auto& cfg : GenerateBNTestConfigs()) + { + const auto& inOutT = std::get<0>(cfg); + size_t batchSize = std::get<1>(cfg); + bool spatial = std::get<2>(cfg); + + auto engCudnn = BNEng::Create(baseDeviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::CuDnn); + auto engCntk = BNEng::Create(deviceId, inOutT, spatial, ImageLayoutKind::CHW, BatchNormEngineKind::Cntk); + + size_t crow = inOutT.GetNumElements(); + size_t ccol = batchSize; + + vec buf(crow * ccol); + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix x(crow, ccol, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix xB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal); + + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix dy(crow, ccol, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix dyB(crow, ccol, buf.data(), baseDeviceId, matrixFlagNormal); + + size_t crowScaleBias = spatial ? inOutT[2] : inOutT.GetNumElements(); + buf.resize(crowScaleBias); + + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix scale(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix scaleB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal); + + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix saveMean(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix saveMeanB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal); + + std::generate(begin(buf), end(buf), [&] { return nd(rng); }); + SingleMatrix saveInvStdDev(crowScaleBias, 1, buf.data(), deviceId, matrixFlagNormal); + SingleMatrix saveInvStdDevB(crowScaleBias, 1, buf.data(), baseDeviceId, matrixFlagNormal); + + SingleMatrix dScaleBuf(deviceId); + SingleMatrix dScale = initMat(dScaleBuf, crowScaleBias, 1, buf); + SingleMatrix dScaleB(dScale.DeepClone(), baseDeviceId); + SingleMatrix dBiasBuf(deviceId); + SingleMatrix dBias = initMat(dBiasBuf, crowScaleBias, 1, buf); + SingleMatrix dBiasB(dBias.DeepClone(), baseDeviceId); + + SingleMatrix dxBuf(deviceId); + SingleMatrix dx = initMat(dxBuf, crow, ccol, buf); + SingleMatrix dxB(dx.DeepClone(), baseDeviceId); + + CudaTimer time1; + time1.Start(); + engCntk->Backward(x, dy, dx, scale, saveMean, saveInvStdDev, dScale, dBias); + time1.Stop(); + + CudaTimer time2; + time2.Start(); + engCudnn->Backward(xB, dyB, dxB, scaleB, saveMeanB, saveInvStdDevB, dScaleB, dBiasB); + time2.Stop(); + + std::stringstream tmsg; + tmsg << "inOut tensor: " << (std::string)inOutT + << ", spatial = " << (spatial ? "true" : "false"); + std::string msg = " are not equal, " + tmsg.str(); + std::string msgNan = " has NaNs, " + tmsg.str(); + std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str(); + + float relErr = Err::Rel; + float absErr = Err::Abs; + std::string emsg; + + BOOST_REQUIRE_MESSAGE(!dx.HasNan("dx"), "dx" << msgNan); + BOOST_REQUIRE_MESSAGE(CheckEqual(dx, dxB, emsg, relErr * 16, absErr * 8), "dx" << msg << ". " << emsg); + BOOST_REQUIRE_MESSAGE(CountNans(dxBuf) == crow * 2 * ccol, "out" << msgNotNan); + // REVIEW alexeyk: add cases for testing numerical stability. + + BOOST_REQUIRE_MESSAGE(!dScale.HasNan("dScale"), "dScale" << msgNan); + BOOST_REQUIRE_MESSAGE(CheckEqual(dScale, dScaleB, emsg, relErr * 32, absErr * 8), "dScale" << msg << ". " << emsg); + BOOST_REQUIRE_MESSAGE(CountNans(dScaleBuf) == crowScaleBias * 2, "dScale" << msgNotNan); + + BOOST_REQUIRE_MESSAGE(!dBias.HasNan("dBias"), "dBias" << msgNan); + BOOST_REQUIRE_MESSAGE(CheckEqual(dBias, dBiasB, emsg, relErr * 32, absErr * 8), "dBias" << msg << ". " << emsg); + BOOST_REQUIRE_MESSAGE(CountNans(dBiasBuf) == crowScaleBias * 2, "dBias" << msgNotNan); + +#ifndef _DEBUG + float elapsedCntk = time1.Elapsed(); + float elapsedCudnn = time2.Elapsed(); + // Check performance. Current version of cuDNN (v4 RC) is significanlty slower than CNTK implementation. + // For optimal cases (vectorSize % 32 == 0 and batchSize % 32 == 0), CNTK implementation can be >5x faster than cuDNN. + if (crow >= 32 && ccol >= 32) + { + // Use conservative estimates. + float speedup = 1.3f; + BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn, + "CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str()); + } +#endif + } + } +} BOOST_AUTO_TEST_SUITE_END()