This commit is contained in:
Yangqing Jia 2013-09-13 17:33:52 -07:00
Родитель 298e7a4129
Коммит dceef03a17
3 изменённых файлов: 14 добавлений и 60 удалений

Просмотреть файл

@ -1,36 +0,0 @@
#ifndef CAFFEINE_BASE_H_
#define CAFFEINE_BASE_H_
#include <vector>
#include "caffeine/blob.hpp"
#include "caffeine/proto/layer_param.pb.h"
using std::vector;
namespace caffeine {
template <typename Dtype>
class Layer {
public:
explicit Layer(const LayerParameter& param)
: initialized_(false), layer_param_(param) {};
~Layer();
virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) = 0;
virtual void Forward(vector<const Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) = 0;
virtual void Predict(vector<const Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) = 0;
virtual void Backward(vector<const Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top, bool propagate_down) = 0;
protected:
bool initialized_;
// The protobuf that stores the layer parameters
LayerParameter layer_param_;
// The vector that stores the parameters as a set of blobs.
vector<Blob<Dtype> > blobs;
}; // class Layer
} // namespace caffeine
#endif // CAFFEINE_BASE_H_

Просмотреть файл

@ -1,8 +1,6 @@
#ifndef CAFFEINE_COMMON_HPP_
#define CAFFEINE_COMMON_HPP_
#include <iostream>
#include <boost/shared_ptr.hpp>
#include <cublas_v2.h>
#include <glog/logging.h>
@ -21,31 +19,23 @@ using boost::shared_ptr;
// A singleton class to hold common caffeine stuff, such as the handler that
// caffeine is going to use for cublas.
class Caffeine {
public:
~Caffeine();
static Caffeine& Get();
enum Brew { CPU, GPU };
// The getters for the variables.
static cublasHandle_t cublas_handle();
static Brew mode();
// The setters for the variables
static Brew set_mode(Brew mode);
private:
Caffeine() {
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
};
Caffeine();
static shared_ptr<Caffeine> singleton_;
cublasHandle_t cublas_handle_;
public:
~Caffeine() {
if (!cublas_handle_) {
CUBLAS_CHECK(cublasDestroy(cublas_handle_));
}
}
static Caffeine& Get() {
if (!singleton_) {
singleton_.reset(new Caffeine());
}
return *singleton_;
}
static cublasHandle_t cublas_handle() {
return Get().cublas_handle_;
}
Brew mode_;
};
}
#endif // CAFFEINE_COMMON_HPP_

Просмотреть файл

@ -1,6 +1,6 @@
#include "caffeine/base.h"
#include "caffeine/base.hpp"
namespace caffeine {
} // namespace caffeine
} // namespace caffeine