diff --git a/include/caffe/python_layer.hpp b/include/caffe/python_layer.hpp index 19cf18c9..9c30250c 100644 --- a/include/caffe/python_layer.hpp +++ b/include/caffe/python_layer.hpp @@ -18,22 +18,11 @@ class PythonLayer : public Layer { virtual void LayerSetUp(const vector*>& bottom, const vector*>& top) { - try { - self_.attr("setup")(bottom, top); - } catch (bp::error_already_set) { - PyErr_Print(); - throw; - } + self_.attr("setup")(bottom, top); } - virtual void Reshape(const vector*>& bottom, const vector*>& top) { - try { - self_.attr("reshape")(bottom, top); - } catch (bp::error_already_set) { - PyErr_Print(); - throw; - } + self_.attr("reshape")(bottom, top); } virtual inline const char* type() const { return "Python"; } @@ -41,21 +30,11 @@ class PythonLayer : public Layer { protected: virtual void Forward_cpu(const vector*>& bottom, const vector*>& top) { - try { - self_.attr("forward")(bottom, top); - } catch (bp::error_already_set) { - PyErr_Print(); - throw; - } + self_.attr("forward")(bottom, top); } virtual void Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { - try { - self_.attr("backward")(top, propagate_down, bottom); - } catch (bp::error_already_set) { - PyErr_Print(); - throw; - } + self_.attr("backward")(top, propagate_down, bottom); } private: diff --git a/python/caffe/test/test_python_layer.py b/python/caffe/test/test_python_layer.py index f41e283f..ff070a35 100644 --- a/python/caffe/test/test_python_layer.py +++ b/python/caffe/test/test_python_layer.py @@ -22,6 +22,13 @@ class SimpleLayer(caffe.Layer): bottom[0].diff[...] = 10 * top[0].diff +class ExceptionLayer(caffe.Layer): + """A layer for checking exceptions from Python""" + + def setup(self, bottom, top): + raise RuntimeError + + def python_net_file(): with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f: f.write("""name: 'pythonnet' force_backward: true @@ -35,6 +42,16 @@ def python_net_file(): return f.name +def exception_net_file(): + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write("""name: 'pythonnet' force_backward: true + input: 'data' input_shape { dim: 10 dim: 9 dim: 8 } + layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top' + python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } } + """) + return f.name + + class TestPythonLayer(unittest.TestCase): def setUp(self): net_file = python_net_file() @@ -62,3 +79,8 @@ class TestPythonLayer(unittest.TestCase): for blob in six.itervalues(self.net.blobs): for d in blob.data.shape: self.assertEqual(s, d) + + def test_exception(self): + net_file = exception_net_file() + self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST) + os.remove(net_file) diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 0b7523fc..9de3abdc 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -8,6 +8,11 @@ #include "boost/algorithm/string.hpp" #include "caffe/caffe.hpp" +#ifdef WITH_PYTHON_LAYER +#include "boost/python.hpp" +namespace bp = boost::python; +#endif + using caffe::Blob; using caffe::Caffe; using caffe::Net; @@ -304,7 +309,16 @@ int main(int argc, char** argv) { // Run tool or show usage. caffe::GlobalInit(&argc, &argv); if (argc == 2) { - return GetBrewFunction(caffe::string(argv[1]))(); +#ifdef WITH_PYTHON_LAYER + try { +#endif + return GetBrewFunction(caffe::string(argv[1]))(); +#ifdef WITH_PYTHON_LAYER + } catch (bp::error_already_set) { + PyErr_Print(); + return 1; + } +#endif } else { gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); }