removed opencv dependency for easier distribution

Yangqing Jia 2013-10-10 16:54:39 -07:00
Parent 5badef47fb
Commit 6ecaa904e5
11 changed files with 33 additions and 362 deletions

View file

@@ -43,7 +43,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 INCLUDE_DIRS := . /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
 LIBRARY_DIRS := . /usr/lib /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand \
-	leveldb snappy opencv_core opencv_highgui pthread tcmalloc
+	leveldb snappy pthread tcmalloc
 WARNINGS := -Wall
 COMMON_FLAGS := $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
@@ -80,6 +80,10 @@ $(TEST_BINS): %.testbin : %.o $(GTEST_OBJ) $(STATIC_NAME)
 $(PROGRAM_BINS): %.bin : %.o $(STATIC_NAME)
 	$(CXX) $< $(STATIC_NAME) -o $@ $(LDFLAGS) $(WARNINGS)
+$(OBJS): $(PROTO_GEN_CC)
+
+$(PROGRAM_OBJS): $(PROTO_GEN_CC)
+
 $(CU_OBJS): %.cuo: %.cu
 	$(NVCC) -c $< -o $@

View file

@@ -43,10 +43,6 @@ private:\
 namespace caffe {

-// Two classes whose purpose are solely for instantiating blob template
-// functions.
-class GPUBrewer {};
-class CPUBrewer {};
 // We will use the boost shared_ptr instead of the new C++11 one mainly
 // because cuda does not work (at least now) well with C++11 features.

View file

@@ -47,10 +47,12 @@ void* DataLayerPrefetch(void* layer_pointer) {
         for (int c = 0; c < channels; ++c) {
           for (int h = 0; h < cropsize; ++h) {
             for (int w = 0; w < cropsize; ++w) {
-              top_data[((itemid * channels + c) * cropsize + h) * cropsize + cropsize - 1 - w] =
-                static_cast<Dtype>((uint8_t)data[
-                    (c * height + h + h_offset) * width + w + w_offset]
-                  ) * scale - subtraction;
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize
+                       + cropsize - 1 - w] =
+                  static_cast<Dtype>(
+                      (uint8_t)data[(c * height + h + h_offset) * width
+                                    + w + w_offset])
+                  * scale - subtraction;
             }
           }
         }
@@ -59,10 +61,11 @@ void* DataLayerPrefetch(void* layer_pointer) {
         for (int c = 0; c < channels; ++c) {
           for (int h = 0; h < cropsize; ++h) {
             for (int w = 0; w < cropsize; ++w) {
-              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w] =
-                static_cast<Dtype>((uint8_t)data[
-                    (c * height + h + h_offset) * width + w + w_offset]
-                  ) * scale - subtraction;
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
+                  = static_cast<Dtype>(
+                      (uint8_t)data[(c * height + h + h_offset) * width
+                                    + w + w_offset])
+                    * scale - subtraction;
             }
           }
         }
@@ -144,10 +147,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_GT(datum_height_, cropsize);
   CHECK_GT(datum_width_, cropsize);
   // Now, start the prefetch thread.
-  //LOG(INFO) << "Initializing prefetch";
-  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>, (void*)this))
-      << "Pthread execution failed.";
-  //LOG(INFO) << "Prefetch initialized.";
+  // LOG(INFO) << "Initializing prefetch";
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
+  // LOG(INFO) << "Prefetch initialized.";
 }

 template <typename Dtype>
@@ -161,8 +164,8 @@ void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
       sizeof(Dtype) * prefetch_label_->count());
   // Start a new prefetch thread
-  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>, (void*)this))
-      << "Pthread execution failed.";
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
 }

 template <typename Dtype>
@@ -171,13 +174,15 @@ void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   // First, join the thread
   CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
   // Copy the data
-  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(), prefetch_data_->cpu_data(),
-      sizeof(Dtype) * prefetch_data_->count(), cudaMemcpyHostToDevice));
-  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(), prefetch_label_->cpu_data(),
-      sizeof(Dtype) * prefetch_label_->count(), cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
+      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
+      cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
+      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
+      cudaMemcpyHostToDevice));
   // Start a new prefetch thread
-  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>, (void*)this))
-      << "Pthread execution failed.";
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
 }

 // The backward operations are dummy - they do not carry any computation.
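
A note on the pattern this file keeps using: DataLayer hides disk latency by always having one batch in flight. The cycle is join the prefetch thread, copy its buffer out, then immediately respawn the thread. A minimal, self-contained sketch of that join/copy/respawn cycle (illustrative names, plain pthreads, no Caffe types):

// Sketch of the join / copy / respawn prefetch cycle, not Caffe's actual code.
#include <pthread.h>
#include <cstddef>
#include <cstdio>
#include <vector>

static std::vector<float> staging(1024);  // buffer filled by the worker thread

void* Prefetch(void*) {
  // Fill the staging buffer while the consumer works on the previous batch.
  for (std::size_t i = 0; i < staging.size(); ++i)
    staging[i] = static_cast<float>(i);
  return NULL;
}

int main() {
  pthread_t thread;
  // SetUp: start the first prefetch so data is ready before the first Forward.
  pthread_create(&thread, NULL, Prefetch, NULL);
  for (int step = 0; step < 3; ++step) {
    pthread_join(thread, NULL);             // wait until the batch is ready
    std::vector<float> batch = staging;     // copy out (cudaMemcpy in Caffe)
    pthread_create(&thread, NULL, Prefetch, NULL);  // refill immediately
    std::printf("step %d consumed %zu values\n", step, batch.size());
  }
  pthread_join(thread, NULL);
  return 0;
}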

View file

@@ -49,7 +49,7 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
         GetFiller<Dtype>(this->layer_param_.bias_filler()));
     bias_filler->Fill(this->blobs_[1].get());
   }
-  } // parameter initialization
+  }  // parameter initialization
   // Setting up the bias multiplier
   if (biasterm_) {
     bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
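
The bias_multiplier_ buffer allocated here is, by all appearances, the usual all-ones vector that lets a GEMM-based layer fold bias addition into a rank-1 matrix update. A hedged sketch of that identity with plain loops (illustrative only, not the layer's actual code):

// Adding bias b (length N) to every row of C (M x N) equals C += ones * b^T,
// which is why the layer keeps an M-vector of ones as a "bias multiplier":
// cblas_sgemm with A = ones (M x 1), B = b (1 x N), alpha = beta = 1.
#include <cstdio>

int main() {
  const int M = 2, N = 3;
  float C[M][N] = {{1, 2, 3}, {4, 5, 6}};
  float b[N] = {10, 20, 30};
  float ones[M] = {1, 1};
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      C[i][j] += ones[i] * b[j];  // the rank-1 update a GEMM would perform
  for (int i = 0; i < M; ++i)
    std::printf("%g %g %g\n", C[i][0], C[i][1], C[i][2]);
  return 0;
}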

View file

@@ -1,54 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <gtest/gtest.h>
-#include <cstring>
-
-#include "caffe/common.hpp"
-#include "caffe/blob.hpp"
-#include "caffe/net.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-class NetProtoTest : public ::testing::Test {};
-
-typedef ::testing::Types<float> Dtypes;
-TYPED_TEST_CASE(NetProtoTest, Dtypes);
-
-TYPED_TEST(NetProtoTest, TestLoadFromText) {
-  NetParameter net_param;
-  ReadProtoFromTextFile("data/simple_conv.prototxt", &net_param);
-  Blob<TypeParam> lena_image;
-  ReadImageToBlob<TypeParam>(string("data/lena_256.jpg"), &lena_image);
-  vector<Blob<TypeParam>*> bottom_vec;
-  bottom_vec.push_back(&lena_image);
-
-  for (int i = 0; i < lena_image.count(); ++i) {
-    EXPECT_GE(lena_image.cpu_data()[i], 0);
-    EXPECT_LE(lena_image.cpu_data()[i], 1);
-  }
-
-  Caffe::set_mode(Caffe::CPU);
-  // Initialize the network, and then does smoothing
-  Net<TypeParam> caffe_net(net_param, bottom_vec);
-  LOG(ERROR) << "Start Forward.";
-  const vector<Blob<TypeParam>*>& output = caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Forward Done.";
-  EXPECT_EQ(output[0]->num(), 1);
-  EXPECT_EQ(output[0]->channels(), 1);
-  EXPECT_EQ(output[0]->height(), 252);
-  EXPECT_EQ(output[0]->width(), 252);
-  for (int i = 0; i < output[0]->count(); ++i) {
-    EXPECT_GE(output[0]->cpu_data()[i], 0);
-    EXPECT_LE(output[0]->cpu_data()[i], 1);
-  }
-  WriteBlobToImage<TypeParam>(string("lena_smoothed.png"), *output[0]);
-}
-
-}  // namespace caffe

View file

View file

@@ -1,104 +0,0 @@
-// Copyright Yangqing Jia 2013
-//
-// This is a working version of the math functions that would hopefully replace
-// the cpu and gpu separate version, that would eventually replace the old
-// math_functions wrapper.
-
-#include "caffe/common.hpp"
-#include "caffe/syncedmem.hpp"
-
-namespace caffe {
-
-namespace blobmath {
-
-// Decaf gemm provides a simpler interface to the gemm functions, with the
-// limitation that the data has to be contiguous in memory.
-template <class Brewer, typename Dtype>
-void gemm(const CBLAS_TRANSPOSE TransA,
-    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
-    const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
-    Dtype* C);
-
-template <typename Dtype>
-void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
-    const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
-    Dtype* y);
-
-template <typename Dtype>
-void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
-    const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
-    Dtype* y);
-
-template <typename Dtype>
-void caffe_axpy(const int N, const Dtype alpha, const Dtype* X,
-    Dtype* Y);
-
-template <typename Dtype>
-void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
-    Dtype* Y);
-
-template <typename Dtype>
-void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
-    const Dtype beta, Dtype* Y);
-
-template <typename Dtype>
-void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
-    const Dtype beta, Dtype* Y);
-
-template <typename Dtype>
-void caffe_copy(const int N, const Dtype *X, Dtype *Y);
-
-template <typename Dtype>
-void caffe_gpu_copy(const int N, const Dtype *X, Dtype *Y);
-
-template <typename Dtype>
-void caffe_scal(const int N, const Dtype alpha, Dtype *X);
-
-template <typename Dtype>
-void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X);
-
-template <typename Dtype>
-void caffe_sqr(const int N, const Dtype* a, Dtype* y);
-
-template <typename Dtype>
-void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_gpu_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
-
-template <typename Dtype>
-void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
-
-template <typename Dtype>
-void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
-    const Dtype sigma);
-
-template <typename Dtype>
-void caffe_exp(const int n, const Dtype* a, Dtype* y);
-
-template <typename Dtype>
-Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
-
-template <typename Dtype>
-void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out);
-
-}  // namespace blobmath
-
-}  // namespace caffe
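
The Brewer template parameter in the gemm declaration above, together with the GPUBrewer/CPUBrewer tag classes removed from the common header earlier in this diff, points at compile-time tag dispatch between backends. A minimal illustration of that pattern, independent of Caffe (the bodies are placeholders):

// Empty "brewer" tag classes select a backend at compile time, so one
// templated entry point can route to CPU or GPU implementations.
#include <cstdio>

class CPUBrewer {};
class GPUBrewer {};

template <class Brewer>
void axpy(int n, float alpha, const float* x, float* y);

template <>
void axpy<CPUBrewer>(int n, float alpha, const float* x, float* y) {
  for (int i = 0; i < n; ++i) y[i] += alpha * x[i];  // would call cblas_saxpy
}

template <>
void axpy<GPUBrewer>(int n, float alpha, const float* x, float* y) {
  std::printf("would call cublasSaxpy on device pointers\n");
}

int main() {
  float x[3] = {1, 2, 3}, y[3] = {0, 0, 0};
  axpy<CPUBrewer>(3, 2.0f, x, y);  // backend chosen by the tag, not at runtime
  std::printf("%g %g %g\n", y[0], y[1], y[2]);
  return 0;
}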

View file

@@ -5,8 +5,6 @@
 #include <google/protobuf/text_format.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/io/coded_stream.h>
-#include <opencv2/core/core.hpp>
-#include <opencv2/highgui/highgui.hpp>
 #include <algorithm>
 #include <string>
@@ -17,8 +15,6 @@
 #include "caffe/util/io.hpp"
 #include "caffe/proto/caffe.pb.h"

-using cv::Mat;
-using cv::Vec3b;
 using std::fstream;
 using std::ios;
 using std::max;
@@ -32,77 +28,6 @@ using google::protobuf::io::CodedOutputStream;
 namespace caffe {

-void ReadImageToProto(const string& filename, BlobProto* proto) {
-  Mat cv_img;
-  cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
-  CHECK(cv_img.data) << "Could not open or find the image.";
-  DCHECK_EQ(cv_img.channels(), 3);
-  proto->set_num(1);
-  proto->set_channels(3);
-  proto->set_height(cv_img.rows);
-  proto->set_width(cv_img.cols);
-  proto->clear_data();
-  proto->clear_diff();
-  for (int c = 0; c < 3; ++c) {
-    for (int h = 0; h < cv_img.rows; ++h) {
-      for (int w = 0; w < cv_img.cols; ++w) {
-        proto->add_data(static_cast<float>(cv_img.at<Vec3b>(h, w)[c]) / 255.);
-      }
-    }
-  }
-}
-
-void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
-  Mat cv_img;
-  cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
-  CHECK(cv_img.data) << "Could not open or find the image.";
-  DCHECK_EQ(cv_img.channels(), 3);
-  datum->set_channels(3);
-  datum->set_height(cv_img.rows);
-  datum->set_width(cv_img.cols);
-  datum->set_label(label);
-  datum->clear_data();
-  datum->clear_float_data();
-  string* datum_string = datum->mutable_data();
-  for (int c = 0; c < 3; ++c) {
-    for (int h = 0; h < cv_img.rows; ++h) {
-      for (int w = 0; w < cv_img.cols; ++w) {
-        datum_string->push_back(static_cast<char>(cv_img.at<Vec3b>(h, w)[c]));
-      }
-    }
-  }
-}
-
-void WriteProtoToImage(const string& filename, const BlobProto& proto) {
-  CHECK_EQ(proto.num(), 1);
-  CHECK(proto.channels() == 3 || proto.channels() == 1);
-  CHECK_GT(proto.height(), 0);
-  CHECK_GT(proto.width(), 0);
-  Mat cv_img(proto.height(), proto.width(), CV_8UC3);
-  if (proto.channels() == 1) {
-    for (int c = 0; c < 3; ++c) {
-      for (int h = 0; h < cv_img.rows; ++h) {
-        for (int w = 0; w < cv_img.cols; ++w) {
-          cv_img.at<Vec3b>(h, w)[c] =
-              uint8_t(proto.data(h * cv_img.cols + w) * 255.);
-        }
-      }
-    }
-  } else {
-    for (int c = 0; c < 3; ++c) {
-      for (int h = 0; h < cv_img.rows; ++h) {
-        for (int w = 0; w < cv_img.cols; ++w) {
-          cv_img.at<Vec3b>(h, w)[c] =
-              uint8_t(proto.data((c * cv_img.rows + h) * cv_img.cols + w)
-                  * 255.);
-        }
-      }
-    }
-  }
-  CHECK(cv::imwrite(filename, cv_img));
-}
-
 void ReadProtoFromTextFile(const char* filename,
     ::google::protobuf::Message* proto) {
   int fd = open(filename, O_RDONLY);
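
The removed readers and DataLayerPrefetch share the same flat channel-major (CHW) indexing, data[(c * height + h) * width + w]. A small self-checking sketch of why that expression enumerates a dense layout (sizes are arbitrary):

// Iterating c, then h, then w visits flat offsets 0, 1, 2, ... exactly once,
// which is what makes (c * height + h) * width + w a dense CHW layout.
#include <cassert>

int main() {
  const int channels = 3, height = 4, width = 5;
  int flat = 0;
  for (int c = 0; c < channels; ++c)
    for (int h = 0; h < height; ++h)
      for (int w = 0; w < width; ++w)
        assert((c * height + h) * width + w == flat++);
  return 0;
}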

View file

@@ -15,26 +15,6 @@ using ::google::protobuf::Message;
 namespace caffe {

-void ReadImageToProto(const string& filename, BlobProto* proto);
-
-template <typename Dtype>
-inline void ReadImageToBlob(const string& filename, Blob<Dtype>* blob) {
-  BlobProto proto;
-  ReadImageToProto(filename, &proto);
-  blob->FromProto(proto);
-}
-
-void WriteProtoToImage(const string& filename, const BlobProto& proto);
-
-template <typename Dtype>
-inline void WriteBlobToImage(const string& filename, const Blob<Dtype>& blob) {
-  BlobProto proto;
-  blob.ToProto(&proto);
-  WriteProtoToImage(filename, proto);
-}
-
-void ReadImageToDatum(const string& filename, const int label, Datum* datum);
-
 void ReadProtoFromTextFile(const char* filename,
     Message* proto);

 inline void ReadProtoFromTextFile(const string& filename,

View file

@@ -1,81 +0,0 @@
-// Copyright 2013 Yangqing Jia
-// This program converts a set of images to a leveldb by storing them as Datum
-// proto buffers.
-// Usage:
-//   convert_dataset ROOTFOLDER LISTFILE DB_NAME
-// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
-// should be a list of files as well as their labels, in the format as
-//   subfolder1/file1.JPEG 0
-//   ....
-// You are responsible for shuffling the files yourself.
-
-#include <glog/logging.h>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
-
-#include <string>
-#include <iostream>
-#include <fstream>
-
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-
-using namespace caffe;
-using std::string;
-using std::stringstream;
-
-// A utility function to generate random strings
-void GenerateRandomPrefix(const int n, string* key) {
-  const char* kCHARS = "abcdefghijklmnopqrstuvwxyz";
-  key->clear();
-  for (int i = 0; i < n; ++i) {
-    key->push_back(kCHARS[rand() % 26]);
-  }
-  key->push_back('_');
-}
-
-int main(int argc, char** argv) {
-  ::google::InitGoogleLogging(argv[0]);
-  std::ifstream infile(argv[2]);
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  LOG(INFO) << "Opening leveldb " << argv[3];
-  leveldb::Status status = leveldb::DB::Open(
-      options, argv[3], &db);
-  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];
-
-  string root_folder(argv[1]);
-  string filename;
-  int label;
-  Datum datum;
-  int count = 0;
-  char key_cstr[100];
-  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
-  while (infile >> filename >> label) {
-    ReadImageToDatum(root_folder + filename, label, &datum);
-    // sequential
-    sprintf(key_cstr, "%08d_%s", count, filename.c_str());
-    string key(key_cstr);
-    // random
-    // string key;
-    // GenerateRandomPrefix(8, &key);
-    // key += filename;
-    string value;
-    // get the value
-    datum.SerializeToString(&value);
-    batch->Put(key, value);
-    if (++count % 1000 == 0) {
-      db->Write(leveldb::WriteOptions(), batch);
-      LOG(ERROR) << "Processed " << count << " files.";
-      delete batch;
-      batch = new leveldb::WriteBatch();
-    }
-  }
-
-  delete db;
-  return 0;
-}
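
One detail of the deleted tool worth noting: leveldb iterates keys in lexicographic order, so the zero-padded "%08d_" prefix makes scan order match insertion order, while the commented-out random prefix would deliberately shuffle it. A tiny sketch of the key format, without leveldb:

// Zero-padded numeric prefixes sort lexicographically in the same order as
// their numeric values, so a database scan replays insertion order.
#include <cstdio>
#include <cstring>

int main() {
  char a[100], b[100];
  std::snprintf(a, sizeof(a), "%08d_%s", 9, "cat.JPEG");
  std::snprintf(b, sizeof(b), "%08d_%s", 10, "dog.JPEG");
  // "00000009_..." < "00000010_..." as strings, matching 9 < 10; without the
  // padding, "10_..." would sort before "9_...".
  std::printf("%s < %s : %d\n", a, b, std::strcmp(a, b) < 0);
  return 0;
}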

View file

@@ -17,7 +17,7 @@
 using namespace caffe;

 int main(int argc, char** argv) {
-  cudaSetDevice(0);
+  cudaSetDevice(1);
   Caffe::set_mode(Caffe::GPU);
   Caffe::set_phase(Caffe::TRAIN);
@@ -35,7 +35,7 @@ int main(int argc, char** argv) {
   SolverParameter solver_param;
   solver_param.set_base_lr(0.01);
-  solver_param.set_display(1);
+  solver_param.set_display(100);
   solver_param.set_max_iter(6000);
   solver_param.set_lr_policy("inv");
   solver_param.set_gamma(0.0001);
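
For reference, the "inv" policy selected here decays the learning rate as base_lr * (1 + gamma * iter)^(-power). The power term is set elsewhere in solver_param and is not part of this hunk, so the 0.75 below is an assumption for illustration:

// Sketch of the "inv" learning-rate schedule configured above. The power
// value 0.75 is assumed for illustration; this hunk does not set it.
#include <cmath>
#include <cstdio>

int main() {
  const double base_lr = 0.01, gamma = 0.0001, power = 0.75;
  for (int iter = 0; iter <= 6000; iter += 1000) {
    double lr = base_lr * std::pow(1.0 + gamma * iter, -power);
    std::printf("iter %5d  lr %.6f\n", iter, lr);
  }
  return 0;
}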