зеркало из https://github.com/microsoft/caffe.git
Merge pull request #2462 from longjon/correct-python-exceptions
Handle Python layer exceptions correctly
This commit is contained in:
Коммит
ac6d4b67c2
|
@ -18,22 +18,11 @@ class PythonLayer : public Layer<Dtype> {
|
|||
|
||||
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
|
||||
const vector<Blob<Dtype>*>& top) {
|
||||
try {
|
||||
self_.attr("setup")(bottom, top);
|
||||
} catch (bp::error_already_set) {
|
||||
PyErr_Print();
|
||||
throw;
|
||||
}
|
||||
self_.attr("setup")(bottom, top);
|
||||
}
|
||||
|
||||
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
|
||||
const vector<Blob<Dtype>*>& top) {
|
||||
try {
|
||||
self_.attr("reshape")(bottom, top);
|
||||
} catch (bp::error_already_set) {
|
||||
PyErr_Print();
|
||||
throw;
|
||||
}
|
||||
self_.attr("reshape")(bottom, top);
|
||||
}
|
||||
|
||||
virtual inline const char* type() const { return "Python"; }
|
||||
|
@ -41,21 +30,11 @@ class PythonLayer : public Layer<Dtype> {
|
|||
protected:
|
||||
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
|
||||
const vector<Blob<Dtype>*>& top) {
|
||||
try {
|
||||
self_.attr("forward")(bottom, top);
|
||||
} catch (bp::error_already_set) {
|
||||
PyErr_Print();
|
||||
throw;
|
||||
}
|
||||
self_.attr("forward")(bottom, top);
|
||||
}
|
||||
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
|
||||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
|
||||
try {
|
||||
self_.attr("backward")(top, propagate_down, bottom);
|
||||
} catch (bp::error_already_set) {
|
||||
PyErr_Print();
|
||||
throw;
|
||||
}
|
||||
self_.attr("backward")(top, propagate_down, bottom);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -22,6 +22,13 @@ class SimpleLayer(caffe.Layer):
|
|||
bottom[0].diff[...] = 10 * top[0].diff
|
||||
|
||||
|
||||
class ExceptionLayer(caffe.Layer):
|
||||
"""A layer for checking exceptions from Python"""
|
||||
|
||||
def setup(self, bottom, top):
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
def python_net_file():
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
||||
f.write("""name: 'pythonnet' force_backward: true
|
||||
|
@ -35,6 +42,16 @@ def python_net_file():
|
|||
return f.name
|
||||
|
||||
|
||||
def exception_net_file():
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write("""name: 'pythonnet' force_backward: true
|
||||
input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
|
||||
layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
|
||||
python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
|
||||
""")
|
||||
return f.name
|
||||
|
||||
|
||||
class TestPythonLayer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
net_file = python_net_file()
|
||||
|
@ -62,3 +79,8 @@ class TestPythonLayer(unittest.TestCase):
|
|||
for blob in six.itervalues(self.net.blobs):
|
||||
for d in blob.data.shape:
|
||||
self.assertEqual(s, d)
|
||||
|
||||
def test_exception(self):
|
||||
net_file = exception_net_file()
|
||||
self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
|
||||
os.remove(net_file)
|
||||
|
|
|
@ -8,6 +8,11 @@
|
|||
#include "boost/algorithm/string.hpp"
|
||||
#include "caffe/caffe.hpp"
|
||||
|
||||
#ifdef WITH_PYTHON_LAYER
|
||||
#include "boost/python.hpp"
|
||||
namespace bp = boost::python;
|
||||
#endif
|
||||
|
||||
using caffe::Blob;
|
||||
using caffe::Caffe;
|
||||
using caffe::Net;
|
||||
|
@ -304,7 +309,16 @@ int main(int argc, char** argv) {
|
|||
// Run tool or show usage.
|
||||
caffe::GlobalInit(&argc, &argv);
|
||||
if (argc == 2) {
|
||||
return GetBrewFunction(caffe::string(argv[1]))();
|
||||
#ifdef WITH_PYTHON_LAYER
|
||||
try {
|
||||
#endif
|
||||
return GetBrewFunction(caffe::string(argv[1]))();
|
||||
#ifdef WITH_PYTHON_LAYER
|
||||
} catch (bp::error_already_set) {
|
||||
PyErr_Print();
|
||||
return 1;
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче