diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 278826bc..d304a595 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode { } }; // struct InitOpAttrs +/*! \brief Attributes used in squeeze operators */ +struct SqueezeAttrs : public tvm::AttrsNode { + Array axes; + + TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { + TVM_ATTR_FIELD(axes) + .describe("The axes to squeeze in the input tensor." + "If `axes = []`, all axis of dimension 1 get squeezed;" + "Else, the dimension in axes get squeezed." + "It is an error if an axes does not has dimension 1.") + .set_default(Array({})); + } +}; // struct SqueezeAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 75fbba84..c2036f50 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -42,12 +42,35 @@ def transpose(data, axes=None): Returns ------- result : relay.Expr - The reshaped result. + The transposed result. """ axes = axes or [] return _make.transpose(data, list(axes)) +def squeeze(data, axes=None): + """Squeeze axes in the array. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + axes : None or List[int] + Axes to remove. + If axes = [] or = None, remove all axis of dimensions 1. + Otherwise, remove all axis in axes. + If any axis in axes has dimension that does not equal 1, it is an error. + + Returns + ------- + result : relay.Expr + The squeezed result. + """ + axes = axes or [] + return _make.squeeze(data, list(axes)) + + def reshape(data, newshape): """Reshapes the input array. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fb7b09fd..95688347 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims") .set_support_level(1) .add_type_rel("ExpandDims", ExpandDimsRel); -/* relay.concatenate */ - TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); bool ConcatenateRel(const Array& types, @@ -633,5 +631,75 @@ Examples:: .set_support_level(4) .add_type_rel("Where", WhereRel); +Expr MakeSqueeze(Expr data, + Array axes) { + auto attrs = make_node(); + attrs->axes = std::move(axes); + static const Op& op = Op::Get("squeeze"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.squeeze") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeSqueeze, args, rv); + }); + +bool SqueezeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* param = attrs.as(); + CHECK(param != nullptr); + std::vector result_shape; + // if axes is empty, squeeze all axes of dimension 1 + if (param->axes.size() == 0) { + for (const auto& e : data->shape) { + const int64_t* axis_ptr = as_const_int(e); + CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; + if (*axis_ptr != 1) { + result_shape.push_back(e); + } + } + } else { + // pair up original shape with a boolean which control whether it will be in the final shape. + std::vector > original_shape; + for (const auto& e : data->shape) { + original_shape.push_back(std::pair(e, true)); + } + for (const auto& e : param->axes) { + const int64_t* axis_ptr = as_const_int(e); + CHECK(axis_ptr != nullptr); + original_shape.at(*axis_ptr).second = false; + } + for (const auto p : original_shape) { + if (p.second) { + result_shape.push_back(p.first); + } else { + const int64_t* axis_ptr = as_const_int(p.first); + CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor"; + CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; + } + } + } + reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype)); + return true; +} + +RELAY_REGISTER_OP("squeeze") +.describe(R"code(Squeeze the input tensor at the dimensions given by axes + +- **data**: The input data to the operator. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(3) +.add_type_rel("Squeeze", SqueezeRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 7d949b21..13ab483f 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -6,6 +6,7 @@ from tvm import relay from tvm.relay.ir_pass import infer_type from tvm.relay.ir_builder import IRBuilder, func_type from tvm.relay.env import Environment +from nose.tools import raises def test_zeros_ones(): for op in [relay.zeros, relay.ones]: @@ -67,6 +68,44 @@ def test_transpose_infer_type(): (t, n, 100), "float32") +def test_squeeze_default_axes_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d = 1, 4, 1 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.squeeze(x)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (4,), "float32") + + +def test_squeeze_axes_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d = 1, 4, 1 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.squeeze(x, axes=(2,))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (1, 4), "float32") + + +@raises(tvm._ffi.base.TVMError) +def test_squeeze_bad_axes_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d = 1, 4, 1 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.squeeze(x, axes=(1,))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + + def test_reshape_infer_type(): ib = relay.ir_builder.IRBuilder() n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20 @@ -181,3 +220,5 @@ if __name__ == "__main__": test_take_infer_type() test_full() test_full_like() + test_squeeze_axes_infer_type() + test_squeeze_default_axes_infer_type()