зеркало из https://github.com/microsoft/caffe.git
Extend Crop to N-D, changed CropParameter.
This commit is contained in:
Родитель
64e78bdc76
Коммит
952fd17e52
|
@ -41,9 +41,27 @@ class CropLayer : public Layer<Dtype> {
|
||||||
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
|
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
|
||||||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
|
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
|
||||||
|
|
||||||
int crop_h_, crop_w_;
|
vector<int> offsets;
|
||||||
};
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
void crop_copy(const vector<Blob<Dtype>*>& bottom,
|
||||||
|
const vector<Blob<Dtype>*>& top,
|
||||||
|
const vector<int>& offsets,
|
||||||
|
vector<int> indices,
|
||||||
|
int cur_dim,
|
||||||
|
const Dtype* src_data,
|
||||||
|
Dtype* dest_data,
|
||||||
|
bool is_forward);
|
||||||
|
|
||||||
|
void crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
|
||||||
|
const vector<Blob<Dtype>*>& top,
|
||||||
|
const vector<int>& offsets,
|
||||||
|
vector<int> indices,
|
||||||
|
int cur_dim,
|
||||||
|
const Dtype* src_data,
|
||||||
|
Dtype* dest_data,
|
||||||
|
bool is_forward);
|
||||||
|
};
|
||||||
} // namespace caffe
|
} // namespace caffe
|
||||||
|
|
||||||
#endif // CAFFE_CROP_LAYER_HPP_
|
#endif // CAFFE_CROP_LAYER_HPP_
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
#include "caffe/layer.hpp"
|
#include "caffe/layer.hpp"
|
||||||
#include "caffe/layers/crop_layer.hpp"
|
#include "caffe/layers/crop_layer.hpp"
|
||||||
#include "caffe/net.hpp"
|
#include "caffe/net.hpp"
|
||||||
|
@ -13,40 +15,108 @@ namespace caffe {
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
void CropLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
|
void CropLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
|
||||||
const vector<Blob<Dtype>*>& top) {
|
const vector<Blob<Dtype>*>& top) {
|
||||||
const CropParameter& param = this->layer_param_.crop_param();
|
|
||||||
CHECK_EQ(bottom.size(), 2) << "Wrong number of bottom blobs.";
|
CHECK_EQ(bottom.size(), 2) << "Wrong number of bottom blobs.";
|
||||||
CHECK_EQ(bottom[0]->num_axes(), 4) << "Only works with 4D blobs.";
|
// parameter setup moved to Reshape because it depends on size.
|
||||||
CHECK_EQ(bottom[1]->num_axes(), 4) << "Only works with 4D blobs.";
|
|
||||||
crop_h_ = param.offset_height();
|
|
||||||
crop_w_ = param.offset_width();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
|
void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
|
||||||
const vector<Blob<Dtype>*>& top) {
|
const vector<Blob<Dtype>*>& top) {
|
||||||
// Check that the image we are cropping minus the margin is bigger than the
|
const CropParameter& param = this->layer_param_.crop_param();
|
||||||
// destination image.
|
// bottom[0] supplies the data
|
||||||
CHECK_GT(bottom[0]->height()-crop_h_, bottom[1]->height())
|
// bottom[1] supplies the size
|
||||||
<< "invalid offset";
|
int input_dim = bottom[0]->num_axes();
|
||||||
CHECK_GT(bottom[0]->width()-crop_w_, bottom[1]->width()) << "invalid offset";
|
CHECK_LT(param.axis(), input_dim) << "crop axis bigger than input dim";
|
||||||
top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[1]->height(),
|
// initialize all offsets to 0
|
||||||
bottom[1]->width());
|
offsets = vector<int>(input_dim, 0);
|
||||||
|
// initialize new shape to bottom[0]
|
||||||
|
vector<int> new_shape(bottom[0]->shape());
|
||||||
|
|
||||||
|
if (param.offset_size() > 1) {
|
||||||
|
// the number of crop values specified must be equal to the number
|
||||||
|
// of dimensions following axis
|
||||||
|
CHECK_EQ(param.axis() + param.offset_size(), input_dim)
|
||||||
|
<< "number of crop values specified must be equal to the number of "
|
||||||
|
<< "dimensions following axis.";
|
||||||
|
}
|
||||||
|
// apply crops
|
||||||
|
for (int i = 0; i < input_dim; ++i) {
|
||||||
|
int crop_offset = 0;
|
||||||
|
int new_size = bottom[0]->shape(i);
|
||||||
|
if (i >= param.axis() && param.offset_size() == 1) {
|
||||||
|
// if only one crop value is supplied, crop all dimensions after axis
|
||||||
|
// by this crop value
|
||||||
|
crop_offset = param.offset(0);
|
||||||
|
new_size = bottom[1]->shape(i);
|
||||||
|
} else if (i >= param.axis() && param.offset_size() > 1) {
|
||||||
|
// crop values specified must be equal to the number of dimensions
|
||||||
|
// following axis
|
||||||
|
crop_offset = param.offset(i - param.axis());
|
||||||
|
new_size = bottom[1]->shape(i);
|
||||||
|
}
|
||||||
|
// Check that the image we are cropping minus the margin is at least as
|
||||||
|
// large as the destination image.
|
||||||
|
CHECK_GE(bottom[0]->shape(i) - crop_offset,
|
||||||
|
bottom[1]->shape(i))
|
||||||
|
<< "invalid crop parameters in dimension: " << i;
|
||||||
|
// Now set new size and offsets
|
||||||
|
new_shape[i] = new_size;
|
||||||
|
offsets[i] = crop_offset;
|
||||||
|
}
|
||||||
|
top[0]->Reshape(new_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// recursive copy function
|
||||||
|
template <typename Dtype>
|
||||||
|
void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
|
||||||
|
const vector<Blob<Dtype>*>& top,
|
||||||
|
const vector<int>& offsets,
|
||||||
|
vector<int> indices,
|
||||||
|
int cur_dim,
|
||||||
|
const Dtype* src_data,
|
||||||
|
Dtype* dest_data,
|
||||||
|
bool is_forward) {
|
||||||
|
if (cur_dim + 1 < top[0]->num_axes()) {
|
||||||
|
// We are not yet at the final dimension, call copy recursively
|
||||||
|
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
|
||||||
|
indices[cur_dim] = i;
|
||||||
|
crop_copy(bottom, top, offsets, indices, cur_dim+1,
|
||||||
|
src_data, dest_data, is_forward);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// We are at the last dimension, which is stored contiguously in memory
|
||||||
|
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
|
||||||
|
// prepare two index vectors: reduced (red) and with offsets applied (off)
|
||||||
|
std::vector<int> ind_red(cur_dim, 0);
|
||||||
|
std::vector<int> ind_off(cur_dim+1, 0);
|
||||||
|
for (int j = 0; j < cur_dim; ++j) {
|
||||||
|
ind_red[j] = indices[j];
|
||||||
|
ind_off[j] = indices[j] + offsets[j];
|
||||||
|
}
|
||||||
|
ind_off[cur_dim] = offsets[cur_dim];
|
||||||
|
// do the copy
|
||||||
|
if (is_forward) {
|
||||||
|
caffe_copy(top[0]->shape(cur_dim),
|
||||||
|
src_data + bottom[0]->offset(ind_off),
|
||||||
|
dest_data + top[0]->offset(ind_red));
|
||||||
|
} else {
|
||||||
|
// in the backwards pass the src_data is top_diff
|
||||||
|
// and the dest_data is bottom_diff
|
||||||
|
caffe_copy(top[0]->shape(cur_dim),
|
||||||
|
src_data + top[0]->offset(ind_red),
|
||||||
|
dest_data + bottom[0]->offset(ind_off));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
|
void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
|
||||||
const vector<Blob<Dtype>*>& top) {
|
const vector<Blob<Dtype>*>& top) {
|
||||||
|
std::vector<int> indices(top[0]->num_axes(), 0);
|
||||||
const Dtype* bottom_data = bottom[0]->cpu_data();
|
const Dtype* bottom_data = bottom[0]->cpu_data();
|
||||||
Dtype* top_data = top[0]->mutable_cpu_data();
|
Dtype* top_data = top[0]->mutable_cpu_data();
|
||||||
for (int n = 0; n < top[0]->num(); ++n) {
|
crop_copy(bottom, top, offsets, indices, 0, bottom_data, top_data, true);
|
||||||
for (int c = 0; c < top[0]->channels(); ++c) {
|
|
||||||
for (int h = 0; h < top[0]->height(); ++h) {
|
|
||||||
caffe_copy(top[0]->width(),
|
|
||||||
bottom_data + bottom[0]->offset(n, c, crop_h_ + h, crop_w_),
|
|
||||||
top_data + top[0]->offset(n, c, h));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
|
@ -54,17 +124,11 @@ void CropLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
|
||||||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
|
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
|
||||||
const Dtype* top_diff = top[0]->cpu_diff();
|
const Dtype* top_diff = top[0]->cpu_diff();
|
||||||
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
|
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
|
||||||
|
|
||||||
if (propagate_down[0]) {
|
if (propagate_down[0]) {
|
||||||
caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
|
caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
|
||||||
for (int n = 0; n < top[0]->num(); ++n) {
|
std::vector<int> indices(top[0]->num_axes(), 0);
|
||||||
for (int c = 0; c < top[0]->channels(); ++c) {
|
crop_copy(bottom, top, offsets, indices, 0, top_diff, bottom_diff, false);
|
||||||
for (int h = 0; h < top[0]->height(); ++h) {
|
|
||||||
caffe_copy(top[0]->width(),
|
|
||||||
top_diff + top[0]->offset(n, c, h),
|
|
||||||
bottom_diff + bottom[0]->offset(n, c, crop_h_ + h, crop_w_));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,19 +22,90 @@ __global__ void copy_kernel(const int n, const int height, const int width,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// recursive copy function, this function is similar to crop_copy but loops
|
||||||
|
// over all but the last two dimensions. It is implemented this way to allow
|
||||||
|
// for ND cropping while still relying on a CUDA kernel for the innermost
|
||||||
|
// two dimensions for performance reasons.
|
||||||
|
// An alternative way to implement ND cropping relying more on the kernel
|
||||||
|
// would require passing offsets to the kernel, which is a bit problematic
|
||||||
|
// because it is of variable length. Since in the standard (N,C,H,W) case
|
||||||
|
// N,C are usually not cropped a speedup could be achieved by not looping
|
||||||
|
// the application of the copy_kernel around these dimensions.
|
||||||
|
template <typename Dtype>
|
||||||
|
void CropLayer<Dtype>::crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
|
||||||
|
const vector<Blob<Dtype>*>& top,
|
||||||
|
const vector<int>& offsets,
|
||||||
|
vector<int> indices,
|
||||||
|
int cur_dim,
|
||||||
|
const Dtype* src_data,
|
||||||
|
Dtype* dest_data,
|
||||||
|
bool is_forward) {
|
||||||
|
if (cur_dim + 2 < top[0]->num_axes()) {
|
||||||
|
// We are not yet at the final dimension, call copy recursively
|
||||||
|
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
|
||||||
|
indices[cur_dim] = i;
|
||||||
|
crop_copy_gpu(bottom, top, offsets, indices, cur_dim+1,
|
||||||
|
src_data, dest_data, is_forward);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// We are at the last two dimensions, which are stored contiguously in memory
|
||||||
|
// With (N,C,H,W)
|
||||||
|
// (0,1,2,3) cur_dim -> H
|
||||||
|
// cur_dim+1 -> W
|
||||||
|
const int lines = top[0]->shape(cur_dim);
|
||||||
|
const int height = top[0]->shape(cur_dim);
|
||||||
|
const int width = top[0]->shape(cur_dim+1);
|
||||||
|
std::vector<int> ind_off(cur_dim+2, 0);
|
||||||
|
for (int j = 0; j < cur_dim; ++j) {
|
||||||
|
ind_off[j] = indices[j] + offsets[j];
|
||||||
|
}
|
||||||
|
ind_off[cur_dim] = offsets[cur_dim];
|
||||||
|
ind_off[cur_dim+1] = offsets[cur_dim+1];
|
||||||
|
// Compute copy strides
|
||||||
|
const int src_outer_stride =
|
||||||
|
bottom[0]->shape(cur_dim)*bottom[0]->shape(cur_dim+1);
|
||||||
|
const int src_inner_stride = bottom[0]->shape(cur_dim+1);
|
||||||
|
const int dest_outer_stride =
|
||||||
|
top[0]->shape(cur_dim)*top[0]->shape(cur_dim+1);
|
||||||
|
const int dest_inner_stride = top[0]->shape(cur_dim+1);
|
||||||
|
|
||||||
|
if (is_forward) {
|
||||||
|
const Dtype* bottom_data = bottom[0]->gpu_data() +
|
||||||
|
bottom[0]->offset(ind_off);
|
||||||
|
Dtype* top_data = top[0]->mutable_gpu_data() +
|
||||||
|
top[0]->offset(indices);
|
||||||
|
// NOLINT_NEXT_LINE(whitespace/operators)
|
||||||
|
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
|
||||||
|
lines, height, width,
|
||||||
|
src_outer_stride, src_inner_stride,
|
||||||
|
dest_outer_stride, dest_inner_stride,
|
||||||
|
bottom_data, top_data);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
const Dtype* top_diff = top[0]->gpu_diff() +
|
||||||
|
top[0]->offset(indices);
|
||||||
|
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff() +
|
||||||
|
bottom[0]->offset(ind_off);
|
||||||
|
// NOLINT_NEXT_LINE(whitespace/operators)
|
||||||
|
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
|
||||||
|
lines, height, width,
|
||||||
|
dest_outer_stride, dest_inner_stride,
|
||||||
|
src_outer_stride, src_inner_stride,
|
||||||
|
top_diff, bottom_diff);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
|
void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
|
||||||
const vector<Blob<Dtype>*>& top) {
|
const vector<Blob<Dtype>*>& top) {
|
||||||
|
std::vector<int> indices(top[0]->num_axes(), 0);
|
||||||
|
// This works because crop_copy uses caffe_copy which calls cudaMemcpy.
|
||||||
|
// My intuition is that calling this thousands of times is probably less
|
||||||
|
// efficient than writing a custom kernel.
|
||||||
const Dtype* bottom_data = bottom[0]->gpu_data();
|
const Dtype* bottom_data = bottom[0]->gpu_data();
|
||||||
Dtype* top_data = top[0]->mutable_gpu_data();
|
Dtype* top_data = top[0]->mutable_gpu_data();
|
||||||
const int lines = top[0]->count() / top[0]->width();
|
crop_copy_gpu(bottom, top, offsets, indices, 0, bottom_data, top_data, true);
|
||||||
|
|
||||||
// NOLINT_NEXT_LINE(whitespace/operators)
|
|
||||||
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
|
|
||||||
lines, top[0]->height(), top[0]->width(),
|
|
||||||
bottom[0]->height() * bottom[0]->width(), bottom[0]->width(),
|
|
||||||
top[0]->height() * top[0]->width(), top[0]->width(),
|
|
||||||
bottom_data + bottom[0]->offset(0, 0, crop_h_, crop_w_), top_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
|
@ -42,16 +113,12 @@ void CropLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
|
||||||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
|
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
|
||||||
const Dtype* top_diff = top[0]->gpu_diff();
|
const Dtype* top_diff = top[0]->gpu_diff();
|
||||||
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
|
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
|
||||||
const int lines = top[0]->count() / top[0]->width();
|
|
||||||
|
|
||||||
if (propagate_down[0]) {
|
if (propagate_down[0]) {
|
||||||
caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
|
caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
|
||||||
// NOLINT_NEXT_LINE(whitespace/operators)
|
std::vector<int> indices(top[0]->num_axes(), 0);
|
||||||
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
|
crop_copy_gpu(bottom, top, offsets, indices, 0, top_diff, bottom_diff,
|
||||||
lines, top[0]->height(), top[0]->width(),
|
false);
|
||||||
top[0]->height() * top[0]->width(), top[0]->width(),
|
|
||||||
bottom[0]->height() * bottom[0]->width(), bottom[0]->width(),
|
|
||||||
top_diff, bottom_diff + bottom[0]->offset(0, 0, crop_h_, crop_w_));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -600,10 +600,19 @@ message ConvolutionParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
message CropParameter {
|
message CropParameter {
|
||||||
// Assumes standard dimensions: ( N,C,H,W )
|
// To crop, elements of the first bottom are selected to fit the dimensions
|
||||||
// This could possibly be extended to use "optional BlobShape offsets"
|
// of the second, reference bottom. The crop is configured by
|
||||||
optional uint32 offset_height = 1[default = 0];
|
// - the crop `axis` to pick the dimensions for cropping
|
||||||
optional uint32 offset_width = 2[default = 0];
|
// - the crop `offset` to set the shift for all/each dimension
|
||||||
|
// to align the cropped bottom with the reference bottom.
|
||||||
|
// All dimensions up to but excluding `axis` are preserved, while
|
||||||
|
// the dimensions including and trailing `axis` are cropped.
|
||||||
|
// If only one `offset` is set, then all dimensions are offset by this amount.
|
||||||
|
// Otherwise, the number of offsets must equal the number of cropped axes to
|
||||||
|
// shift the crop in each dimension accordingly.
|
||||||
|
// Note: standard dimensions are N,C,H,W so the default is a spatial crop.
|
||||||
|
optional uint32 axis = 1 [default = 2];
|
||||||
|
repeated uint32 offset = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DataParameter {
|
message DataParameter {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче