зеркало из https://github.com/microsoft/caffe.git
Merge pull request #2944 from philkr/python_layer_param
Give the python layer parameter/weight blobs.
This commit is contained in:
Коммит
f572eefc8a
|
@ -190,6 +190,21 @@ bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
|
|||
return bp::object();
|
||||
}
|
||||
|
||||
bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
|
||||
if (bp::len(kwargs) > 0) {
|
||||
throw std::runtime_error("BlobVec.add_blob takes no kwargs");
|
||||
}
|
||||
typedef vector<shared_ptr<Blob<Dtype> > > BlobVec;
|
||||
BlobVec* self = bp::extract<BlobVec*>(args[0]);
|
||||
vector<int> shape(bp::len(args) - 1);
|
||||
for (int i = 1; i < bp::len(args); ++i) {
|
||||
shape[i - 1] = bp::extract<int>(args[i]);
|
||||
}
|
||||
self->push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
|
||||
// We need to explicitly return None to use bp::raw_function.
|
||||
return bp::object();
|
||||
}
|
||||
|
||||
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
|
||||
|
||||
BOOST_PYTHON_MODULE(_caffe) {
|
||||
|
@ -288,7 +303,8 @@ BOOST_PYTHON_MODULE(_caffe) {
|
|||
|
||||
// vector wrappers for all the vector types we use
|
||||
bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
|
||||
.def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
|
||||
.def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>())
|
||||
.def("add_blob", bp::raw_function(&BlobVec_add_blob));
|
||||
bp::class_<vector<Blob<Dtype>*> >("RawBlobVec")
|
||||
.def(bp::vector_indexing_suite<vector<Blob<Dtype>*>, true>());
|
||||
bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
|
||||
|
|
|
@ -28,6 +28,21 @@ class ExceptionLayer(caffe.Layer):
|
|||
def setup(self, bottom, top):
|
||||
raise RuntimeError
|
||||
|
||||
class ParameterLayer(caffe.Layer):
|
||||
"""A layer that just multiplies by ten"""
|
||||
|
||||
def setup(self, bottom, top):
|
||||
self.blobs.add_blob(1)
|
||||
self.blobs[0].data[0] = 0
|
||||
|
||||
def reshape(self, bottom, top):
|
||||
top[0].reshape(*bottom[0].data.shape)
|
||||
|
||||
def forward(self, bottom, top):
|
||||
pass
|
||||
|
||||
def backward(self, top, propagate_down, bottom):
|
||||
self.blobs[0].diff[0] = 1
|
||||
|
||||
def python_net_file():
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
||||
|
@ -52,6 +67,16 @@ def exception_net_file():
|
|||
return f.name
|
||||
|
||||
|
||||
def parameter_net_file():
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
||||
f.write("""name: 'pythonnet' force_backward: true
|
||||
input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
|
||||
layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
|
||||
python_param { module: 'test_python_layer' layer: 'ParameterLayer' } }
|
||||
""")
|
||||
return f.name
|
||||
|
||||
|
||||
class TestPythonLayer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
net_file = python_net_file()
|
||||
|
@ -84,3 +109,32 @@ class TestPythonLayer(unittest.TestCase):
|
|||
net_file = exception_net_file()
|
||||
self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
|
||||
os.remove(net_file)
|
||||
|
||||
def test_parameter(self):
|
||||
net_file = parameter_net_file()
|
||||
net = caffe.Net(net_file, caffe.TRAIN)
|
||||
# Test forward and backward
|
||||
net.forward()
|
||||
net.backward()
|
||||
layer = net.layers[list(net._layer_names).index('layer')]
|
||||
self.assertEqual(layer.blobs[0].data[0], 0)
|
||||
self.assertEqual(layer.blobs[0].diff[0], 1)
|
||||
layer.blobs[0].data[0] += layer.blobs[0].diff[0]
|
||||
self.assertEqual(layer.blobs[0].data[0], 1)
|
||||
|
||||
# Test saving and loading
|
||||
h, caffemodel_file = tempfile.mkstemp()
|
||||
net.save(caffemodel_file)
|
||||
layer.blobs[0].data[0] = -1
|
||||
self.assertEqual(layer.blobs[0].data[0], -1)
|
||||
net.copy_from(caffemodel_file)
|
||||
self.assertEqual(layer.blobs[0].data[0], 1)
|
||||
os.remove(caffemodel_file)
|
||||
|
||||
# Test weight sharing
|
||||
net2 = caffe.Net(net_file, caffe.TRAIN)
|
||||
net2.share_with(net)
|
||||
layer = net.layers[list(net2._layer_names).index('layer')]
|
||||
self.assertEqual(layer.blobs[0].data[0], 1)
|
||||
|
||||
os.remove(net_file)
|
||||
|
|
Загрузка…
Ссылка в новой задаче