зеркало из https://github.com/microsoft/caffe.git
some updates
This commit is contained in:
Родитель
298e7a4129
Коммит
dceef03a17
|
@ -1,36 +0,0 @@
|
||||||
#ifndef CAFFEINE_BASE_H_
|
|
||||||
#define CAFFEINE_BASE_H_
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include "caffeine/blob.hpp"
|
|
||||||
#include "caffeine/proto/layer_param.pb.h"
|
|
||||||
|
|
||||||
using std::vector;
|
|
||||||
|
|
||||||
namespace caffeine {
|
|
||||||
|
|
||||||
template <typename Dtype>
|
|
||||||
class Layer {
|
|
||||||
public:
|
|
||||||
explicit Layer(const LayerParameter& param)
|
|
||||||
: initialized_(false), layer_param_(param) {};
|
|
||||||
~Layer();
|
|
||||||
virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
|
|
||||||
vector<Blob<Dtype>*>* top) = 0;
|
|
||||||
virtual void Forward(vector<const Blob<Dtype>*>& bottom,
|
|
||||||
vector<Blob<Dtype>*>* top) = 0;
|
|
||||||
virtual void Predict(vector<const Blob<Dtype>*>& bottom,
|
|
||||||
vector<Blob<Dtype>*>* top) = 0;
|
|
||||||
virtual void Backward(vector<const Blob<Dtype>*>& bottom,
|
|
||||||
vector<Blob<Dtype>*>* top, bool propagate_down) = 0;
|
|
||||||
protected:
|
|
||||||
bool initialized_;
|
|
||||||
// The protobuf that stores the layer parameters
|
|
||||||
LayerParameter layer_param_;
|
|
||||||
// The vector that stores the parameters as a set of blobs.
|
|
||||||
vector<Blob<Dtype> > blobs;
|
|
||||||
}; // class Layer
|
|
||||||
|
|
||||||
} // namespace caffeine
|
|
||||||
|
|
||||||
#endif // CAFFEINE_BASE_H_
|
|
|
@ -1,8 +1,6 @@
|
||||||
#ifndef CAFFEINE_COMMON_HPP_
|
#ifndef CAFFEINE_COMMON_HPP_
|
||||||
#define CAFFEINE_COMMON_HPP_
|
#define CAFFEINE_COMMON_HPP_
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <cublas_v2.h>
|
#include <cublas_v2.h>
|
||||||
#include <glog/logging.h>
|
#include <glog/logging.h>
|
||||||
|
@ -21,31 +19,23 @@ using boost::shared_ptr;
|
||||||
// A singleton class to hold common caffeine stuff, such as the handler that
|
// A singleton class to hold common caffeine stuff, such as the handler that
|
||||||
// caffeine is going to use for cublas.
|
// caffeine is going to use for cublas.
|
||||||
class Caffeine {
|
class Caffeine {
|
||||||
|
public:
|
||||||
|
~Caffeine();
|
||||||
|
static Caffeine& Get();
|
||||||
|
enum Brew { CPU, GPU };
|
||||||
|
|
||||||
|
// The getters for the variables.
|
||||||
|
static cublasHandle_t cublas_handle();
|
||||||
|
static Brew mode();
|
||||||
|
// The setters for the variables
|
||||||
|
static Brew set_mode(Brew mode);
|
||||||
private:
|
private:
|
||||||
Caffeine() {
|
Caffeine();
|
||||||
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
|
|
||||||
};
|
|
||||||
static shared_ptr<Caffeine> singleton_;
|
static shared_ptr<Caffeine> singleton_;
|
||||||
cublasHandle_t cublas_handle_;
|
cublasHandle_t cublas_handle_;
|
||||||
public:
|
Brew mode_;
|
||||||
~Caffeine() {
|
|
||||||
if (!cublas_handle_) {
|
|
||||||
CUBLAS_CHECK(cublasDestroy(cublas_handle_));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
static Caffeine& Get() {
|
|
||||||
if (!singleton_) {
|
|
||||||
singleton_.reset(new Caffeine());
|
|
||||||
}
|
|
||||||
return *singleton_;
|
|
||||||
}
|
|
||||||
|
|
||||||
static cublasHandle_t cublas_handle() {
|
|
||||||
return Get().cublas_handle_;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // CAFFEINE_COMMON_HPP_
|
#endif // CAFFEINE_COMMON_HPP_
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#include "caffeine/base.h"
|
#include "caffeine/base.hpp"
|
||||||
|
|
||||||
namespace caffeine {
|
namespace caffeine {
|
||||||
|
|
||||||
|
|
||||||
} // namespace caffeine
|
} // namespace caffeine
|
||||||
|
|
Загрузка…
Ссылка в новой задаче