зеркало из https://github.com/microsoft/caffe.git
Merge pull request #1014 from longjon/cleaner-pycaffe
Clean up pycaffe core
This commit is contained in:
Коммит
c80c6f355d
3
Makefile
3
Makefile
|
@ -70,6 +70,7 @@ EMPTY_LINT_REPORT := $(BUILD_DIR)/.$(LINT_EXT)
|
|||
NONEMPTY_LINT_REPORT := $(BUILD_DIR)/$(LINT_EXT)
|
||||
# PY$(PROJECT)_SRC is the python wrapper for $(PROJECT)
|
||||
PY$(PROJECT)_SRC := python/$(PROJECT)/_$(PROJECT).cpp
|
||||
PY$(PROJECT)_HXX_SRC := python/$(PROJECT)/_$(PROJECT).hpp
|
||||
PY$(PROJECT)_SO := python/$(PROJECT)/_$(PROJECT).so
|
||||
# MAT$(PROJECT)_SRC is the matlab wrapper for $(PROJECT)
|
||||
MAT$(PROJECT)_SRC := matlab/$(PROJECT)/mat$(PROJECT).cpp
|
||||
|
@ -345,7 +346,7 @@ py$(PROJECT): py
|
|||
|
||||
py: $(PY$(PROJECT)_SO) $(PROTO_GEN_PY)
|
||||
|
||||
$(PY$(PROJECT)_SO): $(STATIC_NAME) $(PY$(PROJECT)_SRC)
|
||||
$(PY$(PROJECT)_SO): $(STATIC_NAME) $(PY$(PROJECT)_SRC) $(PY$(PROJECT)_HXX_SRC)
|
||||
$(CXX) -shared -o $@ $(PY$(PROJECT)_SRC) \
|
||||
$(STATIC_NAME) $(LINKFLAGS) $(PYTHON_LDFLAGS)
|
||||
@ echo
|
||||
|
|
|
@ -2,17 +2,14 @@
|
|||
// caffe::Caffe functions so that one could easily call it from Python.
|
||||
// Note that for Python, we will simply use float as the data type.
|
||||
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include "boost/python.hpp"
|
||||
#include "boost/python/suite/indexing/vector_indexing_suite.hpp"
|
||||
#include "numpy/arrayobject.h"
|
||||
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
|
||||
|
||||
// these need to be included after boost on OS X
|
||||
#include <string> // NOLINT(build/include_order)
|
||||
#include <vector> // NOLINT(build/include_order)
|
||||
#include <fstream> // NOLINT
|
||||
|
||||
#include "_caffe.hpp"
|
||||
#include "caffe/caffe.hpp"
|
||||
|
||||
// Temporary solution for numpy < 1.7 versions: old macro, no promises.
|
||||
|
@ -22,15 +19,7 @@
|
|||
#define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x))
|
||||
#endif
|
||||
|
||||
|
||||
using namespace caffe; // NOLINT(build/namespaces)
|
||||
using boost::python::dict;
|
||||
using boost::python::extract;
|
||||
using boost::python::len;
|
||||
using boost::python::list;
|
||||
using boost::python::object;
|
||||
using boost::python::handle;
|
||||
using boost::python::vector_indexing_suite;
|
||||
namespace caffe {
|
||||
|
||||
// for convenience, check that input files can be opened, and raise an
|
||||
// exception that boost will send to Python if not (caffe could still crash
|
||||
|
@ -45,322 +34,165 @@ static void CheckFile(const string& filename) {
|
|||
f.close();
|
||||
}
|
||||
|
||||
// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
|
||||
// to Python
|
||||
class CaffeBlob {
|
||||
public:
|
||||
CaffeBlob(const shared_ptr<Blob<float> > &blob, const string& name)
|
||||
: blob_(blob), name_(name) {}
|
||||
bp::object PyBlobWrap::get_data() {
|
||||
npy_intp dims[] = {num(), channels(), height(), width()};
|
||||
|
||||
string name() const { return name_; }
|
||||
int num() const { return blob_->num(); }
|
||||
int channels() const { return blob_->channels(); }
|
||||
int height() const { return blob_->height(); }
|
||||
int width() const { return blob_->width(); }
|
||||
int count() const { return blob_->count(); }
|
||||
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
|
||||
blob_->mutable_cpu_data());
|
||||
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
|
||||
Py_INCREF(self_);
|
||||
bp::handle<> h(obj);
|
||||
|
||||
// this is here only to satisfy boost's vector_indexing_suite
|
||||
bool operator == (const CaffeBlob &other) {
|
||||
return this->blob_ == other.blob_;
|
||||
return bp::object(h);
|
||||
}
|
||||
|
||||
bp::object PyBlobWrap::get_diff() {
|
||||
npy_intp dims[] = {num(), channels(), height(), width()};
|
||||
|
||||
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
|
||||
blob_->mutable_cpu_diff());
|
||||
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
|
||||
Py_INCREF(self_);
|
||||
bp::handle<> h(obj);
|
||||
|
||||
return bp::object(h);
|
||||
}
|
||||
|
||||
PyNet::PyNet(string param_file, string pretrained_param_file) {
|
||||
Init(param_file);
|
||||
CheckFile(pretrained_param_file);
|
||||
net_->CopyTrainedLayersFrom(pretrained_param_file);
|
||||
}
|
||||
|
||||
void PyNet::Init(string param_file) {
|
||||
CheckFile(param_file);
|
||||
net_.reset(new Net<float>(param_file));
|
||||
}
|
||||
|
||||
void PyNet::check_contiguous_array(PyArrayObject* arr, string name,
|
||||
int channels, int height, int width) {
|
||||
if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
|
||||
throw std::runtime_error(name + " must be C contiguous");
|
||||
}
|
||||
if (PyArray_NDIM(arr) != 4) {
|
||||
throw std::runtime_error(name + " must be 4-d");
|
||||
}
|
||||
if (PyArray_TYPE(arr) != NPY_FLOAT32) {
|
||||
throw std::runtime_error(name + " must be float32");
|
||||
}
|
||||
if (PyArray_DIMS(arr)[1] != channels) {
|
||||
throw std::runtime_error(name + " has wrong number of channels");
|
||||
}
|
||||
if (PyArray_DIMS(arr)[2] != height) {
|
||||
throw std::runtime_error(name + " has wrong height");
|
||||
}
|
||||
if (PyArray_DIMS(arr)[3] != width) {
|
||||
throw std::runtime_error(name + " has wrong width");
|
||||
}
|
||||
}
|
||||
|
||||
void PyNet::set_input_arrays(bp::object data_obj, bp::object labels_obj) {
|
||||
// check that this network has an input MemoryDataLayer
|
||||
shared_ptr<MemoryDataLayer<float> > md_layer =
|
||||
boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
|
||||
if (!md_layer) {
|
||||
throw std::runtime_error("set_input_arrays may only be called if the"
|
||||
" first layer is a MemoryDataLayer");
|
||||
}
|
||||
|
||||
protected:
|
||||
shared_ptr<Blob<float> > blob_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
|
||||
// We need another wrapper (used as boost::python's HeldType) that receives a
|
||||
// self PyObject * which we can use as ndarray.base, so that data/diff memory
|
||||
// is not freed while still being used in Python.
|
||||
class CaffeBlobWrap : public CaffeBlob {
|
||||
public:
|
||||
CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
|
||||
: CaffeBlob(blob), self_(p) {}
|
||||
|
||||
object get_data() {
|
||||
npy_intp dims[] = {num(), channels(), height(), width()};
|
||||
|
||||
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
|
||||
blob_->mutable_cpu_data());
|
||||
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
|
||||
Py_INCREF(self_);
|
||||
handle<> h(obj);
|
||||
|
||||
return object(h);
|
||||
// check that we were passed appropriately-sized contiguous memory
|
||||
PyArrayObject* data_arr =
|
||||
reinterpret_cast<PyArrayObject*>(data_obj.ptr());
|
||||
PyArrayObject* labels_arr =
|
||||
reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
|
||||
check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
|
||||
md_layer->datum_height(), md_layer->datum_width());
|
||||
check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
|
||||
if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
|
||||
throw std::runtime_error("data and labels must have the same first"
|
||||
" dimension");
|
||||
}
|
||||
if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
|
||||
throw std::runtime_error("first dimensions of input arrays must be a"
|
||||
" multiple of batch size");
|
||||
}
|
||||
|
||||
object get_diff() {
|
||||
npy_intp dims[] = {num(), channels(), height(), width()};
|
||||
// hold references
|
||||
input_data_ = data_obj;
|
||||
input_labels_ = labels_obj;
|
||||
|
||||
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
|
||||
blob_->mutable_cpu_diff());
|
||||
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
|
||||
Py_INCREF(self_);
|
||||
handle<> h(obj);
|
||||
md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
|
||||
static_cast<float*>(PyArray_DATA(labels_arr)),
|
||||
PyArray_DIMS(data_arr)[0]);
|
||||
}
|
||||
|
||||
return object(h);
|
||||
}
|
||||
PySGDSolver::PySGDSolver(const string& param_file) {
|
||||
// as in PyNet, (as a convenience, not a guarantee), create a Python
|
||||
// exception if param_file can't be opened
|
||||
CheckFile(param_file);
|
||||
solver_.reset(new SGDSolver<float>(param_file));
|
||||
// we need to explicitly store the net wrapper, rather than constructing
|
||||
// it on the fly, so that it can hold references to Python objects
|
||||
net_.reset(new PyNet(solver_->net()));
|
||||
}
|
||||
|
||||
private:
|
||||
PyObject *self_;
|
||||
};
|
||||
void PySGDSolver::SolveResume(const string& resume_file) {
|
||||
CheckFile(resume_file);
|
||||
return solver_->Solve(resume_file);
|
||||
}
|
||||
|
||||
|
||||
class CaffeLayer {
|
||||
public:
|
||||
CaffeLayer(const shared_ptr<Layer<float> > &layer, const string &name)
|
||||
: layer_(layer), name_(name) {}
|
||||
|
||||
string name() const { return name_; }
|
||||
vector<CaffeBlob> blobs() {
|
||||
vector<CaffeBlob> result;
|
||||
for (int i = 0; i < layer_->blobs().size(); ++i) {
|
||||
result.push_back(CaffeBlob(layer_->blobs()[i], name_));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// this is here only to satisfy boost's vector_indexing_suite
|
||||
bool operator == (const CaffeLayer &other) {
|
||||
return this->layer_ == other.layer_;
|
||||
}
|
||||
|
||||
protected:
|
||||
shared_ptr<Layer<float> > layer_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
|
||||
// A simple wrapper over CaffeNet that runs the forward process.
|
||||
struct CaffeNet {
|
||||
// For cases where parameters will be determined later by the Python user,
|
||||
// create a Net with unallocated parameters (which will not be zero-filled
|
||||
// when accessed).
|
||||
explicit CaffeNet(string param_file) {
|
||||
Init(param_file);
|
||||
}
|
||||
|
||||
CaffeNet(string param_file, string pretrained_param_file) {
|
||||
Init(param_file);
|
||||
CheckFile(pretrained_param_file);
|
||||
net_->CopyTrainedLayersFrom(pretrained_param_file);
|
||||
}
|
||||
|
||||
explicit CaffeNet(shared_ptr<Net<float> > net)
|
||||
: net_(net) {}
|
||||
|
||||
void Init(string param_file) {
|
||||
CheckFile(param_file);
|
||||
net_.reset(new Net<float>(param_file));
|
||||
}
|
||||
|
||||
|
||||
virtual ~CaffeNet() {}
|
||||
|
||||
// Generate Python exceptions for badly shaped or discontiguous arrays.
|
||||
inline void check_contiguous_array(PyArrayObject* arr, string name,
|
||||
int channels, int height, int width) {
|
||||
if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
|
||||
throw std::runtime_error(name + " must be C contiguous");
|
||||
}
|
||||
if (PyArray_NDIM(arr) != 4) {
|
||||
throw std::runtime_error(name + " must be 4-d");
|
||||
}
|
||||
if (PyArray_TYPE(arr) != NPY_FLOAT32) {
|
||||
throw std::runtime_error(name + " must be float32");
|
||||
}
|
||||
if (PyArray_DIMS(arr)[1] != channels) {
|
||||
throw std::runtime_error(name + " has wrong number of channels");
|
||||
}
|
||||
if (PyArray_DIMS(arr)[2] != height) {
|
||||
throw std::runtime_error(name + " has wrong height");
|
||||
}
|
||||
if (PyArray_DIMS(arr)[3] != width) {
|
||||
throw std::runtime_error(name + " has wrong width");
|
||||
}
|
||||
}
|
||||
|
||||
void Forward(int start, int end) {
|
||||
net_->ForwardFromTo(start, end);
|
||||
}
|
||||
|
||||
void Backward(int start, int end) {
|
||||
net_->BackwardFromTo(start, end);
|
||||
}
|
||||
|
||||
void set_input_arrays(object data_obj, object labels_obj) {
|
||||
// check that this network has an input MemoryDataLayer
|
||||
shared_ptr<MemoryDataLayer<float> > md_layer =
|
||||
boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
|
||||
if (!md_layer) {
|
||||
throw std::runtime_error("set_input_arrays may only be called if the"
|
||||
" first layer is a MemoryDataLayer");
|
||||
}
|
||||
|
||||
// check that we were passed appropriately-sized contiguous memory
|
||||
PyArrayObject* data_arr =
|
||||
reinterpret_cast<PyArrayObject*>(data_obj.ptr());
|
||||
PyArrayObject* labels_arr =
|
||||
reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
|
||||
check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
|
||||
md_layer->datum_height(), md_layer->datum_width());
|
||||
check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
|
||||
if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
|
||||
throw std::runtime_error("data and labels must have the same first"
|
||||
" dimension");
|
||||
}
|
||||
if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
|
||||
throw std::runtime_error("first dimensions of input arrays must be a"
|
||||
" multiple of batch size");
|
||||
}
|
||||
|
||||
// hold references
|
||||
input_data_ = data_obj;
|
||||
input_labels_ = labels_obj;
|
||||
|
||||
md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
|
||||
static_cast<float*>(PyArray_DATA(labels_arr)),
|
||||
PyArray_DIMS(data_arr)[0]);
|
||||
}
|
||||
|
||||
// save the network weights to binary proto for net surgeries.
|
||||
void save(string filename) {
|
||||
NetParameter net_param;
|
||||
net_->ToProto(&net_param, false);
|
||||
WriteProtoToBinaryFile(net_param, filename.c_str());
|
||||
}
|
||||
|
||||
// The caffe::Caffe utility functions.
|
||||
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
|
||||
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
|
||||
void set_phase_train() { Caffe::set_phase(Caffe::TRAIN); }
|
||||
void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
|
||||
void set_device(int device_id) { Caffe::SetDevice(device_id); }
|
||||
|
||||
vector<CaffeBlob> blobs() {
|
||||
vector<CaffeBlob> result;
|
||||
for (int i = 0; i < net_->blobs().size(); ++i) {
|
||||
result.push_back(CaffeBlob(net_->blobs()[i], net_->blob_names()[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
vector<CaffeLayer> layers() {
|
||||
vector<CaffeLayer> result;
|
||||
for (int i = 0; i < net_->layers().size(); ++i) {
|
||||
result.push_back(CaffeLayer(net_->layers()[i], net_->layer_names()[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
list inputs() {
|
||||
list input_blob_names;
|
||||
for (int i = 0; i < net_->input_blob_indices().size(); ++i) {
|
||||
input_blob_names.append(
|
||||
net_->blob_names()[net_->input_blob_indices()[i]]);
|
||||
}
|
||||
return input_blob_names;
|
||||
}
|
||||
|
||||
list outputs() {
|
||||
list output_blob_names;
|
||||
for (int i = 0; i < net_->output_blob_indices().size(); ++i) {
|
||||
output_blob_names.append(
|
||||
net_->blob_names()[net_->output_blob_indices()[i]]);
|
||||
}
|
||||
return output_blob_names;
|
||||
}
|
||||
|
||||
// The pointer to the internal caffe::Net instant.
|
||||
shared_ptr<Net<float> > net_;
|
||||
// Input preprocessing configuration attributes.
|
||||
dict mean_;
|
||||
dict input_scale_;
|
||||
dict raw_scale_;
|
||||
dict channel_swap_;
|
||||
// if taking input from an ndarray, we need to hold references
|
||||
object input_data_;
|
||||
object input_labels_;
|
||||
};
|
||||
|
||||
class CaffeSGDSolver {
|
||||
public:
|
||||
explicit CaffeSGDSolver(const string& param_file) {
|
||||
// as in CaffeNet, (as a convenience, not a guarantee), create a Python
|
||||
// exception if param_file can't be opened
|
||||
CheckFile(param_file);
|
||||
solver_.reset(new SGDSolver<float>(param_file));
|
||||
// we need to explicitly store the net wrapper, rather than constructing
|
||||
// it on the fly, so that it can hold references to Python objects
|
||||
net_.reset(new CaffeNet(solver_->net()));
|
||||
}
|
||||
|
||||
shared_ptr<CaffeNet> net() { return net_; }
|
||||
void Solve() { return solver_->Solve(); }
|
||||
void SolveResume(const string& resume_file) {
|
||||
CheckFile(resume_file);
|
||||
return solver_->Solve(resume_file);
|
||||
}
|
||||
|
||||
protected:
|
||||
shared_ptr<CaffeNet> net_;
|
||||
shared_ptr<SGDSolver<float> > solver_;
|
||||
};
|
||||
|
||||
|
||||
// The boost_python module definition.
|
||||
BOOST_PYTHON_MODULE(_caffe) {
|
||||
// below, we prepend an underscore to methods that will be replaced
|
||||
// in Python
|
||||
boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
|
||||
"Net", boost::python::init<string, string>())
|
||||
.def(boost::python::init<string>())
|
||||
.def("_forward", &CaffeNet::Forward)
|
||||
.def("_backward", &CaffeNet::Backward)
|
||||
.def("set_mode_cpu", &CaffeNet::set_mode_cpu)
|
||||
.def("set_mode_gpu", &CaffeNet::set_mode_gpu)
|
||||
.def("set_phase_train", &CaffeNet::set_phase_train)
|
||||
.def("set_phase_test", &CaffeNet::set_phase_test)
|
||||
.def("set_device", &CaffeNet::set_device)
|
||||
.add_property("_blobs", &CaffeNet::blobs)
|
||||
.add_property("layers", &CaffeNet::layers)
|
||||
.add_property("inputs", &CaffeNet::inputs)
|
||||
.add_property("outputs", &CaffeNet::outputs)
|
||||
.add_property("mean", &CaffeNet::mean_)
|
||||
.add_property("input_scale", &CaffeNet::input_scale_)
|
||||
.add_property("raw_scale", &CaffeNet::raw_scale_)
|
||||
.add_property("channel_swap", &CaffeNet::channel_swap_)
|
||||
.def("_set_input_arrays", &CaffeNet::set_input_arrays)
|
||||
.def("save", &CaffeNet::save);
|
||||
bp::class_<PyNet, shared_ptr<PyNet> >(
|
||||
"Net", bp::init<string, string>())
|
||||
.def(bp::init<string>())
|
||||
.def("_forward", &PyNet::Forward)
|
||||
.def("_backward", &PyNet::Backward)
|
||||
.def("set_mode_cpu", &PyNet::set_mode_cpu)
|
||||
.def("set_mode_gpu", &PyNet::set_mode_gpu)
|
||||
.def("set_phase_train", &PyNet::set_phase_train)
|
||||
.def("set_phase_test", &PyNet::set_phase_test)
|
||||
.def("set_device", &PyNet::set_device)
|
||||
.add_property("_blobs", &PyNet::blobs)
|
||||
.add_property("layers", &PyNet::layers)
|
||||
.add_property("inputs", &PyNet::inputs)
|
||||
.add_property("outputs", &PyNet::outputs)
|
||||
.add_property("mean", &PyNet::mean_)
|
||||
.add_property("input_scale", &PyNet::input_scale_)
|
||||
.add_property("raw_scale", &PyNet::raw_scale_)
|
||||
.add_property("channel_swap", &PyNet::channel_swap_)
|
||||
.def("_set_input_arrays", &PyNet::set_input_arrays)
|
||||
.def("save", &PyNet::save);
|
||||
|
||||
boost::python::class_<CaffeBlob, CaffeBlobWrap>(
|
||||
"Blob", boost::python::no_init)
|
||||
.add_property("name", &CaffeBlob::name)
|
||||
.add_property("num", &CaffeBlob::num)
|
||||
.add_property("channels", &CaffeBlob::channels)
|
||||
.add_property("height", &CaffeBlob::height)
|
||||
.add_property("width", &CaffeBlob::width)
|
||||
.add_property("count", &CaffeBlob::count)
|
||||
.add_property("data", &CaffeBlobWrap::get_data)
|
||||
.add_property("diff", &CaffeBlobWrap::get_diff);
|
||||
bp::class_<PyBlob<float>, PyBlobWrap>(
|
||||
"Blob", bp::no_init)
|
||||
.add_property("num", &PyBlob<float>::num)
|
||||
.add_property("channels", &PyBlob<float>::channels)
|
||||
.add_property("height", &PyBlob<float>::height)
|
||||
.add_property("width", &PyBlob<float>::width)
|
||||
.add_property("count", &PyBlob<float>::count)
|
||||
.def("reshape", &PyBlob<float>::Reshape)
|
||||
.add_property("data", &PyBlobWrap::get_data)
|
||||
.add_property("diff", &PyBlobWrap::get_diff);
|
||||
|
||||
boost::python::class_<CaffeLayer>(
|
||||
"Layer", boost::python::no_init)
|
||||
.add_property("name", &CaffeLayer::name)
|
||||
.add_property("blobs", &CaffeLayer::blobs);
|
||||
bp::class_<PyLayer>(
|
||||
"Layer", bp::no_init)
|
||||
.add_property("blobs", &PyLayer::blobs);
|
||||
|
||||
boost::python::class_<CaffeSGDSolver, boost::noncopyable>(
|
||||
"SGDSolver", boost::python::init<string>())
|
||||
.add_property("net", &CaffeSGDSolver::net)
|
||||
.def("solve", &CaffeSGDSolver::Solve)
|
||||
.def("solve", &CaffeSGDSolver::SolveResume);
|
||||
bp::class_<PySGDSolver, boost::noncopyable>(
|
||||
"SGDSolver", bp::init<string>())
|
||||
.add_property("net", &PySGDSolver::net)
|
||||
.def("solve", &PySGDSolver::Solve)
|
||||
.def("solve", &PySGDSolver::SolveResume);
|
||||
|
||||
boost::python::class_<vector<CaffeBlob> >("BlobVec")
|
||||
.def(vector_indexing_suite<vector<CaffeBlob>, true>());
|
||||
bp::class_<vector<PyBlob<float> > >("BlobVec")
|
||||
.def(bp::vector_indexing_suite<vector<PyBlob<float> >, true>());
|
||||
|
||||
boost::python::class_<vector<CaffeLayer> >("LayerVec")
|
||||
.def(vector_indexing_suite<vector<CaffeLayer>, true>());
|
||||
bp::class_<vector<PyLayer> >("LayerVec")
|
||||
.def(bp::vector_indexing_suite<vector<PyLayer>, true>());
|
||||
|
||||
import_array();
|
||||
}
|
||||
|
||||
} // namespace caffe
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
#ifndef PYTHON_CAFFE__CAFFE_HPP_
|
||||
#define PYTHON_CAFFE__CAFFE_HPP_
|
||||
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <boost/python.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <numpy/arrayobject.h>
|
||||
|
||||
// these need to be included after boost on OS X
|
||||
#include <string> // NOLINT(build/include_order)
|
||||
#include <vector> // NOLINT(build/include_order)
|
||||
|
||||
#include "caffe/caffe.hpp"
|
||||
|
||||
namespace bp = boost::python;
|
||||
using boost::shared_ptr;
|
||||
|
||||
namespace caffe {
|
||||
|
||||
// wrap shared_ptr<Blob> in a class that we construct in C++ and pass
|
||||
// to Python
|
||||
template <typename Dtype>
|
||||
class PyBlob {
|
||||
public:
|
||||
explicit PyBlob(const shared_ptr<Blob<Dtype> > &blob)
|
||||
: blob_(blob) {}
|
||||
|
||||
int num() const { return blob_->num(); }
|
||||
int channels() const { return blob_->channels(); }
|
||||
int height() const { return blob_->height(); }
|
||||
int width() const { return blob_->width(); }
|
||||
int count() const { return blob_->count(); }
|
||||
void Reshape(const int n, const int c, const int h, const int w) {
|
||||
return blob_->Reshape(n, c, h, w);
|
||||
}
|
||||
|
||||
// this is here only to satisfy boost's vector_indexing_suite
|
||||
bool operator == (const PyBlob &other) {
|
||||
return this->blob_ == other.blob_;
|
||||
}
|
||||
|
||||
protected:
|
||||
shared_ptr<Blob<Dtype> > blob_;
|
||||
};
|
||||
|
||||
// We need another wrapper (used as boost::python's HeldType) that receives a
|
||||
// self PyObject * which we can use as ndarray.base, so that data/diff memory
|
||||
// is not freed while still being used in Python.
|
||||
class PyBlobWrap : public PyBlob<float> {
|
||||
public:
|
||||
PyBlobWrap(PyObject *p, const PyBlob<float> &blob)
|
||||
: PyBlob<float>(blob), self_(p) {}
|
||||
|
||||
bp::object get_data();
|
||||
bp::object get_diff();
|
||||
|
||||
private:
|
||||
PyObject *self_;
|
||||
};
|
||||
|
||||
class PyLayer {
|
||||
public:
|
||||
explicit PyLayer(const shared_ptr<Layer<float> > &layer)
|
||||
: layer_(layer) {}
|
||||
|
||||
vector<PyBlob<float> > blobs() {
|
||||
return vector<PyBlob<float> >(layer_->blobs().begin(),
|
||||
layer_->blobs().end());
|
||||
}
|
||||
|
||||
// this is here only to satisfy boost's vector_indexing_suite
|
||||
bool operator == (const PyLayer &other) {
|
||||
return this->layer_ == other.layer_;
|
||||
}
|
||||
|
||||
protected:
|
||||
shared_ptr<Layer<float> > layer_;
|
||||
};
|
||||
|
||||
class PyNet {
|
||||
public:
|
||||
// For cases where parameters will be determined later by the Python user,
|
||||
// create a Net with unallocated parameters (which will not be zero-filled
|
||||
// when accessed).
|
||||
explicit PyNet(string param_file) { Init(param_file); }
|
||||
PyNet(string param_file, string pretrained_param_file);
|
||||
explicit PyNet(shared_ptr<Net<float> > net)
|
||||
: net_(net) {}
|
||||
virtual ~PyNet() {}
|
||||
|
||||
void Init(string param_file);
|
||||
|
||||
|
||||
// Generate Python exceptions for badly shaped or discontiguous arrays.
|
||||
inline void check_contiguous_array(PyArrayObject* arr, string name,
|
||||
int channels, int height, int width);
|
||||
|
||||
void Forward(int start, int end) { net_->ForwardFromTo(start, end); }
|
||||
void Backward(int start, int end) { net_->BackwardFromTo(start, end); }
|
||||
|
||||
void set_input_arrays(bp::object data_obj, bp::object labels_obj);
|
||||
|
||||
// Save the network weights to binary proto for net surgeries.
|
||||
void save(string filename) {
|
||||
NetParameter net_param;
|
||||
net_->ToProto(&net_param, false);
|
||||
WriteProtoToBinaryFile(net_param, filename.c_str());
|
||||
}
|
||||
|
||||
// The caffe::Caffe utility functions.
|
||||
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
|
||||
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
|
||||
void set_phase_train() { Caffe::set_phase(Caffe::TRAIN); }
|
||||
void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
|
||||
void set_device(int device_id) { Caffe::SetDevice(device_id); }
|
||||
|
||||
vector<PyBlob<float> > blobs() {
|
||||
return vector<PyBlob<float> >(net_->blobs().begin(), net_->blobs().end());
|
||||
}
|
||||
|
||||
vector<PyLayer> layers() {
|
||||
return vector<PyLayer>(net_->layers().begin(), net_->layers().end());
|
||||
}
|
||||
|
||||
bp::list inputs() {
|
||||
bp::list input_blob_names;
|
||||
for (int i = 0; i < net_->input_blob_indices().size(); ++i) {
|
||||
input_blob_names.append(
|
||||
net_->blob_names()[net_->input_blob_indices()[i]]);
|
||||
}
|
||||
return input_blob_names;
|
||||
}
|
||||
|
||||
bp::list outputs() {
|
||||
bp::list output_blob_names;
|
||||
for (int i = 0; i < net_->output_blob_indices().size(); ++i) {
|
||||
output_blob_names.append(
|
||||
net_->blob_names()[net_->output_blob_indices()[i]]);
|
||||
}
|
||||
return output_blob_names;
|
||||
}
|
||||
|
||||
// Input preprocessing configuration attributes. These are public for
|
||||
// direct access from Python.
|
||||
bp::dict mean_;
|
||||
bp::dict input_scale_;
|
||||
bp::dict raw_scale_;
|
||||
bp::dict channel_swap_;
|
||||
|
||||
protected:
|
||||
// The pointer to the internal caffe::Net instant.
|
||||
shared_ptr<Net<float> > net_;
|
||||
// if taking input from an ndarray, we need to hold references
|
||||
bp::object input_data_;
|
||||
bp::object input_labels_;
|
||||
};
|
||||
|
||||
class PySGDSolver {
|
||||
public:
|
||||
explicit PySGDSolver(const string& param_file);
|
||||
|
||||
shared_ptr<PyNet> net() { return net_; }
|
||||
void Solve() { return solver_->Solve(); }
|
||||
void SolveResume(const string& resume_file);
|
||||
|
||||
protected:
|
||||
shared_ptr<PyNet> net_;
|
||||
shared_ptr<SGDSolver<float> > solver_;
|
||||
};
|
||||
|
||||
// Declare the module init function created by boost::python, so that we can
|
||||
// use this module from C++ when embedding Python.
|
||||
PyMODINIT_FUNC init_caffe(void);
|
||||
|
||||
} // namespace caffe
|
||||
|
||||
#endif
|
Загрузка…
Ссылка в новой задаче