This commit is contained in:
雾雨魔理沙 2018-10-15 09:46:21 -07:00 коммит произвёл Tianqi Chen
Родитель 47b8c36dcf
Коммит 201cfdc59a
4 изменённых файлов: 149 добавлений и 3 удалений

Просмотреть файл

@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
}
}; // struct InitOpAttrs
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes;
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.")
.set_default(Array<IndexExpr>({}));
}
}; // struct SqueezeAttrs
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_

Просмотреть файл

@ -42,12 +42,35 @@ def transpose(data, axes=None):
Returns
-------
result : relay.Expr
The reshaped result.
The transposed result.
"""
axes = axes or []
return _make.transpose(data, list(axes))
def squeeze(data, axes=None):
"""Squeeze axes in the array.
Parameters
----------
data : relay.Expr
The input data to the operator.
axes : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.
Returns
-------
result : relay.Expr
The squeezed result.
"""
axes = axes or []
return _make.squeeze(data, list(axes))
def reshape(data, newshape):
"""Reshapes the input array.

Просмотреть файл

@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel);
/* relay.concatenate */
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
bool ConcatenateRel(const Array<Type>& types,
@ -633,5 +631,75 @@ Examples::
.set_support_level(4)
.add_type_rel("Where", WhereRel);
Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.squeeze")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
});
bool SqueezeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1
if (param->axes.size() == 0) {
for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
if (*axis_ptr != 1) {
result_shape.push_back(e);
}
}
} else {
// pair up original shape with a boolean which control whether it will be in the final shape.
std::vector<std::pair<IndexExpr, bool> > original_shape;
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
for (const auto& e : param->axes) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
}
for (const auto p : original_shape) {
if (p.second) {
result_shape.push_back(p.first);
} else {
const int64_t* axis_ptr = as_const_int(p.first);
CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
}
}
}
reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
return true;
}
RELAY_REGISTER_OP("squeeze")
.describe(R"code(Squeeze the input tensor at the dimensions given by axes
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);
} // namespace relay
} // namespace tvm

Просмотреть файл

@ -6,6 +6,7 @@ from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.env import Environment
from nose.tools import raises
def test_zeros_ones():
for op in [relay.zeros, relay.ones]:
@ -67,6 +68,44 @@ def test_transpose_infer_type():
(t, n, 100), "float32")
def test_squeeze_default_axes_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = 1, 4, 1
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.squeeze(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(4,), "float32")
def test_squeeze_axes_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = 1, 4, 1
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.squeeze(x, axes=(2,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, 4), "float32")
@raises(tvm._ffi.base.TVMError)
def test_squeeze_bad_axes_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = 1, 4, 1
x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
with ib.function(x) as func:
ib.ret(relay.squeeze(x, axes=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
def test_reshape_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
@ -181,3 +220,5 @@ if __name__ == "__main__":
test_take_infer_type()
test_full()
test_full_like()
test_squeeze_axes_infer_type()
test_squeeze_default_axes_infer_type()