From 510cd5ec81077a7f2808c77584400bd001e04015 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 23 May 2018 09:32:36 +0530 Subject: [PATCH] Squeeze bug fix. (#506) --- nnvm/src/top/tensor/transform.cc | 9 ++++++--- .../tests/python/unittest/test_infer_shape.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index f4e5a7be..48f8428d 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, } else { std::unordered_set axis_checker; for (size_t i = 0; i < param.axis.ndim(); ++i) { + int real_axis; if (param.axis[i] < 0) { - int real_axis = param.axis[i] + static_cast(shp.ndim()); - CHECK(real_axis < static_cast(shp.ndim()) && real_axis >= 0); - axis_checker.insert(real_axis); + real_axis = param.axis[i] + static_cast(shp.ndim()); + } else { + real_axis = param.axis[i]; } + CHECK(real_axis < static_cast(shp.ndim()) && real_axis >= 0); + axis_checker.insert(real_axis); } for (size_t i = 0; i < shp.ndim(); ++i) { if (axis_checker.find(i) == axis_checker.end()) { diff --git a/nnvm/tests/python/unittest/test_infer_shape.py b/nnvm/tests/python/unittest/test_infer_shape.py index 8011e96f..9fbc93c0 100644 --- a/nnvm/tests/python/unittest/test_infer_shape.py +++ b/nnvm/tests/python/unittest/test_infer_shape.py @@ -116,6 +116,24 @@ def test_flatten(): sdict = infer_shape(y) assert(sdict["y"][0] == [10, 200]) +def test_squeeze(): + x = sym.Variable("x", shape=(1, 1, 1, 10)) + y = sym.squeeze(x, axis=(1,2), name='squeeze') + sdict = infer_shape(y) + assert(sdict['squeeze'][0] == [1, 10]) + + x = sym.Variable("x", shape=(1, 3, 1)) + y = sym.squeeze(x, name='squeeze') + sdict = infer_shape(y) + assert(sdict['squeeze'][0] == [3]) + + y = sym.squeeze(x, axis=(0), name='squeeze') + sdict = infer_shape(y) + assert(sdict['squeeze'][0] == [3, 1]) + + y = sym.squeeze(x, axis=(0,2), name='squeeze') + sdict = infer_shape(y) + assert(sdict['squeeze'][0] == [3]) # Level 2 def test_conv2d(): @@ -331,3 +349,4 @@ if __name__ == "__main__": test_reduce() test_transpose() test_prelu() + test_squeeze()