[Relay] [Op] Squeeze (#1858)
This commit is contained in:
Родитель
47b8c36dcf
Коммит
201cfdc59a
|
@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
|
|||
}
|
||||
}; // struct InitOpAttrs
|
||||
|
||||
/*! \brief Attributes used in squeeze operators */
|
||||
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
|
||||
Array<IndexExpr> axes;
|
||||
|
||||
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
|
||||
TVM_ATTR_FIELD(axes)
|
||||
.describe("The axes to squeeze in the input tensor."
|
||||
"If `axes = []`, all axis of dimension 1 get squeezed;"
|
||||
"Else, the dimension in axes get squeezed."
|
||||
"It is an error if an axes does not has dimension 1.")
|
||||
.set_default(Array<IndexExpr>({}));
|
||||
}
|
||||
}; // struct SqueezeAttrs
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
|
||||
|
|
|
@ -42,12 +42,35 @@ def transpose(data, axes=None):
|
|||
Returns
|
||||
-------
|
||||
result : relay.Expr
|
||||
The reshaped result.
|
||||
The transposed result.
|
||||
"""
|
||||
axes = axes or []
|
||||
return _make.transpose(data, list(axes))
|
||||
|
||||
|
||||
def squeeze(data, axes=None):
|
||||
"""Squeeze axes in the array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : relay.Expr
|
||||
The input data to the operator.
|
||||
|
||||
axes : None or List[int]
|
||||
Axes to remove.
|
||||
If axes = [] or = None, remove all axis of dimensions 1.
|
||||
Otherwise, remove all axis in axes.
|
||||
If any axis in axes has dimension that does not equal 1, it is an error.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : relay.Expr
|
||||
The squeezed result.
|
||||
"""
|
||||
axes = axes or []
|
||||
return _make.squeeze(data, list(axes))
|
||||
|
||||
|
||||
def reshape(data, newshape):
|
||||
"""Reshapes the input array.
|
||||
|
||||
|
|
|
@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims")
|
|||
.set_support_level(1)
|
||||
.add_type_rel("ExpandDims", ExpandDimsRel);
|
||||
|
||||
/* relay.concatenate */
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
|
||||
|
||||
bool ConcatenateRel(const Array<Type>& types,
|
||||
|
@ -633,5 +631,75 @@ Examples::
|
|||
.set_support_level(4)
|
||||
.add_type_rel("Where", WhereRel);
|
||||
|
||||
Expr MakeSqueeze(Expr data,
|
||||
Array<IndexExpr> axes) {
|
||||
auto attrs = make_node<SqueezeAttrs>();
|
||||
attrs->axes = std::move(axes);
|
||||
static const Op& op = Op::Get("squeeze");
|
||||
return CallNode::make(op, {data}, Attrs(attrs), {});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay.op._make.squeeze")
|
||||
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
|
||||
runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
|
||||
});
|
||||
|
||||
bool SqueezeRel(const Array<Type>& types,
|
||||
int num_inputs,
|
||||
const Attrs& attrs,
|
||||
const TypeReporter& reporter) {
|
||||
CHECK_EQ(types.size(), 2);
|
||||
const auto* data = types[0].as<TensorTypeNode>();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const auto* param = attrs.as<SqueezeAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
std::vector<IndexExpr> result_shape;
|
||||
// if axes is empty, squeeze all axes of dimension 1
|
||||
if (param->axes.size() == 0) {
|
||||
for (const auto& e : data->shape) {
|
||||
const int64_t* axis_ptr = as_const_int(e);
|
||||
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
|
||||
if (*axis_ptr != 1) {
|
||||
result_shape.push_back(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// pair up original shape with a boolean which control whether it will be in the final shape.
|
||||
std::vector<std::pair<IndexExpr, bool> > original_shape;
|
||||
for (const auto& e : data->shape) {
|
||||
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
|
||||
}
|
||||
for (const auto& e : param->axes) {
|
||||
const int64_t* axis_ptr = as_const_int(e);
|
||||
CHECK(axis_ptr != nullptr);
|
||||
original_shape.at(*axis_ptr).second = false;
|
||||
}
|
||||
for (const auto p : original_shape) {
|
||||
if (p.second) {
|
||||
result_shape.push_back(p.first);
|
||||
} else {
|
||||
const int64_t* axis_ptr = as_const_int(p.first);
|
||||
CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
|
||||
CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
|
||||
}
|
||||
}
|
||||
}
|
||||
reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
|
||||
return true;
|
||||
}
|
||||
|
||||
RELAY_REGISTER_OP("squeeze")
|
||||
.describe(R"code(Squeeze the input tensor at the dimensions given by axes
|
||||
|
||||
- **data**: The input data to the operator.
|
||||
|
||||
)code" TVM_ADD_FILELINE)
|
||||
.set_num_inputs(1)
|
||||
.add_argument("data", "Tensor", "The input tensor.")
|
||||
.set_support_level(3)
|
||||
.add_type_rel("Squeeze", SqueezeRel);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -6,6 +6,7 @@ from tvm import relay
|
|||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm.relay.ir_builder import IRBuilder, func_type
|
||||
from tvm.relay.env import Environment
|
||||
from nose.tools import raises
|
||||
|
||||
def test_zeros_ones():
|
||||
for op in [relay.zeros, relay.ones]:
|
||||
|
@ -67,6 +68,44 @@ def test_transpose_infer_type():
|
|||
(t, n, 100), "float32")
|
||||
|
||||
|
||||
def test_squeeze_default_axes_infer_type():
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, t, d = 1, 4, 1
|
||||
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.squeeze(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
assert ftype.ret_type == relay.ty.TensorType(
|
||||
(4,), "float32")
|
||||
|
||||
|
||||
def test_squeeze_axes_infer_type():
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, t, d = 1, 4, 1
|
||||
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.squeeze(x, axes=(2,)))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
assert ftype.ret_type == relay.ty.TensorType(
|
||||
(1, 4), "float32")
|
||||
|
||||
|
||||
@raises(tvm._ffi.base.TVMError)
|
||||
def test_squeeze_bad_axes_infer_type():
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, t, d = 1, 4, 1
|
||||
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.squeeze(x, axes=(1,)))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
||||
|
||||
def test_reshape_infer_type():
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
|
||||
|
@ -181,3 +220,5 @@ if __name__ == "__main__":
|
|||
test_take_infer_type()
|
||||
test_full()
|
||||
test_full_like()
|
||||
test_squeeze_axes_infer_type()
|
||||
test_squeeze_default_axes_infer_type()
|
||||
|
|
Загрузка…
Ссылка в новой задаче