Merge pull request #1046 from shelhamer/cudnn

cuDNN acceleration
Evan Shelhamer 2014-09-08 09:57:44 +02:00
Parents: ae8599655b 359197b039
Commit: 3bafe2fcbb
30 changed files with 2070 additions and 17 deletions

View file

@ -253,10 +253,17 @@ endif
# Debugging
ifeq ($(DEBUG), 1)
COMMON_FLAGS += -DDEBUG -g -O0
NVCCFLAGS += -G
else
COMMON_FLAGS += -DNDEBUG -O2
endif
# cuDNN acceleration configuration.
ifeq ($(USE_CUDNN), 1)
LIBRARIES += cudnn
COMMON_FLAGS += -DUSE_CUDNN
endif
# CPU-only configuration
ifeq ($(CPU_ONLY), 1)
OBJS := $(PROTO_OBJS) $(CXX_OBJS)
@ -299,7 +306,7 @@ LIBRARY_DIRS += $(BLAS_LIB)
# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS := -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
NVCCFLAGS += -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
# mex may invoke an older gcc that is too liberal with -Wuninitialized
MATLAB_CXXFLAGS := $(CXXFLAGS) -Wno-uninitialized
LINKFLAGS += -fPIC $(COMMON_FLAGS) $(WARNINGS)

View file

@ -1,6 +1,9 @@
## Refer to http://caffe.berkeleyvision.org/installation.html
# Contributions simplifying and improving our build system are welcome!
# cuDNN acceleration switch (uncomment to build with cuDNN).
# USE_CUDNN := 1
# CPU-only switch (uncomment to build without GPU support).
# CPU_ONLY := 1

View file

@ -15,7 +15,7 @@ We have installed Caffe on Ubuntu 14.04, Ubuntu 12.04, OS X 10.9, and OS X 10.8.
Caffe depends on several software packages.
* [CUDA](https://developer.nvidia.com/cuda-zone) library version 6.0, 5.5, or 5.0 and the latest driver version for CUDA 6 or 319.* for CUDA 5 (and NOT 331.*)
* [CUDA](https://developer.nvidia.com/cuda-zone) library version 6.5 (recommended), 6.0, 5.5, or 5.0 and the latest driver version for CUDA 6 or 319.* for CUDA 5 (and NOT 331.*)
* [BLAS](http://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms) (provided via ATLAS, MKL, or OpenBLAS).
* [OpenCV](http://opencv.org/).
* [Boost](http://www.boost.org/) (>= 1.55, although only 1.55 is tested)
@ -25,13 +25,17 @@ Caffe depends on several software packages.
* For the MATLAB wrapper
* MATLAB with the `mex` compiler.
**CPU-only Caffe**: for cold-brewed CPU-only Caffe uncomment the `CPU_ONLY := 1` in `Makefile.config` to configure and build Caffe without CUDA. This is helpful for cloud or cluster deployment.
**cuDNN Caffe**: for fastest operation Caffe is accelerated by drop-in integration of [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). To speed up your Caffe models, install cuDNN then uncomment the `USE_CUDNN := 1` flag in `Makefile.config` when installing Caffe. Acceleration is automatic.
**CPU-only Caffe**: for cold-brewed CPU-only Caffe uncomment the `CPU_ONLY := 1` flag in `Makefile.config` to configure and build Caffe without CUDA. This is helpful for cloud or cluster deployment.
### CUDA and BLAS
Caffe requires the CUDA `nvcc` compiler to compile its GPU code and CUDA driver for GPU operation.
To install CUDA, go to the [NVIDIA CUDA website](https://developer.nvidia.com/cuda-downloads) and follow installation instructions there. Install the library and the latest standalone driver separately; the driver bundled with the library is usually out-of-date. **Warning!** The 331.* CUDA driver series has a critical performance issue: do not use it.
For best performance, Caffe can be accelerated by [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). Register for free at the cuDNN site, install it, then continue with these installation instructions. To compile with cuDNN, set the `USE_CUDNN := 1` flag in your `Makefile.config`.
Caffe requires BLAS as the backend of its matrix and vector computations.
There are several implementations of this library.
The choice is yours:
@ -92,7 +96,7 @@ Keep reading to find out how to manually build and install the Google flags libr
On **CentOS / RHEL / Fedora**, most of the dependencies can be installed with
sudo yum install protobuf-devel leveldb-devel snappy-devel opencv-devel boost-devel hdf5-devel
The Google flags library, Google logging library, and LMDB have already made their way into newer versions of **CentOS / RHEL / Fedora**, so it is better to first attempt to install them using `yum`
sudo yum install gflags-devel glog-devel lmdb-devel
@ -192,7 +196,7 @@ If you're not using Anaconda, include `hdf5` in the list above.
**Note** that in order to build the caffe python wrappers you must install boost using the --with-python option:
brew install --build-from-source --with-python --fresh -vd boost
**Note** that Homebrew maintains itself as a separate git repository and making the above `brew edit FORMULA` changes will change files in your local copy of homebrew's master branch. By default, this will prevent you from updating Homebrew using `brew update`, as you will get an error message like the following:
$ brew update
@ -201,7 +205,7 @@ If you're not using Anaconda, include `hdf5` in the list above.
Please, commit your changes or stash them before you can merge.
Aborting
Error: Failure while executing: git pull -q origin refs/heads/master:refs/remotes/origin/master
One solution is to commit your changes to a separate Homebrew branch, run `brew update`, and rebase your changes onto the updated master, as follows:
cd /usr/local
@ -213,7 +217,7 @@ One solution is to commit your changes to a separate Homebrew branch, run `brew
git rebase master caffe
# Resolve any merge conflicts here
git checkout caffe
At this point, you should be running the latest Homebrew packages and your Caffe-related modifications will remain in place. You may still get the following error:
$ brew update
@ -240,6 +244,8 @@ The defaults should work, but uncomment the relevant lines if using Anaconda Pyt
make test
make runtest
To compile with cuDNN acceleration, you should uncomment the `USE_CUDNN := 1` switch in `Makefile.config`.
If there is no GPU in your machine, you should switch to CPU-only Caffe by uncommenting `CPU_ONLY := 1` in `Makefile.config`.
To compile the Python and MATLAB wrappers do `make pycaffe` and `make matcaffe` respectively.

View file

@ -4,7 +4,7 @@ title: Performance and Hardware Configuration
# Performance and Hardware Configuration
To measure performance on different NVIDIA GPUs we use the Caffe reference ImageNet model.
To measure performance on different NVIDIA GPUs we use CaffeNet, the Caffe reference ImageNet model.
For training, each time point is 20 iterations/minibatches of 256 images for 5,120 images total. For testing, a 50,000 image validation set is classified.
@ -14,11 +14,16 @@ For training, each time point is 20 iterations/minibatches of 256 images for 5,1
Performance is best with ECC off and boost clock enabled. While ECC makes a negligible difference in speed, disabling it frees ~1 GB of GPU memory.
Best settings with ECC off and maximum clock speed:
Best settings with ECC off and maximum clock speed in standard Caffe:
* Training is 26.5 secs / 20 iterations (5,120 images)
* Testing is 100 secs / validation set (50,000 images)
Best settings with Caffe + [cuDNN acceleration](http://nvidia.com/cudnn):
* Training is 19.2 secs / 20 iterations (5,120 images)
* Testing is 60.7 secs / validation set (50,000 images)
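For reference, those figures work out to roughly a 26.5 / 19.2 ≈ 1.38× training speedup and a 100 / 60.7 ≈ 1.65× testing speedup from cuDNN at these settings.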
Other settings:
* ECC on, max speed: training 26.7 secs / 20 iterations, test 101 secs / validation set
@ -50,12 +55,19 @@ but note that this configuration resets across driver reloading / rebooting. Inc
Training: 26.26 secs / 20 iterations (5,120 images).
Testing: 100 secs / validation set (50,000 images).
cuDNN Training: 20.25 secs / 20 iterations (5,120 images).
cuDNN Testing: 66.3 secs / validation set (50,000 images).
## NVIDIA K20
Training: 36.0 secs / 20 iterations (5,120 images).
Testing: 133 secs / validation set (50,000 images)
Testing: 133 secs / validation set (50,000 images).
## NVIDIA GTX 770
Training: 33.0 secs / 20 iterations (5,120 images).
Testing: 129 secs / validation set (50,000 images)
Testing: 129 secs / validation set (50,000 images).
cuDNN Training: 24.3 secs / 20 iterations (5,120 images).
cuDNN Testing: 104 secs / validation set (50,000 images).

View file

@ -375,6 +375,32 @@ class SoftmaxLayer : public Layer<Dtype> {
Blob<Dtype> scale_;
};
#ifdef USE_CUDNN
/**
* @brief cuDNN implementation of SoftmaxLayer.
* Fallback to SoftmaxLayer for CPU mode.
*/
template <typename Dtype>
class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {
public:
explicit CuDNNSoftmaxLayer(const LayerParameter& param)
: SoftmaxLayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual ~CuDNNSoftmaxLayer();
protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
};
#endif
/**
* @brief Creates a "split" path in the network by copying the bottom Blob
* into multiple top Blob%s to be used by multiple consuming layers.

View file

@ -356,6 +356,31 @@ class ReLULayer : public NeuronLayer<Dtype> {
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
};
#ifdef USE_CUDNN
/**
* @brief CuDNN acceleration of ReLULayer.
*/
template <typename Dtype>
class CuDNNReLULayer : public ReLULayer<Dtype> {
public:
explicit CuDNNReLULayer(const LayerParameter& param)
: ReLULayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual ~CuDNNReLULayer();
protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
};
#endif
/**
* @brief Sigmoid function non-linearity @f$
* y = (1 + \exp(-x))^{-1}
@ -413,6 +438,31 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
};
#ifdef USE_CUDNN
/**
* @brief CuDNN acceleration of SigmoidLayer.
*/
template <typename Dtype>
class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
public:
explicit CuDNNSigmoidLayer(const LayerParameter& param)
: SigmoidLayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual ~CuDNNSigmoidLayer();
protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
};
#endif
/**
* @brief TanH hyperbolic tangent non-linearity @f$
* y = \frac{\exp(2x) - 1}{\exp(2x) + 1}
@ -472,6 +522,31 @@ class TanHLayer : public NeuronLayer<Dtype> {
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
};
#ifdef USE_CUDNN
/**
* @brief CuDNN acceleration of TanHLayer.
*/
template <typename Dtype>
class CuDNNTanHLayer : public TanHLayer<Dtype> {
public:
explicit CuDNNTanHLayer(const LayerParameter& param)
: TanHLayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual ~CuDNNTanHLayer();
protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
};
#endif
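These cuDNN neuron layers are drop-in subclasses of their base layers, so they are set up and run exactly like any other Caffe layer. Below is a minimal sketch of exercising one directly, patterned on the unit tests later in this change; it assumes a GPU build with `USE_CUDNN` and uses only headers that appear elsewhere in this diff (the function name is illustrative).

    #ifdef USE_CUDNN
    #include <vector>

    #include "caffe/blob.hpp"
    #include "caffe/common.hpp"         // Caffe::set_mode
    #include "caffe/filler.hpp"         // GaussianFiller
    #include "caffe/vision_layers.hpp"  // pulls in the neuron layer declarations

    void RunCuDNNReLUOnce() {
      using namespace caffe;            // sketch only; real code spells out the namespace
      Caffe::set_mode(Caffe::GPU);      // the cuDNN path only runs in GPU mode

      Blob<float> bottom(2, 3, 4, 5);   // N x C x H x W input blob
      Blob<float> top;
      FillerParameter filler_param;
      GaussianFiller<float> filler(filler_param);
      filler.Fill(&bottom);

      std::vector<Blob<float>*> bottom_vec(1, &bottom);
      std::vector<Blob<float>*> top_vec(1, &top);

      LayerParameter layer_param;
      CuDNNReLULayer<float> layer(layer_param);  // normally chosen by the layer factory
      layer.SetUp(bottom_vec, &top_vec);         // creates the handle and tensor descriptors
      layer.Forward(bottom_vec, &top_vec);       // cudnnActivationForward under the hood
    }
    #endif  // USE_CUDNN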
/**
* @brief Tests whether the input exceeds a threshold: outputs 1 for inputs
* above threshold; 0 otherwise.

View file

@ -0,0 +1,119 @@
#ifndef CAFFE_UTIL_CUDNN_H_
#define CAFFE_UTIL_CUDNN_H_
#ifdef USE_CUDNN
#include <cudnn.h>
#include "caffe/proto/caffe.pb.h"
#define CUDNN_CHECK(condition) \
do { \
cudnnStatus_t status = condition; \
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\
<< cudnnGetErrorString(status); \
} while (0)
inline const char* cudnnGetErrorString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
}
return "Unknown cudnn status";
}
namespace caffe {
namespace cudnn {
template <typename Dtype> class dataType;
template<> class dataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
};
template<> class dataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
};
template <typename Dtype>
inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}
template <typename Dtype>
inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
int n, int c, int h, int w) {
const int stride_w = 1;
const int stride_h = w * stride_w;
const int stride_c = h * stride_h;
const int stride_n = c * stride_c;
createTensor4dDesc<Dtype>(desc, n, c, h, w,
stride_n, stride_c, stride_h, stride_w);
}
template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
CUDNN_CHECK(cudnnSetFilterDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
}
template <typename Dtype>
inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}
template <typename Dtype>
inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv,
PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
int h, int w, int stride_h, int stride_w) {
switch (poolmethod) {
case PoolingParameter_PoolMethod_MAX:
*mode = CUDNN_POOLING_MAX;
break;
case PoolingParameter_PoolMethod_AVE:
*mode = CUDNN_POOLING_AVERAGE;
break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv));
CUDNN_CHECK(cudnnSetPoolingDescriptor(*conv, *mode, h, w,
stride_h, stride_w));
}
} // namespace cudnn
} // namespace caffe
#endif // USE_CUDNN
#endif // CAFFE_UTIL_CUDNN_H_
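The helpers above wrap the cuDNN (v1-era) descriptor API around Caffe's N x C x H x W blob convention, and the `CUDNN_CHECK` macro turns status codes into fatal checks. A minimal sketch of using them outside a layer, assuming a cuDNN v1 install and a `USE_CUDNN` build (the function name is illustrative):

    #ifdef USE_CUDNN
    #include "caffe/common.hpp"      // brings in glog's CHECK_EQ used by CUDNN_CHECK
    #include "caffe/util/cudnn.hpp"

    // Describe a fully-packed float blob of shape N x C x H x W to cuDNN,
    // then tear everything down again.
    void DescribeBlob(int n, int c, int h, int w) {
      cudnnHandle_t handle;
      CUDNN_CHECK(cudnnCreate(&handle));            // one handle per layer, as below

      cudnnTensor4dDescriptor_t desc;
      caffe::cudnn::createTensor4dDesc<float>(&desc, n, c, h, w);  // packed strides

      // ... pass `desc` to cudnnActivationForward / cudnnPoolingForward / etc. ...

      CUDNN_CHECK(cudnnDestroyTensor4dDescriptor(desc));
      CUDNN_CHECK(cudnnDestroy(handle));
    }
    #endif  // USE_CUDNN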

View file

@ -36,6 +36,9 @@ void classname<Dtype>::funcname##_##gpu(const vector<Blob<Dtype>*>& top, \
#include <cuda_runtime.h>
#include <curand.h>
#include <driver_types.h> // cuda driver types
#ifdef USE_CUDNN // cuDNN acceleration library.
#include "caffe/util/cudnn.hpp"
#endif
//
// CUDA macros

View file

@ -63,6 +63,37 @@ class ConvolutionLayer : public Layer<Dtype> {
Blob<Dtype> bias_multiplier_;
};
#ifdef USE_CUDNN
/*
* @brief cuDNN implementation of ConvolutionLayer.
* Fallback to ConvolutionLayer for CPU mode.
*/
template <typename Dtype>
class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype>
{
public:
explicit CuDNNConvolutionLayer(const LayerParameter& param)
: ConvolutionLayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual ~CuDNNConvolutionLayer();
protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
cudnnHandle_t* handle_;
cudaStream_t* stream_;
vector<cudnnTensor4dDescriptor_t> bottom_descs_, top_descs_;
cudnnTensor4dDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
};
#endif
/**
* @brief A helper for image operations that rearranges image regions into
* column vectors. Used by ConvolutionLayer to perform convolution
@ -225,6 +256,33 @@ class PoolingLayer : public Layer<Dtype> {
Blob<int> max_idx_;
};
#ifdef USE_CUDNN
/*
* @brief cuDNN implementation of PoolingLayer.
* Fallback to PoolingLayer for CPU mode.
*/
template <typename Dtype>
class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
public:
explicit CuDNNPoolingLayer(const LayerParameter& param)
: PoolingLayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual ~CuDNNPoolingLayer();
protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_, top_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
cudnnPoolingMode_t mode_;
};
#endif
} // namespace caffe
#endif // CAFFE_VISION_LAYERS_HPP_

View file

@ -16,9 +16,16 @@ ConvolutionLayer<Dtype>* GetConvolutionLayer(const string& name,
ConvolutionParameter_Engine engine = param.convolution_param().engine();
if (engine == ConvolutionParameter_Engine_DEFAULT) {
engine = ConvolutionParameter_Engine_CAFFE;
#ifdef USE_CUDNN
engine = ConvolutionParameter_Engine_CUDNN;
#endif
}
if (engine == ConvolutionParameter_Engine_CAFFE) {
return new ConvolutionLayer<Dtype>(param);
#ifdef USE_CUDNN
} else if (engine == ConvolutionParameter_Engine_CUDNN) {
return new CuDNNConvolutionLayer<Dtype>(param);
#endif
} else {
LOG(FATAL) << "Layer " << name << " has unknown engine.";
}
@ -36,9 +43,16 @@ PoolingLayer<Dtype>* GetPoolingLayer(const string& name,
PoolingParameter_Engine engine = param.pooling_param().engine();
if (engine == PoolingParameter_Engine_DEFAULT) {
engine = PoolingParameter_Engine_CAFFE;
#ifdef USE_CUDNN
engine = PoolingParameter_Engine_CUDNN;
#endif
}
if (engine == PoolingParameter_Engine_CAFFE) {
return new PoolingLayer<Dtype>(param);
#ifdef USE_CUDNN
} else if (engine == PoolingParameter_Engine_CUDNN) {
return new CuDNNPoolingLayer<Dtype>(param);
#endif
} else {
LOG(FATAL) << "Layer " << name << " has unknown engine.";
}
@ -56,9 +70,16 @@ ReLULayer<Dtype>* GetReLULayer(const string& name,
ReLUParameter_Engine engine = param.relu_param().engine();
if (engine == ReLUParameter_Engine_DEFAULT) {
engine = ReLUParameter_Engine_CAFFE;
#ifdef USE_CUDNN
engine = ReLUParameter_Engine_CUDNN;
#endif
}
if (engine == ReLUParameter_Engine_CAFFE) {
return new ReLULayer<Dtype>(param);
#ifdef USE_CUDNN
} else if (engine == ReLUParameter_Engine_CUDNN) {
return new CuDNNReLULayer<Dtype>(param);
#endif
} else {
LOG(FATAL) << "Layer " << name << " has unknown engine.";
}
@ -76,9 +97,16 @@ SigmoidLayer<Dtype>* GetSigmoidLayer(const string& name,
SigmoidParameter_Engine engine = param.sigmoid_param().engine();
if (engine == SigmoidParameter_Engine_DEFAULT) {
engine = SigmoidParameter_Engine_CAFFE;
#ifdef USE_CUDNN
engine = SigmoidParameter_Engine_CUDNN;
#endif
}
if (engine == SigmoidParameter_Engine_CAFFE) {
return new SigmoidLayer<Dtype>(param);
#ifdef USE_CUDNN
} else if (engine == SigmoidParameter_Engine_CUDNN) {
return new CuDNNSigmoidLayer<Dtype>(param);
#endif
} else {
LOG(FATAL) << "Layer " << name << " has unknown engine.";
}
@ -96,9 +124,16 @@ TanHLayer<Dtype>* GetTanHLayer(const string& name,
TanHParameter_Engine engine = param.tanh_param().engine();
if (engine == TanHParameter_Engine_DEFAULT) {
engine = TanHParameter_Engine_CAFFE;
#ifdef USE_CUDNN
engine = TanHParameter_Engine_CUDNN;
#endif
}
if (engine == TanHParameter_Engine_CAFFE) {
return new TanHLayer<Dtype>(param);
#ifdef USE_CUDNN
} else if (engine == TanHParameter_Engine_CUDNN) {
return new CuDNNTanHLayer<Dtype>(param);
#endif
} else {
LOG(FATAL) << "Layer " << name << " has unknown engine.";
}
@ -116,9 +151,16 @@ SoftmaxLayer<Dtype>* GetSoftmaxLayer(const string& name,
SoftmaxParameter_Engine engine = param.softmax_param().engine();
if (engine == SoftmaxParameter_Engine_DEFAULT) {
engine = SoftmaxParameter_Engine_CAFFE;
#ifdef USE_CUDNN
engine = SoftmaxParameter_Engine_CUDNN;
#endif
}
if (engine == SoftmaxParameter_Engine_CAFFE) {
return new SoftmaxLayer<Dtype>(param);
#ifdef USE_CUDNN
} else if (engine == SoftmaxParameter_Engine_CUDNN) {
return new CuDNNSoftmaxLayer<Dtype>(param);
#endif
} else {
LOG(FATAL) << "Layer " << name << " has unknown engine.";
}
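Every factory function in this file follows the same pattern: an explicit `engine` request is honored, while `DEFAULT` resolves to the stock Caffe implementation unless the binary was built with `USE_CUDNN`, in which case cuDNN becomes the default. A condensed sketch of that resolution rule (the enum and function names here are illustrative, not the factory's own):

    // Illustrative mirror of the proto Engine enums (not the factory's actual names).
    enum Engine { ENGINE_DEFAULT, ENGINE_CAFFE, ENGINE_CUDNN };

    // Resolve which implementation a layer gets for a requested engine setting.
    Engine ResolveEngine(Engine requested) {
      if (requested != ENGINE_DEFAULT) {
        return requested;        // an explicit CAFFE or CUDNN choice is honored as-is
      }
    #ifdef USE_CUDNN
      return ENGINE_CUDNN;       // cuDNN builds prefer the accelerated layers by default
    #else
      return ENGINE_CAFFE;       // otherwise the stock Caffe implementation is used
    #endif
    }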

View file

@ -66,11 +66,11 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_EQ(channels_ % group_, 0);
// The im2col result buffer would only hold one image at a time to avoid
// overly large memory usage.
int height_out =
height_out_ =
(height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
int width_out = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
col_buffer_.Reshape(
1, channels_ * kernel_h_ * kernel_w_, height_out, width_out);
1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
// Set the parameters
CHECK_EQ(num_output_ % group_, 0)
<< "Number of output should be multiples of group.";
@ -78,9 +78,9 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
// Figure out the dimensions for individual gemms.
M_ = num_output_ / group_;
K_ = channels_ * kernel_h_ * kernel_w_ / group_;
N_ = height_out * width_out;
N_ = height_out_ * width_out_;
for (int top_id = 0; top_id < top->size(); ++top_id) {
(*top)[top_id]->Reshape(num_, num_output_, height_out, width_out);
(*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
}
// Check if we need to set up the weights
if (this->blobs_.size() > 0) {

View file

@ -0,0 +1,106 @@
#ifdef USE_CUDNN
#include <vector>
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
// Set to three for the benefit of the backward pass, which
// can use separate streams for calculating the gradient w.r.t.
// bias, filter weights, and bottom data for each group independently
#define CUDNN_STREAMS_PER_GROUP 3
/**
* TODO(dox) explain cuDNN interface
*/
template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
ConvolutionLayer<Dtype>::LayerSetUp(bottom, top);
// Initialize CUDA streams and cuDNN.
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
CUDNN_CHECK(cudnnCreate(&handle_[g]));
CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
}
// Set the indexing parameters.
bottom_offset_ = (this->channels_ / this->group_)
* this->height_ * this->width_;
top_offset_ = (this->num_output_ / this->group_)
* this->height_out_ * this->width_out_;
weight_offset_ = (this->num_output_ / this->group_)
* (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_;
bias_offset_ = (this->num_output_ / this->group_);
// Create filter descriptor.
cudnn::createFilterDesc<Dtype>(&filter_desc_,
this->num_output_ / this->group_, this->channels_ / this->group_,
this->kernel_h_, this->kernel_w_);
// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
cudnnTensor4dDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc,
this->num_,
this->channels_ / this->group_,
this->height_, this->width_,
this->channels_ * this->height_ * this->width_,
this->height_ * this->width_,
this->width_, 1);
bottom_descs_.push_back(bottom_desc);
cudnnTensor4dDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc,
this->num_,
this->num_output_ / this->group_,
this->height_out_, this->width_out_,
this->num_output_ * this->height_out_ * this->width_out_,
this->height_out_ * this->width_out_,
this->width_out_, 1);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
cudnn::createConvolutionDesc<Dtype>(&conv_desc, bottom_desc,
filter_desc_, this->pad_h_, this->pad_w_,
this->stride_h_, this->stride_w_);
conv_descs_.push_back(conv_desc);
}
// Tensor descriptor for bias.
if (this->bias_term_) {
cudnn::createTensor4dDesc<Dtype>(&bias_desc_,
1, this->num_output_ / this->group_, 1, 1);
}
}
template <typename Dtype>
CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
for (int i = 0; i < bottom_descs_.size(); i++) {
cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
cudnnDestroyTensor4dDescriptor(top_descs_[i]);
cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
}
if (this->bias_term_) {
cudnnDestroyTensor4dDescriptor(bias_desc_);
}
cudnnDestroyFilterDescriptor(filter_desc_);
for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
cudaStreamDestroy(stream_[g]);
cudnnDestroy(handle_[g]);
}
delete [] stream_;
delete [] handle_;
}
INSTANTIATE_CLASS(CuDNNConvolutionLayer);
} // namespace caffe
#endif
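The `*_offset_` members computed in `LayerSetUp` are element offsets that let the forward/backward passes index group `g`'s slice of the bottom, top, weight, and bias buffers with plain pointer arithmetic. A small standalone sketch of the same arithmetic, using example shapes (the numbers are illustrative only):

    #include <cstdio>

    int main() {
      // Illustrative AlexNet-like conv2 shapes: 96 input channels, 256 outputs,
      // 5x5 kernels, 27x27 input and output maps, group = 2.
      const int group = 2, channels = 96, num_output = 256;
      const int kernel_h = 5, kernel_w = 5;
      const int height = 27, width = 27, height_out = 27, width_out = 27;

      // Same formulas as CuDNNConvolutionLayer<Dtype>::LayerSetUp above.
      const int bottom_offset = (channels / group) * height * width;
      const int top_offset    = (num_output / group) * height_out * width_out;
      const int weight_offset = (num_output / group) * (channels / group)
                                * kernel_h * kernel_w;
      const int bias_offset   = num_output / group;

      std::printf("bottom %d, top %d, weight %d, bias %d\n",
                  bottom_offset, top_offset, weight_offset, bias_offset);
      // Group g then reads bottom_data + bottom_offset * g, writes
      // top_data + top_offset * g, and uses weight + weight_offset * g.
      return 0;
    }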

View file

@ -0,0 +1,109 @@
#ifdef USE_CUDNN
#include <vector>
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
__global__ void sync_conv_groups() { }
template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->gpu_data();
Dtype* top_data = (*top)[i]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
// Filters.
CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + weight_offset_ * g,
conv_descs_[i],
top_descs_[i], top_data + top_offset_ * g,
CUDNN_RESULT_NO_ACCUMULATE));
// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
Dtype alpha = 1.;
CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha,
bias_desc_, bias_data + bias_offset_ * g,
top_descs_[i], top_data + top_offset_ * g));
}
}
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
sync_conv_groups<<<1, 1>>>();
}
}
template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* weight = NULL;
Dtype* weight_diff = NULL;
if (this->param_propagate_down_[0]) {
weight = this->blobs_[0]->gpu_data();
weight_diff = this->blobs_[0]->mutable_gpu_diff();
caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
Dtype* bias_diff = NULL;
if (this->bias_term_ && this->param_propagate_down_[1]) {
bias_diff = this->blobs_[1]->mutable_gpu_diff();
caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->gpu_diff();
// Backward through cuDNN in parallel over groups and gradients.
for (int g = 0; g < this->group_; g++) {
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
top_descs_[i], top_diff + top_offset_ * g,
bias_desc_, bias_diff + bias_offset_ * g,
CUDNN_RESULT_ACCUMULATE));
}
// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = (*bottom)[i]->gpu_data();
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
filter_desc_, weight_diff + weight_offset_ * g,
CUDNN_RESULT_ACCUMULATE));
}
// Gradient w.r.t. bottom data.
if (propagate_down[i]) {
Dtype* bottom_diff = (*bottom)[i]->mutable_gpu_diff();
CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
filter_desc_, weight + weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
bottom_descs_[i], bottom_diff + bottom_offset_ * g,
CUDNN_RESULT_NO_ACCUMULATE));
}
}
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
sync_conv_groups<<<1, 1>>>();
}
}
INSTANTIATE_CLASS(CuDNNConvolutionLayer);
} // namespace caffe
#endif
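Each group's calls are issued on their own cuDNN handle, and each handle is bound to its own stream, so the bias, filter, and bottom-data gradients of every group can proceed concurrently; the empty `sync_conv_groups` kernel launched into the default (null) stream is what joins them back up. The handle chosen for each piece of work follows a simple `kind * group_ + g` scheme, sketched here (the group count and names are made up for the example):

    #include <cstdio>

    // handle_/stream_ have group_ * CUDNN_STREAMS_PER_GROUP entries (3 per group).
    // Backward_gpu above uses handle_[0*group_+g] for the bias gradient,
    // handle_[1*group_+g] for the filter gradient, and handle_[2*group_+g]
    // for the bottom-data gradient, so the three run on distinct streams.
    int main() {
      const int kStreamsPerGroup = 3;   // CUDNN_STREAMS_PER_GROUP
      const int group = 2;              // illustrative group count
      const char* kinds[kStreamsPerGroup] = {"bias", "filter", "bottom"};
      for (int kind = 0; kind < kStreamsPerGroup; ++kind) {
        for (int g = 0; g < group; ++g) {
          std::printf("%s gradient, group %d -> handle_[%d]\n",
                      kinds[kind], g, kind * group + g);
        }
      }
      return 0;
    }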

View file

@ -0,0 +1,38 @@
#ifdef USE_CUDNN
#include <vector>
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
PoolingLayer<Dtype>::LayerSetUp(bottom, top);
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, bottom[0]->num(),
this->channels_, this->height_, this->width_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_, bottom[0]->num(),
this->channels_, this->pooled_height_, this->pooled_width_);
cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
this->layer_param_.pooling_param().pool(), &mode_,
this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
}
template <typename Dtype>
CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
cudnnDestroyTensor4dDescriptor(bottom_desc_);
cudnnDestroyTensor4dDescriptor(top_desc_);
cudnnDestroyPoolingDescriptor(pooling_desc_);
cudnnDestroy(handle_);
}
INSTANTIATE_CLASS(CuDNNPoolingLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,52 @@
#ifdef USE_CUDNN
#include <vector>
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// Fallback to Caffe for padded pooling, max top mask.
if ((this->pad_h_ > 0 || this->pad_w_ > 0) || (*top).size() > 1) {
LOG(WARNING) << "Falling back to standard Caffe for padded pooling.";
return PoolingLayer<Dtype>::Forward_gpu(bottom, top);
}
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_,
bottom_desc_, bottom_data, top_desc_, top_data));
}
template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
if (!propagate_down[0]) {
return;
}
// Fallback to Caffe for padded pooling, max top mask.
if ((this->pad_h_ > 0 || this->pad_w_ > 0) || top.size() > 1) {
LOG(WARNING) << "Falling back to standard Caffe for padded pooling.";
return PoolingLayer<Dtype>::Backward_gpu(top, propagate_down, bottom);
}
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* top_data = top[0]->gpu_data();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_,
top_desc_, top_data, top_desc_, top_diff,
bottom_desc_, bottom_data, bottom_desc_, bottom_diff));
}
INSTANTIATE_CLASS(CuDNNPoolingLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,34 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
ReLULayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
const int N = bottom[0]->num();
const int K = bottom[0]->channels();
const int H = bottom[0]->height();
const int W = bottom[0]->width();
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
}
template <typename Dtype>
CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
INSTANTIATE_CLASS(CuDNNReLULayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,51 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// Fallback to standard Caffe for leaky ReLU.
if (ReLULayer<Dtype>::layer_param_.relu_param().negative_slope() != 0) {
return ReLULayer<Dtype>::Forward_gpu(bottom, top);
}
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnActivationForward(this->handle_,
CUDNN_ACTIVATION_RELU,
this->bottom_desc_, bottom_data, this->top_desc_, top_data));
}
template <typename Dtype>
void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (!propagate_down[0]) {
return;
}
// Fallback to standard Caffe for leaky ReLU.
if (ReLULayer<Dtype>::layer_param_.relu_param().negative_slope() != 0) {
return ReLULayer<Dtype>::Backward_gpu(top, propagate_down, bottom);
}
const Dtype* top_data = top[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
CUDNN_ACTIVATION_RELU,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
}
INSTANTIATE_CLASS(CuDNNReLULayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,34 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
SigmoidLayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
const int N = bottom[0]->num();
const int K = bottom[0]->channels();
const int H = bottom[0]->height();
const int W = bottom[0]->width();
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
}
template <typename Dtype>
CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {
cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
INSTANTIATE_CLASS(CuDNNSigmoidLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,41 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnActivationForward(this->handle_,
CUDNN_ACTIVATION_SIGMOID,
this->bottom_desc_, bottom_data, this->top_desc_, top_data));
}
template <typename Dtype>
void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (!propagate_down[0]) {
return;
}
const Dtype* top_data = top[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
CUDNN_ACTIVATION_SIGMOID,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
}
INSTANTIATE_CLASS(CuDNNSigmoidLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,38 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <cfloat>
#include <vector>
#include "thrust/device_vector.h"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNSoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
SoftmaxLayer<Dtype>::LayerSetUp(bottom, top);
// Initialize CUDNN.
CUDNN_CHECK(cudnnCreate(&handle_));
int N = bottom[0]->num();
int K = bottom[0]->channels();
int H = bottom[0]->height();
int W = bottom[0]->width();
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
}
template <typename Dtype>
CuDNNSoftmaxLayer<Dtype>::~CuDNNSoftmaxLayer() {
cudnnDestroyTensor4dDescriptor(bottom_desc_);
cudnnDestroyTensor4dDescriptor(top_desc_);
cudnnDestroy(handle_);
}
INSTANTIATE_CLASS(CuDNNSoftmaxLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,41 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <cfloat>
#include <vector>
#include "thrust/device_vector.h"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
bottom_desc_, bottom_data, top_desc_, top_data));
}
template <typename Dtype>
void CuDNNSoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
if (propagate_down[0]) {
const Dtype* top_data = top[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
top_desc_, top_data, top_desc_, top_diff, bottom_desc_, bottom_diff));
}
}
INSTANTIATE_CLASS(CuDNNSoftmaxLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,34 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
TanHLayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
const int N = bottom[0]->num();
const int K = bottom[0]->channels();
const int H = bottom[0]->height();
const int W = bottom[0]->width();
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
}
template <typename Dtype>
CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {
cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
INSTANTIATE_CLASS(CuDNNTanHLayer);
} // namespace caffe
#endif

View file

@ -0,0 +1,41 @@
#ifdef USE_CUDNN
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnActivationForward(this->handle_,
CUDNN_ACTIVATION_TANH,
this->bottom_desc_, bottom_data, this->top_desc_, top_data));
}
template <typename Dtype>
void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (!propagate_down[0]) {
return;
}
const Dtype* top_data = top[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
CUDNN_ACTIVATION_TANH,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
}
INSTANTIATE_CLASS(CuDNNTanHLayer);
} // namespace caffe
#endif

View file

@ -1,4 +1,3 @@
//
#include <algorithm>
#include <vector>

View file

@ -388,6 +388,7 @@ message ConvolutionParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 15 [default = DEFAULT];
}
@ -579,6 +580,7 @@ message PoolingParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 11 [default = DEFAULT];
}
@ -602,6 +604,7 @@ message ReLUParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 2 [default = DEFAULT];
}
@ -611,6 +614,7 @@ message SigmoidParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [default = DEFAULT];
}
@ -630,6 +634,7 @@ message SoftmaxParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [default = DEFAULT];
}
@ -639,6 +644,7 @@ message TanHParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [default = DEFAULT];
}
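The new `CUDNN = 2` value in each layer's `Engine` enum lets a model pin individual layers to one implementation, while `DEFAULT` leaves the choice to the layer factory shown earlier (cuDNN when built with `USE_CUDNN`, stock Caffe otherwise). A small sketch of setting it through the generated protobuf API, mirroring how the tests later in this diff configure layers (the helper function is illustrative):

    #include "caffe/proto/caffe.pb.h"

    // Build a convolution LayerParameter that explicitly requests an engine.
    caffe::LayerParameter MakeConvParam(bool force_caffe_engine) {
      caffe::LayerParameter layer_param;
      caffe::ConvolutionParameter* conv = layer_param.mutable_convolution_param();
      conv->set_kernel_size(3);
      conv->set_stride(2);
      conv->set_num_output(4);
      conv->set_engine(force_caffe_engine
                           ? caffe::ConvolutionParameter_Engine_CAFFE
                           : caffe::ConvolutionParameter_Engine_CUDNN);
      return layer_param;
    }

In a prototxt the same choice is written as `engine: CUDNN` (or `engine: CAFFE`) inside the layer's `convolution_param`, and likewise for the pooling, ReLU, sigmoid, softmax, and TanH parameters above.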

View file

@ -302,4 +302,295 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) {
&(this->blob_top_vec_));
}
#ifdef USE_CUDNN
template <typename Dtype>
class CuDNNConvolutionLayerTest : public ::testing::Test {
protected:
CuDNNConvolutionLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
blob_top_(new Blob<Dtype>()),
blob_top_2_(new Blob<Dtype>()) {}
virtual void SetUp() {
// fill the values
FillerParameter filler_param;
filler_param.set_value(1.);
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
filler.Fill(this->blob_bottom_2_);
blob_bottom_vec_.push_back(blob_bottom_);
blob_top_vec_.push_back(blob_top_);
}
virtual ~CuDNNConvolutionLayerTest() {
delete blob_bottom_;
delete blob_bottom_2_;
delete blob_top_;
delete blob_top_2_;
}
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_bottom_2_;
Blob<Dtype>* const blob_top_;
Blob<Dtype>* const blob_top_2_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
TYPED_TEST_CASE(CuDNNConvolutionLayerTest, TestDtypes);
TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
ConvolutionParameter* convolution_param =
layer_param.mutable_convolution_param();
convolution_param->set_kernel_size(3);
convolution_param->set_stride(2);
convolution_param->set_num_output(4);
this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
this->blob_top_vec_.push_back(this->blob_top_2_);
shared_ptr<Layer<TypeParam> > layer(
new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), 2);
EXPECT_EQ(this->blob_top_->channels(), 4);
EXPECT_EQ(this->blob_top_->height(), 2);
EXPECT_EQ(this->blob_top_->width(), 1);
EXPECT_EQ(this->blob_top_2_->num(), 2);
EXPECT_EQ(this->blob_top_2_->channels(), 4);
EXPECT_EQ(this->blob_top_2_->height(), 2);
EXPECT_EQ(this->blob_top_2_->width(), 1);
// setting group should not change the shape
convolution_param->set_num_output(3);
convolution_param->set_group(3);
layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), 2);
EXPECT_EQ(this->blob_top_->channels(), 3);
EXPECT_EQ(this->blob_top_->height(), 2);
EXPECT_EQ(this->blob_top_->width(), 1);
EXPECT_EQ(this->blob_top_2_->num(), 2);
EXPECT_EQ(this->blob_top_2_->channels(), 3);
EXPECT_EQ(this->blob_top_2_->height(), 2);
EXPECT_EQ(this->blob_top_2_->width(), 1);
}
TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) {
// We will simply see if the convolution layer carries out averaging well.
Caffe::set_mode(Caffe::GPU);
shared_ptr<ConstantFiller<TypeParam> > filler;
FillerParameter filler_param;
filler_param.set_value(1.);
filler.reset(new ConstantFiller<TypeParam>(filler_param));
filler->Fill(this->blob_bottom_);
filler_param.set_value(2.);
filler.reset(new ConstantFiller<TypeParam>(filler_param));
filler->Fill(this->blob_bottom_2_);
this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
this->blob_top_vec_.push_back(this->blob_top_2_);
LayerParameter layer_param;
ConvolutionParameter* convolution_param =
layer_param.mutable_convolution_param();
convolution_param->set_kernel_size(3);
convolution_param->set_stride(2);
convolution_param->set_num_output(4);
convolution_param->mutable_weight_filler()->set_type("constant");
convolution_param->mutable_weight_filler()->set_value(1);
convolution_param->mutable_bias_filler()->set_type("constant");
convolution_param->mutable_bias_filler()->set_value(0.1);
shared_ptr<Layer<TypeParam> > layer(
new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// After the convolution, the output should all have output values 27.1
const TypeParam* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_top_->count(); ++i) {
EXPECT_NEAR(top_data[i], 27.1, 1e-4);
}
top_data = this->blob_top_2_->cpu_data();
for (int i = 0; i < this->blob_top_2_->count(); ++i) {
EXPECT_NEAR(top_data[i], 54.1, 1e-4);
}
}
TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) {
// We will simply see if the convolution layer carries out averaging well.
Caffe::set_mode(Caffe::GPU);
FillerParameter filler_param;
filler_param.set_value(1.);
ConstantFiller<TypeParam> filler(filler_param);
filler.Fill(this->blob_bottom_);
TypeParam* bottom_data = this->blob_bottom_->mutable_cpu_data();
for (int n = 0; n < this->blob_bottom_->num(); ++n) {
for (int c = 0; c < this->blob_bottom_->channels(); ++c) {
for (int h = 0; h < this->blob_bottom_->height(); ++h) {
for (int w = 0; w < this->blob_bottom_->width(); ++w) {
bottom_data[this->blob_bottom_->offset(n, c, h, w)] = c;
}
}
}
}
LayerParameter layer_param;
ConvolutionParameter* convolution_param =
layer_param.mutable_convolution_param();
convolution_param->set_kernel_size(3);
convolution_param->set_stride(2);
convolution_param->set_num_output(3);
convolution_param->set_group(3);
convolution_param->mutable_weight_filler()->set_type("constant");
convolution_param->mutable_weight_filler()->set_value(1);
convolution_param->mutable_bias_filler()->set_type("constant");
convolution_param->mutable_bias_filler()->set_value(0.1);
shared_ptr<Layer<TypeParam> > layer(
new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// After the convolution, the output should all have output values 9.1
const TypeParam* top_data = this->blob_top_->cpu_data();
for (int n = 0; n < this->blob_top_->num(); ++n) {
for (int c = 0; c < this->blob_top_->channels(); ++c) {
for (int h = 0; h < this->blob_top_->height(); ++h) {
for (int w = 0; w < this->blob_top_->width(); ++w) {
TypeParam data = top_data[this->blob_top_->offset(n, c, h, w)];
EXPECT_NEAR(data, c * 9 + 0.1, 1e-4);
}
}
}
}
}
TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) {
// Test separable convolution by computing the Sobel operator
// as a single filter then comparing the result
// as the convolution of two rectangular filters.
Caffe::set_mode(Caffe::GPU);
// Fill bottoms with identical Gaussian noise.
shared_ptr<GaussianFiller<TypeParam> > filler;
FillerParameter filler_param;
filler_param.set_value(1.);
filler.reset(new GaussianFiller<TypeParam>(filler_param));
filler->Fill(this->blob_bottom_);
this->blob_bottom_2_->CopyFrom(*this->blob_bottom_);
// Compute Sobel G_x operator as 3 x 3 convolution.
LayerParameter layer_param;
ConvolutionParameter* convolution_param =
layer_param.mutable_convolution_param();
convolution_param->set_kernel_size(3);
convolution_param->set_stride(2);
convolution_param->set_num_output(1);
convolution_param->set_bias_term(false);
shared_ptr<Layer<TypeParam> > layer(
new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->blobs().resize(1);
layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 3, 3));
TypeParam* weights = layer->blobs()[0]->mutable_cpu_data();
for (int c = 0; c < 3; ++c) {
int i = c * 9; // 3 x 3 filter
weights[i + 0] = -1;
weights[i + 1] = 0;
weights[i + 2] = 1;
weights[i + 3] = -2;
weights[i + 4] = 0;
weights[i + 5] = 2;
weights[i + 6] = -1;
weights[i + 7] = 0;
weights[i + 8] = 1;
}
layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Compute Sobel G_x operator as separable 3 x 1 and 1 x 3 convolutions.
// (1) the [1 2 1] column filter
vector<Blob<TypeParam>*> sep_blob_bottom_vec;
vector<Blob<TypeParam>*> sep_blob_top_vec;
shared_ptr<Blob<TypeParam> > blob_sep(new Blob<TypeParam>());
sep_blob_bottom_vec.push_back(this->blob_bottom_2_);
sep_blob_top_vec.push_back(this->blob_top_2_);
convolution_param->clear_kernel_size();
convolution_param->clear_stride();
convolution_param->set_kernel_h(3);
convolution_param->set_kernel_w(1);
convolution_param->set_stride_h(2);
convolution_param->set_stride_w(1);
convolution_param->set_num_output(1);
convolution_param->set_bias_term(false);
layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->blobs().resize(1);
layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 3, 1));
TypeParam* weights_1 = layer->blobs()[0]->mutable_cpu_data();
for (int c = 0; c < 3; ++c) {
int i = c * 3; // 3 x 1 filter
weights_1[i + 0] = 1;
weights_1[i + 1] = 2;
weights_1[i + 2] = 1;
}
layer->SetUp(sep_blob_bottom_vec, &(sep_blob_top_vec));
layer->Forward(sep_blob_bottom_vec, &(sep_blob_top_vec));
// (2) the [-1 0 1] row filter
blob_sep->CopyFrom(*this->blob_top_2_, false, true);
sep_blob_bottom_vec.clear();
sep_blob_bottom_vec.push_back(blob_sep.get());
convolution_param->set_kernel_h(1);
convolution_param->set_kernel_w(3);
convolution_param->set_stride_h(1);
convolution_param->set_stride_w(2);
convolution_param->set_num_output(1);
convolution_param->set_bias_term(false);
layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
layer->blobs().resize(1);
layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 1, 3));
TypeParam* weights_2 = layer->blobs()[0]->mutable_cpu_data();
for (int c = 0; c < 3; ++c) {
int i = c * 3; // 1 x 3 filter
weights_2[i + 0] = -1;
weights_2[i + 1] = 0;
weights_2[i + 2] = 1;
}
layer->SetUp(sep_blob_bottom_vec, &(sep_blob_top_vec));
layer->Forward(sep_blob_bottom_vec, &(sep_blob_top_vec));
// Test equivalence of full and separable filters.
const TypeParam* top_data = this->blob_top_->cpu_data();
const TypeParam* sep_top_data = this->blob_top_2_->cpu_data();
for (int i = 0; i < this->blob_top_->count(); ++i) {
EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4);
}
}
TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
ConvolutionParameter* convolution_param =
layer_param.mutable_convolution_param();
this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
this->blob_top_vec_.push_back(this->blob_top_2_);
convolution_param->set_kernel_size(3);
convolution_param->set_stride(2);
convolution_param->set_num_output(2);
convolution_param->mutable_weight_filler()->set_type("gaussian");
convolution_param->mutable_bias_filler()->set_type("gaussian");
CuDNNConvolutionLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
ConvolutionParameter* convolution_param =
layer_param.mutable_convolution_param();
convolution_param->set_kernel_size(3);
convolution_param->set_stride(2);
convolution_param->set_num_output(3);
convolution_param->set_group(3);
convolution_param->mutable_weight_filler()->set_type("gaussian");
convolution_param->mutable_bias_filler()->set_type("gaussian");
CuDNNConvolutionLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
#endif
} // namespace caffe

View file

@ -272,5 +272,137 @@ TYPED_TEST(NeuronLayerTest, TestBNLLGradient) {
&(this->blob_top_vec_));
}
#ifdef USE_CUDNN
template <typename Dtype>
class CuDNNNeuronLayerTest : public ::testing::Test {
protected:
CuDNNNeuronLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
blob_top_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
blob_bottom_vec_.push_back(blob_bottom_);
blob_top_vec_.push_back(blob_top_);
}
virtual ~CuDNNNeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
TYPED_TEST_CASE(CuDNNNeuronLayerTest, TestDtypes);
TYPED_TEST(CuDNNNeuronLayerTest, TestReLUCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNReLULayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_GE(top_data[i], 0.);
EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
}
}
TYPED_TEST(CuDNNNeuronLayerTest, TestReLUGradientCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNReLULayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
TYPED_TEST(CuDNNNeuronLayerTest, TestReLUWithNegativeSlopeCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
layer_param.ParseFromString("relu_param{negative_slope:0.01}");
CuDNNReLULayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_GE(top_data[i], 0.);
EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
}
}
TYPED_TEST(CuDNNNeuronLayerTest, TestReLUGradientWithNegativeSlopeCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
layer_param.ParseFromString("relu_param{negative_slope:0.01}");
CuDNNReLULayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
TYPED_TEST(CuDNNNeuronLayerTest, TestSigmoidCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNSigmoidLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Now, check values
const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
const TypeParam* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_FLOAT_EQ(top_data[i], 1. / (1 + exp(-bottom_data[i])));
// check that we squashed the value between 0 and 1
EXPECT_GE(top_data[i], 0.);
EXPECT_LE(top_data[i], 1.);
}
}
TYPED_TEST(CuDNNNeuronLayerTest, TestSigmoidGradientCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNSigmoidLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
TYPED_TEST(CuDNNNeuronLayerTest, TestTanHCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNTanHLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test exact values
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
for (int k = 0; k < this->blob_bottom_->height(); ++k) {
for (int l = 0; l < this->blob_bottom_->width(); ++l) {
EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
(exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
(exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
(exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
(exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
}
}
}
}
}
TYPED_TEST(CuDNNNeuronLayerTest, TestTanHGradientCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNTanHLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
#endif
} // namespace caffe
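
For reference, the value checks in the cuDNN neuron-layer tests above compare against the usual element-wise definitions. A small stand-alone sketch of those reference formulas (illustrative only; the layers under test run through cuDNN):

#include <cmath>

// Element-wise references matching the forward checks above.
inline double relu(double x) { return x > 0 ? x : 0; }
inline double leaky_relu(double x, double slope) {  // relu_param.negative_slope
  return x > 0 ? x : slope * x;
}
inline double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }
inline double tanh_ref(double x) {
  return (std::exp(2 * x) - 1) / (std::exp(2 * x) + 1);
}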

View file

@@ -592,5 +592,587 @@ TYPED_TEST(PoolingLayerTest, TestGradientAvePadded) {
}
}
#ifdef USE_CUDNN
template <typename Dtype>
class CuDNNPoolingLayerTest : public ::testing::Test {
protected:
CuDNNPoolingLayerTest()
: blob_bottom_(new Blob<Dtype>()),
blob_top_(new Blob<Dtype>()),
blob_top_mask_(new Blob<Dtype>()) {}
virtual void SetUp() {
Caffe::set_random_seed(1701);
blob_bottom_->Reshape(2, 3, 6, 5);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
blob_bottom_vec_.push_back(blob_bottom_);
blob_top_vec_.push_back(blob_top_);
}
virtual ~CuDNNPoolingLayerTest() {
delete blob_bottom_;
delete blob_top_;
delete blob_top_mask_;
}
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
Blob<Dtype>* const blob_top_mask_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
// Test for 2x 2 square pooling layer
void TestForwardSquare() {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
const int num = 2;
const int channels = 2;
blob_bottom_->Reshape(num, channels, 3, 5);
// Input: 2x 2 channels of:
// [1 2 5 2 3]
// [9 4 1 4 8]
// [1 2 5 2 3]
for (int i = 0; i < 15 * num * channels; i += 15) {
blob_bottom_->mutable_cpu_data()[i + 0] = 1;
blob_bottom_->mutable_cpu_data()[i + 1] = 2;
blob_bottom_->mutable_cpu_data()[i + 2] = 5;
blob_bottom_->mutable_cpu_data()[i + 3] = 2;
blob_bottom_->mutable_cpu_data()[i + 4] = 3;
blob_bottom_->mutable_cpu_data()[i + 5] = 9;
blob_bottom_->mutable_cpu_data()[i + 6] = 4;
blob_bottom_->mutable_cpu_data()[i + 7] = 1;
blob_bottom_->mutable_cpu_data()[i + 8] = 4;
blob_bottom_->mutable_cpu_data()[i + 9] = 8;
blob_bottom_->mutable_cpu_data()[i + 10] = 1;
blob_bottom_->mutable_cpu_data()[i + 11] = 2;
blob_bottom_->mutable_cpu_data()[i + 12] = 5;
blob_bottom_->mutable_cpu_data()[i + 13] = 2;
blob_bottom_->mutable_cpu_data()[i + 14] = 3;
}
CuDNNPoolingLayer<Dtype> layer(layer_param);
layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
EXPECT_EQ(blob_top_->num(), num);
EXPECT_EQ(blob_top_->channels(), channels);
EXPECT_EQ(blob_top_->height(), 2);
EXPECT_EQ(blob_top_->width(), 4);
if (blob_top_vec_.size() > 1) {
EXPECT_EQ(blob_top_mask_->num(), num);
EXPECT_EQ(blob_top_mask_->channels(), channels);
EXPECT_EQ(blob_top_mask_->height(), 2);
EXPECT_EQ(blob_top_mask_->width(), 4);
}
layer.Forward(blob_bottom_vec_, &blob_top_vec_);
// Expected output: 2x 2 channels of:
// [9 5 5 8]
// [9 5 5 8]
for (int i = 0; i < 8 * num * channels; i += 8) {
EXPECT_EQ(blob_top_->cpu_data()[i + 0], 9);
EXPECT_EQ(blob_top_->cpu_data()[i + 1], 5);
EXPECT_EQ(blob_top_->cpu_data()[i + 2], 5);
EXPECT_EQ(blob_top_->cpu_data()[i + 3], 8);
EXPECT_EQ(blob_top_->cpu_data()[i + 4], 9);
EXPECT_EQ(blob_top_->cpu_data()[i + 5], 5);
EXPECT_EQ(blob_top_->cpu_data()[i + 6], 5);
EXPECT_EQ(blob_top_->cpu_data()[i + 7], 8);
}
if (blob_top_vec_.size() > 1) {
// Expected mask output: 2x 2 channels of:
// [5 2 2 9]
// [5 12 12 9]
for (int i = 0; i < 8 * num * channels; i += 8) {
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 5);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 2);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 2);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 9);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 5);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 12);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 12);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 9);
}
}
}
// Test for 3x 2 rectangular pooling layer with kernel_h > kernel_w
void TestForwardRectHigh() {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_h(3);
pooling_param->set_kernel_w(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
const int num = 2;
const int channels = 2;
blob_bottom_->Reshape(num, channels, 6, 6);
// Input: 2x 2 channels of:
// [35 1 6 26 19 24]
// [ 3 32 7 21 23 25]
// [31 9 2 22 27 20]
// [ 8 28 33 17 10 15]
// [30 5 34 12 14 16]
// [ 4 36 29 13 18 11]
// (this is generated by magic(6) in MATLAB)
for (int i = 0; i < 36 * num * channels; i += 36) {
blob_bottom_->mutable_cpu_data()[i + 0] = 35;
blob_bottom_->mutable_cpu_data()[i + 1] = 1;
blob_bottom_->mutable_cpu_data()[i + 2] = 6;
blob_bottom_->mutable_cpu_data()[i + 3] = 26;
blob_bottom_->mutable_cpu_data()[i + 4] = 19;
blob_bottom_->mutable_cpu_data()[i + 5] = 24;
blob_bottom_->mutable_cpu_data()[i + 6] = 3;
blob_bottom_->mutable_cpu_data()[i + 7] = 32;
blob_bottom_->mutable_cpu_data()[i + 8] = 7;
blob_bottom_->mutable_cpu_data()[i + 9] = 21;
blob_bottom_->mutable_cpu_data()[i + 10] = 23;
blob_bottom_->mutable_cpu_data()[i + 11] = 25;
blob_bottom_->mutable_cpu_data()[i + 12] = 31;
blob_bottom_->mutable_cpu_data()[i + 13] = 9;
blob_bottom_->mutable_cpu_data()[i + 14] = 2;
blob_bottom_->mutable_cpu_data()[i + 15] = 22;
blob_bottom_->mutable_cpu_data()[i + 16] = 27;
blob_bottom_->mutable_cpu_data()[i + 17] = 20;
blob_bottom_->mutable_cpu_data()[i + 18] = 8;
blob_bottom_->mutable_cpu_data()[i + 19] = 28;
blob_bottom_->mutable_cpu_data()[i + 20] = 33;
blob_bottom_->mutable_cpu_data()[i + 21] = 17;
blob_bottom_->mutable_cpu_data()[i + 22] = 10;
blob_bottom_->mutable_cpu_data()[i + 23] = 15;
blob_bottom_->mutable_cpu_data()[i + 24] = 30;
blob_bottom_->mutable_cpu_data()[i + 25] = 5;
blob_bottom_->mutable_cpu_data()[i + 26] = 34;
blob_bottom_->mutable_cpu_data()[i + 27] = 12;
blob_bottom_->mutable_cpu_data()[i + 28] = 14;
blob_bottom_->mutable_cpu_data()[i + 29] = 16;
blob_bottom_->mutable_cpu_data()[i + 30] = 4;
blob_bottom_->mutable_cpu_data()[i + 31] = 36;
blob_bottom_->mutable_cpu_data()[i + 32] = 29;
blob_bottom_->mutable_cpu_data()[i + 33] = 13;
blob_bottom_->mutable_cpu_data()[i + 34] = 18;
blob_bottom_->mutable_cpu_data()[i + 35] = 11;
}
CuDNNPoolingLayer<Dtype> layer(layer_param);
layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
EXPECT_EQ(blob_top_->num(), num);
EXPECT_EQ(blob_top_->channels(), channels);
EXPECT_EQ(blob_top_->height(), 4);
EXPECT_EQ(blob_top_->width(), 5);
if (blob_top_vec_.size() > 1) {
EXPECT_EQ(blob_top_mask_->num(), num);
EXPECT_EQ(blob_top_mask_->channels(), channels);
EXPECT_EQ(blob_top_mask_->height(), 4);
EXPECT_EQ(blob_top_mask_->width(), 5);
}
layer.Forward(blob_bottom_vec_, &blob_top_vec_);
// Expected output: 2x 2 channels of:
// [35 32 26 27 27]
// [32 33 33 27 27]
// [31 34 34 27 27]
// [36 36 34 18 18]
for (int i = 0; i < 20 * num * channels; i += 20) {
EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35);
EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32);
EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26);
EXPECT_EQ(blob_top_->cpu_data()[i + 3], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 4], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32);
EXPECT_EQ(blob_top_->cpu_data()[i + 6], 33);
EXPECT_EQ(blob_top_->cpu_data()[i + 7], 33);
EXPECT_EQ(blob_top_->cpu_data()[i + 8], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 9], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 10], 31);
EXPECT_EQ(blob_top_->cpu_data()[i + 11], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 13], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 14], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 15], 36);
EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
EXPECT_EQ(blob_top_->cpu_data()[i + 17], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 18], 18);
EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
}
if (blob_top_vec_.size() > 1) {
      // Expected mask output: 2x 2 channels of (0-based argmax indices):
      // [ 0  7  3 16 16]
      // [ 7 20 20 16 16]
      // [12 26 26 16 16]
      // [31 31 26 34 34]
for (int i = 0; i < 20 * num * channels; i += 20) {
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 20);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 20);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 12);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 31);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 34);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
}
}
}
// Test for rectangular pooling layer with kernel_w > kernel_h
void TestForwardRectWide() {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_h(2);
pooling_param->set_kernel_w(3);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
const int num = 2;
const int channels = 2;
blob_bottom_->Reshape(num, channels, 6, 6);
// Input: 2x 2 channels of:
// [35 1 6 26 19 24]
// [ 3 32 7 21 23 25]
// [31 9 2 22 27 20]
// [ 8 28 33 17 10 15]
// [30 5 34 12 14 16]
// [ 4 36 29 13 18 11]
// (this is generated by magic(6) in MATLAB)
for (int i = 0; i < 36 * num * channels; i += 36) {
blob_bottom_->mutable_cpu_data()[i + 0] = 35;
blob_bottom_->mutable_cpu_data()[i + 1] = 1;
blob_bottom_->mutable_cpu_data()[i + 2] = 6;
blob_bottom_->mutable_cpu_data()[i + 3] = 26;
blob_bottom_->mutable_cpu_data()[i + 4] = 19;
blob_bottom_->mutable_cpu_data()[i + 5] = 24;
blob_bottom_->mutable_cpu_data()[i + 6] = 3;
blob_bottom_->mutable_cpu_data()[i + 7] = 32;
blob_bottom_->mutable_cpu_data()[i + 8] = 7;
blob_bottom_->mutable_cpu_data()[i + 9] = 21;
blob_bottom_->mutable_cpu_data()[i + 10] = 23;
blob_bottom_->mutable_cpu_data()[i + 11] = 25;
blob_bottom_->mutable_cpu_data()[i + 12] = 31;
blob_bottom_->mutable_cpu_data()[i + 13] = 9;
blob_bottom_->mutable_cpu_data()[i + 14] = 2;
blob_bottom_->mutable_cpu_data()[i + 15] = 22;
blob_bottom_->mutable_cpu_data()[i + 16] = 27;
blob_bottom_->mutable_cpu_data()[i + 17] = 20;
blob_bottom_->mutable_cpu_data()[i + 18] = 8;
blob_bottom_->mutable_cpu_data()[i + 19] = 28;
blob_bottom_->mutable_cpu_data()[i + 20] = 33;
blob_bottom_->mutable_cpu_data()[i + 21] = 17;
blob_bottom_->mutable_cpu_data()[i + 22] = 10;
blob_bottom_->mutable_cpu_data()[i + 23] = 15;
blob_bottom_->mutable_cpu_data()[i + 24] = 30;
blob_bottom_->mutable_cpu_data()[i + 25] = 5;
blob_bottom_->mutable_cpu_data()[i + 26] = 34;
blob_bottom_->mutable_cpu_data()[i + 27] = 12;
blob_bottom_->mutable_cpu_data()[i + 28] = 14;
blob_bottom_->mutable_cpu_data()[i + 29] = 16;
blob_bottom_->mutable_cpu_data()[i + 30] = 4;
blob_bottom_->mutable_cpu_data()[i + 31] = 36;
blob_bottom_->mutable_cpu_data()[i + 32] = 29;
blob_bottom_->mutable_cpu_data()[i + 33] = 13;
blob_bottom_->mutable_cpu_data()[i + 34] = 18;
blob_bottom_->mutable_cpu_data()[i + 35] = 11;
}
CuDNNPoolingLayer<Dtype> layer(layer_param);
layer.SetUp(blob_bottom_vec_, &blob_top_vec_);
EXPECT_EQ(blob_top_->num(), num);
EXPECT_EQ(blob_top_->channels(), channels);
EXPECT_EQ(blob_top_->height(), 5);
EXPECT_EQ(blob_top_->width(), 4);
if (blob_top_vec_.size() > 1) {
EXPECT_EQ(blob_top_mask_->num(), num);
EXPECT_EQ(blob_top_mask_->channels(), channels);
EXPECT_EQ(blob_top_mask_->height(), 5);
EXPECT_EQ(blob_top_mask_->width(), 4);
}
layer.Forward(blob_bottom_vec_, &blob_top_vec_);
// Expected output: 2x 2 channels of:
// [35 32 26 26]
// [32 32 27 27]
// [33 33 33 27]
// [34 34 34 17]
// [36 36 34 18]
for (int i = 0; i < 20 * num * channels; i += 20) {
EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35);
EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32);
EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26);
EXPECT_EQ(blob_top_->cpu_data()[i + 3], 26);
EXPECT_EQ(blob_top_->cpu_data()[i + 4], 32);
EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32);
EXPECT_EQ(blob_top_->cpu_data()[i + 6], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 7], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 8], 33);
EXPECT_EQ(blob_top_->cpu_data()[i + 9], 33);
EXPECT_EQ(blob_top_->cpu_data()[i + 10], 33);
EXPECT_EQ(blob_top_->cpu_data()[i + 11], 27);
EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 13], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 14], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 15], 17);
EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
EXPECT_EQ(blob_top_->cpu_data()[i + 17], 36);
EXPECT_EQ(blob_top_->cpu_data()[i + 18], 34);
EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
}
if (blob_top_vec_.size() > 1) {
      // Expected mask output: 2x 2 channels of (0-based argmax indices):
      // [ 0  7  3  3]
      // [ 7  7 16 16]
      // [20 20 20 16]
      // [26 26 26 21]
      // [31 31 26 34]
for (int i = 0; i < 20 * num * channels; i += 20) {
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 3);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 7);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 20);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 20);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 20);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 16);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 21);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 31);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 26);
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
}
}
}
};
TYPED_TEST_CASE(CuDNNPoolingLayerTest, TestDtypes);
TYPED_TEST(CuDNNPoolingLayerTest, TestSetupCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
EXPECT_EQ(this->blob_top_->height(), 3);
EXPECT_EQ(this->blob_top_->width(), 2);
}
TYPED_TEST(CuDNNPoolingLayerTest, TestSetupPaddedCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
pooling_param->set_pad(1);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
EXPECT_EQ(this->blob_top_->height(), 4);
EXPECT_EQ(this->blob_top_->width(), 3);
}
/*
TYPED_TEST(CuDNNPoolingLayerTest, PrintBackwardCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
layer_param.set_kernelsize(3);
layer_param.set_stride(2);
layer_param.set_pool(LayerParameter_PoolMethod_MAX);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
cout << "bottom data " << i << " " << this->blob_bottom_->cpu_data()[i] << endl;
}
for (int i = 0; i < this->blob_top_->count(); ++i) {
cout << "top data " << i << " " << this->blob_top_->cpu_data()[i] << endl;
}
for (int i = 0; i < this->blob_top_->count(); ++i) {
this->blob_top_->mutable_cpu_diff()[i] = i;
}
layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
cout << "bottom diff " << i << " " << this->blob_bottom_->cpu_diff()[i] << endl;
}
}
*/
TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxCuDNN) {
Caffe::set_mode(Caffe::GPU);
this->TestForwardSquare();
this->TestForwardRectHigh();
this->TestForwardRectWide();
}
TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxTopMaskCuDNN) {
Caffe::set_mode(Caffe::GPU);
this->blob_top_vec_.push_back(this->blob_top_mask_);
this->TestForwardSquare();
this->TestForwardRectHigh();
this->TestForwardRectWide();
}
TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxCuDNN) {
Caffe::set_mode(Caffe::GPU);
for (int kernel_h = 3; kernel_h <= 4; kernel_h++) {
for (int kernel_w = 3; kernel_w <= 4; kernel_w++) {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_h(kernel_h);
pooling_param->set_kernel_w(kernel_w);
pooling_param->set_stride(2);
pooling_param->set_pad(1);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-4, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
}
}
TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxPaddedCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(2);
pooling_param->set_pad(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
this->blob_bottom_->Reshape(1, 1, 3, 3);
// Input:
// [ 1 2 4 ]
// [ 2 3 2 ]
// [ 4 2 1 ]
this->blob_bottom_->mutable_cpu_data()[0] = 1;
this->blob_bottom_->mutable_cpu_data()[1] = 2;
this->blob_bottom_->mutable_cpu_data()[2] = 4;
this->blob_bottom_->mutable_cpu_data()[3] = 2;
this->blob_bottom_->mutable_cpu_data()[4] = 3;
this->blob_bottom_->mutable_cpu_data()[5] = 2;
this->blob_bottom_->mutable_cpu_data()[6] = 4;
this->blob_bottom_->mutable_cpu_data()[7] = 2;
this->blob_bottom_->mutable_cpu_data()[8] = 1;
CuDNNPoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), 1);
EXPECT_EQ(this->blob_top_->channels(), 1);
EXPECT_EQ(this->blob_top_->height(), 3);
EXPECT_EQ(this->blob_top_->width(), 3);
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
TypeParam epsilon = 1e-8;
// Output:
// [ 1 4 4 ]
// [ 4 4 4 ]
// [ 4 4 1 ]
EXPECT_NEAR(this->blob_top_->cpu_data()[0], 1, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[2], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[4], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[6], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[8], 1, epsilon);
}
TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxTopMaskCuDNN) {
Caffe::set_mode(Caffe::GPU);
for (int kernel_h = 3; kernel_h <= 4; kernel_h++) {
for (int kernel_w = 3; kernel_w <= 4; kernel_w++) {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_h(kernel_h);
pooling_param->set_kernel_w(kernel_w);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
this->blob_top_vec_.push_back(this->blob_top_mask_);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-4, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
this->blob_top_vec_.pop_back();
}
}
}
TYPED_TEST(CuDNNPoolingLayerTest, TestForwardAveCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
pooling_param->set_stride(1);
pooling_param->set_pad(1);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
this->blob_bottom_->Reshape(1, 1, 3, 3);
FillerParameter filler_param;
filler_param.set_value(TypeParam(2));
ConstantFiller<TypeParam> filler(filler_param);
filler.Fill(this->blob_bottom_);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
EXPECT_EQ(this->blob_top_->num(), 1);
EXPECT_EQ(this->blob_top_->channels(), 1);
EXPECT_EQ(this->blob_top_->height(), 3);
EXPECT_EQ(this->blob_top_->width(), 3);
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
TypeParam epsilon = 1e-5;
EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0 , epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon);
}
TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAveCuDNN) {
Caffe::set_mode(Caffe::GPU);
for (int kernel_h = 3; kernel_h <= 4; kernel_h++) {
for (int kernel_w = 3; kernel_w <= 4; kernel_w++) {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_h(kernel_h);
pooling_param->set_kernel_w(kernel_w);
pooling_param->set_stride(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
}
}
TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAvePaddedCuDNN) {
Caffe::set_mode(Caffe::GPU);
for (int kernel_h = 3; kernel_h <= 4; kernel_h++) {
for (int kernel_w = 3; kernel_w <= 4; kernel_w++) {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_h(kernel_h);
pooling_param->set_kernel_w(kernel_w);
pooling_param->set_stride(2);
pooling_param->set_pad(2);
pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
CuDNNPoolingLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-2);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
}
}
#endif
} // namespace caffe
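
The expected values hard-coded in TestForwardSquare, TestForwardRectHigh, and TestForwardRectWide can be reproduced with a plain single-channel CPU max pool. A minimal sketch follows (stride 1, no padding; illustrative only, not the CuDNNPoolingLayer implementation). It also records the 0-based argmax indices used by the top-mask checks:

#include <vector>

// Single-channel max pooling over an h x w input with a kh x kw kernel,
// stride 1 and no padding; also returns the flat argmax index per window.
void max_pool_ref(const std::vector<float>& in, int h, int w, int kh, int kw,
                  std::vector<float>* out, std::vector<int>* mask) {
  const int out_h = h - kh + 1;
  const int out_w = w - kw + 1;
  out->assign(out_h * out_w, 0.f);
  mask->assign(out_h * out_w, 0);
  for (int oh = 0; oh < out_h; ++oh) {
    for (int ow = 0; ow < out_w; ++ow) {
      int best = oh * w + ow;
      for (int i = 0; i < kh; ++i) {
        for (int j = 0; j < kw; ++j) {
          const int idx = (oh + i) * w + (ow + j);
          if (in[idx] > in[best]) { best = idx; }
        }
      }
      (*out)[oh * out_w + ow] = in[best];
      (*mask)[oh * out_w + ow] = best;
    }
  }
}
// For the 3x5 input [1 2 5 2 3; 9 4 1 4 8; 1 2 5 2 3] and a 2x2 kernel this
// gives values [9 5 5 8; 9 5 5 8] and mask [5 2 2 9; 5 12 12 9], matching
// TestForwardSquare.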

View file

@@ -80,4 +80,72 @@ TYPED_TEST(SoftmaxLayerTest, TestGradient) {
&(this->blob_top_vec_));
}
#ifdef USE_CUDNN
template <typename Dtype>
class CuDNNSoftmaxLayerTest : public ::testing::Test {
protected:
CuDNNSoftmaxLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 10, 2, 3)),
blob_top_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
blob_bottom_vec_.push_back(blob_bottom_);
blob_top_vec_.push_back(blob_top_);
}
virtual ~CuDNNSoftmaxLayerTest() { delete blob_bottom_; delete blob_top_; }
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
TYPED_TEST_CASE(CuDNNSoftmaxLayerTest, TestDtypes);
TYPED_TEST(CuDNNSoftmaxLayerTest, TestForwardCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNSoftmaxLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test sum
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
for (int k = 0; k < this->blob_bottom_->height(); ++k) {
for (int l = 0; l < this->blob_bottom_->width(); ++l) {
TypeParam sum = 0;
for (int j = 0; j < this->blob_top_->channels(); ++j) {
sum += this->blob_top_->data_at(i, j, k, l);
}
EXPECT_GE(sum, 0.999);
EXPECT_LE(sum, 1.001);
// Test exact values
TypeParam scale = 0;
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
scale += exp(this->blob_bottom_->data_at(i, j, k, l));
}
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
<< "debug: " << i << " " << j;
EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
<< "debug: " << i << " " << j;
}
}
}
}
}
TYPED_TEST(CuDNNSoftmaxLayerTest, TestGradientCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
CuDNNSoftmaxLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_));
}
#endif
} // namespace caffe
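
The softmax forward check above asserts two properties at every spatial position: the outputs sum to one over channels, and each output equals exp(x_j) / sum_k exp(x_k). A small reference sketch of that channel-wise computation (illustrative only; not the cuDNN path, and without the max-subtraction trick a production softmax uses for numerical stability):

#include <cmath>
#include <cstddef>
#include <vector>

// Channel-wise softmax reference: out[j] = exp(in[j]) / sum_k exp(in[k]).
std::vector<double> softmax_channels(const std::vector<double>& in) {
  double scale = 0;
  for (double v : in) { scale += std::exp(v); }
  std::vector<double> out(in.size());
  for (std::size_t j = 0; j < in.size(); ++j) {
    out[j] = std::exp(in[j]) / scale;
  }
  return out;  // entries sum to 1 up to rounding, as the test asserts
}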

View file

@@ -329,7 +329,12 @@ void caffe_gpu_powx<double>(const int N, const double* a,
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index])
- (x[index] < Dtype(0)));
#if CUDA_VERSION >= 6050
// Use __signbit so nvcc picks up the CUDA function on CUDA 6.5 and later.
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sgnbit, y[index] = __signbit(x[index]));
#else
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sgnbit, y[index] = signbit(x[index]));
#endif
__global__ void popc_kernel(const int n, const float* a,
const float* b, uint8_t* y) {