Merge pull request #221 from jamt9000/fix-dump-network

Fix dump network
This commit is contained in:
Evan Shelhamer 2014-03-19 09:56:20 -07:00
Родитель ee0b50a083 28d27ee908
Коммит 32ee91c204
1 изменённых файлов: 3 добавлений и 3 удалений

Просмотреть файл

@ -28,7 +28,6 @@
using namespace caffe; // NOLINT(build/namespaces)
int main(int argc, char** argv) {
cudaSetDevice(1);
Caffe::set_mode(Caffe::GPU);
Caffe::set_phase(Caffe::TEST);
@ -44,10 +43,10 @@ int main(int argc, char** argv) {
ReadProtoFromBinaryFile(argv[2], &trained_net_param);
vector<Blob<float>* > input_vec;
shared_ptr<Blob<float> > input_blob(new Blob<float>());
if (strcmp(argv[3], "none") != 0) {
BlobProto input_blob_proto;
ReadProtoFromBinaryFile(argv[3], &input_blob_proto);
shared_ptr<Blob<float> > input_blob(new Blob<float>());
input_blob->FromProto(input_blob_proto);
input_vec.push_back(input_blob.get());
}
@ -59,8 +58,9 @@ int main(int argc, char** argv) {
// Run the network without training.
LOG(ERROR) << "Performing Forward";
caffe_net->Forward(input_vec);
if (argc > 4 && strcmp(argv[4], "1")) {
if (argc > 5 && strcmp(argv[5], "1") == 0) {
LOG(ERROR) << "Performing Backward";
Caffe::set_phase(Caffe::TRAIN);
caffe_net->Backward();
// Dump the network
NetParameter output_net_param;