зеркало из https://github.com/microsoft/caffe.git
some updates
This commit is contained in:
Родитель
298e7a4129
Коммит
dceef03a17
|
@ -1,36 +0,0 @@
|
|||
#ifndef CAFFEINE_BASE_H_
|
||||
#define CAFFEINE_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "caffeine/blob.hpp"
|
||||
#include "caffeine/proto/layer_param.pb.h"
|
||||
|
||||
using std::vector;
|
||||
|
||||
namespace caffeine {
|
||||
|
||||
template <typename Dtype>
|
||||
class Layer {
|
||||
public:
|
||||
explicit Layer(const LayerParameter& param)
|
||||
: initialized_(false), layer_param_(param) {};
|
||||
~Layer();
|
||||
virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
|
||||
vector<Blob<Dtype>*>* top) = 0;
|
||||
virtual void Forward(vector<const Blob<Dtype>*>& bottom,
|
||||
vector<Blob<Dtype>*>* top) = 0;
|
||||
virtual void Predict(vector<const Blob<Dtype>*>& bottom,
|
||||
vector<Blob<Dtype>*>* top) = 0;
|
||||
virtual void Backward(vector<const Blob<Dtype>*>& bottom,
|
||||
vector<Blob<Dtype>*>* top, bool propagate_down) = 0;
|
||||
protected:
|
||||
bool initialized_;
|
||||
// The protobuf that stores the layer parameters
|
||||
LayerParameter layer_param_;
|
||||
// The vector that stores the parameters as a set of blobs.
|
||||
vector<Blob<Dtype> > blobs;
|
||||
}; // class Layer
|
||||
|
||||
} // namespace caffeine
|
||||
|
||||
#endif // CAFFEINE_BASE_H_
|
|
@ -1,8 +1,6 @@
|
|||
#ifndef CAFFEINE_COMMON_HPP_
|
||||
#define CAFFEINE_COMMON_HPP_
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <cublas_v2.h>
|
||||
#include <glog/logging.h>
|
||||
|
@ -21,31 +19,23 @@ using boost::shared_ptr;
|
|||
// A singleton class to hold common caffeine stuff, such as the handler that
|
||||
// caffeine is going to use for cublas.
|
||||
class Caffeine {
|
||||
public:
|
||||
~Caffeine();
|
||||
static Caffeine& Get();
|
||||
enum Brew { CPU, GPU };
|
||||
|
||||
// The getters for the variables.
|
||||
static cublasHandle_t cublas_handle();
|
||||
static Brew mode();
|
||||
// The setters for the variables
|
||||
static Brew set_mode(Brew mode);
|
||||
private:
|
||||
Caffeine() {
|
||||
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
|
||||
};
|
||||
Caffeine();
|
||||
static shared_ptr<Caffeine> singleton_;
|
||||
cublasHandle_t cublas_handle_;
|
||||
public:
|
||||
~Caffeine() {
|
||||
if (!cublas_handle_) {
|
||||
CUBLAS_CHECK(cublasDestroy(cublas_handle_));
|
||||
}
|
||||
}
|
||||
static Caffeine& Get() {
|
||||
if (!singleton_) {
|
||||
singleton_.reset(new Caffeine());
|
||||
}
|
||||
return *singleton_;
|
||||
}
|
||||
|
||||
static cublasHandle_t cublas_handle() {
|
||||
return Get().cublas_handle_;
|
||||
}
|
||||
Brew mode_;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif // CAFFEINE_COMMON_HPP_
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#include "caffeine/base.h"
|
||||
#include "caffeine/base.hpp"
|
||||
|
||||
namespace caffeine {
|
||||
|
||||
|
||||
} // namespace caffeine
|
||||
} // namespace caffeine
|
||||
|
|
Загрузка…
Ссылка в новой задаче