Merge pull request #2462 from longjon/correct-python-exceptions

Handle Python layer exceptions correctly
This commit is contained in:
Evan Shelhamer 2015-08-06 00:27:59 -07:00
Родитель d4aa5fe9b3 977023f171
Коммит ac6d4b67c2
3 изменённых файлов: 41 добавлений и 26 удалений

Просмотреть файл

@ -18,22 +18,11 @@ class PythonLayer : public Layer<Dtype> {
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
self_.attr("setup")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("setup")(bottom, top);
}
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
self_.attr("reshape")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("reshape")(bottom, top);
}
virtual inline const char* type() const { return "Python"; }
@ -41,21 +30,11 @@ class PythonLayer : public Layer<Dtype> {
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
self_.attr("forward")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("forward")(bottom, top);
}
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
try {
self_.attr("backward")(top, propagate_down, bottom);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("backward")(top, propagate_down, bottom);
}
private:

Просмотреть файл

@ -22,6 +22,13 @@ class SimpleLayer(caffe.Layer):
bottom[0].diff[...] = 10 * top[0].diff
class ExceptionLayer(caffe.Layer):
"""A layer for checking exceptions from Python"""
def setup(self, bottom, top):
raise RuntimeError
def python_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
@ -35,6 +42,16 @@ def python_net_file():
return f.name
def exception_net_file():
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
""")
return f.name
class TestPythonLayer(unittest.TestCase):
def setUp(self):
net_file = python_net_file()
@ -62,3 +79,8 @@ class TestPythonLayer(unittest.TestCase):
for blob in six.itervalues(self.net.blobs):
for d in blob.data.shape:
self.assertEqual(s, d)
def test_exception(self):
net_file = exception_net_file()
self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
os.remove(net_file)

Просмотреть файл

@ -8,6 +8,11 @@
#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
#ifdef WITH_PYTHON_LAYER
#include "boost/python.hpp"
namespace bp = boost::python;
#endif
using caffe::Blob;
using caffe::Caffe;
using caffe::Net;
@ -304,7 +309,16 @@ int main(int argc, char** argv) {
// Run tool or show usage.
caffe::GlobalInit(&argc, &argv);
if (argc == 2) {
return GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER
try {
#endif
return GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER
} catch (bp::error_already_set) {
PyErr_Print();
return 1;
}
#endif
} else {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
}