- add position-sensitive RoI pooling
- modify BN and scale layers to save memory
- add online hard example mining (OHEM)
- enrich Matlab interface
This commit is contained in:
Jifeng Dai 2016-07-25 15:27:12 +08:00
Parent 827b78a868
Commit 4cdcd00850
31 changed files with 1341 additions and 92 deletions

View file

@ -93,6 +93,10 @@ using std::string;
using std::stringstream;
using std::vector;
#ifdef _MSC_VER
#define snprintf _snprintf
#endif
// A global initialization function that you should call in your main function.
// Currently it initializes google flags and google logging.
void GlobalInit(int* pargc, char*** pargv);

View file

@ -318,6 +318,15 @@ class Layer {
inline Phase phase() { return phase_; }
/**
* @brief set phase
* enable train and test with one network, for saving memory
*/
virtual inline void set_phase(Phase phase){
phase_ = phase;
}
protected:
/** The protobuf that stores the layer parameters */

View file

@ -0,0 +1,57 @@
#ifndef CAFFE_BOX_ANNOTATOR_OHEM_LAYER_HPP_
#define CAFFE_BOX_ANNOTATOR_OHEM_LAYER_HPP_
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/layers/loss_layer.hpp"
namespace caffe {
/**
* @brief BoxAnnotatorOHEMLayer: Annotate box labels for Online Hard Example Mining (OHEM) training
* R-FCN
* Written by Yi Li
*/
template <typename Dtype>
class BoxAnnotatorOHEMLayer :public Layer<Dtype>{
public:
explicit BoxAnnotatorOHEMLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual inline const char* type() const { return "BoxAnnotatorOHEM"; }
virtual inline int ExactNumBottomBlobs() const { return 4; }
virtual inline int ExactNumTopBlobs() const { return 2; }
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
int num_;
int height_;
int width_;
int spatial_dim_;
int bbox_channels_;
int roi_per_img_;
int ignore_label_;
};
} // namespace caffe
#endif //CAFFE_BOX_ANNOTATOR_OHEM_LAYER_HPP_
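
The selection logic itself lives in the GPU forward pass further down in this commit: ROIs are sorted by their per-ROI loss, and only the roi_per_img_ highest-loss ones per image keep their labels and bbox loss weights; the rest are set to ignore_label_ and zero. A minimal stand-alone CPU sketch of that idea (function and variable names here are illustrative, not from the commit; kept weights are simplified to 1 instead of being copied from the input blob):

// Illustrative sketch of OHEM box annotation: keep the top roi_per_img
// hardest ROIs per image, ignore the rest.
#include <algorithm>
#include <numeric>
#include <vector>

struct OhemResult {
  std::vector<int> labels;              // per-ROI labels after selection
  std::vector<float> bbox_loss_weight;  // per-ROI bbox loss weight (0 or 1)
};

OhemResult SelectHardExamples(const std::vector<int>& batch_index,
                              const std::vector<float>& per_roi_loss,
                              const std::vector<int>& labels,
                              int num_images, int roi_per_img,
                              int ignore_label) {
  const int num_rois = static_cast<int>(per_roi_loss.size());
  OhemResult out;
  out.labels.assign(num_rois, ignore_label);
  out.bbox_loss_weight.assign(num_rois, 0.f);

  // Sort ROI indices by descending loss, as the GPU forward does on CPU data.
  std::vector<int> order(num_rois);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int a, int b) {
    return per_roi_loss[a] > per_roi_loss[b];
  });

  // Keep at most roi_per_img ROIs per image; the rest stay ignored.
  std::vector<int> left(num_images, roi_per_img);
  for (int idx : order) {
    int img = batch_index[idx];
    if (left[img] > 0) {
      --left[img];
      out.labels[idx] = labels[idx];
      out.bbox_loss_weight[idx] = 1.f;  // real layer copies the input weights
    }
  }
  return out;
}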

View file

@ -13,7 +13,20 @@ namespace caffe {
/**
* @brief Perform position-sensitive max pooling on regions of interest specified by input, takes
* as input N position-sensitive score maps and a list of R regions of interest.
*
* PSROIPoolingLayer takes 2 inputs and produces 1 output. bottom[0] is
* [N x (C x K^2) x H x W] position-sensitive score maps on which pooling is performed. bottom[1] is
* [R x 5] containing a list of R ROI tuples with batch index and coordinates of
* regions of interest. Each row in bottom[1] is a ROI tuple in format
* [batch_index x1 y1 x2 y2], where batch_index corresponds to the index of
* instance in the first input and x1 y1 x2 y2 are 0-indexed coordinates
* of ROI rectangle (including its boundaries). The output top[0] is [R x C x K x K] score maps pooled
* within the ROI tuples.
* @param param provides PSROIPoolingParameter psroi_pooling_param,
* with PSROIPoolingLayer options:
* - output_dim. The pooled output channel number.
* - group_size. The number of groups to encode position-sensitive score maps
* - spatial_scale. Multiplicative spatial scale factor to translate ROI
* coordinates from their input scale to the scale used when pooling.
* R-FCN
* Written by Yi Li
*/
@ -21,43 +34,43 @@ namespace caffe {
template <typename Dtype>
class PSROIPoolingLayer : public Layer<Dtype> {
public:
explicit PSROIPoolingLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
explicit PSROIPoolingLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual inline const char* type() const { return "PSROIPooling"; }
virtual inline const char* type() const { return "PSROIPooling"; }
virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 2; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 1; }
virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 2; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 1; }
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
Dtype spatial_scale_;
int output_dim_;
int group_size_;
Dtype spatial_scale_;
int output_dim_;
int group_size_;
int channels_;
int height_;
int width_;
int channels_;
int height_;
int width_;
int pooled_height_;
int pooled_width_;
Blob<int> mapping_channel_;
int pooled_height_;
int pooled_width_;
Blob<int> mapping_channel_;
};
} // namespace caffe
#endif // CAFFE_ROI_POOLING_LAYER_HPP_
#endif // CAFFE_PSROI_POOLING_LAYER_HPP_
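
The "position-sensitive" part of the pooling is the channel mapping: each output cell (c, ph, pw) of an ROI is pooled only from its own dedicated channel among the C x K^2 input score maps. Below is a small hypothetical demo of the index arithmetic, assuming the mapping used by common R-FCN implementations (channel = (c * K + ph) * K + pw); it is a sketch, not code from this commit:

// Demo of the position-sensitive channel mapping for output_dim C and
// group_size K, assuming bottom[0] has layout [N x (C*K*K) x H x W].
#include <cstdio>

int main() {
  const int output_dim = 3;   // C: pooled output channels (e.g. classes)
  const int group_size = 7;   // K: pooled grid is K x K
  for (int c = 0; c < output_dim; ++c)
    for (int ph = 0; ph < group_size; ++ph)
      for (int pw = 0; pw < group_size; ++pw) {
        // Input channel that feeds output cell (c, ph, pw) of one ROI.
        int in_channel = (c * group_size + ph) * group_size + pw;
        if (ph == 0 && pw == 0)
          std::printf("class %d, cell (0,0) -> input channel %d\n", c, in_channel);
      }
  return 0;
}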

View file

@ -0,0 +1,76 @@
#ifndef CAFFE_SMOOTH_L1_LOSS_OHEM_LAYER_HPP_
#define CAFFE_SMOOTH_L1_LOSS_OHEM_LAYER_HPP_
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/layers/loss_layer.hpp"
namespace caffe {
/**
* @brief SmoothL1LossOHEMLayer
*
* R-FCN
* Written by Yi Li
*/
template <typename Dtype>
class SmoothL1LossOHEMLayer : public LossLayer<Dtype> {
public:
explicit SmoothL1LossOHEMLayer(const LayerParameter& param)
: LossLayer<Dtype>(param), diff_() {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual inline const char* type() const { return "SmoothL1LossOHEM"; }
virtual inline int ExactNumBottomBlobs() const { return -1; }
virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 3; }
virtual inline int ExactNumTopBlobs() const { return -1; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 2; }
/**
* Unlike most loss layers, in the SmoothL1LossOHEMLayer we can backpropagate
* to both inputs -- override to return true and always allow force_backward.
*/
virtual inline bool AllowForceBackward(const int bottom_index) const {
return true;
}
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
/// Read the normalization mode parameter and compute the normalizer based
/// on the blob size.
virtual Dtype get_normalizer(
LossParameter_NormalizationMode normalization_mode, Dtype pre_fixed_normalizer);
Blob<Dtype> diff_;
Blob<Dtype> errors_;
bool has_weights_;
int outer_num_, inner_num_;
/// How to normalize the output loss.
LossParameter_NormalizationMode normalization_;
};
} // namespace caffe
#endif // CAFFE_SMOOTH_L1_LOSS_OHEM_LAYER_HPP_
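
For reference, the element-wise function the CUDA kernels later in this commit implement is the standard smooth L1: quadratic for |x| < 1 and linear beyond, with gradient x inside that zone and sign(x) outside. A scalar sketch (illustrative only):

// Scalar smooth L1 value and gradient, matching the GPU kernels in this commit.
#include <cmath>

float SmoothL1(float x) {
  float ax = std::fabs(x);
  return ax < 1.f ? 0.5f * x * x   // quadratic near zero
                  : ax - 0.5f;     // linear tail, continuous at |x| = 1
}

float SmoothL1Grad(float x) {
  float ax = std::fabs(x);
  return ax < 1.f ? x                        // f'(x) = x inside the quadratic zone
                  : (x > 0.f) - (x < 0.f);   // sign(x) outside
}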

View file

@ -0,0 +1,132 @@
#ifndef CAFFE_SOFTMAX_WITH_LOSS_OHEM_LAYER_HPP_
#define CAFFE_SOFTMAX_WITH_LOSS_OHEM_LAYER_HPP_
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/layers/loss_layer.hpp"
#include "caffe/layers/softmax_layer.hpp"
namespace caffe {
/**
* @brief Computes the multinomial logistic loss for a one-of-many
* classification task, passing real-valued predictions through a
* softmax to get a probability distribution over classes.
* An additional per-instance loss output is produced for OHEM.
*
* This layer should be preferred over separate
* SoftmaxLayer + MultinomialLogisticLossLayer
* as its gradient computation is more numerically stable.
* At test time, this layer can be replaced simply by a SoftmaxLayer.
*
* @param bottom input Blob vector (length 2)
* -# @f$ (N \times C \times H \times W) @f$
* the predictions @f$ x @f$, a Blob with values in
* @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
* the @f$ K = CHW @f$ classes. This layer maps these scores to a
* probability distribution over classes using the softmax function
* @f$ \hat{p}_{nk} = \exp(x_{nk}) /
* \left[\sum_{k'} \exp(x_{nk'})\right] @f$ (see SoftmaxLayer).
* -# @f$ (N \times 1 \times 1 \times 1) @f$
* the labels @f$ l @f$, an integer-valued Blob with values
* @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
* indicating the correct class label among the @f$ K @f$ classes
* @param top output Blob vector (length 1)
* -# @f$ (1 \times 1 \times 1 \times 1) @f$
* the computed cross-entropy classification loss: @f$ E =
* \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n})
@f$, for softmax output class probabilities @f$ \hat{p} @f$
-# per-instance cross-entropy classification loss (optional, for OHEM)
*/
template <typename Dtype>
class SoftmaxWithLossOHEMLayer : public LossLayer<Dtype> {
public:
/**
* @param param provides LossParameter loss_param, with options:
* - ignore_label (optional)
* Specify a label value that should be ignored when computing the loss.
* - normalize (optional, default true)
* If true, the loss is normalized by the number of (nonignored) labels
* present; otherwise the loss is simply summed over spatial locations.
*/
explicit SoftmaxWithLossOHEMLayer(const LayerParameter& param)
: LossLayer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual inline const char* type() const { return "SoftmaxWithLoss"; }
virtual inline int ExactNumTopBlobs() const { return -1; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 3; }
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
/**
* @brief Computes the softmax loss error gradient w.r.t. the predictions.
*
* Gradients cannot be computed with respect to the label inputs (bottom[1]),
* so this method ignores bottom[1] and requires !propagate_down[1], crashing
* if propagate_down[1] is set.
*
* @param top output Blob vector (length 1), providing the error gradient with
* respect to the outputs
* -# @f$ (1 \times 1 \times 1 \times 1) @f$
* This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
* as @f$ \lambda @f$ is the coefficient of this layer's output
* @f$\ell_i@f$ in the overall Net loss
* @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
* @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
* (*Assuming that this top Blob is not used as a bottom (input) by any
* other layer of the Net.)
* @param propagate_down see Layer::Backward.
* propagate_down[1] must be false as we can't compute gradients with
* respect to the labels.
* @param bottom input Blob vector (length 2)
* -# @f$ (N \times C \times H \times W) @f$
* the predictions @f$ x @f$; Backward computes diff
* @f$ \frac{\partial E}{\partial x} @f$
* -# @f$ (N \times 1 \times 1 \times 1) @f$
* the labels -- ignored as we can't compute their error gradients
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
/// Read the normalization mode parameter and compute the normalizer based
/// on the blob size. If normalization_mode is VALID, the count of valid
/// outputs will be read from valid_count, unless it is -1 in which case
/// all outputs are assumed to be valid.
virtual Dtype get_normalizer(
LossParameter_NormalizationMode normalization_mode, int valid_count);
/// The internal SoftmaxLayer used to map predictions to a distribution.
shared_ptr<Layer<Dtype> > softmax_layer_;
/// prob stores the output probability predictions from the SoftmaxLayer.
Blob<Dtype> prob_;
/// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
vector<Blob<Dtype>*> softmax_bottom_vec_;
/// top vector holder used in call to the underlying SoftmaxLayer::Forward
vector<Blob<Dtype>*> softmax_top_vec_;
/// Whether to ignore instances with a certain label.
bool has_ignore_label_;
/// The label indicating that an instance should be ignored.
int ignore_label_;
/// How to normalize the output loss.
LossParameter_NormalizationMode normalization_;
int softmax_axis_, outer_num_, inner_num_;
};
} // namespace caffe
#endif // CAFFE_SOFTMAX_WITH_LOSS_OHEM_LAYER_HPP_
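
The optional extra top blobs expose the softmax output and, for OHEM, the per-instance loss -log(p_label) at each location. A minimal sketch of that per-instance value for a single spatial location, assuming no ignore_label (names are illustrative, not from the commit):

// Softmax over class scores followed by the cross-entropy of the true label,
// i.e. loss = -log(p[label]), clamped away from log(0) like the GPU kernel.
#include <algorithm>
#include <cmath>
#include <vector>

float PerInstanceSoftmaxLoss(const std::vector<float>& scores, int label) {
  float max_score = *std::max_element(scores.begin(), scores.end());
  float denom = 0.f;
  for (float s : scores) denom += std::exp(s - max_score);  // stable softmax
  float p = std::exp(scores[label] - max_score) / denom;
  return -std::log(std::max(p, 1e-37f));  // FLT_MIN-style clamp
}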

View file

@ -32,6 +32,10 @@ class Net {
/// @brief Initialize a network with a NetParameter.
void Init(const NetParameter& param);
/// @brief set phase
/// enable train and test with one network, for saving memory
void SetPhase(Phase phase);
/**
* @brief Run Forward and return the result.
*
@ -150,6 +154,14 @@ class Net {
inline const vector<vector<Blob<Dtype>*> >& top_vecs() const {
return top_vecs_;
}
inline const vector<vector<int> >& bottom_id_vecs() const {
return bottom_id_vecs_;
}
inline const vector<vector<int> >& top_id_vecs() const {
return top_id_vecs_;
}
/// @brief returns the ids of the top blobs of layer i
inline const vector<int> & top_ids(int i) const {
CHECK_GE(i, 0) << "Invalid layer id";

View file

@ -73,6 +73,7 @@ class Solver {
return test_nets_;
}
int iter() { return iter_; }
int max_iter() const { return param_.max_iter(); }
// Invoked at specific points during an iteration
class Callback {

View file

@ -33,6 +33,9 @@ classdef Blob < handle
diff = self.check_and_preprocess_data(diff);
caffe_('blob_set_diff', self.hBlob_self, diff);
end
function copy_data_from(self, blob)
caffe_('blob_copy_data', self.hBlob_self, blob.hBlob_self);
end
end
methods (Access = private)

View file

@ -25,6 +25,9 @@ classdef Layer < handle
self.params(n) = caffe.Blob(self.attributes.hBlob_blobs(n));
end
end
function set_params_data(self, blob_index, params)
caffe.Blob(self.attributes.hBlob_blobs(blob_index)).set_data(params);
end
function layer_type = type(self)
layer_type = caffe_('layer_get_type', self.hLayer_self);
end

View file

@ -21,6 +21,8 @@ classdef Net < handle
name2blob_index
layer_names
blob_names
bottom_id_vecs
top_id_vecs
end
methods
@ -67,6 +69,23 @@ classdef Net < handle
% expose layer_names and blob_names for public read access
self.layer_names = self.attributes.layer_names;
self.blob_names = self.attributes.blob_names;
% expose bottom_id_vecs and top_id_vecs for public read access
self.attributes.bottom_id_vecs = cellfun(@(x) x+1, self.attributes.bottom_id_vecs, 'UniformOutput', false);
self.bottom_id_vecs = self.attributes.bottom_id_vecs;
self.attributes.top_id_vecs = cellfun(@(x) x+1, self.attributes.top_id_vecs, 'UniformOutput', false);
self.top_id_vecs = self.attributes.top_id_vecs;
end
function set_phase(self, phase_name)
CHECK(ischar(phase_name), 'phase_name must be a string');
CHECK(strcmp(phase_name, 'train') || strcmp(phase_name, 'test'), ...
sprintf('phase_name can only be %strain%s or %stest%s', ...
char(39), char(39), char(39), char(39)));
caffe_('net_set_phase', self.hNet_self, phase_name);
end
function share_weights_with(self, net)
CHECK(is_valid_handle(net.hNet_net), 'invalid Net handle');
caffe_('net_share_trained_layers_with', self.hNet_net, net.hNet_net);
end
function layer = layers(self, layer_name)
CHECK(ischar(layer_name), 'layer_name must be a string');
@ -81,18 +100,43 @@ classdef Net < handle
CHECK(isscalar(blob_index), 'blob_index must be a scalar');
blob = self.layer_vec(self.name2layer_index(layer_name)).params(blob_index);
end
function set_params_data(self, layer_name, blob_index, data)
CHECK(ischar(layer_name), 'layer_name must be a string');
CHECK(isscalar(blob_index), 'blob_index must be a scalar');
self.layer_vec(self.name2layer_index(layer_name)).set_params_data(blob_index, data);
end
function forward_prefilled(self)
caffe_('net_forward', self.hNet_self);
end
function backward_prefilled(self)
caffe_('net_backward', self.hNet_self);
end
function set_input_data(self, input_data)
CHECK(iscell(input_data), 'input_data must be a cell array');
CHECK(length(input_data) == length(self.inputs), ...
'input data cell length must match input blob number');
% copy data to input blobs
for n = 1:length(self.inputs)
self.blobs(self.inputs{n}).set_data(input_data{n});
end
end
function res = get_output(self)
% get output blobs
res = struct('blob_name', '', 'data', []);
for n = 1:length(self.outputs)
res(n).blob_name = self.outputs{n};
res(n).data = self.blobs(self.outputs{n}).get_data();
end
end
function res = forward(self, input_data)
CHECK(iscell(input_data), 'input_data must be a cell array');
CHECK(length(input_data) == length(self.inputs), ...
'input data cell length must match input blob number');
% copy data to input blobs
for n = 1:length(self.inputs)
if isempty(input_data{n})
continue;
end
self.blobs(self.inputs{n}).set_data(input_data{n});
end
self.forward_prefilled();
@ -125,6 +169,21 @@ classdef Net < handle
function reshape(self)
caffe_('net_reshape', self.hNet_self);
end
function reshape_as_input(self, input_data)
CHECK(iscell(input_data), 'input_data must be a cell array');
CHECK(length(input_data) == length(self.inputs), ...
'input data cell length must match input blob number');
% reshape input blobs
for n = 1:length(self.inputs)
if isempty(input_data{n})
continue;
end
input_data_size = size(input_data{n});
input_data_size_extended = [input_data_size, ones(1, 4 - length(input_data_size))];
self.blobs(self.inputs{n}).reshape(input_data_size_extended);
end
self.reshape();
end
function save(self, weights_file)
CHECK(ischar(weights_file), 'weights_file must be a string');
caffe_('net_save', self.hNet_self, weights_file);

View file

@ -39,6 +39,9 @@ classdef Solver < handle
function iter = iter(self)
iter = caffe_('solver_get_iter', self.hSolver_self);
end
function max_iter = max_iter(self)
max_iter = caffe_('solver_get_max_iter', self.hSolver_self);
end
function restore(self, snapshot_filename)
CHECK(ischar(snapshot_filename), 'snapshot_filename must be a string');
CHECK_FILE_EXIST(snapshot_filename);

matlab/+caffe/init_log.m (new file, 15 lines)
View file

@ -0,0 +1,15 @@
function init_log(log_base_filename)
% init_log(log_base_filename)
% init Caffe's log
CHECK(ischar(log_base_filename) && ~isempty(log_base_filename), ...
'log_base_filename must be string');
[log_base_dir] = fileparts(log_base_filename);
if ~exist(log_base_dir, 'dir')
mkdir(log_base_dir);
end
caffe_('init_log', log_base_filename);
end

View file

@ -14,6 +14,7 @@
#include <vector>
#include "mex.h"
#include "gpu/mxGPUArray.h"
#include "caffe/caffe.hpp"
@ -63,9 +64,21 @@ enum WhichMemory { DATA, DIFF };
// Copy matlab array to Blob data or diff
static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
WhichMemory data_or_diff) {
mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),
const float* mat_mem_ptr = NULL;
mxGPUArray const *mx_mat_gpu;
if (mxIsGPUArray(mx_mat)){
mxInitGPU();
mx_mat_gpu = mxGPUCreateFromMxArray(mx_mat);
mat_mem_ptr = reinterpret_cast<const float*>(mxGPUGetDataReadOnly(mx_mat_gpu));
mxCHECK(blob->count() == mxGPUGetNumberOfElements(mx_mat_gpu),
"number of elements in target blob doesn't match that in input mxArray");
const float* mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
}
else{
mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),
"number of elements in target blob doesn't match that in input mxArray");
mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
}
float* blob_mem_ptr = NULL;
switch (Caffe::mode()) {
case Caffe::CPU:
@ -80,6 +93,10 @@ static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
mxERROR("Unknown Caffe mode");
}
caffe_copy(blob->count(), mat_mem_ptr, blob_mem_ptr);
if (mxIsGPUArray(mx_mat)){
mxGPUDestroyGPUArray(mx_mat_gpu);
}
}
// Copy Blob data or diff to matlab array
@ -123,6 +140,16 @@ static mxArray* int_vec_to_mx_vec(const vector<int>& int_vec) {
return mx_vec;
}
// Convert vector<vector<int> > to matlab cell of (row vector)s
static mxArray* int_vec_vec_to_mx_cell_vec(const vector<vector<int> >& int_vec_vec) {
mxArray* mx_cell_vec = mxCreateCellMatrix(int_vec_vec.size(), 1);
for (int i = 0; i < int_vec_vec.size(); i++){
mxSetCell(mx_cell_vec, i, int_vec_to_mx_vec(int_vec_vec[i]));
}
return mx_cell_vec;
}
// Convert vector<string> to matlab cell vector of strings
static mxArray* str_vec_to_mx_strcell(const vector<std::string>& str_vec) {
mxArray* mx_strcell = mxCreateCellMatrix(str_vec.size(), 1);
@ -228,6 +255,14 @@ static void solver_get_iter(MEX_ARGS) {
plhs[0] = mxCreateDoubleScalar(solver->iter());
}
// Usage: caffe_('solver_get_max_iter', hSolver)
static void solver_get_max_iter(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
"Usage: caffe_('solver_get_max_iter', hSolver)");
Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
plhs[0] = mxCreateDoubleScalar(solver->max_iter());
}
// Usage: caffe_('solver_restore', hSolver, snapshot_file)
static void solver_restore(MEX_ARGS) {
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
@ -278,14 +313,34 @@ static void get_net(MEX_ARGS) {
mxFree(phase_name);
}
// Usage: caffe_('net_set_phase', hNet, phase_name)
static void net_set_phase(MEX_ARGS) {
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
"Usage: caffe_('net_set_phase', hNet, phase_name)");
Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
char* phase_name = mxArrayToString(prhs[1]);
Phase phase;
if (strcmp(phase_name, "train") == 0) {
phase = TRAIN;
}
else if (strcmp(phase_name, "test") == 0) {
phase = TEST;
}
else {
mxERROR("Unknown phase");
}
net->SetPhase(phase);
mxFree(phase_name);
}
// Usage: caffe_('net_get_attr', hNet)
static void net_get_attr(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
"Usage: caffe_('net_get_attr', hNet)");
Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
const int net_attr_num = 6;
const int net_attr_num = 8;
const char* net_attrs[net_attr_num] = { "hLayer_layers", "hBlob_blobs",
"input_blob_indices", "output_blob_indices", "layer_names", "blob_names"};
"input_blob_indices", "output_blob_indices", "layer_names", "blob_names", "bottom_id_vecs", "top_id_vecs" };
mxArray* mx_net_attr = mxCreateStructMatrix(1, 1, net_attr_num,
net_attrs);
mxSetField(mx_net_attr, 0, "hLayer_layers",
@ -300,6 +355,10 @@ static void net_get_attr(MEX_ARGS) {
str_vec_to_mx_strcell(net->layer_names()));
mxSetField(mx_net_attr, 0, "blob_names",
str_vec_to_mx_strcell(net->blob_names()));
mxSetField(mx_net_attr, 0, "bottom_id_vecs",
int_vec_vec_to_mx_cell_vec(net->bottom_id_vecs()));
mxSetField(mx_net_attr, 0, "top_id_vecs",
int_vec_vec_to_mx_cell_vec(net->top_id_vecs()));
plhs[0] = mx_net_attr;
}
@ -330,6 +389,15 @@ static void net_copy_from(MEX_ARGS) {
mxFree(weights_file);
}
// Usage: caffe_('net_shared_with', hNet, hNet_trained)
static void net_share_trained_layers_with(MEX_ARGS) {
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsStruct(prhs[1]),
"Usage: caffe_('net_shared_with', hNet, hNet_trained)");
Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
Net<float>* net_trained = handle_to_ptr<Net<float> >(prhs[1]);
net->ShareTrainedLayersWith(net_trained);
}
// Usage: caffe_('net_reshape', hNet)
static void net_reshape(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@ -413,12 +481,24 @@ static void blob_get_data(MEX_ARGS) {
// Usage: caffe_('blob_set_data', hBlob, new_data)
static void blob_set_data(MEX_ARGS) {
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsSingle(prhs[1]),
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && (mxIsSingle(prhs[1]) || mxIsGPUArray(prhs[1])),
"Usage: caffe_('blob_set_data', hBlob, new_data)");
Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
mx_mat_to_blob(prhs[1], blob, DATA);
}
// Usage: caffe_('blob_copy_data', hBlob_to, hBlob_from)
static void blob_copy_data(MEX_ARGS) {
mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsStruct(prhs[1]),
"Usage: caffe_('blob_copy_data', hBlob_to, hBlob_from)");
Blob<float>* blob_to = handle_to_ptr<Blob<float> >(prhs[0]);
Blob<float>* blob_from = handle_to_ptr<Blob<float> >(prhs[1]);
//mxCHECK(blob_from->count() == blob_to->count(),
// "number of elements in target blob doesn't match that in source blob");
blob_to->CopyFrom(*blob_from, false, true);
}
// Usage: caffe_('blob_get_diff', hBlob)
static void blob_get_diff(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@ -473,6 +553,54 @@ static void reset(MEX_ARGS) {
init_key = static_cast<double>(caffe_rng_rand());
}
// Usage: caffe_('set_random_seed', random_seed)
static void set_random_seed(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsDouble(prhs[0]),
"Usage: caffe_('set_random_seed', random_seed)");
int random_seed = static_cast<unsigned int>(mxGetScalar(prhs[0]));
Caffe::set_random_seed(random_seed);
}
static void glog_failure_handler(){
static bool is_glog_failure = false;
if (!is_glog_failure)
{
is_glog_failure = true;
::google::FlushLogFiles(0);
mexErrMsgTxt("glog check error, please check log and clear mex");
}
}
static void protobuf_log_handler(::google::protobuf::LogLevel level, const char* filename, int line,
const std::string& message)
{
const int max_err_length = 512;
char err_message[max_err_length];
snprintf(err_message, max_err_length, "Protobuf : %s . at %s Line %d",
message.c_str(), filename, line);
LOG(INFO) << err_message;
::google::FlushLogFiles(0);
mexErrMsgTxt(err_message);
}
// Usage: caffe_('init_log', log_base_filename)
static void init_log(MEX_ARGS) {
static bool is_log_inited = false;
mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
"Usage: caffe_('init_log', log_dir)");
if (is_log_inited)
::google::ShutdownGoogleLogging();
char* log_base_filename = mxArrayToString(prhs[0]);
::google::SetLogDestination(0, log_base_filename);
mxFree(log_base_filename);
::google::protobuf::SetLogHandler(&protobuf_log_handler);
::google::InitGoogleLogging("caffe_mex");
::google::InstallFailureFunction(&glog_failure_handler);
is_log_inited = true;
}
// Usage: caffe_('read_mean', mean_proto_file)
static void read_mean(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
@ -528,37 +656,43 @@ struct handler_registry {
static handler_registry handlers[] = {
// Public API functions
{ "get_solver", get_solver },
{ "solver_get_attr", solver_get_attr },
{ "solver_get_iter", solver_get_iter },
{ "solver_restore", solver_restore },
{ "solver_solve", solver_solve },
{ "solver_step", solver_step },
{ "get_net", get_net },
{ "net_get_attr", net_get_attr },
{ "net_forward", net_forward },
{ "net_backward", net_backward },
{ "net_copy_from", net_copy_from },
{ "net_reshape", net_reshape },
{ "net_save", net_save },
{ "layer_get_attr", layer_get_attr },
{ "layer_get_type", layer_get_type },
{ "blob_get_shape", blob_get_shape },
{ "blob_reshape", blob_reshape },
{ "blob_get_data", blob_get_data },
{ "blob_set_data", blob_set_data },
{ "blob_get_diff", blob_get_diff },
{ "blob_set_diff", blob_set_diff },
{ "set_mode_cpu", set_mode_cpu },
{ "set_mode_gpu", set_mode_gpu },
{ "set_device", set_device },
{ "get_init_key", get_init_key },
{ "reset", reset },
{ "read_mean", read_mean },
{ "write_mean", write_mean },
{ "version", version },
{ "get_solver", get_solver },
{ "solver_get_attr", solver_get_attr },
{ "solver_get_iter", solver_get_iter },
{ "solver_get_max_iter", solver_get_max_iter },
{ "solver_restore", solver_restore },
{ "solver_solve", solver_solve },
{ "solver_step", solver_step },
{ "get_net", get_net },
{ "net_get_attr", net_get_attr },
{ "net_set_phase", net_set_phase },
{ "net_forward", net_forward },
{ "net_backward", net_backward },
{ "net_copy_from", net_copy_from },
{ "net_share_trained_layers_with", net_share_trained_layers_with },
{ "net_reshape", net_reshape },
{ "net_save", net_save },
{ "layer_get_attr", layer_get_attr },
{ "layer_get_type", layer_get_type },
{ "blob_get_shape", blob_get_shape },
{ "blob_reshape", blob_reshape },
{ "blob_get_data", blob_get_data },
{ "blob_set_data", blob_set_data },
{ "blob_copy_data", blob_copy_data },
{ "blob_get_diff", blob_get_diff },
{ "blob_set_diff", blob_set_diff },
{ "set_mode_cpu", set_mode_cpu },
{ "set_mode_gpu", set_mode_gpu },
{ "set_device", set_device },
{ "set_random_seed", set_random_seed },
{ "get_init_key", get_init_key },
{ "init_log", init_log },
{ "reset", reset },
{ "read_mean", read_mean },
{ "write_mean", write_mean },
{ "version", version },
// The end.
{ "END", NULL },
{ "END", NULL },
};
/** -----------------------------------------------------------------

View file

@ -0,0 +1,11 @@
function set_random_seed(random_seed)
% set_random_seed(random_seed)
% set Caffe's random_seed
CHECK(isscalar(random_seed) && random_seed >= 0, ...
'random_seed must be non-negative integer');
random_seed = double(random_seed);
caffe_('set_random_seed', random_seed);
end

View file

@ -49,7 +49,7 @@ void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
variance_.Reshape(sz);
temp_.ReshapeLike(*bottom[0]);
if (use_global_stats_) {
x_norm_.ReshapeLike(*bottom[0]);
x_norm_.ReshapeLike(*bottom[0]);
}
sz[0]=bottom[0]->shape(0);
batch_sum_multiplier_.Reshape(sz);

View file

@ -86,8 +86,8 @@ void BatchNormLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// TODO(cdoersch): The caching is only needed because later in-place layers
// might clobber the data. Can we skip this if they won't?
if (!use_global_stats_) {
caffe_copy(x_norm_.count(), top_data,
x_norm_.mutable_gpu_data());
caffe_copy(x_norm_.count(), top_data,
x_norm_.mutable_gpu_data());
}
}
@ -99,13 +99,13 @@ void BatchNormLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
if (bottom[0] != top[0]) {
top_diff = top[0]->gpu_diff();
} else {
if (use_global_stats_) {
top_diff = top[0]->gpu_diff();
}
else {
caffe_copy(x_norm_.count(), top[0]->gpu_diff(), x_norm_.mutable_gpu_diff());
top_diff = x_norm_.gpu_diff();
}
if (use_global_stats_) {
top_diff = top[0]->gpu_diff();
}
else {
caffe_copy(x_norm_.count(), top[0]->gpu_diff(), x_norm_.mutable_gpu_diff());
top_diff = x_norm_.gpu_diff();
}
}
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
if (use_global_stats_) {

View file

@ -0,0 +1,85 @@
// ------------------------------------------------------------------
// R-FCN
// Copyright (c) 2016 Microsoft
// Licensed under The MIT License [see r-fcn/LICENSE for details]
// Written by Yi Li
// ------------------------------------------------------------------
#include <cfloat>
#include <string>
#include <utility>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/box_annotator_ohem_layer.hpp"
#include "caffe/proto/caffe.pb.h"
using std::max;
using std::min;
using std::floor;
using std::ceil;
namespace caffe {
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
BoxAnnotatorOHEMParameter box_anno_param = this->layer_param_.box_annotator_ohem_param();
roi_per_img_ = box_anno_param.roi_per_img();
CHECK_GT(roi_per_img_, 0);
ignore_label_ = box_anno_param.ignore_label();
}
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
num_ = bottom[0]->num();
CHECK_EQ(5, bottom[0]->channels());
height_ = bottom[0]->height();
width_ = bottom[0]->width();
spatial_dim_ = height_*width_;
CHECK_EQ(bottom[1]->num(), num_);
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_EQ(bottom[1]->height(), height_);
CHECK_EQ(bottom[1]->width(), width_);
CHECK_EQ(bottom[2]->num(), num_);
CHECK_EQ(bottom[2]->channels(), 1);
CHECK_EQ(bottom[2]->height(), height_);
CHECK_EQ(bottom[2]->width(), width_);
CHECK_EQ(bottom[3]->num(), num_);
bbox_channels_ = bottom[3]->channels();
CHECK_EQ(bottom[3]->height(), height_);
CHECK_EQ(bottom[3]->width(), width_);
// Labels for scoring
top[0]->Reshape(num_, 1, height_, width_);
// Loss weights for bbox regression
top[1]->Reshape(num_, bbox_channels_, height_, width_);
}
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
NOT_IMPLEMENTED;
}
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
NOT_IMPLEMENTED;
}
#ifdef CPU_ONLY
STUB_GPU(BoxAnnotatorOHEMLayer);
#endif
INSTANTIATE_CLASS(BoxAnnotatorOHEMLayer);
REGISTER_LAYER_CLASS(BoxAnnotatorOHEM);
} // namespace caffe

View file

@ -0,0 +1,79 @@
// ------------------------------------------------------------------
// R-FCN
// Copyright (c) 2016 Microsoft
// Licensed under The MIT License [see r-fcn/LICENSE for details]
// Written by Yi Li
// ------------------------------------------------------------------
#include <algorithm>
#include <cfloat>
#include <vector>
#include "caffe/layers/box_annotator_ohem_layer.hpp"
using std::max;
using std::min;
namespace caffe {
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_rois = bottom[0]->cpu_data();
const Dtype* bottom_loss = bottom[1]->cpu_data();
const Dtype* bottom_labels = bottom[2]->cpu_data();
const Dtype* bottom_bbox_loss_weights = bottom[3]->cpu_data();
Dtype* top_labels = top[0]->mutable_cpu_data();
Dtype* top_bbox_loss_weights = top[1]->mutable_cpu_data();
caffe_set(top[0]->count(), Dtype(ignore_label_), top_labels);
caffe_set(top[1]->count(), Dtype(0), top_bbox_loss_weights);
int num_rois_ = bottom[1]->count();
int num_imgs = -1;
for (int n = 0; n < num_rois_; n++){
for (int s = 0; s < spatial_dim_; s++){
num_imgs = bottom_rois[0]>num_imgs ? bottom_rois[0] : num_imgs;
bottom_rois++;
}
bottom_rois += (5-1)*spatial_dim_;
}
num_imgs++;
CHECK_GT(num_imgs, 0)
<< "number of images must be greater than 0 at BoxAnnotatorOHEMLayer";
bottom_rois = bottom[0]->cpu_data();
// Find rois with max loss
vector<int> sorted_idx(num_rois_);
for (int i = 0; i < num_rois_; i++){
sorted_idx[i] = i;
}
std::sort(sorted_idx.begin(), sorted_idx.end(),
[bottom_loss](int i1, int i2){return bottom_loss[i1] > bottom_loss[i2]; });
// Generate output labels for scoring and loss_weights for bbox regression
vector<int> number_left(num_imgs, roi_per_img_);
for (int i = 0; i < num_rois_; i++){
int index = sorted_idx[i];
int s = index % (width_*height_);
int n = index / (width_*height_);
int batch_ind = bottom_rois[n*5*spatial_dim_+s];
if (number_left[batch_ind]>0){
number_left[batch_ind]--;
top_labels[index] = bottom_labels[index];
for (int j = 0; j < bbox_channels_; j++){
int bbox_index = (n*bbox_channels_+j)*spatial_dim_+s;
top_bbox_loss_weights[bbox_index]=bottom_bbox_loss_weights[bbox_index];
}
}
}
}
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
return;
}
INSTANTIATE_LAYER_GPU_FUNCS(BoxAnnotatorOHEMLayer);
} // namespace caffe

View file

@ -121,7 +121,7 @@ namespace caffe {
int ph = (index / pooled_width) % pooled_height;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
// [start, end) interval for spatial sampling
bottom_rois += n * 5;
int roi_batch_ind = bottom_rois[0];
Dtype roi_start_w = static_cast<Dtype>(round(bottom_rois[1])) * spatial_scale;

View file

@ -89,10 +89,10 @@ void ScaleLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
scale_dim_ = scale->count();
inner_dim_ = bottom[0]->count(axis_ + scale->num_axes());
if (bottom[0] == top[0]) { // in-place computation
const bool scale_param = (bottom.size() == 1);
if (!scale_param || (scale_param && this->param_propagate_down_[0])) {
temp_.ReshapeLike(*bottom[0]);
}
const bool scale_param = (bottom.size() == 1);
if (!scale_param || (scale_param && this->param_propagate_down_[0])) {
temp_.ReshapeLike(*bottom[0]);
}
} else {
top[0]->ReshapeLike(*bottom[0]);
}

View file

@ -36,11 +36,11 @@ void ScaleLayer<Dtype>::Forward_gpu(
// Note that this is only necessary for Backward; we could skip this if not
// doing Backward, but Caffe currently provides no way of knowing whether
// we'll need to do Backward at the time of the Forward call.
const bool scale_param = (bottom.size() == 1);
if (!scale_param || (scale_param && this->param_propagate_down_[0])) {
caffe_copy(bottom[0]->count(), bottom[0]->gpu_data(),
temp_.mutable_gpu_data());
}
const bool scale_param = (bottom.size() == 1);
if (!scale_param || (scale_param && this->param_propagate_down_[0])) {
caffe_copy(bottom[0]->count(), bottom[0]->gpu_data(),
temp_.mutable_gpu_data());
}
}
const Dtype* scale_data =
((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data();

View file

@ -0,0 +1,106 @@
// --------------------------------------------------------
// R-FCN
// Written by Yi Li, 2016.
// --------------------------------------------------------
#include <string>
#include <utility>
#include <vector>
#include "caffe/layers/smooth_l1_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
has_weights_ = (bottom.size() == 3);
if (!this->layer_param_.loss_param().has_normalization() &&
this->layer_param_.loss_param().has_normalize()) {
normalization_ = this->layer_param_.loss_param().normalize() ?
LossParameter_NormalizationMode_VALID :
LossParameter_NormalizationMode_BATCH_SIZE;
}
else {
normalization_ = this->layer_param_.loss_param().normalization();
}
}
template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
LossLayer<Dtype>::Reshape(bottom, top);
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->height(), bottom[1]->height());
CHECK_EQ(bottom[0]->width(), bottom[1]->width());
if (has_weights_) {
CHECK_EQ(bottom[0]->channels(), bottom[2]->channels());
CHECK_EQ(bottom[0]->height(), bottom[2]->height());
CHECK_EQ(bottom[0]->width(), bottom[2]->width());
}
outer_num_ = bottom[0]->num();
inner_num_ = bottom[0]->height() * bottom[0]->width();
diff_.Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
errors_.Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
// top[1] stores per-instance loss, which takes the shape of N*1*H*W
if (top.size()>=2) {
top[1]->Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
}
}
template <typename Dtype>
Dtype SmoothL1LossOHEMLayer<Dtype>::get_normalizer(
LossParameter_NormalizationMode normalization_mode, Dtype pre_fixed_normalizer) {
Dtype normalizer;
switch (normalization_mode) {
case LossParameter_NormalizationMode_FULL:
normalizer = Dtype(outer_num_ * inner_num_);
break;
case LossParameter_NormalizationMode_VALID:
normalizer = Dtype(outer_num_ * inner_num_);
break;
case LossParameter_NormalizationMode_BATCH_SIZE:
normalizer = Dtype(outer_num_);
break;
case LossParameter_NormalizationMode_PRE_FIXED:
normalizer = pre_fixed_normalizer;
break;
case LossParameter_NormalizationMode_NONE:
normalizer = Dtype(1);
break;
default:
LOG(FATAL) << "Unknown normalization mode: "
<< LossParameter_NormalizationMode_Name(normalization_mode);
}
// Some users will have no labels for some examples in order to 'turn off' a
// particular loss in a multi-task setup. The max prevents NaNs in that case.
return std::max(Dtype(1.0), normalizer);
}
template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
NOT_IMPLEMENTED;
}
template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
NOT_IMPLEMENTED;
}
#ifdef CPU_ONLY
STUB_GPU(SmoothL1LossOHEMLayer);
#endif
INSTANTIATE_CLASS(SmoothL1LossOHEMLayer);
REGISTER_LAYER_CLASS(SmoothL1LossOHEM);
} // namespace caffe
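
The new PRE_FIXED mode makes get_normalizer return the user-supplied pre_fixed_normalizer (clamped to at least 1) instead of a count derived from the blob shape, so the reported loss is simply the summed error divided by that constant. A tiny numeric illustration (values are made up):

// Illustration of PRE_FIXED normalization: summed loss / pre_fixed_normalizer.
#include <algorithm>
#include <cstdio>

int main() {
  float summed_loss = 12.0f;           // asum of the per-element smooth L1 errors
  float pre_fixed_normalizer = 128.f;  // e.g. the number of sampled ROIs
  float normalizer = std::max(1.0f, pre_fixed_normalizer);  // guard against 0
  std::printf("reported loss = %f\n", summed_loss / normalizer);  // 0.09375
  return 0;
}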

View file

@ -0,0 +1,127 @@
// --------------------------------------------------------
// R-FCN
// Written by Yi Li, 2016.
// --------------------------------------------------------
#include <algorithm>
#include <cfloat>
#include <vector>
#include "thrust/device_vector.h"
#include "caffe/layers/smooth_l1_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
template <typename Dtype>
__global__ void SmoothL1ForwardGPU(const int n, const Dtype* in, Dtype* out) {
// f(x) = 0.5 * x^2 if |x| < 1
// |x| - 0.5 otherwise
CUDA_KERNEL_LOOP(index, n) {
Dtype val = in[index];
Dtype abs_val = abs(val);
if (abs_val < 1) {
out[index] = 0.5 * val * val;
}
else {
out[index] = abs_val - 0.5;
}
}
}
template <typename Dtype>
__global__ void kernel_channel_sum(const int num, const int channels,
const int spatial_dim, const Dtype* data, Dtype* channel_sum) {
CUDA_KERNEL_LOOP(index, num * spatial_dim) {
int n = index / spatial_dim;
int s = index % spatial_dim;
Dtype sum = 0;
for (int c = 0; c < channels; ++c) {
sum += data[(n * channels + c) * spatial_dim + s];
}
channel_sum[index] = sum;
}
}
template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
int count = bottom[0]->count();
caffe_gpu_sub(
count,
bottom[0]->gpu_data(),
bottom[1]->gpu_data(),
diff_.mutable_gpu_data()); // d := b0 - b1
if (has_weights_) {
caffe_gpu_mul(
count,
bottom[2]->gpu_data(),
diff_.gpu_data(),
diff_.mutable_gpu_data()); // d := w * (b0 - b1)
}
SmoothL1ForwardGPU<Dtype> << <CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS >> >(
count, diff_.gpu_data(), errors_.mutable_gpu_data());
CUDA_POST_KERNEL_CHECK;
Dtype loss;
caffe_gpu_asum(count, errors_.gpu_data(), &loss);
int spatial_dim = diff_.height() * diff_.width();
Dtype pre_fixed_normalizer = this->layer_param_.loss_param().pre_fixed_normalizer();
top[0]->mutable_cpu_data()[0] = loss / get_normalizer(normalization_,
pre_fixed_normalizer);
// Output per-instance loss
if (top.size() >= 2) {
kernel_channel_sum<Dtype> << <CAFFE_GET_BLOCKS(top[0]->count()), CAFFE_CUDA_NUM_THREADS >> >
(outer_num_, bottom[0]->channels(), inner_num_, errors_.gpu_data(),
top[1]->mutable_gpu_data());
}
}
template <typename Dtype>
__global__ void SmoothL1BackwardGPU(const int n, const Dtype* in, Dtype* out) {
// f'(x) = x if |x| < 1
// = sign(x) otherwise
CUDA_KERNEL_LOOP(index, n) {
Dtype val = in[index];
Dtype abs_val = abs(val);
if (abs_val < 1) {
out[index] = val;
}
else {
out[index] = (Dtype(0) < val) - (val < Dtype(0));
}
}
}
template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
int count = diff_.count();
SmoothL1BackwardGPU<Dtype> << <CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS >> >(
count, diff_.gpu_data(), diff_.mutable_gpu_data());
CUDA_POST_KERNEL_CHECK;
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
int spatial_dim = diff_.height() * diff_.width();
Dtype pre_fixed_normalizer = this->layer_param_.loss_param().pre_fixed_normalizer();
Dtype normalizer = get_normalizer(normalization_, pre_fixed_normalizer);
Dtype alpha = sign * top[0]->cpu_diff()[0] / normalizer;
caffe_gpu_axpby(
bottom[i]->count(), // count
alpha, // alpha
diff_.gpu_data(), // x
Dtype(0), // beta
bottom[i]->mutable_gpu_diff()); // y
}
}
}
INSTANTIATE_LAYER_GPU_FUNCS(SmoothL1LossOHEMLayer);
} // namespace caffe
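
kernel_channel_sum above produces the per-instance loss written to top[1]: the smooth L1 error summed over channels, one value per (n, h, w) location. A CPU equivalent, as an illustrative sketch (not from the commit):

// CPU equivalent of kernel_channel_sum: sum data over channels per location.
#include <vector>

std::vector<float> ChannelSum(const std::vector<float>& data,
                              int num, int channels, int spatial_dim) {
  std::vector<float> out(num * spatial_dim, 0.f);
  for (int n = 0; n < num; ++n)
    for (int c = 0; c < channels; ++c)
      for (int s = 0; s < spatial_dim; ++s)
        out[n * spatial_dim + s] += data[(n * channels + c) * spatial_dim + s];
  return out;
}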

View file

@ -0,0 +1,114 @@
#include <algorithm>
#include <cfloat>
#include <vector>
#include "caffe/layers/softmax_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
LossLayer<Dtype>::LayerSetUp(bottom, top);
LayerParameter softmax_param(this->layer_param_);
// Fix a bug which occurs with more than one output
softmax_param.clear_loss_weight();
softmax_param.set_type("Softmax");
softmax_layer_ = LayerRegistry<Dtype>::CreateLayer(softmax_param);
softmax_bottom_vec_.clear();
softmax_bottom_vec_.push_back(bottom[0]);
softmax_top_vec_.clear();
softmax_top_vec_.push_back(&prob_);
softmax_layer_->SetUp(softmax_bottom_vec_, softmax_top_vec_);
has_ignore_label_ =
this->layer_param_.loss_param().has_ignore_label();
if (has_ignore_label_) {
ignore_label_ = this->layer_param_.loss_param().ignore_label();
}
if (!this->layer_param_.loss_param().has_normalization() &&
this->layer_param_.loss_param().has_normalize()) {
normalization_ = this->layer_param_.loss_param().normalize() ?
LossParameter_NormalizationMode_VALID :
LossParameter_NormalizationMode_BATCH_SIZE;
} else {
normalization_ = this->layer_param_.loss_param().normalization();
}
}
template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
LossLayer<Dtype>::Reshape(bottom, top);
softmax_layer_->Reshape(softmax_bottom_vec_, softmax_top_vec_);
softmax_axis_ =
bottom[0]->CanonicalAxisIndex(this->layer_param_.softmax_param().axis());
outer_num_ = bottom[0]->count(0, softmax_axis_);
inner_num_ = bottom[0]->count(softmax_axis_ + 1);
CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count())
<< "Number of labels must match number of predictions; "
<< "e.g., if softmax axis == 1 and prediction shape is (N, C, H, W), "
<< "label count (number of labels) must be N*H*W, "
<< "with integer values in {0, 1, ..., C-1}.";
if (top.size() >= 2) {
// softmax output
top[1]->ReshapeLike(*bottom[0]);
}
// top[2] stores per-instance loss, which takes the shape of N*1*H*W
if (top.size() >= 3) {
top[2]->ReshapeLike(*bottom[1]);
}
}
template <typename Dtype>
Dtype SoftmaxWithLossOHEMLayer<Dtype>::get_normalizer(
LossParameter_NormalizationMode normalization_mode, int valid_count) {
Dtype normalizer;
switch (normalization_mode) {
case LossParameter_NormalizationMode_FULL:
normalizer = Dtype(outer_num_ * inner_num_);
break;
case LossParameter_NormalizationMode_VALID:
if (valid_count == -1) {
normalizer = Dtype(outer_num_ * inner_num_);
} else {
normalizer = Dtype(valid_count);
}
break;
case LossParameter_NormalizationMode_BATCH_SIZE:
normalizer = Dtype(outer_num_);
break;
case LossParameter_NormalizationMode_NONE:
normalizer = Dtype(1);
break;
default:
LOG(FATAL) << "Unknown normalization mode: "
<< LossParameter_NormalizationMode_Name(normalization_mode);
}
// Some users will have no labels for some examples in order to 'turn off' a
// particular loss in a multi-task setup. The max prevents NaNs in that case.
return std::max(Dtype(1.0), normalizer);
}
template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
NOT_IMPLEMENTED;
}
template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
NOT_IMPLEMENTED;
}
#ifdef CPU_ONLY
STUB_GPU(SoftmaxWithLossOHEMLayer);
#endif
INSTANTIATE_CLASS(SoftmaxWithLossOHEMLayer);
REGISTER_LAYER_CLASS(SoftmaxWithLossOHEM);
} // namespace caffe

View file

@ -0,0 +1,136 @@
#include <algorithm>
#include <cfloat>
#include <vector>
#include "caffe/layers/softmax_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
template <typename Dtype>
__global__ void SoftmaxLossForwardGPU(const int nthreads,
const Dtype* prob_data, const Dtype* label, Dtype* loss,
const int num, const int dim, const int spatial_dim,
const bool has_ignore_label_, const int ignore_label_,
Dtype* counts) {
CUDA_KERNEL_LOOP(index, nthreads) {
const int n = index / spatial_dim;
const int s = index % spatial_dim;
const int label_value = static_cast<int>(label[n * spatial_dim + s]);
if (has_ignore_label_ && label_value == ignore_label_) {
loss[index] = 0;
counts[index] = 0;
} else {
loss[index] = -log(max(prob_data[n * dim + label_value * spatial_dim + s],
Dtype(FLT_MIN)));
counts[index] = 1;
}
}
}
template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
const Dtype* prob_data = prob_.gpu_data();
const Dtype* label = bottom[1]->gpu_data();
const int dim = prob_.count() / outer_num_;
const int nthreads = outer_num_ * inner_num_;
// Since this memory is not used for anything until it is overwritten
// on the backward pass, we use it here to avoid having to allocate new GPU
// memory to accumulate intermediate results in the kernel.
Dtype* loss_data = bottom[0]->mutable_gpu_diff();
// Similarly, this memory is never used elsewhere, and thus we can use it
// to avoid having to allocate additional GPU memory.
Dtype* counts = prob_.mutable_gpu_diff();
// NOLINT_NEXT_LINE(whitespace/operators)
SoftmaxLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, prob_data, label, loss_data,
outer_num_, dim, inner_num_, has_ignore_label_, ignore_label_, counts);
Dtype loss;
caffe_gpu_asum(nthreads, loss_data, &loss);
Dtype valid_count = -1;
// Only launch another CUDA kernel if we actually need the count of valid
// outputs.
if (normalization_ == LossParameter_NormalizationMode_VALID &&
has_ignore_label_) {
caffe_gpu_asum(nthreads, counts, &valid_count);
}
top[0]->mutable_cpu_data()[0] = loss / get_normalizer(normalization_,
valid_count);
if (top.size() >= 2) {
top[1]->ShareData(prob_);
}
if (top.size() >= 3) {
// Output per-instance loss
caffe_gpu_memcpy(top[2]->count() * sizeof(Dtype), loss_data, top[2]->mutable_gpu_data());
}
// Fix a bug that occurs when propagate_down[0] = false in backward
caffe_gpu_set(bottom[0]->count(), Dtype(0), bottom[0]->mutable_gpu_diff());
}
template <typename Dtype>
__global__ void SoftmaxLossBackwardGPU(const int nthreads, const Dtype* top,
const Dtype* label, Dtype* bottom_diff, const int num, const int dim,
const int spatial_dim, const bool has_ignore_label_,
const int ignore_label_, Dtype* counts) {
const int channels = dim / spatial_dim;
CUDA_KERNEL_LOOP(index, nthreads) {
const int n = index / spatial_dim;
const int s = index % spatial_dim;
const int label_value = static_cast<int>(label[n * spatial_dim + s]);
if (has_ignore_label_ && label_value == ignore_label_) {
for (int c = 0; c < channels; ++c) {
bottom_diff[n * dim + c * spatial_dim + s] = 0;
}
counts[index] = 0;
} else {
bottom_diff[n * dim + label_value * spatial_dim + s] -= 1;
counts[index] = 1;
}
}
}
template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[1]) {
LOG(FATAL) << this->type()
<< " Layer cannot backpropagate to label inputs.";
}
if (propagate_down[0]) {
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const Dtype* prob_data = prob_.gpu_data();
const Dtype* top_data = top[0]->gpu_data();
caffe_gpu_memcpy(prob_.count() * sizeof(Dtype), prob_data, bottom_diff);
const Dtype* label = bottom[1]->gpu_data();
const int dim = prob_.count() / outer_num_;
const int nthreads = outer_num_ * inner_num_;
// Since this memory is never used for anything else,
// we use it to avoid allocating new GPU memory.
Dtype* counts = prob_.mutable_gpu_diff();
// NOLINT_NEXT_LINE(whitespace/operators)
SoftmaxLossBackwardGPU<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, top_data, label, bottom_diff,
outer_num_, dim, inner_num_, has_ignore_label_, ignore_label_, counts);
Dtype valid_count = -1;
// Only launch another CUDA kernel if we actually need the count of valid
// outputs.
if (normalization_ == LossParameter_NormalizationMode_VALID &&
has_ignore_label_) {
caffe_gpu_asum(nthreads, counts, &valid_count);
}
const Dtype loss_weight = top[0]->cpu_diff()[0] /
get_normalizer(normalization_, valid_count);
caffe_gpu_scal(prob_.count(), loss_weight , bottom_diff);
}
}
INSTANTIATE_LAYER_GPU_FUNCS(SoftmaxWithLossOHEMLayer);
} // namespace caffe

View file

@ -283,6 +283,16 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
LOG_IF(INFO, Caffe::root_solver()) << "Network initialization done.";
}
template <typename Dtype>
void Net<Dtype>::SetPhase(Phase phase){
// set all layers
for (int i = 0; i < layers_.size(); ++i){
layers_[i]->set_phase(phase);
}
// set net phase
phase_ = phase;
}
template <typename Dtype>
void Net<Dtype>::FilterNet(const NetParameter& param,
NetParameter* param_filtered) {
@ -761,7 +771,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
LOG(INFO) << "Ignoring source layer " << source_layer_name;
continue;
}
DLOG(INFO) << "Copying source layer " << source_layer_name;
LOG(INFO) << "Copying source layer " << source_layer_name;
vector<shared_ptr<Blob<Dtype> > >& target_blobs =
layers_[target_layer_id]->blobs();
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())

View file

@ -306,7 +306,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 150 (last added: psroi_pooling_param)
// LayerParameter next available layer-specific ID: 151 (last added: box_annotator_ohem_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@ -361,6 +361,7 @@ message LayerParameter {
optional AccuracyParameter accuracy_param = 102;
optional ArgMaxParameter argmax_param = 103;
optional BatchNormParameter batch_norm_param = 139;
optional BoxAnnotatorOHEMParameter box_annotator_ohem_param = 150;
optional BiasParameter bias_param = 141;
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
@ -389,7 +390,7 @@ message LayerParameter {
optional PoolingParameter pooling_param = 121;
optional PowerParameter power_param = 122;
optional PReLUParameter prelu_param = 131;
optional PSROIPoolingParameter psroi_pooling_param = 149;
optional PSROIPoolingParameter psroi_pooling_param = 149;
optional PythonParameter python_param = 130;
optional RecurrentParameter recurrent_param = 146;
optional ReductionParameter reduction_param = 136;
@ -450,14 +451,18 @@ message LossParameter {
VALID = 1;
// Divide by the batch size.
BATCH_SIZE = 2;
// Divide by pre-fixed normalizer
PRE_FIXED = 3;
// Do not normalize the loss.
NONE = 3;
NONE = 4;
}
optional NormalizationMode normalization = 3 [default = VALID];
// Deprecated. Ignored if normalization is specified. If normalization
// is not specified, then setting this to false will be equivalent to
// normalization = BATCH_SIZE to be consistent with previous behavior.
optional bool normalize = 2;
//pre-fixed normalizer
optional float pre_fixed_normalizer = 4 [default = 1];
}
// Messages that store parameters used by individual layer types follow, in
@ -514,6 +519,11 @@ message BatchNormParameter {
optional float eps = 3 [default = 1e-5];
}
message BoxAnnotatorOHEMParameter {
  required uint32 roi_per_img = 1; // Number of RoIs per image kept for training
  optional int32 ignore_label = 2 [default = -1]; // Label marking RoIs to be ignored by the downstream losses
}
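
For context, a hedged sketch of the selection these two parameters control (illustrative only, not the BoxAnnotatorOHEMLayer code, which also rewrites the bbox loss weights): per image, RoIs are ranked by their per-RoI loss, the roi_per_img hardest ones are kept, and the rest are marked with ignore_label so the downstream loss layers skip them.

#include <algorithm>
#include <numeric>
#include <vector>

// Sketch of OHEM RoI selection driven by roi_per_img / ignore_label.
void select_hard_rois(const std::vector<float>& per_roi_loss,
                      int roi_per_img, int ignore_label,
                      std::vector<int>* labels) {  // one label per RoI, in/out
  std::vector<int> order(per_roi_loss.size());
  std::iota(order.begin(), order.end(), 0);
  // Hardest examples (largest loss) first.
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return per_roi_loss[a] > per_roi_loss[b]; });
  // Everything beyond the top roi_per_img is ignored by the losses.
  for (size_t r = static_cast<size_t>(roi_per_img); r < order.size(); ++r) {
    (*labels)[order[r]] = ignore_label;
  }
}
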
message BiasParameter {
// The first axis of bottom[0] (the first input Blob) along which to apply
// bottom[1] (the second input Blob). May be negative to index from the end
@ -924,7 +934,7 @@ message PowerParameter {
message PSROIPoolingParameter {
required float spatial_scale = 1;
required int32 output_dim = 2; // output channel number
required int32 group_size = 3; // equal to pooled_size
required int32 group_size = 3; // number of groups to encode position-sensitive score maps
}
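
Here group_size is the per-RoI pooling grid resolution of the position-sensitive score maps. As a reading aid, a small sketch of the channel indexing this implies (an assumption consistent with the R-FCN design, not quoted from this diff):

// For output channel c and pooling bin (ph, pw) of a group_size x group_size
// grid, the position-sensitive score is pooled from this input channel:
inline int ps_input_channel(int c, int ph, int pw, int group_size) {
  return (c * group_size + ph) * group_size + pw;
}
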
message PythonParameter {

View file

@ -106,6 +106,7 @@
<ClCompile Include="..\..\src\caffe\layers\batch_reindex_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\bias_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\bnll_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\concat_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\contrastive_loss_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\conv_layer.cpp" />
@ -148,6 +149,7 @@
<ClCompile Include="..\..\src\caffe\layers\pooling_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\power_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\prelu_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\reduction_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\relu_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\reshape_layer.cpp" />
@ -156,8 +158,11 @@
<ClCompile Include="..\..\src\caffe\layers\sigmoid_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\silence_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\slice_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\softmax_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\softmax_loss_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\split_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\spp_layer.cpp" />
<ClCompile Include="..\..\src\caffe\layers\tanh_layer.cpp" />
@ -250,6 +255,7 @@
<ClInclude Include="..\..\include\caffe\layers\pooling_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\power_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\prelu_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\psroi_pooling_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\python_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\reduction_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\relu_layer.hpp" />
@ -259,6 +265,7 @@
<ClInclude Include="..\..\include\caffe\layers\sigmoid_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\silence_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\slice_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\smooth_l1_loss_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\softmax_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\softmax_loss_layer.hpp" />
<ClInclude Include="..\..\include\caffe\layers\split_layer.hpp" />
@ -354,6 +361,13 @@
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
<ItemGroup>
<CudaCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cu" />
<CudaCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cu" />
<CudaCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cu" />
<CudaCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cu" />
<CudaCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cu" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<Import Project="$(SolutionDir)\CommonSettings.targets" />
<ImportGroup Label="ExtensionTargets">

View file

@ -336,6 +336,21 @@
<ClCompile Include="..\..\src\caffe\util\signal_handler.cpp">
<Filter>src\util</Filter>
</ClCompile>
<ClCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cpp">
<Filter>src\layers</Filter>
</ClCompile>
<ClCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cpp">
<Filter>src\layers</Filter>
</ClCompile>
<ClCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cpp">
<Filter>src\layers</Filter>
</ClCompile>
<ClCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cpp">
<Filter>src\layers</Filter>
</ClCompile>
<ClCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cpp">
<Filter>src\layers</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\..\include\caffe\proto\caffe.pb.h">
@ -638,6 +653,12 @@
<ClInclude Include="..\..\include\caffe\layers\scale_layer.hpp">
<Filter>include\layers</Filter>
</ClInclude>
<ClInclude Include="..\..\include\caffe\layers\psroi_pooling_layer.hpp">
<Filter>include\layers</Filter>
</ClInclude>
<ClInclude Include="..\..\include\caffe\layers\smooth_l1_loss_layer.hpp">
<Filter>include\layers</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<CudaCompile Include="..\..\src\caffe\layers\contrastive_loss_layer.cu">
@ -811,6 +832,21 @@
<CudaCompile Include="..\..\src\caffe\solvers\sgd_solver.cu">
<Filter>cu\solvers</Filter>
</CudaCompile>
<CudaCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cu">
<Filter>cu\layers</Filter>
</CudaCompile>
<CudaCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cu">
<Filter>cu\layers</Filter>
</CudaCompile>
<CudaCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cu">
<Filter>cu\layers</Filter>
</CudaCompile>
<CudaCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cu">
<Filter>cu\layers</Filter>
</CudaCompile>
<CudaCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cu">
<Filter>cu\layers</Filter>
</CudaCompile>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />

View file

@ -34,12 +34,12 @@
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<Link>
<AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;libmat.lib;gpu.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<Link>
<AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;libmat.lib;gpu.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup>