From 814c72947745bcea7a0a64f22adf7665d2228b10 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Fri, 25 Oct 2013 13:48:34 -0700
Subject: [PATCH] bnll

---
 include/caffe/vision_layers.hpp | 19 +++++++
 src/caffe/layer_factory.cpp     |  2 +
 src/caffe/layers/bnll_layer.cu  | 89 +++++++++++++++++++++++++++++++++
 3 files changed, 110 insertions(+)
 create mode 100644 src/caffe/layers/bnll_layer.cu

diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 432d7ffe..a57badfc 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -45,6 +45,25 @@ class ReLULayer : public NeuronLayer<Dtype> {
 };
 
 
+template <typename Dtype>
+class BNLLLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit BNLLLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+
 template <typename Dtype>
 class DropoutLayer : public NeuronLayer<Dtype> {
  public:
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 6961bb3f..178607f4 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -21,6 +21,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   const std::string& type = param.type();
   if (type == "accuracy") {
     return new AccuracyLayer<Dtype>(param);
+  } else if (type == "bnll") {
+    return new BNLLLayer<Dtype>(param);
   } else if (type == "conv") {
     return new ConvolutionLayer<Dtype>(param);
   } else if (type == "data") {
diff --git a/src/caffe/layers/bnll_layer.cu b/src/caffe/layers/bnll_layer.cu
new file mode 100644
index 00000000..c9a33ed5
--- /dev/null
+++ b/src/caffe/layers/bnll_layer.cu
@@ -0,0 +1,89 @@
+// Copyright 2013 Yangqing Jia
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include <algorithm>
+
+using std::min;
+
+namespace caffe {
+
+const float kBNLL_THRESHOLD = 50.;
+
+template <typename Dtype>
+void BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+  for (int i = 0; i < count; ++i) {
+    top_data[i] = log(1. + exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD))));
+  }
+}
+
+template <typename Dtype>
+Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int count = (*bottom)[0]->count();
+    for (int i = 0; i < count; ++i) {
+      Dtype expval = exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD)));
+      bottom_diff[i] = top_diff[i] * expval / (expval + 1.);
+    }
+  }
+  return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD))));
+  }
+}
+
+template <typename Dtype>
+void BNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  BNLLForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void BNLLBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    Dtype expval = exp(min(in_data[index], Dtype(kBNLL_THRESHOLD)));
+    out_diff[index] = in_diff[index] * expval / (expval + 1.);
+  }
+}
+
+template <typename Dtype>
+Dtype BNLLLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    BNLLBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff);
+    CUDA_POST_KERNEL_CHECK;
+  }
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(BNLLLayer);
+
+
+}  // namespace caffe
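
Note: the BNLL activation above is the softplus f(x) = log(1 + exp(x)), whose derivative is exp(x) / (exp(x) + 1), i.e. the logistic sigmoid, which is exactly what Backward_cpu/Backward_gpu apply to the top diff. kBNLL_THRESHOLD caps the argument of exp() so single-precision exp() cannot overflow; for x > 50 the output is already equal to x to within float precision. The snippet below is a minimal standalone sketch of that math for sanity-checking on the host; it is not part of the patch, and the file name, helper names, and test inputs are illustrative only.

// bnll_check.cc -- standalone mirror of the BNLL CPU math above
// (hypothetical helper, not part of the patch).
// Build: g++ -std=c++11 bnll_check.cc -o bnll_check
#include <algorithm>
#include <cmath>
#include <cstdio>

static const float kBNLL_THRESHOLD = 50.f;

// Forward: f(x) = log(1 + exp(x)), with the same overflow clamp as the layer.
static float BNLLForwardRef(float x) {
  return std::log(1.f + std::exp(std::min(x, kBNLL_THRESHOLD)));
}

// Backward: df/dx = exp(x) / (exp(x) + 1) = sigmoid(x), scaled by the
// incoming top_diff, matching the chain rule applied in Backward_cpu.
static float BNLLBackwardRef(float x, float top_diff) {
  float expval = std::exp(std::min(x, kBNLL_THRESHOLD));
  return top_diff * expval / (expval + 1.f);
}

int main() {
  const float xs[] = {-5.f, 0.f, 5.f, 100.f};  // 100 exercises the clamp
  for (float x : xs) {
    // Expect f(0) = log(2) ~ 0.6931 and f'(0) = 0.5.
    std::printf("x = %8.2f   f(x) = %10.4f   f'(x) = %6.4f\n",
                x, BNLLForwardRef(x), BNLLBackwardRef(x, 1.f));
  }
  return 0;
}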