Merge pull request #11 from longjon/master

Python interface to blobs and blob data through boost::python
This commit is contained in:
Evan Shelhamer 2014-01-19 21:17:43 -08:00
Родитель e54bd1be09 ae72b9e870
Коммит f8039bc22b
1 изменённых файлов: 102 добавлений и 0 удалений

Просмотреть файл

@ -6,6 +6,7 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <boost/python.hpp>
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include <numpy/arrayobject.h>
#include "caffe/caffe.hpp"
@ -20,6 +21,78 @@ using boost::python::extract;
using boost::python::len;
using boost::python::list;
using boost::python::object;
using boost::python::handle;
using boost::python::vector_indexing_suite;
// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
// to Python
class CaffeBlob {
public:
CaffeBlob(const shared_ptr<Blob<float> > &blob)
: blob_(blob) {}
CaffeBlob()
{}
int num() const { return blob_->num(); }
int channels() const { return blob_->channels(); }
int height() const { return blob_->height(); }
int width() const { return blob_->width(); }
int count() const { return blob_->count(); }
bool operator == (const CaffeBlob &other)
{
return this->blob_ == other.blob_;
}
protected:
shared_ptr<Blob<float> > blob_;
};
// we need another wrapper (used as boost::python's HeldType) that receives a
// self PyObject * which we can use as ndarray.base, so that data/diff memory
// is not freed while still being used in Python
class CaffeBlobWrap : public CaffeBlob {
public:
CaffeBlobWrap(PyObject *p, shared_ptr<Blob<float> > &blob)
: CaffeBlob(blob), self_(p) {}
CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
: CaffeBlob(blob), self_(p) {}
object get_data()
{
npy_intp dims[] = {num(), channels(), height(), width()};
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
blob_->mutable_cpu_data());
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
Py_INCREF(self_);
handle<> h(obj);
return object(h);
}
object get_diff()
{
npy_intp dims[] = {num(), channels(), height(), width()};
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
blob_->mutable_cpu_diff());
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
Py_INCREF(self_);
handle<> h(obj);
return object(h);
}
private:
PyObject *self_;
};
// A simple wrapper over CaffeNet that runs the forward process.
@ -143,14 +216,24 @@ struct CaffeNet
void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
void set_device(int device_id) { Caffe::SetDevice(device_id); }
vector<CaffeBlob> blobs() {
return vector<CaffeBlob>(net_->blobs().begin(), net_->blobs().end());
}
vector<CaffeBlob> params() {
return vector<CaffeBlob>(net_->params().begin(), net_->params().end());
}
// The pointer to the internal caffe::Net instant.
shared_ptr<Net<float> > net_;
};
// The boost python module definition.
BOOST_PYTHON_MODULE(pycaffe)
{
boost::python::class_<CaffeNet>(
"CaffeNet", boost::python::init<string, string>())
.def("Forward", &CaffeNet::Forward)
@ -160,5 +243,24 @@ BOOST_PYTHON_MODULE(pycaffe)
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
.def("blobs", &CaffeNet::blobs)
.def("params", &CaffeNet::params)
;
boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"CaffeBlob", boost::python::no_init)
.add_property("num", &CaffeBlob::num)
.add_property("channels", &CaffeBlob::channels)
.add_property("height", &CaffeBlob::height)
.add_property("width", &CaffeBlob::width)
.add_property("count", &CaffeBlob::count)
.add_property("data", &CaffeBlobWrap::get_data)
.add_property("diff", &CaffeBlobWrap::get_diff)
;
boost::python::class_<vector<CaffeBlob> >("BlobVec")
.def(vector_indexing_suite<vector<CaffeBlob>, true>());
import_array();
}