This commit is contained in:
Yangqing Jia 2013-10-25 13:48:34 -07:00
Родитель 22f45e67ab
Коммит 814c729477
3 изменённых файлов: 110 добавлений и 0 удалений

Просмотреть файл

@ -45,6 +45,25 @@ class ReLULayer : public NeuronLayer<Dtype> {
};
template <typename Dtype>
class BNLLLayer : public NeuronLayer<Dtype> {
public:
explicit BNLLLayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param) {}
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};
template <typename Dtype>
class DropoutLayer : public NeuronLayer<Dtype> {
public:

Просмотреть файл

@ -21,6 +21,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
const std::string& type = param.type();
if (type == "accuracy") {
return new AccuracyLayer<Dtype>(param);
} else if (type == "bnll") {
return new BNLLLayer<Dtype>(param);
} else if (type == "conv") {
return new ConvolutionLayer<Dtype>(param);
} else if (type == "data") {

Просмотреть файл

@ -0,0 +1,89 @@
// Copyright 2013 Yangqing Jia
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include <algorithm>
using std::max;
namespace caffe {
const float kBNLL_THRESHOLD = 50.;
template <typename Dtype>
void BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
top_data[i] = log(1. + exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD))));
}
}
template <typename Dtype>
Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
const int count = (*bottom)[0]->count();
for (int i = 0; i < count; ++i) {
Dtype expval = exp(min(bottom_data[index], Dtype(kBNLL_THRESHOLD)));
bottom_diff[index] = top_diff[index] * expval / (expval + 1.);
}
}
return Dtype(0);
}
template <typename Dtype>
__global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD)));
}
}
template <typename Dtype>
void BNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
BNLLForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data);
CUDA_POST_KERNEL_CHECK;
}
template <typename Dtype>
__global__ void BNLLBackward(const int n, const Dtype* in_diff,
const Dtype* in_data, Dtype* out_diff) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
Dtype expval = exp(min(in_data[index], Dtype(kBNLL_THRESHOLD)));
out_diff[index] = in_diff[index] * expval / (expval + 1.);
}
}
template <typename Dtype>
Dtype BNLLLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
const int count = (*bottom)[0]->count();
BNLLBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_data, bottom_diff);
CUDA_POST_KERNEL_CHECK;
}
return Dtype(0);
}
INSTANTIATE_CLASS(BNLLLayer);
} // namespace caffe