зеркало из https://github.com/microsoft/caffe.git
Enhance help, log message & format of the feature extraction example
This commit is contained in:
Родитель
dfe63805e9
Коммит
01bb481702
|
@ -15,7 +15,6 @@
|
|||
|
||||
using namespace caffe;
|
||||
|
||||
|
||||
template<typename Dtype>
|
||||
int feature_extraction_pipeline(int argc, char** argv);
|
||||
|
||||
|
@ -29,11 +28,11 @@ int feature_extraction_pipeline(int argc, char** argv) {
|
|||
const int num_required_args = 6;
|
||||
if (argc < num_required_args) {
|
||||
LOG(ERROR)<<
|
||||
"This program takes in a trained network and an input data layer, and then"
|
||||
" extract features of the input data produced by the net.\n"
|
||||
"Usage: demo_extract_features pretrained_net_param"
|
||||
" feature_extraction_proto_file extract_feature_blob_name"
|
||||
" save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
|
||||
"This program takes in a trained network and an input data layer, and then"
|
||||
" extract features of the input data produced by the net.\n"
|
||||
"Usage: demo_extract_features pretrained_net_param"
|
||||
" feature_extraction_proto_file extract_feature_blob_name"
|
||||
" save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
|
||||
return 1;
|
||||
}
|
||||
int arg_pos = num_required_args;
|
||||
|
@ -63,33 +62,34 @@ int feature_extraction_pipeline(int argc, char** argv) {
|
|||
&pretrained_net_param);
|
||||
|
||||
// Expected prototxt contains at least one data layer such as
|
||||
// the layer data_layer_name and one feature blob such as the
|
||||
// fc7 top blob to extract features.
|
||||
/*
|
||||
layers {
|
||||
layer {
|
||||
name: "data_layer_name"
|
||||
type: "data"
|
||||
source: "/path/to/your/images/to/extract/feature/images_leveldb"
|
||||
meanfile: "/path/to/your/image_mean.binaryproto"
|
||||
batchsize: 128
|
||||
cropsize: 227
|
||||
mirror: false
|
||||
}
|
||||
top: "data_blob_name"
|
||||
top: "label_blob_name"
|
||||
}
|
||||
layers {
|
||||
layer {
|
||||
name: "drop7"
|
||||
type: "dropout"
|
||||
dropout_ratio: 0.5
|
||||
}
|
||||
bottom: "fc7"
|
||||
top: "fc7"
|
||||
}
|
||||
*/
|
||||
NetParameter feature_extraction_net_param;;
|
||||
// the layer data_layer_name and one feature blob such as the
|
||||
// fc7 top blob to extract features.
|
||||
/*
|
||||
layers {
|
||||
layer {
|
||||
name: "data_layer_name"
|
||||
type: "data"
|
||||
source: "/path/to/your/images/to/extract/feature/images_leveldb"
|
||||
meanfile: "/path/to/your/image_mean.binaryproto"
|
||||
batchsize: 128
|
||||
cropsize: 227
|
||||
mirror: false
|
||||
}
|
||||
top: "data_blob_name"
|
||||
top: "label_blob_name"
|
||||
}
|
||||
layers {
|
||||
layer {
|
||||
name: "drop7"
|
||||
type: "dropout"
|
||||
dropout_ratio: 0.5
|
||||
}
|
||||
bottom: "fc7"
|
||||
top: "fc7"
|
||||
}
|
||||
*/
|
||||
NetParameter feature_extraction_net_param;
|
||||
;
|
||||
string feature_extraction_proto(argv[++arg_pos]);
|
||||
ReadProtoFromTextFile(feature_extraction_proto,
|
||||
&feature_extraction_net_param);
|
||||
|
@ -98,11 +98,9 @@ int feature_extraction_pipeline(int argc, char** argv) {
|
|||
feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
|
||||
|
||||
string extract_feature_blob_name(argv[++arg_pos]);
|
||||
if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
|
||||
LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
|
||||
" in the network " << feature_extraction_proto;
|
||||
return 1;
|
||||
}
|
||||
CHECK(feature_extraction_net->HasBlob(extract_feature_blob_name))
|
||||
<< "Unknown feature blob name " << extract_feature_blob_name
|
||||
<< " in the network " << feature_extraction_proto;
|
||||
|
||||
string save_feature_leveldb_name(argv[++arg_pos]);
|
||||
leveldb::DB* db;
|
||||
|
@ -110,9 +108,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
|
|||
options.error_if_exists = true;
|
||||
options.create_if_missing = true;
|
||||
options.write_buffer_size = 268435456;
|
||||
LOG(INFO) << "Opening leveldb " << save_feature_leveldb_name;
|
||||
leveldb::Status status = leveldb::DB::Open(
|
||||
options, save_feature_leveldb_name.c_str(), &db);
|
||||
LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
|
||||
leveldb::Status status = leveldb::DB::Open(options,
|
||||
save_feature_leveldb_name.c_str(),
|
||||
&db);
|
||||
CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
|
||||
|
||||
int num_mini_batches = atoi(argv[++arg_pos]);
|
||||
|
@ -124,51 +123,52 @@ int feature_extraction_pipeline(int argc, char** argv) {
|
|||
const int max_key_str_length = 100;
|
||||
char key_str[max_key_str_length];
|
||||
int num_bytes_of_binary_code = sizeof(Dtype);
|
||||
vector<Blob<float>* > input_vec;
|
||||
vector<Blob<float>*> input_vec;
|
||||
int image_index = 0;
|
||||
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
|
||||
feature_extraction_net->Forward(input_vec);
|
||||
const shared_ptr<Blob<Dtype> > feature_blob =
|
||||
feature_extraction_net->GetBlob(extract_feature_blob_name);
|
||||
const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
|
||||
->GetBlob(extract_feature_blob_name);
|
||||
int num_features = feature_blob->num();
|
||||
int dim_features = feature_blob->count() / num_features;
|
||||
for (int n = 0; n < num_features; ++n) {
|
||||
datum.set_height(dim_features);
|
||||
datum.set_width(1);
|
||||
datum.set_channels(1);
|
||||
datum.clear_data();
|
||||
datum.clear_float_data();
|
||||
string* datum_string = datum.mutable_data();
|
||||
const Dtype* feature_blob_data = feature_blob->cpu_data();
|
||||
for (int d = 0; d < dim_features; ++d) {
|
||||
const char* data_byte = reinterpret_cast<const char*>(feature_blob_data + d);
|
||||
for(int i = 0; i < num_bytes_of_binary_code; ++i) {
|
||||
datum_string->push_back(data_byte[i]);
|
||||
}
|
||||
}
|
||||
string value;
|
||||
datum.SerializeToString(&value);
|
||||
snprintf(key_str, max_key_str_length, "%d", image_index);
|
||||
batch->Put(string(key_str), value);
|
||||
if (++image_index % 1000 == 0) {
|
||||
db->Write(leveldb::WriteOptions(), batch);
|
||||
LOG(ERROR) << "Extracted features of " << image_index << " query images.";
|
||||
delete batch;
|
||||
batch = new leveldb::WriteBatch();
|
||||
}
|
||||
datum.set_height(dim_features);
|
||||
datum.set_width(1);
|
||||
datum.set_channels(1);
|
||||
datum.clear_data();
|
||||
datum.clear_float_data();
|
||||
string* datum_string = datum.mutable_data();
|
||||
const Dtype* feature_blob_data = feature_blob->cpu_data();
|
||||
for (int d = 0; d < dim_features; ++d) {
|
||||
const char* data_byte = reinterpret_cast<const char*>(feature_blob_data
|
||||
+ d);
|
||||
for (int i = 0; i < num_bytes_of_binary_code; ++i) {
|
||||
datum_string->push_back(data_byte[i]);
|
||||
}
|
||||
}
|
||||
string value;
|
||||
datum.SerializeToString(&value);
|
||||
snprintf(key_str, max_key_str_length, "%d", image_index);
|
||||
batch->Put(string(key_str), value);
|
||||
if (++image_index % 1000 == 0) {
|
||||
db->Write(leveldb::WriteOptions(), batch);
|
||||
LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
|
||||
delete batch;
|
||||
batch = new leveldb::WriteBatch();
|
||||
}
|
||||
}
|
||||
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
|
||||
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
|
||||
// write the last batch
|
||||
if (image_index % 1000 != 0) {
|
||||
db->Write(leveldb::WriteOptions(), batch);
|
||||
LOG(ERROR) << "Extracted features of " << image_index << " query images.";
|
||||
LOG(ERROR)<< "Extracted features of " << image_index << " query images.";
|
||||
delete batch;
|
||||
batch = new leveldb::WriteBatch();
|
||||
}
|
||||
|
||||
delete batch;
|
||||
delete db;
|
||||
LOG(ERROR)<< "Successfully ended!";
|
||||
LOG(ERROR)<< "Successfully extracted the features!";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче