Fix bugs in the image retrieval example

This commit is contained in:
Kai Li 2014-02-26 05:34:23 +08:00
Родитель cfb2f915b9
Коммит 23eecde6b7
1 изменённых файлов: 36 добавлений и 84 удалений

Просмотреть файл

@ -19,7 +19,7 @@ template<typename Dtype>
void similarity_search(
const vector<shared_ptr<Blob<Dtype> > >& sample_binary_feature_blobs,
const shared_ptr<Blob<Dtype> > query_binary_feature,
const int top_k_results, shared_ptr<Blob<Dtype> > retrieval_results);
const int top_k_results, vector<vector<Dtype> >* retrieval_results);
template<typename Dtype>
int image_retrieval_pipeline(int argc, char** argv);
@ -35,7 +35,7 @@ int image_retrieval_pipeline(int argc, char** argv) {
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in binarized features of query images and sample images"
" extracted by Caffe to retrieve similar images."
" extracted by Caffe to retrieve similar images.\n"
"Usage: demo_retrieve_images sample_binary_features_binaryproto_file"
" query_binary_features_binaryproto_file save_retrieval_result_filename"
" [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]";
@ -67,10 +67,9 @@ int image_retrieval_pipeline(int argc, char** argv) {
}
Caffe::set_phase(Caffe::TEST);
NetParameter pretrained_net_param;
arg_pos = 0; // the name of the executable
LOG(ERROR)<< "Loading sample binary features";
string sample_binary_features_binaryproto_file(argv[++arg_pos]);
BlobProtoVector sample_binary_features;
ReadProtoFromBinaryFile(sample_binary_features_binaryproto_file,
@ -87,92 +86,47 @@ int image_retrieval_pipeline(int argc, char** argv) {
top_k_results = num_samples;
}
LOG(ERROR)<< "Loading query binary features";
string query_images_feature_blob_binaryproto(argv[++arg_pos]);
BlobProtoVector query_images_features;
ReadProtoFromBinaryFile(query_images_feature_blob_binaryproto,
&query_images_features);
vector<shared_ptr<Blob<Dtype> > > query_binary_feature_blobs;
for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
for (int i = 0; i < query_images_features.blobs_size(); ++i) {
shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
blob->FromProto(query_images_features.blobs(i));
query_binary_feature_blobs.push_back(blob);
}
string save_retrieval_result_filename(argv[++arg_pos]);
LOG(ERROR)<< "Opening result file " << save_retrieval_result_filename;
std::ofstream retrieval_result_ofs(save_retrieval_result_filename.c_str(),
std::ofstream::out);
LOG(ERROR)<< "Retrieving images";
shared_ptr<Blob<Dtype> > retrieval_results;
vector<vector<Dtype> > retrieval_results;
int query_image_index = 0;
int num_bytes_of_binary_code = sizeof(Dtype);
int num_query_batches = query_binary_feature_blobs.size();
for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
LOG(ERROR)<< "Batch " << batch_index << " image retrieval";
similarity_search<Dtype>(sample_binary_feature_blobs,
query_binary_feature_blobs[batch_index],
top_k_results, retrieval_results);
LOG(ERROR) << "Batch " << batch_index << " save image retrieval results";
int num_results = retrieval_results->num();
const Dtype* retrieval_results_data = retrieval_results->cpu_data();
query_binary_feature_blobs[batch_index],
top_k_results, &retrieval_results);
int num_results = retrieval_results.size();
for (int i = 0; i < num_results; ++i) {
retrieval_result_ofs << ++query_image_index;
retrieval_results_data += retrieval_results->offset(i);
for (int j = 0; j < top_k_results; ++j) {
retrieval_result_ofs << " " << retrieval_results_data[j];
retrieval_result_ofs << query_image_index++;
for (int j = 0; j < retrieval_results[i].size(); ++j) {
retrieval_result_ofs << " " << retrieval_results[i][j];
}
retrieval_result_ofs << "\n";
}
} // for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
retrieval_result_ofs.close();
LOG(ERROR)<< "Successfully ended!";
LOG(ERROR)<< "Successfully retrieved similar images for " << query_image_index << " queries!";
return 0;
}
template<typename Dtype>
void binarize(const int n, const Dtype* real_valued_feature,
Dtype* binary_codes) {
// TODO: more advanced binarization algorithm such as bilinear projection
// Yunchao Gong, Sanjiv Kumar, Henry A. Rowley, and Svetlana Lazebnik.
// Learning Binary Codes for High-Dimensional Data Using Bilinear Projections.
// In IEEE International Conference on Computer Vision and Pattern Recognition (CVPR), 2013.
// http://www.unc.edu/~yunchao/bpbc.htm
int size_of_code = sizeof(Dtype) * 8;
CHECK_EQ(n % size_of_code, 0);
int num_binary_codes = n / size_of_code;
uint64_t code;
int offset;
for (int i = 0; i < num_binary_codes; ++i) {
code = 0;
offset = i * size_of_code;
for (int j = 0; j < size_of_code; ++j) {
code |= sign(real_valued_feature[offset + j]);
code << 1;
}
binary_codes[i] = static_cast<Dtype>(code);
}
}
template<typename Dtype>
void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
shared_ptr<Blob<Dtype> > binary_codes) {
int num = real_valued_features->num();
int dim = real_valued_features->count() / num;
int size_of_code = sizeof(Dtype) * 8;
CHECK_EQ(dim % size_of_code, 0);
binary_codes->Reshape(num, dim / size_of_code, 1, 1);
const Dtype* real_valued_features_data = real_valued_features->cpu_data();
Dtype* binary_codes_data = binary_codes->mutable_cpu_data();
for (int n = 0; n < num; ++n) {
binarize<Dtype>(dim,
real_valued_features_data + real_valued_features->offset(n),
binary_codes_data + binary_codes->offset(n));
}
}
class MinHeapComparison {
public:
bool operator()(const std::pair<int, int>& lhs,
@ -185,39 +139,37 @@ template<typename Dtype>
void similarity_search(
const vector<shared_ptr<Blob<Dtype> > >& sample_images_feature_blobs,
const shared_ptr<Blob<Dtype> > query_image_feature, const int top_k_results,
shared_ptr<Blob<Dtype> > retrieval_results) {
vector<vector<Dtype> >* retrieval_results) {
int num_queries = query_image_feature->num();
int dim = query_image_feature->count() / num_queries;
int hamming_dist;
retrieval_results->Reshape(num_queries, top_k_results, 1, 1);
Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
retrieval_results->resize(num_queries);
std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
MinHeapComparison> results;
for (int i = 0; i < num_queries; ++i) {
std::priority_queue<std::pair<int, int>,
std::vector<std::pair<int, int> >, MinHeapComparison> results;
for (int num_sample_blob;
num_sample_blob < sample_images_feature_blobs.size();
++num_sample_blob) {
shared_ptr<Blob<Dtype> > sample_images_feature =
sample_images_feature_blobs[num_sample_blob];
int num_samples = sample_images_feature->num();
for (int j = 0; j < num_samples; ++j) {
while (!results.empty()) {
results.pop();
}
for (int j = 0; j < sample_images_feature_blobs.size(); ++j) {
int num_samples = sample_images_feature_blobs[j]->num();
for (int k = 0; k < num_samples; ++k) {
hamming_dist = caffe_hamming_distance(
dim,
query_image_feature->cpu_data() + query_image_feature->offset(i),
sample_images_feature->cpu_data()
+ sample_images_feature->offset(j));
sample_images_feature_blobs[j]->cpu_data()
+ sample_images_feature_blobs[j]->offset(k));
if (results.size() < top_k_results) {
results.push(std::make_pair(-hamming_dist, j));
results.push(std::make_pair(-hamming_dist, k));
} else if (-hamming_dist > results.top().first) { // smaller hamming dist
results.pop();
results.push(std::make_pair(-hamming_dist, j));
results.push(std::make_pair(-hamming_dist, k));
}
} // for (int j = 0; j < num_samples; ++j) {
retrieval_results_data += retrieval_results->offset(i);
for (int k = 0; k < results.size(); ++k) {
retrieval_results_data[k] = results.top().second;
results.pop();
}
} // for(...; sample_images_feature_blobs.size(); ...)
} // for (int i = 0; i < num_queries; ++i) {
} // for (int k = 0; k < num_samples; ++k) {
} // for (int j = 0; j < sample_images_feature_blobs.size(); ++j)
retrieval_results->at(i).resize(results.size());
for (int k = results.size() - 1; k >= 0; --k) {
retrieval_results->at(i)[k] = results.top().second;
results.pop();
}
} // for (int i = 0; i < num_queries; ++i) {
}