Squeeze bug fix. (#506)
This commit is contained in:
Родитель
69d5fcab74
Коммит
510cd5ec81
|
@ -638,11 +638,14 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
|
||||||
} else {
|
} else {
|
||||||
std::unordered_set<dim_t> axis_checker;
|
std::unordered_set<dim_t> axis_checker;
|
||||||
for (size_t i = 0; i < param.axis.ndim(); ++i) {
|
for (size_t i = 0; i < param.axis.ndim(); ++i) {
|
||||||
|
int real_axis;
|
||||||
if (param.axis[i] < 0) {
|
if (param.axis[i] < 0) {
|
||||||
int real_axis = param.axis[i] + static_cast<int>(shp.ndim());
|
real_axis = param.axis[i] + static_cast<int>(shp.ndim());
|
||||||
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
|
} else {
|
||||||
axis_checker.insert(real_axis);
|
real_axis = param.axis[i];
|
||||||
}
|
}
|
||||||
|
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
|
||||||
|
axis_checker.insert(real_axis);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < shp.ndim(); ++i) {
|
for (size_t i = 0; i < shp.ndim(); ++i) {
|
||||||
if (axis_checker.find(i) == axis_checker.end()) {
|
if (axis_checker.find(i) == axis_checker.end()) {
|
||||||
|
|
|
@ -116,6 +116,24 @@ def test_flatten():
|
||||||
sdict = infer_shape(y)
|
sdict = infer_shape(y)
|
||||||
assert(sdict["y"][0] == [10, 200])
|
assert(sdict["y"][0] == [10, 200])
|
||||||
|
|
||||||
|
def test_squeeze():
|
||||||
|
x = sym.Variable("x", shape=(1, 1, 1, 10))
|
||||||
|
y = sym.squeeze(x, axis=(1,2), name='squeeze')
|
||||||
|
sdict = infer_shape(y)
|
||||||
|
assert(sdict['squeeze'][0] == [1, 10])
|
||||||
|
|
||||||
|
x = sym.Variable("x", shape=(1, 3, 1))
|
||||||
|
y = sym.squeeze(x, name='squeeze')
|
||||||
|
sdict = infer_shape(y)
|
||||||
|
assert(sdict['squeeze'][0] == [3])
|
||||||
|
|
||||||
|
y = sym.squeeze(x, axis=(0), name='squeeze')
|
||||||
|
sdict = infer_shape(y)
|
||||||
|
assert(sdict['squeeze'][0] == [3, 1])
|
||||||
|
|
||||||
|
y = sym.squeeze(x, axis=(0,2), name='squeeze')
|
||||||
|
sdict = infer_shape(y)
|
||||||
|
assert(sdict['squeeze'][0] == [3])
|
||||||
|
|
||||||
# Level 2
|
# Level 2
|
||||||
def test_conv2d():
|
def test_conv2d():
|
||||||
|
@ -331,3 +349,4 @@ if __name__ == "__main__":
|
||||||
test_reduce()
|
test_reduce()
|
||||||
test_transpose()
|
test_transpose()
|
||||||
test_prelu()
|
test_prelu()
|
||||||
|
test_squeeze()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче