Squeeze bug fix. (#506)
This commit is contained in:
Родитель
69d5fcab74
Коммит
510cd5ec81
|
@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
|
|||
} else {
|
||||
std::unordered_set<dim_t> axis_checker;
|
||||
for (size_t i = 0; i < param.axis.ndim(); ++i) {
|
||||
int real_axis;
|
||||
if (param.axis[i] < 0) {
|
||||
int real_axis = param.axis[i] + static_cast<int>(shp.ndim());
|
||||
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
|
||||
axis_checker.insert(real_axis);
|
||||
real_axis = param.axis[i] + static_cast<int>(shp.ndim());
|
||||
} else {
|
||||
real_axis = param.axis[i];
|
||||
}
|
||||
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
|
||||
axis_checker.insert(real_axis);
|
||||
}
|
||||
for (size_t i = 0; i < shp.ndim(); ++i) {
|
||||
if (axis_checker.find(i) == axis_checker.end()) {
|
||||
|
|
|
@ -116,6 +116,24 @@ def test_flatten():
|
|||
sdict = infer_shape(y)
|
||||
assert(sdict["y"][0] == [10, 200])
|
||||
|
||||
def test_squeeze():
|
||||
x = sym.Variable("x", shape=(1, 1, 1, 10))
|
||||
y = sym.squeeze(x, axis=(1,2), name='squeeze')
|
||||
sdict = infer_shape(y)
|
||||
assert(sdict['squeeze'][0] == [1, 10])
|
||||
|
||||
x = sym.Variable("x", shape=(1, 3, 1))
|
||||
y = sym.squeeze(x, name='squeeze')
|
||||
sdict = infer_shape(y)
|
||||
assert(sdict['squeeze'][0] == [3])
|
||||
|
||||
y = sym.squeeze(x, axis=(0), name='squeeze')
|
||||
sdict = infer_shape(y)
|
||||
assert(sdict['squeeze'][0] == [3, 1])
|
||||
|
||||
y = sym.squeeze(x, axis=(0,2), name='squeeze')
|
||||
sdict = infer_shape(y)
|
||||
assert(sdict['squeeze'][0] == [3])
|
||||
|
||||
# Level 2
|
||||
def test_conv2d():
|
||||
|
@ -331,3 +349,4 @@ if __name__ == "__main__":
|
|||
test_reduce()
|
||||
test_transpose()
|
||||
test_prelu()
|
||||
test_squeeze()
|
||||
|
|
Загрузка…
Ссылка в новой задаче