Enhance help, log message & format of the feature extraction example

This commit is contained in:
Kai Li 2014-02-26 03:46:32 +08:00
Родитель dfe63805e9
Коммит 01bb481702
1 изменённых файлов: 70 добавлений и 70 удалений

Просмотреть файл

@ -15,7 +15,6 @@
using namespace caffe; using namespace caffe;
template<typename Dtype> template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv); int feature_extraction_pipeline(int argc, char** argv);
@ -29,11 +28,11 @@ int feature_extraction_pipeline(int argc, char** argv) {
const int num_required_args = 6; const int num_required_args = 6;
if (argc < num_required_args) { if (argc < num_required_args) {
LOG(ERROR)<< LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then" "This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n" " extract features of the input data produced by the net.\n"
"Usage: demo_extract_features pretrained_net_param" "Usage: demo_extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name" " feature_extraction_proto_file extract_feature_blob_name"
" save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]"; " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
return 1; return 1;
} }
int arg_pos = num_required_args; int arg_pos = num_required_args;
@ -63,33 +62,34 @@ int feature_extraction_pipeline(int argc, char** argv) {
&pretrained_net_param); &pretrained_net_param);
// Expected prototxt contains at least one data layer such as // Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the // the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features. // fc7 top blob to extract features.
/* /*
layers { layers {
layer { layer {
name: "data_layer_name" name: "data_layer_name"
type: "data" type: "data"
source: "/path/to/your/images/to/extract/feature/images_leveldb" source: "/path/to/your/images/to/extract/feature/images_leveldb"
meanfile: "/path/to/your/image_mean.binaryproto" meanfile: "/path/to/your/image_mean.binaryproto"
batchsize: 128 batchsize: 128
cropsize: 227 cropsize: 227
mirror: false mirror: false
} }
top: "data_blob_name" top: "data_blob_name"
top: "label_blob_name" top: "label_blob_name"
} }
layers { layers {
layer { layer {
name: "drop7" name: "drop7"
type: "dropout" type: "dropout"
dropout_ratio: 0.5 dropout_ratio: 0.5
} }
bottom: "fc7" bottom: "fc7"
top: "fc7" top: "fc7"
} }
*/ */
NetParameter feature_extraction_net_param;; NetParameter feature_extraction_net_param;
;
string feature_extraction_proto(argv[++arg_pos]); string feature_extraction_proto(argv[++arg_pos]);
ReadProtoFromTextFile(feature_extraction_proto, ReadProtoFromTextFile(feature_extraction_proto,
&feature_extraction_net_param); &feature_extraction_net_param);
@ -98,11 +98,9 @@ int feature_extraction_pipeline(int argc, char** argv) {
feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param); feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
string extract_feature_blob_name(argv[++arg_pos]); string extract_feature_blob_name(argv[++arg_pos]);
if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) { CHECK(feature_extraction_net->HasBlob(extract_feature_blob_name))
LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name << << "Unknown feature blob name " << extract_feature_blob_name
" in the network " << feature_extraction_proto; << " in the network " << feature_extraction_proto;
return 1;
}
string save_feature_leveldb_name(argv[++arg_pos]); string save_feature_leveldb_name(argv[++arg_pos]);
leveldb::DB* db; leveldb::DB* db;
@ -110,9 +108,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
options.error_if_exists = true; options.error_if_exists = true;
options.create_if_missing = true; options.create_if_missing = true;
options.write_buffer_size = 268435456; options.write_buffer_size = 268435456;
LOG(INFO) << "Opening leveldb " << save_feature_leveldb_name; LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
leveldb::Status status = leveldb::DB::Open( leveldb::Status status = leveldb::DB::Open(options,
options, save_feature_leveldb_name.c_str(), &db); save_feature_leveldb_name.c_str(),
&db);
CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name; CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
int num_mini_batches = atoi(argv[++arg_pos]); int num_mini_batches = atoi(argv[++arg_pos]);
@ -124,51 +123,52 @@ int feature_extraction_pipeline(int argc, char** argv) {
const int max_key_str_length = 100; const int max_key_str_length = 100;
char key_str[max_key_str_length]; char key_str[max_key_str_length];
int num_bytes_of_binary_code = sizeof(Dtype); int num_bytes_of_binary_code = sizeof(Dtype);
vector<Blob<float>* > input_vec; vector<Blob<float>*> input_vec;
int image_index = 0; int image_index = 0;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) { for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec); feature_extraction_net->Forward(input_vec);
const shared_ptr<Blob<Dtype> > feature_blob = const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
feature_extraction_net->GetBlob(extract_feature_blob_name); ->GetBlob(extract_feature_blob_name);
int num_features = feature_blob->num(); int num_features = feature_blob->num();
int dim_features = feature_blob->count() / num_features; int dim_features = feature_blob->count() / num_features;
for (int n = 0; n < num_features; ++n) { for (int n = 0; n < num_features; ++n) {
datum.set_height(dim_features); datum.set_height(dim_features);
datum.set_width(1); datum.set_width(1);
datum.set_channels(1); datum.set_channels(1);
datum.clear_data(); datum.clear_data();
datum.clear_float_data(); datum.clear_float_data();
string* datum_string = datum.mutable_data(); string* datum_string = datum.mutable_data();
const Dtype* feature_blob_data = feature_blob->cpu_data(); const Dtype* feature_blob_data = feature_blob->cpu_data();
for (int d = 0; d < dim_features; ++d) { for (int d = 0; d < dim_features; ++d) {
const char* data_byte = reinterpret_cast<const char*>(feature_blob_data + d); const char* data_byte = reinterpret_cast<const char*>(feature_blob_data
for(int i = 0; i < num_bytes_of_binary_code; ++i) { + d);
datum_string->push_back(data_byte[i]); for (int i = 0; i < num_bytes_of_binary_code; ++i) {
} datum_string->push_back(data_byte[i]);
} }
string value; }
datum.SerializeToString(&value); string value;
snprintf(key_str, max_key_str_length, "%d", image_index); datum.SerializeToString(&value);
batch->Put(string(key_str), value); snprintf(key_str, max_key_str_length, "%d", image_index);
if (++image_index % 1000 == 0) { batch->Put(string(key_str), value);
db->Write(leveldb::WriteOptions(), batch); if (++image_index % 1000 == 0) {
LOG(ERROR) << "Extracted features of " << image_index << " query images."; db->Write(leveldb::WriteOptions(), batch);
delete batch; LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
batch = new leveldb::WriteBatch(); delete batch;
} batch = new leveldb::WriteBatch();
}
} }
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch // write the last batch
if (image_index % 1000 != 0) { if (image_index % 1000 != 0) {
db->Write(leveldb::WriteOptions(), batch); db->Write(leveldb::WriteOptions(), batch);
LOG(ERROR) << "Extracted features of " << image_index << " query images."; LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
delete batch; delete batch;
batch = new leveldb::WriteBatch(); batch = new leveldb::WriteBatch();
} }
delete batch; delete batch;
delete db; delete db;
LOG(ERROR)<< "Successfully ended!"; LOG(ERROR)<< "Successfully extracted the features!";
return 0; return 0;
} }