This commit is contained in:
Siva 2018-05-23 09:32:36 +05:30 коммит произвёл Tianqi Chen
Родитель 69d5fcab74
Коммит 510cd5ec81
2 изменённых файлов: 25 добавлений и 3 удалений

Просмотреть файл

@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
} else { } else {
std::unordered_set<dim_t> axis_checker; std::unordered_set<dim_t> axis_checker;
for (size_t i = 0; i < param.axis.ndim(); ++i) { for (size_t i = 0; i < param.axis.ndim(); ++i) {
int real_axis;
if (param.axis[i] < 0) { if (param.axis[i] < 0) {
int real_axis = param.axis[i] + static_cast<int>(shp.ndim()); real_axis = param.axis[i] + static_cast<int>(shp.ndim());
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0); } else {
axis_checker.insert(real_axis); real_axis = param.axis[i];
} }
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
axis_checker.insert(real_axis);
} }
for (size_t i = 0; i < shp.ndim(); ++i) { for (size_t i = 0; i < shp.ndim(); ++i) {
if (axis_checker.find(i) == axis_checker.end()) { if (axis_checker.find(i) == axis_checker.end()) {

Просмотреть файл

@ -116,6 +116,24 @@ def test_flatten():
sdict = infer_shape(y) sdict = infer_shape(y)
assert(sdict["y"][0] == [10, 200]) assert(sdict["y"][0] == [10, 200])
def test_squeeze():
x = sym.Variable("x", shape=(1, 1, 1, 10))
y = sym.squeeze(x, axis=(1,2), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [1, 10])
x = sym.Variable("x", shape=(1, 3, 1))
y = sym.squeeze(x, name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3])
y = sym.squeeze(x, axis=(0), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3, 1])
y = sym.squeeze(x, axis=(0,2), name='squeeze')
sdict = infer_shape(y)
assert(sdict['squeeze'][0] == [3])
# Level 2 # Level 2
def test_conv2d(): def test_conv2d():
@ -331,3 +349,4 @@ if __name__ == "__main__":
test_reduce() test_reduce()
test_transpose() test_transpose()
test_prelu() test_prelu()
test_squeeze()