Added contrastive loss layer, associated tests, and a siamese network example using shared weights and the contrastive loss.

Nick Carlevaris-Bianco 2014-08-21 09:32:10 -04:00
Parent fc921bf9d6
Commit d149c9a98d
14 changed files with 1323 additions and 2 deletions


@@ -0,0 +1,123 @@
//
// This script converts the MNIST dataset to the leveldb format used
// by Caffe to train a siamese network.
// Usage:
// convert_mnist_data input_image_file input_label_file output_db_file
// The MNIST dataset can be downloaded at
// http://yann.lecun.com/exdb/mnist/
#include <cstdio>  // printf in the usage message
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "leveldb/db.h"
#include "stdint.h"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/math_functions.hpp"
uint32_t swap_endian(uint32_t val) {
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
return (val << 16) | (val >> 16);
}
void read_image(std::ifstream* image_file, std::ifstream* label_file,
uint32_t index, uint32_t rows, uint32_t cols,
char* pixels, char* label) {
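// MNIST IDX image files start with a 16-byte header and label files with an
// 8-byte header, hence the offsets used when seeking to the index-th record.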
image_file->seekg(index * rows * cols + 16);
image_file->read(pixels, rows * cols);
label_file->seekg(index + 8);
label_file->read(label, 1);
}
void convert_dataset(const char* image_filename, const char* label_filename,
const char* db_filename) {
// Open files
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_filename;
// Read the magic and the meta data
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;
image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
CHECK_EQ(num_items, num_labels);
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);
// Open leveldb
leveldb::DB* db;
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
leveldb::Status status = leveldb::DB::Open(
options, db_filename, &db);
CHECK(status.ok()) << "Failed to open leveldb " << db_filename
<< ". Is it already existing?";
char label_i;
char label_j;
char* pixels = new char[2 * rows * cols];
const int kMaxKeyLength = 10;
char key[kMaxKeyLength];
std::string value;
caffe::Datum datum;
datum.set_channels(2); // one channel for each image in the pair
datum.set_height(rows);
datum.set_width(cols);
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
for (int itemid = 0; itemid < num_items; ++itemid) {
int i = caffe::caffe_rng_rand() % num_items; // pick a random pair
int j = caffe::caffe_rng_rand() % num_items;
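// Note: pairs are drawn uniformly at random, so with ten roughly balanced
// MNIST classes only about 10% of the generated pairs are labeled similar.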
read_image(&image_file, &label_file, i, rows, cols,
pixels, &label_i);
read_image(&image_file, &label_file, j, rows, cols,
pixels + (rows * cols), &label_j);
datum.set_data(pixels, 2*rows*cols);
if (label_i == label_j) {
datum.set_label(1);
} else {
datum.set_label(0);
}
datum.SerializeToString(&value);
snprintf(key, kMaxKeyLength, "%08d", itemid);
db->Put(leveldb::WriteOptions(), std::string(key), value);
}
delete db;
delete[] pixels;
}
int main(int argc, char** argv) {
if (argc != 4) {
printf("This script converts the MNIST dataset to the leveldb format used\n"
"by caffe to train a siamese network.\n"
"Usage:\n"
" convert_mnist_data input_image_file input_label_file "
"output_db_file\n"
"The MNIST dataset could be downloaded at\n"
" http://yann.lecun.com/exdb/mnist/\n"
"You should gunzip them after downloading.\n");
} else {
google::InitGoogleLogging(argv[0]);
convert_dataset(argv[1], argv[2], argv[3]);
}
return 0;
}


@@ -0,0 +1,21 @@
#!/usr/bin/env sh
# This script converts the MNIST data into leveldb format.
EXAMPLES=./build/examples/siamese
DATA=./data/mnist
echo "Creating leveldb..."
rm -rf ./examples/siamese/mnist_siamese_train_leveldb
rm -rf ./examples/siamese/mnist_siamese_test_leveldb
$EXAMPLES/convert_mnist_siamese_data.bin \
$DATA/train-images-idx3-ubyte \
$DATA/train-labels-idx1-ubyte \
./examples/siamese/mnist_siamese_train_leveldb
$EXAMPLES/convert_mnist_siamese_data.bin \
$DATA/t10k-images-idx3-ubyte \
$DATA/t10k-labels-idx1-ubyte \
./examples/siamese/mnist_siamese_test_leveldb
echo "Done."

File diff suppressed because one or more lines are too long


@@ -0,0 +1,95 @@
name: "mnist_siamese"
input: "data"
input_dim: 10000
input_dim: 1
input_dim: 28
input_dim: 28
layers {
name: "conv1"
type: CONVOLUTION
bottom: "data"
top: "conv1"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 20
kernel_size: 5
stride: 1
}
}
layers {
name: "pool1"
type: POOLING
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "conv2"
type: CONVOLUTION
bottom: "pool1"
top: "conv2"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
}
}
layers {
name: "pool2"
type: POOLING
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "ip1"
type: INNER_PRODUCT
bottom: "pool2"
top: "ip1"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 500
}
}
layers {
name: "relu1"
type: RELU
bottom: "ip1"
top: "ip1"
}
layers {
name: "ip2"
type: INNER_PRODUCT
bottom: "ip1"
top: "ip2"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 10
}
}
layers {
name: "feat"
type: INNER_PRODUCT
bottom: "ip2"
top: "feat"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 2
}
}


@@ -0,0 +1,25 @@
# The train/test net protocol buffer definition
net: "examples/siamese/mnist_siamese_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0000
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 50000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/siamese/mnist_siamese"
# solver mode: CPU or GPU
solver_mode: GPU


@@ -0,0 +1,313 @@
name: "mnist_siamese_train_test"
layers {
name: "pair_data"
type: DATA
top: "pair_data"
top: "sim"
data_param {
source: "examples/siamese/mnist_siamese_train_leveldb"
scale: 0.00390625
batch_size: 64
}
include: { phase: TRAIN }
}
layers {
name: "pair_data"
type: DATA
top: "pair_data"
top: "sim"
data_param {
source: "examples/siamese/mnist_siamese_test_leveldb"
scale: 0.00390625
batch_size: 100
}
include: { phase: TEST }
}
layers {
name: "slice_pair"
type: SLICE
bottom: "pair_data"
top: "data"
top: "data_p"
slice_param {
slice_dim: 1
slice_point: 1
}
}
layers {
name: "conv1"
type: CONVOLUTION
bottom: "data"
top: "conv1"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 20
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "conv1_w"
param: "conv1_b"
}
layers {
name: "pool1"
type: POOLING
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "conv2"
type: CONVOLUTION
bottom: "pool1"
top: "conv2"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "conv2_w"
param: "conv2_b"
}
layers {
name: "pool2"
type: POOLING
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "ip1"
type: INNER_PRODUCT
bottom: "pool2"
top: "ip1"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 500
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "ip1_w"
param: "ip1_b"
}
layers {
name: "relu1"
type: RELU
bottom: "ip1"
top: "ip1"
}
layers {
name: "ip2"
type: INNER_PRODUCT
bottom: "ip1"
top: "ip2"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 10
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "ip2_w"
param: "ip2_b"
}
layers {
name: "feat"
type: INNER_PRODUCT
bottom: "ip2"
top: "feat"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 2
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "feat_w"
param: "feat_b"
}
layers {
name: "conv1_p"
type: CONVOLUTION
bottom: "data_p"
top: "conv1_p"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 20
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "conv1_w"
param: "conv1_b"
}
layers {
name: "pool1_p"
type: POOLING
bottom: "conv1_p"
top: "pool1_p"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "conv2_p"
type: CONVOLUTION
bottom: "pool1_p"
top: "conv2_p"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "conv2_w"
param: "conv2_b"
}
layers {
name: "pool2_p"
type: POOLING
bottom: "conv2_p"
top: "pool2_p"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "ip1_p"
type: INNER_PRODUCT
bottom: "pool2_p"
top: "ip1_p"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 500
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "ip1_w"
param: "ip1_b"
}
layers {
name: "relu1_p"
type: RELU
bottom: "ip1_p"
top: "ip1_p"
}
layers {
name: "ip2_p"
type: INNER_PRODUCT
bottom: "ip1_p"
top: "ip2_p"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 10
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "ip2_w"
param: "ip2_b"
}
layers {
name: "feat_p"
type: INNER_PRODUCT
bottom: "ip2_p"
top: "feat_p"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 2
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
param: "feat_w"
param: "feat_b"
}
layers {
name: "loss"
type: CONTRASTIVE_LOSS
contrastive_loss_param {
margin: 1.0
}
bottom: "feat"
bottom: "feat_p"
bottom: "sim"
top: "loss"
}

examples/siamese/readme.md (new file, 179 lines)

@@ -0,0 +1,179 @@
---
title: Siamese Network Tutorial
description: Train and test a siamese network on MNIST data.
category: example
include_in_docs: true
layout: default
priority: 100
---
# Siamese Network Training with Caffe
This example shows how you can use weight sharing and a contrastive loss
function to learn a model using a siamese network in Caffe.
We will assume that you have Caffe successfully compiled. If not, please refer
to the [Installation page](../../installation.html). This example builds on the
[MNIST tutorial](mnist.html), so it would be a good idea to read that before
continuing.
*All paths in this guide are given relative to the root Caffe directory, and all
commands are assumed to be executed from there.*
## Prepare Datasets
You will first need to download and convert the data from the MNIST
website. To do this, simply run the following commands:
./data/mnist/get_mnist.sh
./examples/siamese/create_mnist_siamese.sh
After running the script there should be two datasets,
`./examples/siamese/mnist_siamese_train_leveldb` and
`./examples/siamese/mnist_siamese_test_leveldb`.
## The Model
First, we will define the model that we want to train using the siamese network.
We will use the convolutional net defined in
`./examples/siamese/mnist_siamese.prototxt`. This model is almost
exactly the same as the [LeNet model](mnist.html); the only difference is that
we have replaced the top layers that produced probabilities over the 10 digit
classes with a linear "feature" layer that produces a 2-dimensional vector.
layers {
name: "feat"
type: INNER_PRODUCT
bottom: "ip2"
top: "feat"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 2
}
}
## Define the Siamese Network
In this section we will define the siamese network used for training. The
resulting network is defined in
`./examples/siamese/mnist_siamese_train_test.prototxt`.
### Reading in the Pair Data
We start with a data layer that reads from the LevelDB database we created
earlier. Each entry in this database contains the image data for a pair of
images (`pair_data`) and a binary label indicating whether they belong to the
same class or to different classes (`sim`).
layers {
name: "pair_data"
type: DATA
top: "pair_data"
top: "sim"
data_param {
source: "examples/siamese/mnist-siamese-train-leveldb"
scale: 0.00390625
batch_size: 64
}
include: { phase: TRAIN }
}
In order to pack a pair of images into the same blob in the database, we pack one
image per channel. We want to be able to work with these two images separately,
so we add a slice layer after the data layer. This takes the `pair_data` and
slices it along the channel dimension so that we have a single image in `data`
and its paired image in `data_p`.
layers {
name: "slice_pair"
type: SLICE
bottom: "pair_data"
top: "data"
top: "data_p"
slice_param {
slice_dim: 1
slice_point: 1
}
}
### Building the First Side of the Siamese Net
Now we can specify the first side of the siamese net. This side operates on
`data` and produces `feat`. Starting from the net in
`./examples/siamese/mnist_siamese.prototxt`, we add default weight fillers. Then
we name the parameters of the convolutional and inner product layers. Naming the
parameters allows Caffe to share the parameters between layers on both sides of
the siamese net. In the definition this looks like:
...
param: "conv1_w"
param: "conv1_b"
...
param: "conv2_w"
param: "conv2_b"
...
param: "ip1_w"
param: "ip1_b"
...
param: "ip2_w"
param: "ip2_b"
...
### Building the Second Side of the Siamese Net
Now we need to create the second path that operates on `data_p` and produces
`feat_p`. This path is exactly the same as the first, so we can just copy and
paste it, changing the name of each layer, input, and output by appending
`_p` to differentiate the "paired" layers from the originals.
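For example, in `./examples/siamese/mnist_siamese_train_test.prototxt` the paired
version of the first convolution layer operates on `data_p` and reuses the same
parameter names as `conv1`, so the two layers share their weights and biases:
layers {
name: "conv1_p"
type: CONVOLUTION
bottom: "data_p"
top: "conv1_p"
...
param: "conv1_w"
param: "conv1_b"
}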
### Adding the Contrastive Loss Function
To train the network we will optimize a contrastive loss function proposed in
Raia Hadsell, Sumit Chopra, and Yann LeCun, "Dimensionality Reduction by Learning
an Invariant Mapping". This loss encourages matching pairs to be close together
in feature space while pushing non-matching pairs apart. It is implemented by the
`CONTRASTIVE_LOSS` layer:
layers {
name: "loss"
type: CONTRASTIVE_LOSS
contrastive_loss_param {
margin: 1.0
}
bottom: "feat"
bottom: "feat_p"
bottom: "sim"
top: "loss"
}
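For reference, the loss computed by this layer (documented in
`caffe/loss_layers.hpp`) for a batch of N pairs with features a_n, b_n and
binary similarity label y_n is
E = \frac{1}{2N} \sum_{n=1}^{N} \left( y_n d_n + (1 - y_n) \max(margin - d_n, 0) \right),
\quad d_n = \left\| a_n - b_n \right\|_2^2
so similar pairs (y_n = 1) are penalized by their squared distance, while
dissimilar pairs contribute only when they fall inside the margin.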
## Define the Solver
Nothing special needs to be done to the solver besides pointing it at the
correct model file. The solver is defined in
`./examples/siamese/mnist_siamese_solver.prototxt`.
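The key lines simply point the solver at the siamese train/test net and set the
usual optimization parameters:
net: "examples/siamese/mnist_siamese_train_test.prototxt"
base_lr: 0.01
momentum: 0.9
lr_policy: "inv"
gamma: 0.0001
power: 0.75
max_iter: 50000
snapshot_prefix: "examples/siamese/mnist_siamese"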
## Training and Testing the Model
Training the model is straightforward once you have written the network definition
and solver protobuf files. Simply run
`./examples/siamese/train_mnist_siamese.sh`:
./examples/siamese/train_mnist_siamese.sh
## Plotting the Results
First, we can draw the model and the siamese network by running the following
commands, which draw the DAGs defined in the `.prototxt` files:
./python/draw_net.py \
./examples/siamese/mnist_siamese.prototxt \
./examples/siamese/mnist_siamese.png
./python/draw_net.py \
./examples/siamese/mnist_siamese_train_test.prototxt \
./examples/siamese/mnist_siamese_train_test.png
Second, we can load the learned model and plot the features using the IPython
notebook:
ipython notebook ./examples/siamese/mnist_siamese.ipynb


@@ -0,0 +1,5 @@
#!/usr/bin/env sh
TOOLS=./build/tools
$TOOLS/caffe train --solver=examples/siamese/mnist_siamese_solver.prototxt


@@ -117,6 +117,93 @@ class LossLayer : public Layer<Dtype> {
}
};
/**
* @brief Computes the contrastive loss @f$
* E = \frac{1}{2N} \sum\limits_{n=1}^N \left( y_n d_n +
* \left(1-y_n\right) \max \left(margin-d_n, 0\right) \right)
* @f$ where @f$
* d_n = \left| \left| a_n - b_n \right| \right|_2^2 @f$. This can be
* used to train siamese networks.
*
* @param bottom input Blob vector (length 3)
* -# @f$ (N \times C \times 1 \times 1) @f$
* the features @f$ a \in [-\infty, +\infty]@f$
* -# @f$ (N \times C \times 1 \times 1) @f$
* the features @f$ b \in [-\infty, +\infty]@f$
* -# @f$ (N \times 1 \times 1 \times 1) @f$
* the binary similarity @f$ s \in [0, 1]@f$
* @param top output Blob vector (length 1)
* -# @f$ (1 \times 1 \times 1 \times 1) @f$
* the computed contrastive loss: @f$ E =
* \frac{1}{2N} \sum\limits_{n=1}^N \left( y_n d_n +
* \left(1-y_n\right) \max \left(margin-d_n, 0\right) \right)
* @f$ where @f$
* d_n = \left| \left| a_n - b_n \right| \right|_2^2 @f$.
* This can be used to train siamese networks.
*/
template <typename Dtype>
class ContrastiveLossLayer : public LossLayer<Dtype> {
public:
explicit ContrastiveLossLayer(const LayerParameter& param)
: LossLayer<Dtype>(param), diff_() {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual inline int ExactNumBottomBlobs() const { return 3; }
virtual inline LayerParameter_LayerType type() const {
return LayerParameter_LayerType_CONTRASTIVE_LOSS;
}
/**
* Unlike most loss layers, in the ContrastiveLossLayer we can backpropagate
* to the first two inputs.
*/
virtual inline bool AllowForceBackward(const int bottom_index) const {
return bottom_index != 2;
}
protected:
/// @copydoc ContrastiveLossLayer
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
/**
* @brief Computes the Contrastive error gradient w.r.t. the inputs.
*
* Computes the gradients with respect to the two input vectors (bottom[0] and
* bottom[1]), but not the similarity label (bottom[2]).
*
* @param top output Blob vector (length 1), providing the error gradient with
* respect to the outputs
* -# @f$ (1 \times 1 \times 1 \times 1) @f$
* This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
* as @f$ \lambda @f$ is the coefficient of this layer's output
* @f$\ell_i@f$ in the overall Net loss
* @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
* @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
* (*Assuming that this top Blob is not used as a bottom (input) by any
* other layer of the Net.)
* @param propagate_down see Layer::Backward.
* @param bottom input Blob vector (length 2)
* -# @f$ (N \times C \times 1 \times 1) @f$
* the features @f$a@f$; Backward fills their diff with
* gradients if propagate_down[0]
* -# @f$ (N \times C \times 1 \times 1) @f$
* the features @f$b@f$; Backward fills their diff with gradients if
* propagate_down[1]
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Blob<Dtype> diff_; // cached for backward pass
Blob<Dtype> dist_sq_; // cached for backward pass
Blob<Dtype> diff_sq_; // tmp storage for gpu forward pass
Blob<Dtype> summer_vec_; // tmp storage for gpu forward pass
};
/**
* @brief Computes the Euclidean (L2) loss @f$
* E = \frac{1}{2N} \sum\limits_{n=1}^N \left| \left| \hat{y}_n - y_n


@@ -189,6 +189,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new BNLLLayer<Dtype>(param);
case LayerParameter_LayerType_CONCAT:
return new ConcatLayer<Dtype>(param);
case LayerParameter_LayerType_CONTRASTIVE_LOSS:
return new ContrastiveLossLayer<Dtype>(param);
case LayerParameter_LayerType_CONVOLUTION:
return GetConvolutionLayer<Dtype>(name, param);
case LayerParameter_LayerType_DATA:


@@ -0,0 +1,101 @@
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/loss_layers.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
LossLayer<Dtype>::LayerSetUp(bottom, top);
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->height(), 1);
CHECK_EQ(bottom[0]->width(), 1);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
CHECK_EQ(bottom[2]->channels(), 1);
CHECK_EQ(bottom[2]->height(), 1);
CHECK_EQ(bottom[2]->width(), 1);
diff_.Reshape(bottom[0]->num(), bottom[0]->channels(), 1, 1);
diff_sq_.Reshape(bottom[0]->num(), bottom[0]->channels(), 1, 1);
dist_sq_.Reshape(bottom[0]->num(), 1, 1, 1);
// vector of ones used to sum along channels
summer_vec_.Reshape(bottom[0]->channels(), 1, 1, 1);
for (int i = 0; i < bottom[0]->channels(); ++i)
summer_vec_.mutable_cpu_data()[i] = Dtype(1);
}
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
int count = bottom[0]->count();
caffe_sub(
count,
bottom[0]->cpu_data(), // a
bottom[1]->cpu_data(), // b
diff_.mutable_cpu_data()); // a_i-b_i
const int channels = bottom[0]->channels();
Dtype margin = this->layer_param_.contrastive_loss_param().margin();
Dtype loss(0.0);
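// Accumulate y_n * d_n + (1 - y_n) * max(margin - d_n, 0) over the batch,
// where d_n is the squared Euclidean distance between the pair of features;
// the sum is averaged over the batch and halved below.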
for (int i = 0; i < bottom[0]->num(); ++i) {
dist_sq_.mutable_cpu_data()[i] = caffe_cpu_dot(channels,
diff_.cpu_data() + (i*channels), diff_.cpu_data() + (i*channels));
if (static_cast<int>(bottom[2]->cpu_data()[i])) { // similar pairs
loss += dist_sq_.cpu_data()[i];
} else { // dissimilar pairs
loss += std::max(margin-dist_sq_.cpu_data()[i], Dtype(0.0));
}
}
loss = loss / static_cast<Dtype>(bottom[0]->num()) / Dtype(2);
(*top)[0]->mutable_cpu_data()[0] = loss;
}
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
Dtype margin = this->layer_param_.contrastive_loss_param().margin();
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] /
static_cast<Dtype>((*bottom)[i]->num());
int num = (*bottom)[i]->num();
int channels = (*bottom)[i]->channels();
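// For similar pairs the cached difference is scaled by alpha, pulling the
// features together; for dissimilar pairs still inside the margin it is
// scaled by -alpha, pushing them apart; dissimilar pairs outside the margin
// receive zero gradient.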
for (int j = 0; j < num; ++j) {
Dtype* bout = (*bottom)[i]->mutable_cpu_diff();
if (static_cast<int>((*bottom)[2]->cpu_data()[j])) { // similar pairs
caffe_cpu_axpby(
channels,
alpha,
diff_.cpu_data() + (j*channels),
Dtype(0.0),
bout + (j*channels));
} else { // dissimilar pairs
if ((margin-dist_sq_.cpu_data()[j]) > Dtype(0.0)) {
caffe_cpu_axpby(
channels,
-alpha,
diff_.cpu_data() + (j*channels),
Dtype(0.0),
bout + (j*channels));
} else {
caffe_set(channels, Dtype(0), bout + (j*channels));
}
}
}
}
}
}
#ifdef CPU_ONLY
STUB_GPU(ContrastiveLossLayer);
#endif
INSTANTIATE_CLASS(ContrastiveLossLayer);
} // namespace caffe


@@ -0,0 +1,91 @@
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
const int count = bottom[0]->count();
caffe_gpu_sub(
count,
bottom[0]->gpu_data(), // a
bottom[1]->gpu_data(), // b
diff_.mutable_gpu_data()); // a_i-b_i
caffe_gpu_powx(
count,
diff_.mutable_gpu_data(), // a_i-b_i
Dtype(2),
diff_sq_.mutable_gpu_data()); // (a_i-b_i)^2
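// Multiplying the num-by-channels matrix of squared differences with the
// all-ones vector summer_vec_ sums over channels, giving one squared
// distance per pair in dist_sq_.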
caffe_gpu_gemv(
CblasNoTrans,
bottom[0]->num(),
bottom[0]->channels(),
Dtype(1.0),
diff_sq_.gpu_data(), // (a_i-b_i)^2
summer_vec_.gpu_data(),
Dtype(0.0),
dist_sq_.mutable_gpu_data()); // \Sum (a_i-b_i)^2
Dtype margin = this->layer_param_.contrastive_loss_param().margin();
Dtype loss(0.0);
for (int i = 0; i < bottom[0]->num(); ++i) {
if (static_cast<int>(bottom[2]->cpu_data()[i])) { // similar pairs
loss += dist_sq_.cpu_data()[i];
} else { // dissimilar pairs
loss += std::max(margin-dist_sq_.cpu_data()[i], Dtype(0.0));
}
}
loss = loss / static_cast<Dtype>(bottom[0]->num()) / Dtype(2);
(*top)[0]->mutable_cpu_data()[0] = loss;
}
template <typename Dtype>
__global__ void CLLForward(const int count, const int channels,
const Dtype margin, const Dtype alpha,
const Dtype* y, const Dtype* diff, const Dtype* dist_sq,
Dtype *bottom_diff) {
CUDA_KERNEL_LOOP(i, count) {
int n = i / channels; // the num index, to access y and dist_sq
if (static_cast<int>(y[n])) { // similar pairs
bottom_diff[i] = alpha * diff[i];
} else { // dissimilar pairs
if ((margin-dist_sq[n]) > 0.0) {
bottom_diff[i] = -alpha * diff[i];
} else {
bottom_diff[i] = 0;
}
}
}
}
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const int count = (*bottom)[0]->count();
const int channels = (*bottom)[0]->channels();
Dtype margin = this->layer_param_.contrastive_loss_param().margin();
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] /
static_cast<Dtype>((*bottom)[0]->num());
// NOLINT_NEXT_LINE(whitespace/operators)
CLLForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, channels, margin, alpha,
(*bottom)[2]->gpu_data(), // pair similarity 0 or 1
diff_.gpu_data(), // the cached eltwise difference between a and b
dist_sq_.gpu_data(), // the cached square distance between a and b
(*bottom)[i]->mutable_gpu_diff());
CUDA_POST_KERNEL_CHECK;
}
}
}
INSTANTIATE_CLASS(ContrastiveLossLayer);
} // namespace caffe


@@ -198,7 +198,7 @@ message NetStateRule {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available ID: 40 (last added: softmax_param)
// LayerParameter next available ID: 41 (last added: contrastive_loss_param)
message LayerParameter {
repeated string bottom = 2; // the name of the bottom blobs
repeated string top = 3; // the name of the top blobs
@@ -219,7 +219,7 @@ message LayerParameter {
// line above the enum. Update the next available ID when you add a new
// LayerType.
//
// LayerType next available ID: 37 (last added: SILENCE)
// LayerType next available ID: 38 (last added: CONTRASTIVE_LOSS)
enum LayerType {
// "NONE" layer type is 0th enum element so that we don't cause confusion
// by defaulting to an existent LayerType (instead, should usually error if
@@ -230,6 +230,7 @@ message LayerParameter {
ARGMAX = 30;
BNLL = 2;
CONCAT = 3;
CONTRASTIVE_LOSS = 37;
CONVOLUTION = 4;
DATA = 5;
DROPOUT = 6;
@@ -292,6 +293,7 @@ message LayerParameter {
optional AccuracyParameter accuracy_param = 27;
optional ArgMaxParameter argmax_param = 23;
optional ConcatParameter concat_param = 9;
optional ContrastiveLossParameter contrastive_loss_param = 40;
optional ConvolutionParameter convolution_param = 10;
optional DataParameter data_param = 11;
optional DropoutParameter dropout_param = 12;
@@ -367,6 +369,12 @@ message ConcatParameter {
optional uint32 concat_dim = 1 [default = 1];
}
// Message that stores parameters used by ContrastiveLossLayer
message ContrastiveLossParameter {
// margin for dissimilar pairs
optional float margin = 1 [default = 1.0];
}
// Message that stores parameters used by ConvolutionLayer
message ConvolutionParameter {
optional uint32 num_output = 1; // The number of outputs for the layer


@@ -0,0 +1,102 @@
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <vector>
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/test/test_caffe_main.hpp"
#include "caffe/test/test_gradient_check_util.hpp"
namespace caffe {
template <typename TypeParam>
class ContrastiveLossLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
protected:
ContrastiveLossLayerTest()
: blob_bottom_data_i_(new Blob<Dtype>(128, 10, 1, 1)),
blob_bottom_data_j_(new Blob<Dtype>(128, 10, 1, 1)),
blob_bottom_y_(new Blob<Dtype>(128, 1, 1, 1)),
blob_top_loss_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
filler_param.set_mean(0.0);
filler_param.set_std(0.3);  // distances ~= 1.0 to test both sides of the margin
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_data_i_);
blob_bottom_vec_.push_back(blob_bottom_data_i_);
filler.Fill(this->blob_bottom_data_j_);
blob_bottom_vec_.push_back(blob_bottom_data_j_);
for (int i = 0; i < blob_bottom_y_->count(); ++i) {
blob_bottom_y_->mutable_cpu_data()[i] = caffe_rng_rand() % 2; // 0 or 1
}
blob_bottom_vec_.push_back(blob_bottom_y_);
blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~ContrastiveLossLayerTest() {
delete blob_bottom_data_i_;
delete blob_bottom_data_j_;
delete blob_bottom_y_;
delete blob_top_loss_;
}
Blob<Dtype>* const blob_bottom_data_i_;
Blob<Dtype>* const blob_bottom_data_j_;
Blob<Dtype>* const blob_bottom_y_;
Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
TYPED_TEST_CASE(ContrastiveLossLayerTest, TestDtypesAndDevices);
TYPED_TEST(ContrastiveLossLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
ContrastiveLossLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
// manually compute to compare
const Dtype margin = layer_param.contrastive_loss_param().margin();
const int num = this->blob_bottom_data_i_->num();
const int channels = this->blob_bottom_data_i_->channels();
Dtype loss(0);
for (int i = 0; i < num; ++i) {
Dtype dist_sq(0);
for (int j = 0; j < channels; ++j) {
Dtype diff = this->blob_bottom_data_i_->cpu_data()[i*channels+j] -
this->blob_bottom_data_j_->cpu_data()[i*channels+j];
dist_sq += diff*diff;
}
if (this->blob_bottom_y_->cpu_data()[i]) { // similar pairs
loss += dist_sq;
} else {
loss += std::max(margin-dist_sq, Dtype(0));
}
}
loss /= static_cast<Dtype>(num) * Dtype(2);
EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], loss, 1e-6);
}
TYPED_TEST(ContrastiveLossLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
ContrastiveLossLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<Dtype> checker(1e-2, 1e-2, 1701);
// check the gradient for the first two bottom blobs
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_), 0);
checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
&(this->blob_top_vec_), 1);
}
} // namespace caffe