This commit is contained in:
Yangqing Jia 2013-09-19 16:26:01 -07:00
Родитель 1e31fc50ce
Коммит 6f1de8bf5a
9 изменённых файлов: 130 добавлений и 4 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -2,6 +2,7 @@
*.slo *.slo
*.lo *.lo
*.o *.o
*.cuo
# Compiled Dynamic libraries # Compiled Dynamic libraries
*.so *.so

Просмотреть файл

@ -17,7 +17,7 @@ PROTO_GEN_HEADER := ${PROTO_SRCS:.proto=.pb.h}
PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc} PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
PROTO_GEN_PY := ${PROTO_SRCS:.proto=_pb2.py} PROTO_GEN_PY := ${PROTO_SRCS:.proto=_pb2.py}
CXX_OBJS := ${CXX_SRCS:.cpp=.o} CXX_OBJS := ${CXX_SRCS:.cpp=.o}
CU_OBJS := ${CU_SRCS:.cu=.o} CU_OBJS := ${CU_SRCS:.cu=.cuo}
PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o} PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o}
OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS) OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS)
TEST_OBJS := ${TEST_SRCS:.cpp=.o} TEST_OBJS := ${TEST_SRCS:.cpp=.o}
@ -63,7 +63,7 @@ $(TEST_BINS): %.testbin : %.o
$(NAME): $(PROTO_GEN_CC) $(OBJS) $(NAME): $(PROTO_GEN_CC) $(OBJS)
$(LINK) -shared $(OBJS) -o $(NAME) $(LINK) -shared $(OBJS) -o $(NAME)
$(CU_OBJS): %.o: %.cu $(CU_OBJS): %.cuo: %.cu
$(NVCC) -c $< -o $@ $(NVCC) -c $< -o $@
$(PROTO_GEN_CC): $(PROTO_SRCS) $(PROTO_GEN_CC): $(PROTO_SRCS)

Просмотреть файл

@ -82,6 +82,7 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>( DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, bottom_data, (unsigned int*)rand_vec_->gpu_data(), uint_thres_, scale_, count, bottom_data, (unsigned int*)rand_vec_->gpu_data(), uint_thres_, scale_,
top_data); top_data);
CUDA_POST_KERNEL_CHECK;
} else { } else {
CUDA_CHECK(cudaMemcpy(top_data, bottom_data, CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
count * sizeof(Dtype), cudaMemcpyDeviceToDevice)); count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
@ -112,6 +113,7 @@ Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const int count = (*bottom)[0]->count(); const int count = (*bottom)[0]->count();
DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>( DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, uint_thres_, scale_, bottom_diff); count, top_diff, mask, uint_thres_, scale_, bottom_diff);
CUDA_POST_KERNEL_CHECK;
} }
return Dtype(0); return Dtype(0);
} }

Просмотреть файл

@ -1,6 +1,7 @@
#include "caffeine/layer.hpp" #include "caffeine/layer.hpp"
#include "caffeine/util/im2col.hpp" #include "caffeine/util/im2col.hpp"
#include "caffeine/vision_layers.hpp" #include "caffeine/vision_layers.hpp"
#include "caffeine/common.hpp"
namespace caffeine { namespace caffeine {
@ -29,6 +30,17 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
} }
} }
template <typename Dtype>
void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
for (int n = 0; n < bottom[0]->num(); ++n) {
im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
}
}
template <typename Dtype> template <typename Dtype>
Dtype Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, Dtype Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) { const bool propagate_down, vector<Blob<Dtype>*>* bottom) {

Просмотреть файл

@ -126,6 +126,7 @@ Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>( PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_, count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
PAD_); PAD_);
CUDA_POST_KERNEL_CHECK;
} }
return Dtype(0); return Dtype(0);
} }

Просмотреть файл

@ -64,6 +64,21 @@ TYPED_TEST(Im2colLayerTest, TestCPU) {
} }
} }
TYPED_TEST(Im2colLayerTest, TestGPU) {
LayerParameter layer_param;
layer_param.set_kernelsize(3);
layer_param.set_stride(2);
Im2colLayer<TypeParam> layer(layer_param);
Caffeine::set_mode(Caffeine::GPU);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// We are lazy and will only check the top left block
for (int c = 0; c < 27; ++c) {
EXPECT_EQ(this->blob_bottom_->data_at(0, (c / 9), (c / 3) % 3, c % 3),
this->blob_top_->data_at(0, c, 0, 0));
}
}
TYPED_TEST(Im2colLayerTest, TestCPUGradient) { TYPED_TEST(Im2colLayerTest, TestCPUGradient) {
LayerParameter layer_param; LayerParameter layer_param;
layer_param.set_kernelsize(3); layer_param.set_kernelsize(3);

Просмотреть файл

@ -0,0 +1,87 @@
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "caffeine/common.hpp"
#include "caffeine/util/im2col.hpp"
namespace caffeine {
template <typename Dtype>
__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
const int height, const int width, const int ksize,
const int stride, const int height_col, const int width_col, Dtype* data_col) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
int w_out = index % width_col;
index /= width_col;
int h_out = index % height_col;
int channel_in = index / height_col;
int channel_out = channel_in * ksize * ksize;
int h_in = h_out * stride;
int w_in = w_out * stride;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
data_im += (channel_in * height + h_in) * width + w_in;
for (int i = 0; i < ksize; ++i) {
for (int j = 0; j < ksize; ++j) {
*data_col = data_im[i * width + j];
data_col += height_col * width_col;
}
}
}
}
template <typename Dtype>
void im2col_gpu(const Dtype* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
Dtype* data_col) {
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
int num_kernels = channels * height_col * width_col;
im2col_gpu_kernel<<<CAFFEINE_GET_BLOCKS(num_kernels), CAFFEINE_CUDA_NUM_THREADS>>>(
num_kernels, data_im, height, width, ksize, stride, height_col, width_col,
data_col);
CUDA_POST_KERNEL_CHECK;
}
// Explicit instantiation
template void im2col_gpu<float>(const float* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
float* data_col);
template void im2col_gpu<double>(const double* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
double* data_col);
/*
template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int ksize, const int stride,
Dtype* data_im) {
memset(data_im, 0, sizeof(Dtype) * height * width * channels);
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
int channels_col = channels * ksize * ksize;
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % ksize;
int h_offset = (c / ksize) % ksize;
int c_im = c / ksize / ksize;
for (int h = 0; h < height_col; ++h) {
for (int w = 0; w < width_col; ++w) {
data_im[(c_im * height + h * stride + h_offset) * width + w * stride
+ w_offset] += data_col[(c * height_col + h) * width_col + w];
}
}
}
}
// Explicit instantiation
template void col2im_gpu<float>(const float* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
float* data_im);
template void col2im_gpu<double>(const double* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
double* data_im);
*/
} // namespace caffeine

Просмотреть файл

@ -13,7 +13,15 @@ void col2im_cpu(const Dtype* data_col, const int channels,
const int height, const int width, const int psize, const int stride, const int height, const int width, const int psize, const int stride,
Dtype* data_im); Dtype* data_im);
template <typename Dtype>
void im2col_gpu(const Dtype* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
Dtype* data_col);
template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
Dtype* data_im);
} // namespace caffeine } // namespace caffeine

Просмотреть файл

@ -146,8 +146,8 @@ class Im2colLayer : public Layer<Dtype> {
protected: protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top); vector<Blob<Dtype>*>* top);
//virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// vector<Blob<Dtype>*>* top); vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top, virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom); const bool propagate_down, vector<Blob<Dtype>*>* bottom);
//virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top, //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,