Mirror of https://github.com/microsoft/caffe.git

R-FCN change sets

- add position-sensitive RoI pooling
- modify BN and scale layers to save memory
- add online hard example mining (OHEM)
- enrich the MATLAB interface

Parent: 827b78a868
Commit: 4cdcd00850
@@ -93,6 +93,10 @@ using std::string;
using std::stringstream;
using std::vector;

#ifdef _MSC_VER
#define snprintf _snprintf
#endif

// A global initialization function that you should call in your main function.
// Currently it initializes google flags and google logging.
void GlobalInit(int* pargc, char*** pargv);
@@ -318,6 +318,15 @@ class Layer {
  inline Phase phase() { return phase_; }

  /**
   * @brief Set the phase.
   * Enables training and testing with one network, to save memory.
   */
  virtual inline void set_phase(Phase phase) {
    phase_ = phase;
  }

 protected:
  /** The protobuf that stores the layer parameters */
@@ -0,0 +1,57 @@
#ifndef CAFFE_BOX_ANNOTATOR_OHEM_LAYER_HPP_
#define CAFFE_BOX_ANNOTATOR_OHEM_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/loss_layer.hpp"

namespace caffe {

/**
 * @brief BoxAnnotatorOHEMLayer: Annotate box labels for Online Hard Example Mining (OHEM) training
 * R-FCN
 * Written by Yi Li
 */
template <typename Dtype>
class BoxAnnotatorOHEMLayer : public Layer<Dtype> {
 public:
  explicit BoxAnnotatorOHEMLayer(const LayerParameter& param)
    : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "BoxAnnotatorOHEM"; }

  virtual inline int ExactNumBottomBlobs() const { return 4; }
  virtual inline int ExactNumTopBlobs() const { return 2; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  int num_;
  int height_;
  int width_;
  int spatial_dim_;
  int bbox_channels_;

  int roi_per_img_;
  int ignore_label_;
};

}  // namespace caffe

#endif  // CAFFE_BOX_ANNOTATOR_OHEM_LAYER_HPP_
@@ -13,7 +13,20 @@ namespace caffe {
/**
 * @brief Performs position-sensitive pooling on regions of interest: takes
 *        as input N position-sensitive score maps and a list of R regions of interest.
 *
 * PSROIPoolingLayer takes 2 inputs and produces 1 output. bottom[0] is
 * [N x (C x K^2) x H x W] position-sensitive score maps on which pooling is performed. bottom[1] is
 * [R x 5] containing a list of R ROI tuples with batch index and coordinates of
 * regions of interest. Each row in bottom[1] is a ROI tuple in the format
 * [batch_index x1 y1 x2 y2], where batch_index corresponds to the index of an
 * instance in the first input and x1 y1 x2 y2 are 0-indexed coordinates
 * of the ROI rectangle (including its boundaries). The output top[0] is [R x C x K x K] score maps pooled
 * within the ROI tuples.
 * @param param provides PSROIPoolingParameter psroi_pooling_param,
 *        with PSROIPoolingLayer options:
 *   - output_dim. The pooled output channel number.
 *   - group_size. The number of groups used to encode the position-sensitive score maps.
 *   - spatial_scale. Multiplicative spatial scale factor to translate ROI
 *     coordinates from their input scale to the scale used when pooling.
 * R-FCN
 * Written by Yi Li
 */
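The bin-to-channel mapping is the essence of the comment above. The sketch below spells out that arithmetic for a single ROI as plain CPU code; it is an illustration only (average pooling per bin, rounded ROI coordinates assumed), and the names Roi and psroi_pool_single_roi are hypothetical, not the layer's actual implementation.

#include <algorithm>
#include <cmath>
#include <vector>

struct Roi { int batch_index; float x1, y1, x2, y2; };

// bottom: one image's score maps, shape [C*K*K, H, W], row-major.
// Returns the pooled output for one ROI, shape [C, K, K] (one value per bin).
std::vector<float> psroi_pool_single_roi(const std::vector<float>& bottom,
                                         int C, int K, int H, int W,
                                         float spatial_scale, const Roi& roi) {
  // Scale ROI coordinates from image space to score-map space.
  float x1 = std::round(roi.x1) * spatial_scale;
  float y1 = std::round(roi.y1) * spatial_scale;
  float x2 = (std::round(roi.x2) + 1.f) * spatial_scale;
  float y2 = (std::round(roi.y2) + 1.f) * spatial_scale;
  float roi_w = std::max(x2 - x1, 0.1f), roi_h = std::max(y2 - y1, 0.1f);
  float bin_w = roi_w / K, bin_h = roi_h / K;

  std::vector<float> top(C * K * K, 0.f);
  for (int c = 0; c < C; ++c) {
    for (int ph = 0; ph < K; ++ph) {
      for (int pw = 0; pw < K; ++pw) {
        // Position-sensitive part: bin (ph, pw) reads from its own input channel.
        int in_c = (c * K + ph) * K + pw;
        int hstart = std::max(0, (int)std::floor(y1 + ph * bin_h));
        int hend   = std::min(H, (int)std::ceil(y1 + (ph + 1) * bin_h));
        int wstart = std::max(0, (int)std::floor(x1 + pw * bin_w));
        int wend   = std::min(W, (int)std::ceil(x1 + (pw + 1) * bin_w));
        float sum = 0.f; int count = 0;
        for (int h = hstart; h < hend; ++h)
          for (int w = wstart; w < wend; ++w, ++count)
            sum += bottom[(in_c * H + h) * W + w];
        top[(c * K + ph) * K + pw] = count > 0 ? sum / count : 0.f;
      }
    }
  }
  return top;
}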
@@ -21,43 +34,43 @@ namespace caffe {
template <typename Dtype>
class PSROIPoolingLayer : public Layer<Dtype> {
 public:
  explicit PSROIPoolingLayer(const LayerParameter& param)
    : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "PSROIPooling"; }

  virtual inline int MinBottomBlobs() const { return 2; }
  virtual inline int MaxBottomBlobs() const { return 2; }
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  Dtype spatial_scale_;
  int output_dim_;
  int group_size_;

  int channels_;
  int height_;
  int width_;

  int pooled_height_;
  int pooled_width_;
  Blob<int> mapping_channel_;
};

}  // namespace caffe

#endif  // CAFFE_ROI_POOLING_LAYER_HPP_
#endif  // CAFFE_PSROI_POOLING_LAYER_HPP_
@@ -0,0 +1,76 @@
#ifndef CAFFE_SMOOTH_L1_LOSS_OHEM_LAYER_HPP_
#define CAFFE_SMOOTH_L1_LOSS_OHEM_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/loss_layer.hpp"

namespace caffe {

/**
 * @brief SmoothL1LossOHEMLayer
 *
 * R-FCN
 * Written by Yi Li
 */
template <typename Dtype>
class SmoothL1LossOHEMLayer : public LossLayer<Dtype> {
 public:
  explicit SmoothL1LossOHEMLayer(const LayerParameter& param)
    : LossLayer<Dtype>(param), diff_() {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "SmoothL1LossOHEM"; }

  virtual inline int ExactNumBottomBlobs() const { return -1; }
  virtual inline int MinBottomBlobs() const { return 2; }
  virtual inline int MaxBottomBlobs() const { return 3; }
  virtual inline int ExactNumTopBlobs() const { return -1; }
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 2; }

  /**
   * Unlike most loss layers, in the SmoothL1LossOHEMLayer we can backpropagate
   * to both inputs -- override to return true and always allow force_backward.
   */
  virtual inline bool AllowForceBackward(const int bottom_index) const {
    return true;
  }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);

  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  /// Read the normalization mode parameter and compute the normalizer based
  /// on the blob size.
  virtual Dtype get_normalizer(
    LossParameter_NormalizationMode normalization_mode, Dtype pre_fixed_normalizer);

  Blob<Dtype> diff_;
  Blob<Dtype> errors_;
  bool has_weights_;

  int outer_num_, inner_num_;

  /// How to normalize the output loss.
  LossParameter_NormalizationMode normalization_;
};

}  // namespace caffe

#endif  // CAFFE_SMOOTH_L1_LOSS_OHEM_LAYER_HPP_
@@ -0,0 +1,132 @@
#ifndef CAFFE_SOFTMAX_WITH_LOSS_OHEM_LAYER_HPP_
#define CAFFE_SOFTMAX_WITH_LOSS_OHEM_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/loss_layer.hpp"
#include "caffe/layers/softmax_layer.hpp"

namespace caffe {

/**
 * @brief Computes the multinomial logistic loss for a one-of-many
 *        classification task, passing real-valued predictions through a
 *        softmax to get a probability distribution over classes.
 *        An additional per-instance loss is produced as output for OHEM.
 *
 * This layer should be preferred over separate
 * SoftmaxLayer + MultinomialLogisticLossLayer
 * as its gradient computation is more numerically stable.
 * At test time, this layer can be replaced simply by a SoftmaxLayer.
 *
 * @param bottom input Blob vector (length 2)
 *   -# @f$ (N \times C \times H \times W) @f$
 *      the predictions @f$ x @f$, a Blob with values in
 *      @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
 *      the @f$ K = CHW @f$ classes. This layer maps these scores to a
 *      probability distribution over classes using the softmax function
 *      @f$ \hat{p}_{nk} = \exp(x_{nk}) /
 *      \left[\sum_{k'} \exp(x_{nk'})\right] @f$ (see SoftmaxLayer).
 *   -# @f$ (N \times 1 \times 1 \times 1) @f$
 *      the labels @f$ l @f$, an integer-valued Blob with values
 *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
 *      indicating the correct class label among the @f$ K @f$ classes
 * @param top output Blob vector (length 1)
 *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
 *      the computed cross-entropy classification loss: @f$ E =
 *      \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n})
 *      @f$, for softmax output class probabilities @f$ \hat{p} @f$;
 *      per-instance cross-entropy classification loss
 */
template <typename Dtype>
class SoftmaxWithLossOHEMLayer : public LossLayer<Dtype> {
 public:
  /**
   * @param param provides LossParameter loss_param, with options:
   *  - ignore_label (optional)
   *    Specify a label value that should be ignored when computing the loss.
   *  - normalize (optional, default true)
   *    If true, the loss is normalized by the number of (nonignored) labels
   *    present; otherwise the loss is simply summed over spatial locations.
   */
  explicit SoftmaxWithLossOHEMLayer(const LayerParameter& param)
    : LossLayer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "SoftmaxWithLoss"; }
  virtual inline int ExactNumTopBlobs() const { return -1; }
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 3; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top);
  /**
   * @brief Computes the softmax loss error gradient w.r.t. the predictions.
   *
   * Gradients cannot be computed with respect to the label inputs (bottom[1]),
   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
   * if propagate_down[1] is set.
   *
   * @param top output Blob vector (length 1), providing the error gradient with
   *      respect to the outputs
   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
   *      This Blob's diff will simply contain the loss_weight * @f$ \lambda @f$,
   *      as @f$ \lambda @f$ is the coefficient of this layer's output
   *      @f$\ell_i@f$ in the overall Net loss
   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
   *      (*Assuming that this top Blob is not used as a bottom (input) by any
   *      other layer of the Net.)
   * @param propagate_down see Layer::Backward.
   *      propagate_down[1] must be false as we can't compute gradients with
   *      respect to the labels.
   * @param bottom input Blob vector (length 2)
   *   -# @f$ (N \times C \times H \times W) @f$
   *      the predictions @f$ x @f$; Backward computes diff
   *      @f$ \frac{\partial E}{\partial x} @f$
   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
   *      the labels -- ignored as we can't compute their error gradients
   */
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  /// Read the normalization mode parameter and compute the normalizer based
  /// on the blob size. If normalization_mode is VALID, the count of valid
  /// outputs will be read from valid_count, unless it is -1 in which case
  /// all outputs are assumed to be valid.
  virtual Dtype get_normalizer(
    LossParameter_NormalizationMode normalization_mode, int valid_count);

  /// The internal SoftmaxLayer used to map predictions to a distribution.
  shared_ptr<Layer<Dtype> > softmax_layer_;
  /// prob stores the output probability predictions from the SoftmaxLayer.
  Blob<Dtype> prob_;
  /// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
  vector<Blob<Dtype>*> softmax_bottom_vec_;
  /// top vector holder used in call to the underlying SoftmaxLayer::Forward
  vector<Blob<Dtype>*> softmax_top_vec_;
  /// Whether to ignore instances with a certain label.
  bool has_ignore_label_;
  /// The label indicating that an instance should be ignored.
  int ignore_label_;
  /// How to normalize the output loss.
  LossParameter_NormalizationMode normalization_;

  int softmax_axis_, outer_num_, inner_num_;
};

}  // namespace caffe

#endif  // CAFFE_SOFTMAX_WITH_LOSS_OHEM_LAYER_HPP_
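The comment above summarizes the math this layer implements. As a minimal sketch, the per-instance computation (stable softmax over K scores, then the cross-entropy of the true label) looks like the following; the function name is illustrative and not part of the layer's API. The per-instance values are what the OHEM variant exposes as an extra top blob.

#include <algorithm>
#include <cmath>
#include <vector>

double softmax_cross_entropy(const std::vector<double>& scores, int label) {
  // Subtract the max score so exp() cannot overflow (the "numerically stable"
  // property the comment refers to).
  double max_score = *std::max_element(scores.begin(), scores.end());
  double denom = 0.0;
  for (double s : scores) denom += std::exp(s - max_score);
  double log_prob = (scores[label] - max_score) - std::log(denom);
  return -log_prob;  // this instance's contribution to E before normalization
}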
@@ -32,6 +32,10 @@ class Net {
  /// @brief Initialize a network with a NetParameter.
  void Init(const NetParameter& param);

  /// @brief Set the phase.
  /// Enables training and testing with one network, to save memory.
  void SetPhase(Phase phase);

  /**
   * @brief Run Forward and return the result.
   *
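A hypothetical usage sketch of the new SetPhase() call: one Net instance can serve both training iterations and test-time evaluation instead of holding a second copy of the network (and its blobs) in memory. The function name below is illustrative only.

#include "caffe/net.hpp"

void evaluate_then_resume_training(caffe::Net<float>& net) {
  net.SetPhase(caffe::TEST);   // e.g. run phase-dependent layers in inference mode
  net.Forward();               // score the currently loaded inputs
  net.SetPhase(caffe::TRAIN);  // switch back before the next training step
}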
@@ -150,6 +154,14 @@ class Net {
  inline const vector<vector<Blob<Dtype>*> >& top_vecs() const {
    return top_vecs_;
  }

  inline const vector<vector<int> >& bottom_id_vecs() const {
    return bottom_id_vecs_;
  }
  inline const vector<vector<int> >& top_id_vecs() const {
    return top_id_vecs_;
  }

  /// @brief returns the ids of the top blobs of layer i
  inline const vector<int>& top_ids(int i) const {
    CHECK_GE(i, 0) << "Invalid layer id";
@@ -73,6 +73,7 @@ class Solver {
    return test_nets_;
  }
  int iter() { return iter_; }
  int max_iter() const { return param_.max_iter(); }

  // Invoked at specific points during an iteration
  class Callback {
@@ -33,6 +33,9 @@ classdef Blob < handle
      diff = self.check_and_preprocess_data(diff);
      caffe_('blob_set_diff', self.hBlob_self, diff);
    end
    function copy_data_from(self, blob)
      caffe_('blob_copy_data', self.hBlob_self, blob.hBlob_self);
    end
  end

  methods (Access = private)
@@ -25,6 +25,9 @@ classdef Layer < handle
        self.params(n) = caffe.Blob(self.attributes.hBlob_blobs(n));
      end
    end
    function set_params_data(self, blob_index, params)
      caffe.Blob(self.attributes.hBlob_blobs(blob_index)).set_data(params);
    end
    function layer_type = type(self)
      layer_type = caffe_('layer_get_type', self.hLayer_self);
    end
@@ -21,6 +21,8 @@ classdef Net < handle
    name2blob_index
    layer_names
    blob_names
    bottom_id_vecs
    top_id_vecs
  end

  methods
@@ -67,6 +69,23 @@ classdef Net < handle
      % expose layer_names and blob_names for public read access
      self.layer_names = self.attributes.layer_names;
      self.blob_names = self.attributes.blob_names;

      % expose bottom_id_vecs and top_id_vecs for public read access
      self.attributes.bottom_id_vecs = cellfun(@(x) x+1, self.attributes.bottom_id_vecs, 'UniformOutput', false);
      self.bottom_id_vecs = self.attributes.bottom_id_vecs;
      self.attributes.top_id_vecs = cellfun(@(x) x+1, self.attributes.top_id_vecs, 'UniformOutput', false);
      self.top_id_vecs = self.attributes.top_id_vecs;
    end
    function set_phase(self, phase_name)
      CHECK(ischar(phase_name), 'phase_name must be a string');
      CHECK(strcmp(phase_name, 'train') || strcmp(phase_name, 'test'), ...
        sprintf('phase_name can only be %strain%s or %stest%s', ...
        char(39), char(39), char(39), char(39)));
      caffe_('net_set_phase', self.hNet_self, phase_name);
    end
    function share_weights_with(self, net)
      CHECK(is_valid_handle(net.hNet_net), 'invalid Net handle');
      caffe_('net_share_trained_layers_with', self.hNet_net, net.hNet_net);
    end
    function layer = layers(self, layer_name)
      CHECK(ischar(layer_name), 'layer_name must be a string');
@@ -81,18 +100,43 @@ classdef Net < handle
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      blob = self.layer_vec(self.name2layer_index(layer_name)).params(blob_index);
    end
    function set_params_data(self, layer_name, blob_index, data)
      CHECK(ischar(layer_name), 'layer_name must be a string');
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      self.layer_vec(self.name2layer_index(layer_name)).set_params_data(blob_index, data);
    end
    function forward_prefilled(self)
      caffe_('net_forward', self.hNet_self);
    end
    function backward_prefilled(self)
      caffe_('net_backward', self.hNet_self);
    end
    function set_input_data(self, input_data)
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % copy data to input blobs
      for n = 1:length(self.inputs)
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end
    end
    function res = get_output(self)
      % get output blobs
      res = struct('blob_name', '', 'data', []);
      for n = 1:length(self.outputs)
        res(n).blob_name = self.outputs{n};
        res(n).data = self.blobs(self.outputs{n}).get_data();
      end
    end
    function res = forward(self, input_data)
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % copy data to input blobs
      for n = 1:length(self.inputs)
        if isempty(input_data{n})
          continue;
        end
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end
      self.forward_prefilled();
@@ -125,6 +169,21 @@ classdef Net < handle
    function reshape(self)
      caffe_('net_reshape', self.hNet_self);
    end
    function reshape_as_input(self, input_data)
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % reshape input blobs
      for n = 1:length(self.inputs)
        if isempty(input_data{n})
          continue;
        end
        input_data_size = size(input_data{n});
        input_data_size_extended = [input_data_size, ones(1, 4 - length(input_data_size))];
        self.blobs(self.inputs{n}).reshape(input_data_size_extended);
      end
      self.reshape();
    end
    function save(self, weights_file)
      CHECK(ischar(weights_file), 'weights_file must be a string');
      caffe_('net_save', self.hNet_self, weights_file);
@@ -39,6 +39,9 @@ classdef Solver < handle
    function iter = iter(self)
      iter = caffe_('solver_get_iter', self.hSolver_self);
    end
    function max_iter = max_iter(self)
      max_iter = caffe_('solver_get_max_iter', self.hSolver_self);
    end
    function restore(self, snapshot_filename)
      CHECK(ischar(snapshot_filename), 'snapshot_filename must be a string');
      CHECK_FILE_EXIST(snapshot_filename);
@@ -0,0 +1,15 @@
function init_log(log_base_filename)
% init_log(log_base_filename)
%   init Caffe's log

CHECK(ischar(log_base_filename) && ~isempty(log_base_filename), ...
  'log_base_filename must be string');

[log_base_dir] = fileparts(log_base_filename);
if ~exist(log_base_dir, 'dir')
  mkdir(log_base_dir);
end

caffe_('init_log', log_base_filename);

end
@@ -14,6 +14,7 @@
#include <vector>

#include "mex.h"
#include "gpu/mxGPUArray.h"

#include "caffe/caffe.hpp"

@@ -63,9 +64,21 @@ enum WhichMemory { DATA, DIFF };
// Copy matlab array to Blob data or diff
static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
    WhichMemory data_or_diff) {
  mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),

  const float* mat_mem_ptr = NULL;
  mxGPUArray const *mx_mat_gpu;
  if (mxIsGPUArray(mx_mat)){
    mxInitGPU();
    mx_mat_gpu = mxGPUCreateFromMxArray(mx_mat);
    mat_mem_ptr = reinterpret_cast<const float*>(mxGPUGetDataReadOnly(mx_mat_gpu));
    mxCHECK(blob->count() == mxGPUGetNumberOfElements(mx_mat_gpu),
      "number of elements in target blob doesn't match that in input mxArray");
  const float* mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
  }
  else{
    mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),
      "number of elements in target blob doesn't match that in input mxArray");
    mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
  }
  float* blob_mem_ptr = NULL;
  switch (Caffe::mode()) {
  case Caffe::CPU:
@@ -80,6 +93,10 @@ static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
    mxERROR("Unknown Caffe mode");
  }
  caffe_copy(blob->count(), mat_mem_ptr, blob_mem_ptr);

  if (mxIsGPUArray(mx_mat)){
    mxGPUDestroyGPUArray(mx_mat_gpu);
  }
}

// Copy Blob data or diff to matlab array
@@ -123,6 +140,16 @@ static mxArray* int_vec_to_mx_vec(const vector<int>& int_vec) {
  return mx_vec;
}

// Convert vector<vector<int> > to matlab cell of (row vector)s
static mxArray* int_vec_vec_to_mx_cell_vec(const vector<vector<int> >& int_vec_vec) {
  mxArray* mx_cell_vec = mxCreateCellMatrix(int_vec_vec.size(), 1);
  for (int i = 0; i < int_vec_vec.size(); i++){
    mxSetCell(mx_cell_vec, i, int_vec_to_mx_vec(int_vec_vec[i]));
  }
  return mx_cell_vec;
}

// Convert vector<string> to matlab cell vector of strings
static mxArray* str_vec_to_mx_strcell(const vector<std::string>& str_vec) {
  mxArray* mx_strcell = mxCreateCellMatrix(str_vec.size(), 1);
@@ -228,6 +255,14 @@ static void solver_get_iter(MEX_ARGS) {
  plhs[0] = mxCreateDoubleScalar(solver->iter());
}

// Usage: caffe_('solver_get_max_iter', hSolver)
static void solver_get_max_iter(MEX_ARGS) {
  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
    "Usage: caffe_('solver_get_max_iter', hSolver)");
  Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
  plhs[0] = mxCreateDoubleScalar(solver->max_iter());
}

// Usage: caffe_('solver_restore', hSolver, snapshot_file)
static void solver_restore(MEX_ARGS) {
  mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
@@ -278,14 +313,34 @@ static void get_net(MEX_ARGS) {
  mxFree(phase_name);
}

// Usage: caffe_('net_set_phase', hNet, phase_name)
static void net_set_phase(MEX_ARGS) {
  mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
    "Usage: caffe_('net_set_phase', hNet, phase_name)");
  Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
  char* phase_name = mxArrayToString(prhs[1]);
  Phase phase;
  if (strcmp(phase_name, "train") == 0) {
    phase = TRAIN;
  }
  else if (strcmp(phase_name, "test") == 0) {
    phase = TEST;
  }
  else {
    mxERROR("Unknown phase");
  }
  net->SetPhase(phase);
  mxFree(phase_name);
}

// Usage: caffe_('net_get_attr', hNet)
static void net_get_attr(MEX_ARGS) {
  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
    "Usage: caffe_('net_get_attr', hNet)");
  Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
  const int net_attr_num = 6;
  const int net_attr_num = 8;
  const char* net_attrs[net_attr_num] = { "hLayer_layers", "hBlob_blobs",
    "input_blob_indices", "output_blob_indices", "layer_names", "blob_names"};
    "input_blob_indices", "output_blob_indices", "layer_names", "blob_names", "bottom_id_vecs", "top_id_vecs" };
  mxArray* mx_net_attr = mxCreateStructMatrix(1, 1, net_attr_num,
    net_attrs);
  mxSetField(mx_net_attr, 0, "hLayer_layers",
@@ -300,6 +355,10 @@ static void net_get_attr(MEX_ARGS) {
    str_vec_to_mx_strcell(net->layer_names()));
  mxSetField(mx_net_attr, 0, "blob_names",
    str_vec_to_mx_strcell(net->blob_names()));
  mxSetField(mx_net_attr, 0, "bottom_id_vecs",
    int_vec_vec_to_mx_cell_vec(net->bottom_id_vecs()));
  mxSetField(mx_net_attr, 0, "top_id_vecs",
    int_vec_vec_to_mx_cell_vec(net->top_id_vecs()));
  plhs[0] = mx_net_attr;
}

@@ -330,6 +389,15 @@ static void net_copy_from(MEX_ARGS) {
  mxFree(weights_file);
}

// Usage: caffe_('net_shared_with', hNet, hNet_trained)
static void net_share_trained_layers_with(MEX_ARGS) {
  mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsStruct(prhs[1]),
    "Usage: caffe_('net_shared_with', hNet, hNet_trained)");
  Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
  Net<float>* net_trained = handle_to_ptr<Net<float> >(prhs[1]);
  net->ShareTrainedLayersWith(net_trained);
}

// Usage: caffe_('net_reshape', hNet)
static void net_reshape(MEX_ARGS) {
  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -413,12 +481,24 @@ static void blob_get_data(MEX_ARGS) {

// Usage: caffe_('blob_set_data', hBlob, new_data)
static void blob_set_data(MEX_ARGS) {
  mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsSingle(prhs[1]),
  mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && (mxIsSingle(prhs[1]) || mxIsGPUArray(prhs[1])),
    "Usage: caffe_('blob_set_data', hBlob, new_data)");
  Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
  mx_mat_to_blob(prhs[1], blob, DATA);
}

// Usage: caffe_('blob_copy_data', hBlob_to, hBlob_from)
static void blob_copy_data(MEX_ARGS) {
  mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsStruct(prhs[1]),
    "Usage: caffe_('blob_copy_data', hBlob_to, hBlob_from)");
  Blob<float>* blob_to = handle_to_ptr<Blob<float> >(prhs[0]);
  Blob<float>* blob_from = handle_to_ptr<Blob<float> >(prhs[1]);
  //mxCHECK(blob_from->count() == blob_to->count(),
  //  "number of elements in target blob doesn't match that in source blob");

  blob_to->CopyFrom(*blob_from, false, true);
}

// Usage: caffe_('blob_get_diff', hBlob)
static void blob_get_diff(MEX_ARGS) {
  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -473,6 +553,54 @@ static void reset(MEX_ARGS) {
  init_key = static_cast<double>(caffe_rng_rand());
}

// Usage: caffe_('set_random_seed', random_seed)
static void set_random_seed(MEX_ARGS) {
  mxCHECK(nrhs == 1 && mxIsDouble(prhs[0]),
    "Usage: caffe_('set_random_seed', random_seed)");
  int random_seed = static_cast<unsigned int>(mxGetScalar(prhs[0]));
  Caffe::set_random_seed(random_seed);
}

static void glog_failure_handler(){
  static bool is_glog_failure = false;
  if (!is_glog_failure)
  {
    is_glog_failure = true;
    ::google::FlushLogFiles(0);
    mexErrMsgTxt("glog check error, please check log and clear mex");
  }
}

static void protobuf_log_handler(::google::protobuf::LogLevel level, const char* filename, int line,
  const std::string& message)
{
  const int max_err_length = 512;
  char err_message[max_err_length];
  snprintf(err_message, max_err_length, "Protobuf : %s . at %s Line %d",
    message.c_str(), filename, line);
  LOG(INFO) << err_message;
  ::google::FlushLogFiles(0);
  mexErrMsgTxt(err_message);
}

// Usage: caffe_('init_log', log_base_filename)
static void init_log(MEX_ARGS) {
  static bool is_log_inited = false;

  mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
    "Usage: caffe_('init_log', log_dir)");
  if (is_log_inited)
    ::google::ShutdownGoogleLogging();
  char* log_base_filename = mxArrayToString(prhs[0]);
  ::google::SetLogDestination(0, log_base_filename);
  mxFree(log_base_filename);
  ::google::protobuf::SetLogHandler(&protobuf_log_handler);
  ::google::InitGoogleLogging("caffe_mex");
  ::google::InstallFailureFunction(&glog_failure_handler);

  is_log_inited = true;
}

// Usage: caffe_('read_mean', mean_proto_file)
static void read_mean(MEX_ARGS) {
  mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
@@ -528,37 +656,43 @@ struct handler_registry {

static handler_registry handlers[] = {
  // Public API functions
  { "get_solver",         get_solver      },
  { "solver_get_attr",    solver_get_attr },
  { "solver_get_iter",    solver_get_iter },
  { "solver_restore",     solver_restore  },
  { "solver_solve",       solver_solve    },
  { "solver_step",        solver_step     },
  { "get_net",            get_net         },
  { "net_get_attr",       net_get_attr    },
  { "net_forward",        net_forward     },
  { "net_backward",       net_backward    },
  { "net_copy_from",      net_copy_from   },
  { "net_reshape",        net_reshape     },
  { "net_save",           net_save        },
  { "layer_get_attr",     layer_get_attr  },
  { "layer_get_type",     layer_get_type  },
  { "blob_get_shape",     blob_get_shape  },
  { "blob_reshape",       blob_reshape    },
  { "blob_get_data",      blob_get_data   },
  { "blob_set_data",      blob_set_data   },
  { "blob_get_diff",      blob_get_diff   },
  { "blob_set_diff",      blob_set_diff   },
  { "set_mode_cpu",       set_mode_cpu    },
  { "set_mode_gpu",       set_mode_gpu    },
  { "set_device",         set_device      },
  { "get_init_key",       get_init_key    },
  { "reset",              reset           },
  { "read_mean",          read_mean       },
  { "write_mean",         write_mean      },
  { "version",            version         },
  { "get_solver",         get_solver      },
  { "solver_get_attr",    solver_get_attr },
  { "solver_get_iter",    solver_get_iter },
  { "solver_get_max_iter", solver_get_max_iter },
  { "solver_restore",     solver_restore  },
  { "solver_solve",       solver_solve    },
  { "solver_step",        solver_step     },
  { "get_net",            get_net         },
  { "net_get_attr",       net_get_attr    },
  { "net_set_phase",      net_set_phase   },
  { "net_forward",        net_forward     },
  { "net_backward",       net_backward    },
  { "net_copy_from",      net_copy_from   },
  { "net_share_trained_layers_with", net_share_trained_layers_with },
  { "net_reshape",        net_reshape     },
  { "net_save",           net_save        },
  { "layer_get_attr",     layer_get_attr  },
  { "layer_get_type",     layer_get_type  },
  { "blob_get_shape",     blob_get_shape  },
  { "blob_reshape",       blob_reshape    },
  { "blob_get_data",      blob_get_data   },
  { "blob_set_data",      blob_set_data   },
  { "blob_copy_data",     blob_copy_data  },
  { "blob_get_diff",      blob_get_diff   },
  { "blob_set_diff",      blob_set_diff   },
  { "set_mode_cpu",       set_mode_cpu    },
  { "set_mode_gpu",       set_mode_gpu    },
  { "set_device",         set_device      },
  { "set_random_seed",    set_random_seed },
  { "get_init_key",       get_init_key    },
  { "init_log",           init_log        },
  { "reset",              reset           },
  { "read_mean",          read_mean       },
  { "write_mean",         write_mean      },
  { "version",            version         },
  // The end.
  { "END",                NULL            },
};

/** -----------------------------------------------------------------
@@ -0,0 +1,11 @@
function set_random_seed(random_seed)
% set_random_seed(random_seed)
%   set Caffe's random_seed

CHECK(isscalar(random_seed) && random_seed >= 0, ...
  'random_seed must be non-negative integer');
random_seed = double(random_seed);

caffe_('set_random_seed', random_seed);

end
@@ -49,7 +49,7 @@ void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
  variance_.Reshape(sz);
  temp_.ReshapeLike(*bottom[0]);
  if (use_global_stats_) {
    x_norm_.ReshapeLike(*bottom[0]);
  }
  sz[0]=bottom[0]->shape(0);
  batch_sum_multiplier_.Reshape(sz);
@@ -86,8 +86,8 @@ void BatchNormLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
  // TODO(cdoersch): The caching is only needed because later in-place layers
  // might clobber the data.  Can we skip this if they won't?
  if (!use_global_stats_) {
    caffe_copy(x_norm_.count(), top_data,
      x_norm_.mutable_gpu_data());
  }
}

@@ -99,13 +99,13 @@ void BatchNormLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
  if (bottom[0] != top[0]) {
    top_diff = top[0]->gpu_diff();
  } else {
    if (use_global_stats_) {
      top_diff = top[0]->gpu_diff();
    }
    else {
      caffe_copy(x_norm_.count(), top[0]->gpu_diff(), x_norm_.mutable_gpu_diff());
      top_diff = x_norm_.gpu_diff();
    }
  }
  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
  if (use_global_stats_) {
@@ -0,0 +1,85 @@
// ------------------------------------------------------------------
// R-FCN
// Copyright (c) 2016 Microsoft
// Licensed under The MIT License [see r-fcn/LICENSE for details]
// Written by Yi Li
// ------------------------------------------------------------------

#include <cfloat>

#include <string>
#include <utility>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/box_annotator_ohem_layer.hpp"
#include "caffe/proto/caffe.pb.h"

using std::max;
using std::min;
using std::floor;
using std::ceil;

namespace caffe {

template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
  const vector<Blob<Dtype>*>& top) {
  BoxAnnotatorOHEMParameter box_anno_param = this->layer_param_.box_annotator_ohem_param();
  roi_per_img_ = box_anno_param.roi_per_img();
  CHECK_GT(roi_per_img_, 0);
  ignore_label_ = box_anno_param.ignore_label();
}

template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
  const vector<Blob<Dtype>*>& top) {
  num_ = bottom[0]->num();
  CHECK_EQ(5, bottom[0]->channels());
  height_ = bottom[0]->height();
  width_ = bottom[0]->width();
  spatial_dim_ = height_*width_;

  CHECK_EQ(bottom[1]->num(), num_);
  CHECK_EQ(bottom[1]->channels(), 1);
  CHECK_EQ(bottom[1]->height(), height_);
  CHECK_EQ(bottom[1]->width(), width_);

  CHECK_EQ(bottom[2]->num(), num_);
  CHECK_EQ(bottom[2]->channels(), 1);
  CHECK_EQ(bottom[2]->height(), height_);
  CHECK_EQ(bottom[2]->width(), width_);

  CHECK_EQ(bottom[3]->num(), num_);
  bbox_channels_ = bottom[3]->channels();
  CHECK_EQ(bottom[3]->height(), height_);
  CHECK_EQ(bottom[3]->width(), width_);

  // Labels for scoring
  top[0]->Reshape(num_, 1, height_, width_);
  // Loss weights for bbox regression
  top[1]->Reshape(num_, bbox_channels_, height_, width_);
}

template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
  const vector<Blob<Dtype>*>& top) {
  NOT_IMPLEMENTED;
}

template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  NOT_IMPLEMENTED;
}

#ifdef CPU_ONLY
STUB_GPU(BoxAnnotatorOHEMLayer);
#endif

INSTANTIATE_CLASS(BoxAnnotatorOHEMLayer);
REGISTER_LAYER_CLASS(BoxAnnotatorOHEM);

}  // namespace caffe
@@ -0,0 +1,79 @@
// ------------------------------------------------------------------
// R-FCN
// Copyright (c) 2016 Microsoft
// Licensed under The MIT License [see r-fcn/LICENSE for details]
// Written by Yi Li
// ------------------------------------------------------------------

#include <algorithm>
#include <cfloat>
#include <vector>

#include "caffe/layers/box_annotator_ohem_layer.hpp"

using std::max;
using std::min;

namespace caffe {
template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
  const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_rois = bottom[0]->cpu_data();
  const Dtype* bottom_loss = bottom[1]->cpu_data();
  const Dtype* bottom_labels = bottom[2]->cpu_data();
  const Dtype* bottom_bbox_loss_weights = bottom[3]->cpu_data();
  Dtype* top_labels = top[0]->mutable_cpu_data();
  Dtype* top_bbox_loss_weights = top[1]->mutable_cpu_data();
  caffe_set(top[0]->count(), Dtype(ignore_label_), top_labels);
  caffe_set(top[1]->count(), Dtype(0), top_bbox_loss_weights);

  int num_rois_ = bottom[1]->count();

  int num_imgs = -1;
  for (int n = 0; n < num_rois_; n++){
    for (int s = 0; s < spatial_dim_; s++){
      num_imgs = bottom_rois[0]>num_imgs ? bottom_rois[0] : num_imgs;
      bottom_rois++;
    }
    bottom_rois += (5-1)*spatial_dim_;
  }
  num_imgs++;
  CHECK_GT(num_imgs, 0)
    << "number of images must be greater than 0 at BoxAnnotatorOHEMLayer";
  bottom_rois = bottom[0]->cpu_data();

  // Find rois with max loss
  vector<int> sorted_idx(num_rois_);
  for (int i = 0; i < num_rois_; i++){
    sorted_idx[i] = i;
  }
  std::sort(sorted_idx.begin(), sorted_idx.end(),
    [bottom_loss](int i1, int i2){return bottom_loss[i1] > bottom_loss[i2]; });

  // Generate output labels for scoring and loss_weights for bbox regression
  vector<int> number_left(num_imgs, roi_per_img_);
  for (int i = 0; i < num_rois_; i++){
    int index = sorted_idx[i];
    int s = index % (width_*height_);
    int n = index / (width_*height_);
    int batch_ind = bottom_rois[n*5*spatial_dim_+s];
    if (number_left[batch_ind]>0){
      number_left[batch_ind]--;
      top_labels[index] = bottom_labels[index];
      for (int j = 0; j < bbox_channels_; j++){
        int bbox_index = (n*bbox_channels_+j)*spatial_dim_+s;
        top_bbox_loss_weights[bbox_index]=bottom_bbox_loss_weights[bbox_index];
      }
    }
  }
}

template <typename Dtype>
void BoxAnnotatorOHEMLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  return;
}

INSTANTIATE_LAYER_GPU_FUNCS(BoxAnnotatorOHEMLayer);

}  // namespace caffe
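The selection rule implemented above boils down to: sort all ROIs by their loss in descending order and keep at most roi_per_img of them per image; every other ROI keeps the ignore label and a zero bbox loss weight. The sketch below restates that rule in isolation; the types and names are hypothetical and not part of the layer's API.

#include <algorithm>
#include <numeric>
#include <vector>

std::vector<bool> select_hard_examples(const std::vector<float>& roi_loss,
                                       const std::vector<int>& roi_image,  // image index per ROI
                                       int num_images, int rois_per_image) {
  std::vector<int> order(roi_loss.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return roi_loss[a] > roi_loss[b]; });

  std::vector<int> budget(num_images, rois_per_image);
  std::vector<bool> keep(roi_loss.size(), false);
  for (int idx : order) {
    if (budget[roi_image[idx]] > 0) {  // hardest ROIs claim the per-image budget first
      --budget[roi_image[idx]];
      keep[idx] = true;                // this ROI contributes to the loss/gradient
    }
  }
  return keep;
}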
@@ -121,7 +121,7 @@ namespace caffe {
    int ph = (index / pooled_width) % pooled_height;
    int n = index / pooled_width / pooled_height / output_dim;

    // [start, end) interval for spatial sampling
    bottom_rois += n * 5;
    int roi_batch_ind = bottom_rois[0];
    Dtype roi_start_w = static_cast<Dtype>(round(bottom_rois[1])) * spatial_scale;
@@ -89,10 +89,10 @@ void ScaleLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
  scale_dim_ = scale->count();
  inner_dim_ = bottom[0]->count(axis_ + scale->num_axes());
  if (bottom[0] == top[0]) {  // in-place computation
    const bool scale_param = (bottom.size() == 1);
    if (!scale_param || (scale_param && this->param_propagate_down_[0])) {
      temp_.ReshapeLike(*bottom[0]);
    }
  } else {
    top[0]->ReshapeLike(*bottom[0]);
  }
@@ -36,11 +36,11 @@ void ScaleLayer<Dtype>::Forward_gpu(
    // Note that this is only necessary for Backward; we could skip this if not
    // doing Backward, but Caffe currently provides no way of knowing whether
    // we'll need to do Backward at the time of the Forward call.
    const bool scale_param = (bottom.size() == 1);
    if (!scale_param || (scale_param && this->param_propagate_down_[0])) {
      caffe_copy(bottom[0]->count(), bottom[0]->gpu_data(),
        temp_.mutable_gpu_data());
    }
  }
  const Dtype* scale_data =
    ((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data();
@@ -0,0 +1,106 @@
// --------------------------------------------------------
// R-FCN
// Written by Yi Li, 2016.
// --------------------------------------------------------

#include <string>
#include <utility>
#include <vector>

#include "caffe/layers/smooth_l1_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::LayerSetUp(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  has_weights_ = (bottom.size() == 3);

  if (!this->layer_param_.loss_param().has_normalization() &&
      this->layer_param_.loss_param().has_normalize()) {
    normalization_ = this->layer_param_.loss_param().normalize() ?
      LossParameter_NormalizationMode_VALID :
      LossParameter_NormalizationMode_BATCH_SIZE;
  }
  else {
    normalization_ = this->layer_param_.loss_param().normalization();
  }
}

template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Reshape(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::Reshape(bottom, top);
  CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
  CHECK_EQ(bottom[0]->height(), bottom[1]->height());
  CHECK_EQ(bottom[0]->width(), bottom[1]->width());
  if (has_weights_) {
    CHECK_EQ(bottom[0]->channels(), bottom[2]->channels());
    CHECK_EQ(bottom[0]->height(), bottom[2]->height());
    CHECK_EQ(bottom[0]->width(), bottom[2]->width());
  }

  outer_num_ = bottom[0]->num();
  inner_num_ = bottom[0]->height() * bottom[0]->width();

  diff_.Reshape(bottom[0]->num(), bottom[0]->channels(),
    bottom[0]->height(), bottom[0]->width());
  errors_.Reshape(bottom[0]->num(), bottom[0]->channels(),
    bottom[0]->height(), bottom[0]->width());

  // top[2] stores per-instance loss, which takes the shape of N*1*H*W
  if (top.size()>=2) {
    top[1]->Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
  }
}

template <typename Dtype>
Dtype SmoothL1LossOHEMLayer<Dtype>::get_normalizer(
  LossParameter_NormalizationMode normalization_mode, Dtype pre_fixed_normalizer) {
  Dtype normalizer;
  switch (normalization_mode) {
  case LossParameter_NormalizationMode_FULL:
    normalizer = Dtype(outer_num_ * inner_num_);
    break;
  case LossParameter_NormalizationMode_VALID:
    normalizer = Dtype(outer_num_ * inner_num_);
    break;
  case LossParameter_NormalizationMode_BATCH_SIZE:
    normalizer = Dtype(outer_num_);
    break;
  case LossParameter_NormalizationMode_PRE_FIXED:
    normalizer = pre_fixed_normalizer;
    break;
  case LossParameter_NormalizationMode_NONE:
    normalizer = Dtype(1);
    break;
  default:
    LOG(FATAL) << "Unknown normalization mode: "
      << LossParameter_NormalizationMode_Name(normalization_mode);
  }
  // Some users will have no labels for some examples in order to 'turn off' a
  // particular loss in a multi-task setup. The max prevents NaNs in that case.
  return std::max(Dtype(1.0), normalizer);
}

template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
  const vector<Blob<Dtype>*>& top) {
  NOT_IMPLEMENTED;
}

template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  NOT_IMPLEMENTED;
}

#ifdef CPU_ONLY
STUB_GPU(SmoothL1LossOHEMLayer);
#endif

INSTANTIATE_CLASS(SmoothL1LossOHEMLayer);
REGISTER_LAYER_CLASS(SmoothL1LossOHEM);

}  // namespace caffe
@@ -0,0 +1,127 @@
// --------------------------------------------------------
// R-FCN
// Written by Yi Li, 2016.
// --------------------------------------------------------

#include <algorithm>
#include <cfloat>
#include <vector>

#include "thrust/device_vector.h"

#include "caffe/layers/smooth_l1_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {
template <typename Dtype>
__global__ void SmoothL1ForwardGPU(const int n, const Dtype* in, Dtype* out) {
  // f(x) = 0.5 * x^2    if |x| < 1
  //        |x| - 0.5    otherwise
  CUDA_KERNEL_LOOP(index, n) {
    Dtype val = in[index];
    Dtype abs_val = abs(val);
    if (abs_val < 1) {
      out[index] = 0.5 * val * val;
    }
    else {
      out[index] = abs_val - 0.5;
    }
  }
}

template <typename Dtype>
__global__ void kernel_channel_sum(const int num, const int channels,
  const int spatial_dim, const Dtype* data, Dtype* channel_sum) {
  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
    int n = index / spatial_dim;
    int s = index % spatial_dim;
    Dtype sum = 0;
    for (int c = 0; c < channels; ++c) {
      sum += data[(n * channels + c) * spatial_dim + s];
    }
    channel_sum[index] = sum;
  }
}

template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
  const vector<Blob<Dtype>*>& top) {
  int count = bottom[0]->count();
  caffe_gpu_sub(
    count,
    bottom[0]->gpu_data(),
    bottom[1]->gpu_data(),
    diff_.mutable_gpu_data());    // d := b0 - b1
  if (has_weights_) {
    caffe_gpu_mul(
      count,
      bottom[2]->gpu_data(),
      diff_.gpu_data(),
      diff_.mutable_gpu_data());  // d := w * (b0 - b1)
  }
  SmoothL1ForwardGPU<Dtype> << <CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS >> >(
    count, diff_.gpu_data(), errors_.mutable_gpu_data());
  CUDA_POST_KERNEL_CHECK;

  Dtype loss;
  caffe_gpu_asum(count, errors_.gpu_data(), &loss);
  int spatial_dim = diff_.height() * diff_.width();

  Dtype pre_fixed_normalizer = this->layer_param_.loss_param().pre_fixed_normalizer();
  top[0]->mutable_cpu_data()[0] = loss / get_normalizer(normalization_,
    pre_fixed_normalizer);

  // Output per-instance loss
  if (top.size() >= 2) {
    kernel_channel_sum<Dtype> << <CAFFE_GET_BLOCKS(top[0]->count()), CAFFE_CUDA_NUM_THREADS >> >
      (outer_num_, bottom[0]->channels(), inner_num_, errors_.gpu_data(),
      top[1]->mutable_gpu_data());
  }
}

template <typename Dtype>
__global__ void SmoothL1BackwardGPU(const int n, const Dtype* in, Dtype* out) {
  // f'(x) = x         if |x| < 1
  //       = sign(x)   otherwise
  CUDA_KERNEL_LOOP(index, n) {
    Dtype val = in[index];
    Dtype abs_val = abs(val);
    if (abs_val < 1) {
      out[index] = val;
    }
    else {
      out[index] = (Dtype(0) < val) - (val < Dtype(0));
    }
  }
}

template <typename Dtype>
void SmoothL1LossOHEMLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  int count = diff_.count();
  SmoothL1BackwardGPU<Dtype> << <CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS >> >(
    count, diff_.gpu_data(), diff_.mutable_gpu_data());
  CUDA_POST_KERNEL_CHECK;
  for (int i = 0; i < 2; ++i) {
    if (propagate_down[i]) {
      const Dtype sign = (i == 0) ? 1 : -1;
      int spatial_dim = diff_.height() * diff_.width();

      Dtype pre_fixed_normalizer = this->layer_param_.loss_param().pre_fixed_normalizer();
      Dtype normalizer = get_normalizer(normalization_, pre_fixed_normalizer);
      Dtype alpha = sign * top[0]->cpu_diff()[0] / normalizer;

      caffe_gpu_axpby(
        bottom[i]->count(),              // count
        alpha,                           // alpha
        diff_.gpu_data(),                // x
        Dtype(0),                        // beta
        bottom[i]->mutable_gpu_diff());  // y
    }
  }
}

INSTANTIATE_LAYER_GPU_FUNCS(SmoothL1LossOHEMLayer);

}  // namespace caffe
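For reference, the two kernels above compute the following scalar function and derivative element-wise on the (optionally weighted) difference d = b0 - b1; the function names here are illustrative only.

#include <cmath>

double smooth_l1(double x) {
  double a = std::fabs(x);
  return a < 1.0 ? 0.5 * x * x : a - 0.5;        // value used by SmoothL1ForwardGPU
}

double smooth_l1_grad(double x) {
  double a = std::fabs(x);
  return a < 1.0 ? x : (x > 0.0) - (x < 0.0);    // derivative used by SmoothL1BackwardGPU
}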
@@ -0,0 +1,114 @@
#include <algorithm>
#include <cfloat>
#include <vector>

#include "caffe/layers/softmax_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::LayerSetUp(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);
  LayerParameter softmax_param(this->layer_param_);
  // Fix a bug which occurs with more than one output
  softmax_param.clear_loss_weight();
  softmax_param.set_type("Softmax");
  softmax_layer_ = LayerRegistry<Dtype>::CreateLayer(softmax_param);
  softmax_bottom_vec_.clear();
  softmax_bottom_vec_.push_back(bottom[0]);
  softmax_top_vec_.clear();
  softmax_top_vec_.push_back(&prob_);
  softmax_layer_->SetUp(softmax_bottom_vec_, softmax_top_vec_);

  has_ignore_label_ =
    this->layer_param_.loss_param().has_ignore_label();
  if (has_ignore_label_) {
    ignore_label_ = this->layer_param_.loss_param().ignore_label();
  }
  if (!this->layer_param_.loss_param().has_normalization() &&
      this->layer_param_.loss_param().has_normalize()) {
    normalization_ = this->layer_param_.loss_param().normalize() ?
      LossParameter_NormalizationMode_VALID :
      LossParameter_NormalizationMode_BATCH_SIZE;
  } else {
    normalization_ = this->layer_param_.loss_param().normalization();
  }
}

template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Reshape(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::Reshape(bottom, top);
  softmax_layer_->Reshape(softmax_bottom_vec_, softmax_top_vec_);
  softmax_axis_ =
    bottom[0]->CanonicalAxisIndex(this->layer_param_.softmax_param().axis());
  outer_num_ = bottom[0]->count(0, softmax_axis_);
  inner_num_ = bottom[0]->count(softmax_axis_ + 1);
  CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count())
    << "Number of labels must match number of predictions; "
    << "e.g., if softmax axis == 1 and prediction shape is (N, C, H, W), "
    << "label count (number of labels) must be N*H*W, "
    << "with integer values in {0, 1, ..., C-1}.";
  if (top.size() >= 2) {
    // softmax output
    top[1]->ReshapeLike(*bottom[0]);
  }

  // top[2] stores per-instance loss, which takes the shape of N*1*H*W
  if (top.size() >= 3) {
    top[2]->ReshapeLike(*bottom[1]);
  }
}

template <typename Dtype>
Dtype SoftmaxWithLossOHEMLayer<Dtype>::get_normalizer(
  LossParameter_NormalizationMode normalization_mode, int valid_count) {
  Dtype normalizer;
  switch (normalization_mode) {
  case LossParameter_NormalizationMode_FULL:
    normalizer = Dtype(outer_num_ * inner_num_);
    break;
  case LossParameter_NormalizationMode_VALID:
    if (valid_count == -1) {
      normalizer = Dtype(outer_num_ * inner_num_);
    } else {
      normalizer = Dtype(valid_count);
    }
    break;
  case LossParameter_NormalizationMode_BATCH_SIZE:
    normalizer = Dtype(outer_num_);
    break;
  case LossParameter_NormalizationMode_NONE:
    normalizer = Dtype(1);
    break;
  default:
    LOG(FATAL) << "Unknown normalization mode: "
      << LossParameter_NormalizationMode_Name(normalization_mode);
  }
  // Some users will have no labels for some examples in order to 'turn off' a
  // particular loss in a multi-task setup. The max prevents NaNs in that case.
  return std::max(Dtype(1.0), normalizer);
}

template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Forward_cpu(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  NOT_IMPLEMENTED;
}

template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  NOT_IMPLEMENTED;
}

#ifdef CPU_ONLY
STUB_GPU(SoftmaxWithLossOHEMLayer);
#endif

INSTANTIATE_CLASS(SoftmaxWithLossOHEMLayer);
REGISTER_LAYER_CLASS(SoftmaxWithLossOHEM);

}  // namespace caffe
@@ -0,0 +1,136 @@
#include <algorithm>
#include <cfloat>
#include <vector>

#include "caffe/layers/softmax_loss_ohem_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
__global__ void SoftmaxLossForwardGPU(const int nthreads,
    const Dtype* prob_data, const Dtype* label, Dtype* loss,
    const int num, const int dim, const int spatial_dim,
    const bool has_ignore_label_, const int ignore_label_,
    Dtype* counts) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int n = index / spatial_dim;
    const int s = index % spatial_dim;
    const int label_value = static_cast<int>(label[n * spatial_dim + s]);
    if (has_ignore_label_ && label_value == ignore_label_) {
      loss[index] = 0;
      counts[index] = 0;
    } else {
      loss[index] = -log(max(prob_data[n * dim + label_value * spatial_dim + s],
                             Dtype(FLT_MIN)));
      counts[index] = 1;
    }
  }
}

template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Forward_gpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
  const Dtype* prob_data = prob_.gpu_data();
  const Dtype* label = bottom[1]->gpu_data();
  const int dim = prob_.count() / outer_num_;
  const int nthreads = outer_num_ * inner_num_;
  // Since this memory is not used for anything until it is overwritten
  // on the backward pass, we use it here to avoid having to allocate new GPU
  // memory to accumulate intermediate results in the kernel.
  Dtype* loss_data = bottom[0]->mutable_gpu_diff();
  // Similarly, this memory is never used elsewhere, and thus we can use it
  // to avoid having to allocate additional GPU memory.
  Dtype* counts = prob_.mutable_gpu_diff();
  // NOLINT_NEXT_LINE(whitespace/operators)
  SoftmaxLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
      CAFFE_CUDA_NUM_THREADS>>>(nthreads, prob_data, label, loss_data,
      outer_num_, dim, inner_num_, has_ignore_label_, ignore_label_, counts);
  Dtype loss;
  caffe_gpu_asum(nthreads, loss_data, &loss);
  Dtype valid_count = -1;
  // Only launch another CUDA kernel if we actually need the count of valid
  // outputs.
  if (normalization_ == LossParameter_NormalizationMode_VALID &&
      has_ignore_label_) {
    caffe_gpu_asum(nthreads, counts, &valid_count);
  }
  top[0]->mutable_cpu_data()[0] = loss / get_normalizer(normalization_,
                                                        valid_count);
  if (top.size() >= 2) {
    top[1]->ShareData(prob_);
  }
  if (top.size() >= 3) {
    // Output per-instance loss
    caffe_gpu_memcpy(top[2]->count() * sizeof(Dtype), loss_data,
                     top[2]->mutable_gpu_data());
  }

  // Fix a bug which happens when propagate_down[0] = false in the backward pass:
  // clear the diff that was borrowed as a loss buffer above.
  caffe_gpu_set(bottom[0]->count(), Dtype(0), bottom[0]->mutable_gpu_diff());
}

template <typename Dtype>
__global__ void SoftmaxLossBackwardGPU(const int nthreads, const Dtype* top,
    const Dtype* label, Dtype* bottom_diff, const int num, const int dim,
    const int spatial_dim, const bool has_ignore_label_,
    const int ignore_label_, Dtype* counts) {
  const int channels = dim / spatial_dim;

  CUDA_KERNEL_LOOP(index, nthreads) {
    const int n = index / spatial_dim;
    const int s = index % spatial_dim;
    const int label_value = static_cast<int>(label[n * spatial_dim + s]);

    if (has_ignore_label_ && label_value == ignore_label_) {
      for (int c = 0; c < channels; ++c) {
        bottom_diff[n * dim + c * spatial_dim + s] = 0;
      }
      counts[index] = 0;
    } else {
      bottom_diff[n * dim + label_value * spatial_dim + s] -= 1;
      counts[index] = 1;
    }
  }
}

template <typename Dtype>
void SoftmaxWithLossOHEMLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[1]) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to label inputs.";
  }
  if (propagate_down[0]) {
    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
    const Dtype* prob_data = prob_.gpu_data();
    const Dtype* top_data = top[0]->gpu_data();
    caffe_gpu_memcpy(prob_.count() * sizeof(Dtype), prob_data, bottom_diff);
    const Dtype* label = bottom[1]->gpu_data();
    const int dim = prob_.count() / outer_num_;
    const int nthreads = outer_num_ * inner_num_;
    // Since this memory is never used for anything else,
    // we use it to avoid allocating new GPU memory.
    Dtype* counts = prob_.mutable_gpu_diff();
    // NOLINT_NEXT_LINE(whitespace/operators)
    SoftmaxLossBackwardGPU<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
        CAFFE_CUDA_NUM_THREADS>>>(nthreads, top_data, label, bottom_diff,
        outer_num_, dim, inner_num_, has_ignore_label_, ignore_label_, counts);

    Dtype valid_count = -1;
    // Only launch another CUDA kernel if we actually need the count of valid
    // outputs.
    if (normalization_ == LossParameter_NormalizationMode_VALID &&
        has_ignore_label_) {
      caffe_gpu_asum(nthreads, counts, &valid_count);
    }
    const Dtype loss_weight = top[0]->cpu_diff()[0] /
                              get_normalizer(normalization_, valid_count);
    caffe_gpu_scal(prob_.count(), loss_weight, bottom_diff);
  }
}

INSTANTIATE_LAYER_GPU_FUNCS(SoftmaxWithLossOHEMLayer);

}  // namespace caffe

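The per-instance losses exported through top[2] above are what an OHEM step consumes: the hard examples are simply the RoIs with the largest loss. Below is a minimal CPU-side sketch of that selection, assuming a flat vector of per-RoI losses and a roi_per_img budget; it is illustrative only and is not the BoxAnnotatorOHEMLayer implementation.

#include <algorithm>
#include <numeric>
#include <vector>

// Return the indices of the `roi_per_img` RoIs with the highest per-RoI loss.
// `per_roi_loss` is assumed to hold one loss value per RoI of a single image.
std::vector<int> SelectHardExamples(const std::vector<float>& per_roi_loss,
                                    int roi_per_img) {
  std::vector<int> idx(per_roi_loss.size());
  std::iota(idx.begin(), idx.end(), 0);  // 0, 1, ..., num_rois - 1
  // Order RoI indices by descending loss; stable sort keeps ties deterministic.
  std::stable_sort(idx.begin(), idx.end(),
      [&](int a, int b) { return per_roi_loss[a] > per_roi_loss[b]; });
  if (static_cast<int>(idx.size()) > roi_per_img) {
    idx.resize(roi_per_img);  // keep only the hardest RoIs
  }
  return idx;
}

In the actual layers, RoIs that are not selected would then have their labels set to ignore_label and their bbox loss weights zeroed, so they contribute nothing to the gradients.
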
@@ -283,6 +283,16 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
  LOG_IF(INFO, Caffe::root_solver()) << "Network initialization done.";
}

template <typename Dtype>
void Net<Dtype>::SetPhase(Phase phase) {
  // set all layers
  for (int i = 0; i < layers_.size(); ++i) {
    layers_[i]->set_phase(phase);
  }
  // set net phase
  phase_ = phase;
}

template <typename Dtype>
void Net<Dtype>::FilterNet(const NetParameter& param,
    NetParameter* param_filtered) {

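Net<Dtype>::SetPhase is what makes the "train and test with one network" memory saving possible: rather than instantiating a separate test net, the same Net instance can be flipped between phases. A hedged usage sketch (the prototxt path is hypothetical):

#include "caffe/net.hpp"

void RunTrainAndEvalStep() {
  // One network instance for both phases; "train_val.prototxt" is a placeholder.
  caffe::Net<float> net("train_val.prototxt", caffe::TRAIN);
  net.Forward();                // training-phase forward pass
  net.Backward();               // gradients for the training step
  net.SetPhase(caffe::TEST);    // switch every layer, and the net itself, to TEST
  net.Forward();                // evaluation pass reusing the same blobs
  net.SetPhase(caffe::TRAIN);   // back to training for the next iteration
}
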
@@ -761,7 +771,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
      LOG(INFO) << "Ignoring source layer " << source_layer_name;
      continue;
    }
    DLOG(INFO) << "Copying source layer " << source_layer_name;
    LOG(INFO) << "Copying source layer " << source_layer_name;
    vector<shared_ptr<Blob<Dtype> > >& target_blobs =
        layers_[target_layer_id]->blobs();
    CHECK_EQ(target_blobs.size(), source_layer.blobs_size())

@@ -306,7 +306,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 150 (last added: psroi_pooling_param)
// LayerParameter next available layer-specific ID: 151 (last added: box_annotator_ohem_param)
message LayerParameter {
  optional string name = 1; // the layer name
  optional string type = 2; // the layer type

@@ -361,6 +361,7 @@ message LayerParameter {
  optional AccuracyParameter accuracy_param = 102;
  optional ArgMaxParameter argmax_param = 103;
  optional BatchNormParameter batch_norm_param = 139;
  optional BoxAnnotatorOHEMParameter box_annotator_ohem_param = 150;
  optional BiasParameter bias_param = 141;
  optional ConcatParameter concat_param = 104;
  optional ContrastiveLossParameter contrastive_loss_param = 105;

@@ -389,7 +390,7 @@ message LayerParameter {
  optional PoolingParameter pooling_param = 121;
  optional PowerParameter power_param = 122;
  optional PReLUParameter prelu_param = 131;
  optional PSROIPoolingParameter psroi_pooling_param = 149;
  optional PythonParameter python_param = 130;
  optional RecurrentParameter recurrent_param = 146;
  optional ReductionParameter reduction_param = 136;

@@ -450,14 +451,18 @@ message LossParameter {
    VALID = 1;
    // Divide by the batch size.
    BATCH_SIZE = 2;
    // Divide by a pre-fixed normalizer.
    PRE_FIXED = 3;
    // Do not normalize the loss.
    NONE = 3;
    NONE = 4;
  }
  optional NormalizationMode normalization = 3 [default = VALID];
  // Deprecated.  Ignored if normalization is specified.  If normalization
  // is not specified, then setting this to false will be equivalent to
  // normalization = BATCH_SIZE to be consistent with previous behavior.
  optional bool normalize = 2;
  // Pre-fixed normalizer
  optional float pre_fixed_normalizer = 4 [default = 1];
}

// Messages that store parameters used by individual layer types follow, in

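The new PRE_FIXED mode pairs with pre_fixed_normalizer: instead of deriving the divisor from batch or valid-label counts, the loss is divided by the configured constant, which is why SmoothL1LossOHEMLayer::Backward_gpu above passes loss_param().pre_fixed_normalizer() into get_normalizer. A small self-contained sketch of how that mode could be resolved (hypothetical helper; the layers' own get_normalizer functions are authoritative):

#include <algorithm>

#include "caffe/proto/caffe.pb.h"

// Hypothetical helper: pick the divisor for a loss given the normalization mode.
float ResolveNormalizer(caffe::LossParameter_NormalizationMode mode,
                        float pre_fixed_normalizer,
                        float count_based_normalizer) {
  float normalizer = (mode == caffe::LossParameter_NormalizationMode_PRE_FIXED)
                         ? pre_fixed_normalizer      // fixed value from the prototxt
                         : count_based_normalizer;   // e.g. batch size or valid count
  return std::max(1.0f, normalizer);                 // guard against division by zero
}
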
@@ -514,6 +519,11 @@ message BatchNormParameter {
  optional float eps = 3 [default = 1e-5];
}

message BoxAnnotatorOHEMParameter {
  required uint32 roi_per_img = 1;                 // number of RoIs for training
  optional int32 ignore_label = 2 [default = -1];  // ignore_label in scoring
}

message BiasParameter {
  // The first axis of bottom[0] (the first input Blob) along which to apply
  // bottom[1] (the second input Blob). May be negative to index from the end

@@ -924,7 +934,7 @@ message PowerParameter {
message PSROIPoolingParameter {
  required float spatial_scale = 1;
  required int32 output_dim = 2;   // output channel number
  required int32 group_size = 3;   // equal to pooled_size
  required int32 group_size = 3;   // number of groups to encode position-sensitive score maps
}

message PythonParameter {

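The reworded comment reflects R-FCN's position-sensitive design: the pooled output carries output_dim channels over a group_size x group_size grid, so the score-map input to the pooling layer needs output_dim * group_size^2 channels. A quick sanity-check sketch with the values commonly used for PASCAL VOC (illustrative assumptions, not values fixed by this commit):

#include <cassert>

int main() {
  const int output_dim = 21;   // 20 VOC classes + background (assumed setup)
  const int group_size = 7;    // 7x7 position-sensitive grid (assumed setup)
  const int score_map_channels = output_dim * group_size * group_size;
  assert(score_map_channels == 1029);  // channels expected on bottom[0]
  return 0;
}
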
@@ -106,6 +106,7 @@
    <ClCompile Include="..\..\src\caffe\layers\batch_reindex_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\bias_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\bnll_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\concat_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\contrastive_loss_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\conv_layer.cpp" />

@@ -148,6 +149,7 @@
    <ClCompile Include="..\..\src\caffe\layers\pooling_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\power_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\prelu_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\reduction_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\relu_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\reshape_layer.cpp" />

@@ -156,8 +158,11 @@
    <ClCompile Include="..\..\src\caffe\layers\sigmoid_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\silence_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\slice_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\softmax_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\softmax_loss_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\split_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\spp_layer.cpp" />
    <ClCompile Include="..\..\src\caffe\layers\tanh_layer.cpp" />

@@ -250,6 +255,7 @@
    <ClInclude Include="..\..\include\caffe\layers\pooling_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\power_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\prelu_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\psroi_pooling_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\python_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\reduction_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\relu_layer.hpp" />

@@ -259,6 +265,7 @@
    <ClInclude Include="..\..\include\caffe\layers\sigmoid_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\silence_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\slice_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\smooth_l1_loss_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\softmax_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\softmax_loss_layer.hpp" />
    <ClInclude Include="..\..\include\caffe\layers\split_layer.hpp" />

@@ -354,6 +361,13 @@
  <ItemGroup>
    <None Include="packages.config" />
  </ItemGroup>
  <ItemGroup>
    <CudaCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cu" />
    <CudaCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cu" />
    <CudaCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cu" />
    <CudaCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cu" />
    <CudaCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cu" />
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <Import Project="$(SolutionDir)\CommonSettings.targets" />
  <ImportGroup Label="ExtensionTargets">

@@ -336,6 +336,21 @@
    <ClCompile Include="..\..\src\caffe\util\signal_handler.cpp">
      <Filter>src\util</Filter>
    </ClCompile>
    <ClCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cpp">
      <Filter>src\layers</Filter>
    </ClCompile>
    <ClCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cpp">
      <Filter>src\layers</Filter>
    </ClCompile>
    <ClCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cpp">
      <Filter>src\layers</Filter>
    </ClCompile>
    <ClCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cpp">
      <Filter>src\layers</Filter>
    </ClCompile>
    <ClCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cpp">
      <Filter>src\layers</Filter>
    </ClCompile>
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="..\..\include\caffe\proto\caffe.pb.h">

@@ -638,6 +653,12 @@
    <ClInclude Include="..\..\include\caffe\layers\scale_layer.hpp">
      <Filter>include\layers</Filter>
    </ClInclude>
    <ClInclude Include="..\..\include\caffe\layers\psroi_pooling_layer.hpp">
      <Filter>include\layers</Filter>
    </ClInclude>
    <ClInclude Include="..\..\include\caffe\layers\smooth_l1_loss_layer.hpp">
      <Filter>include\layers</Filter>
    </ClInclude>
  </ItemGroup>
  <ItemGroup>
    <CudaCompile Include="..\..\src\caffe\layers\contrastive_loss_layer.cu">

@@ -811,6 +832,21 @@
    <CudaCompile Include="..\..\src\caffe\solvers\sgd_solver.cu">
      <Filter>cu\solvers</Filter>
    </CudaCompile>
    <CudaCompile Include="..\..\src\caffe\layers\psroi_pooling_layer.cu">
      <Filter>cu\layers</Filter>
    </CudaCompile>
    <CudaCompile Include="..\..\src\caffe\layers\smooth_l1_loss_layer.cu">
      <Filter>cu\layers</Filter>
    </CudaCompile>
    <CudaCompile Include="..\..\src\caffe\layers\softmax_loss_ohem_layer.cu">
      <Filter>cu\layers</Filter>
    </CudaCompile>
    <CudaCompile Include="..\..\src\caffe\layers\box_annotator_ohem_layer.cu">
      <Filter>cu\layers</Filter>
    </CudaCompile>
    <CudaCompile Include="..\..\src\caffe\layers\smooth_L1_loss_ohem_layer.cu">
      <Filter>cu\layers</Filter>
    </CudaCompile>
  </ItemGroup>
  <ItemGroup>
    <None Include="packages.config" />

@@ -34,12 +34,12 @@
  </PropertyGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Link>
      <AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
      <AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;libmat.lib;gpu.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Link>
      <AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
      <AdditionalDependencies>libcaffe.lib;libmx.lib;libmex.lib;libmat.lib;gpu.lib;$(CudaDependencies);%(AdditionalDependencies)</AdditionalDependencies>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup>