зеркало из https://github.com/microsoft/caffe.git
Merge pull request #11 from longjon/master
Python interface to blobs and blob data through boost::python
This commit is contained in:
Коммит
f8039bc22b
|
@ -6,6 +6,7 @@
|
|||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <boost/python.hpp>
|
||||
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
|
||||
#include <numpy/arrayobject.h>
|
||||
#include "caffe/caffe.hpp"
|
||||
|
||||
|
@ -20,6 +21,78 @@ using boost::python::extract;
|
|||
using boost::python::len;
|
||||
using boost::python::list;
|
||||
using boost::python::object;
|
||||
using boost::python::handle;
|
||||
using boost::python::vector_indexing_suite;
|
||||
|
||||
|
||||
// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
|
||||
// to Python
|
||||
class CaffeBlob {
|
||||
public:
|
||||
|
||||
CaffeBlob(const shared_ptr<Blob<float> > &blob)
|
||||
: blob_(blob) {}
|
||||
|
||||
CaffeBlob()
|
||||
{}
|
||||
|
||||
int num() const { return blob_->num(); }
|
||||
int channels() const { return blob_->channels(); }
|
||||
int height() const { return blob_->height(); }
|
||||
int width() const { return blob_->width(); }
|
||||
int count() const { return blob_->count(); }
|
||||
|
||||
bool operator == (const CaffeBlob &other)
|
||||
{
|
||||
return this->blob_ == other.blob_;
|
||||
}
|
||||
|
||||
protected:
|
||||
shared_ptr<Blob<float> > blob_;
|
||||
};
|
||||
|
||||
|
||||
// we need another wrapper (used as boost::python's HeldType) that receives a
|
||||
// self PyObject * which we can use as ndarray.base, so that data/diff memory
|
||||
// is not freed while still being used in Python
|
||||
class CaffeBlobWrap : public CaffeBlob {
|
||||
public:
|
||||
CaffeBlobWrap(PyObject *p, shared_ptr<Blob<float> > &blob)
|
||||
: CaffeBlob(blob), self_(p) {}
|
||||
|
||||
CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
|
||||
: CaffeBlob(blob), self_(p) {}
|
||||
|
||||
object get_data()
|
||||
{
|
||||
npy_intp dims[] = {num(), channels(), height(), width()};
|
||||
|
||||
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
|
||||
blob_->mutable_cpu_data());
|
||||
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
|
||||
Py_INCREF(self_);
|
||||
handle<> h(obj);
|
||||
|
||||
return object(h);
|
||||
}
|
||||
|
||||
object get_diff()
|
||||
{
|
||||
npy_intp dims[] = {num(), channels(), height(), width()};
|
||||
|
||||
PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
|
||||
blob_->mutable_cpu_diff());
|
||||
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
|
||||
Py_INCREF(self_);
|
||||
handle<> h(obj);
|
||||
|
||||
return object(h);
|
||||
}
|
||||
|
||||
private:
|
||||
PyObject *self_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
// A simple wrapper over CaffeNet that runs the forward process.
|
||||
|
@ -143,14 +216,24 @@ struct CaffeNet
|
|||
void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
|
||||
void set_device(int device_id) { Caffe::SetDevice(device_id); }
|
||||
|
||||
vector<CaffeBlob> blobs() {
|
||||
return vector<CaffeBlob>(net_->blobs().begin(), net_->blobs().end());
|
||||
}
|
||||
|
||||
vector<CaffeBlob> params() {
|
||||
return vector<CaffeBlob>(net_->params().begin(), net_->params().end());
|
||||
}
|
||||
|
||||
// The pointer to the internal caffe::Net instant.
|
||||
shared_ptr<Net<float> > net_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
// The boost python module definition.
|
||||
BOOST_PYTHON_MODULE(pycaffe)
|
||||
{
|
||||
|
||||
boost::python::class_<CaffeNet>(
|
||||
"CaffeNet", boost::python::init<string, string>())
|
||||
.def("Forward", &CaffeNet::Forward)
|
||||
|
@ -160,5 +243,24 @@ BOOST_PYTHON_MODULE(pycaffe)
|
|||
.def("set_phase_train", &CaffeNet::set_phase_train)
|
||||
.def("set_phase_test", &CaffeNet::set_phase_test)
|
||||
.def("set_device", &CaffeNet::set_device)
|
||||
.def("blobs", &CaffeNet::blobs)
|
||||
.def("params", &CaffeNet::params)
|
||||
;
|
||||
|
||||
boost::python::class_<CaffeBlob, CaffeBlobWrap>(
|
||||
"CaffeBlob", boost::python::no_init)
|
||||
.add_property("num", &CaffeBlob::num)
|
||||
.add_property("channels", &CaffeBlob::channels)
|
||||
.add_property("height", &CaffeBlob::height)
|
||||
.add_property("width", &CaffeBlob::width)
|
||||
.add_property("count", &CaffeBlob::count)
|
||||
.add_property("data", &CaffeBlobWrap::get_data)
|
||||
.add_property("diff", &CaffeBlobWrap::get_diff)
|
||||
;
|
||||
|
||||
boost::python::class_<vector<CaffeBlob> >("BlobVec")
|
||||
.def(vector_indexing_suite<vector<CaffeBlob>, true>());
|
||||
|
||||
import_array();
|
||||
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче