[Relay/TOPI][OP] Add clip and wrap mode support in take (#2858)
* Update take * Add special case for canonical simplify and fix test cases * Use lower case for wrap and clip * remove unnecessary lower * Fix mxnet converter for take * fix
This commit is contained in:
Родитель
7cc9240ae8
Коммит
3746d9026a
|
@ -1 +1 @@
|
|||
Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90
|
||||
Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
|
|
@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
|
|||
|
||||
/*! \brief Attributes of the relay "take" operator. */
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
  /*! \brief Axis over which to select values; defaults to a null (undefined) Integer. */
  Integer axis;
  /*! \brief Out-of-bound index handling: "clip" (default) or "wrap". */
  std::string mode;

  TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
        .describe("The axis over which to select values.");
    // Adjacent string literals are concatenated verbatim, so every fragment
    // carries its own trailing separator to keep the generated help readable.
    TVM_ATTR_FIELD(mode).set_default("clip")
        .describe("Specify how out-of-bound indices will behave. "
                  "clip - clip to the range (default); "
                  "wrap - wrap around the indices.");
  }
};
|
||||
|
||||
|
|
|
@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
|
|||
return _op.tile(inputs[0], **new_attrs)
|
||||
|
||||
|
||||
def _mx_take(inputs, attrs):
    """Convert an MXNet ``take`` node into a relay ``take`` call.

    Parameters
    ----------
    inputs : sequence of length 2
        The data expression followed by the indices expression.
    attrs : attribute dictionary
        MXNet operator attributes; ``mode`` (default "clip") and
        ``axis`` (default 0) are read from it.

    Returns
    -------
    The relay expression produced by ``_op.take``.

    Raises
    ------
    RuntimeError
        If the MXNet node requests ``mode == "raise"``.
    """
    assert len(inputs) == 2
    oob_mode = attrs.get_str("mode", "clip")
    # Only "clip" and "wrap" are handled; "raise" is rejected up front.
    if oob_mode == "raise":
        raise RuntimeError("take doesn't support raise mode")
    take_axis = attrs.get_int("axis", 0)
    data, raw_indices = inputs[0], inputs[1]
    # The indices are cast to int32 before being handed to relay.
    return _op.take(data, raw_indices.astype("int32"), take_axis, oob_mode)
|
||||
|
||||
|
||||
def _mx_reverse(inputs, attrs):
|
||||
assert len(inputs) == 1
|
||||
new_attrs = {}
|
||||
|
@ -749,6 +758,7 @@ _convert_map = {
|
|||
"_full" : _mx_full,
|
||||
"repeat" : _mx_repeat,
|
||||
"tile" : _mx_tile,
|
||||
"take" : _mx_take,
|
||||
"reverse" : _mx_reverse,
|
||||
"squeeze" : _mx_squeeze,
|
||||
"broadcast_axis": _mx_broadcast_axis,
|
||||
|
|
|
@ -186,7 +186,7 @@ def reshape_like(data, shape_like):
|
|||
return _make.reshape_like(data, shape_like)
|
||||
|
||||
|
||||
def take(data, indices, axis=None):
|
||||
def take(data, indices, axis=None, mode="clip"):
|
||||
"""Take elements from an array along an axis.
|
||||
|
||||
Parameters
|
||||
|
@ -201,12 +201,17 @@ def take(data, indices, axis=None):
|
|||
The axis over which to select values. By default,
|
||||
the flattened input array is used.
|
||||
|
||||
mode : str, optional
|
||||
Specifies how out-of-bound indices will behave.
|
||||
clip - clip to the range (default)
|
||||
wrap - wrap around the indices
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret : relay.Expr
|
||||
The computed result.
|
||||
"""
|
||||
return _make.take(data, indices, axis)
|
||||
return _make.take(data, indices, axis, mode)
|
||||
|
||||
|
||||
def full(fill_value, shape=(), dtype=""):
|
||||
|
|
|
@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs,
|
|||
const auto* param = attrs.as<TakeAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
if (!param->axis.defined()) {
|
||||
return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
|
||||
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
|
||||
} else {
|
||||
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
|
||||
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
|
||||
}
|
||||
}
|
||||
|
||||
Expr MakeTake(Expr data,
|
||||
Expr indices,
|
||||
Integer axis) {
|
||||
Integer axis,
|
||||
std::string mode) {
|
||||
auto attrs = make_node<TakeAttrs>();
|
||||
attrs->axis = std::move(axis);
|
||||
attrs->mode = std::move(mode);
|
||||
static const Op& op = Op::Get("take");
|
||||
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay.op._make.take")
|
||||
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
|
||||
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
|
||||
runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
|
||||
});
|
||||
|
||||
RELAY_REGISTER_OP("take")
|
||||
|
|
|
@ -464,7 +464,6 @@ def test_forward_embedding():
|
|||
verify((2, 2), (4, 5))
|
||||
verify((2, 3, 4), (4, 5))
|
||||
|
||||
|
||||
def test_forward_smooth_l1():
|
||||
data = mx.sym.var('data')
|
||||
mx_sym = mx.sym.smooth_l1(data)
|
||||
|
@ -472,6 +471,26 @@ def test_forward_smooth_l1():
|
|||
mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
|
||||
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
|
||||
|
||||
def test_forward_take():
    """Check the MXNet ``take`` frontend conversion against ``mx.nd.take``.

    Each case is converted with ``relay.frontend.from_mxnet`` and evaluated
    with both the "graph" and "debug" executors on every target returned by
    ``ctx_list()``; the result is compared with MXNet's own output.
    """
    def verify(shape, indices_src, axis, mode="clip"):
        data_np = np.random.uniform(size=shape).astype("float32")
        # Indices are fed as float32 arrays to both MXNet and the converted graph.
        idx_np = np.array(indices_src, dtype="float32")
        expected = mx.nd.take(mx.nd.array(data_np), mx.nd.array(idx_np), axis, mode)
        symbol = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
        shape_dict = {"x": shape, "y": idx_np.shape}
        converted, _ = relay.frontend.from_mxnet(symbol, shape_dict)
        for target, ctx in ctx_list():
            for executor_kind in ("graph", "debug"):
                executor = relay.create_executor(executor_kind, ctx=ctx, target=target)
                actual = executor.evaluate(converted)(data_np, idx_np)
                tvm.testing.assert_allclose(actual.asnumpy(), expected.asnumpy())

    # (shape, indices, axis, mode) -- same cases and order as before.
    cases = [
        ((2, 2), [[[1, 0], [0, 1]]], 0, "clip"),
        ((2, 2), [[[1, 0], [0, 1]]], 1, "clip"),
        ((4, 3, 5, 6), [[2, 1, 0, 0]], -2, "clip"),
        ((3, 4), [-1, 5], 0, "clip"),
        ((3, 4), [-1, 5], 0, "wrap"),
        ((3, 4), [-1, 5], 1, "clip"),
        ((3, 4), [-1, 5], 1, "wrap"),
    ]
    for shape, indices_src, axis, mode in cases:
        verify(shape, indices_src, axis, mode=mode)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_forward_mlp()
|
||||
test_forward_vgg()
|
||||
|
@ -507,3 +526,4 @@ if __name__ == '__main__':
|
|||
test_forward_full()
|
||||
test_forward_embedding()
|
||||
test_forward_smooth_l1()
|
||||
test_forward_take()
|
||||
|
|
|
@ -243,17 +243,17 @@ def test_take_infer_type():
|
|||
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
|
||||
|
||||
def test_take():
|
||||
def verify_take(src_shape, indices_src, axis=None):
|
||||
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
|
||||
src_dtype = "float32"
|
||||
indices_dtype = "int32"
|
||||
indices_src = np.array(indices_src, dtype=indices_dtype)
|
||||
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
|
||||
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
|
||||
z = relay.take(x, indices, axis=axis)
|
||||
z = relay.take(x, indices, axis=axis, mode=mode)
|
||||
|
||||
func = relay.Function([x, indices], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
|
||||
ref_res = np.take(x_data, indices=indices_src, axis=axis)
|
||||
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
for kind in ["graph", "debug"]:
|
||||
|
@ -269,6 +269,12 @@ def test_take():
|
|||
verify_take((2,2), [[[1,0],[0,1]]], 0)
|
||||
verify_take((2,2), [[[1,0],[0,1]]], 1)
|
||||
verify_take((4,3,5,6), [[2,1,0,0]], -2)
|
||||
verify_take((3,4), [-5, 20])
|
||||
verify_take((3,4), [-5, 20], mode="wrap")
|
||||
verify_take((3,4), [-1, 2], axis=0)
|
||||
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
|
||||
verify_take((3,4), [-1, 2], axis=1)
|
||||
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
|
||||
|
||||
|
||||
def test_split_infer_type():
|
||||
|
|
|
@ -39,6 +39,11 @@ def test_simplify_mod():
|
|||
stmt = tvm.ir_pass.CanonicalSimplify(body)
|
||||
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
|
||||
assert diff.value == 0
|
||||
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
|
||||
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
|
||||
assert index != j
|
||||
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
|
||||
assert index == j
|
||||
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
|
||||
index = tvm.ir_pass.CanonicalSimplify(
|
||||
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
|
||||
|
|
|
@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
|
|||
*/
|
||||
inline Tensor take(const Tensor& a,
|
||||
const Tensor& indices,
|
||||
std::string mode = "clip",
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
Array<Expr> a_shape = a->shape;
|
||||
Array<Expr> out_shape;
|
||||
for (size_t j = 0; j < indices->shape.size(); ++j) {
|
||||
out_shape.push_back(indices->shape[j]);
|
||||
Array<Expr> out_shape = indices->shape;
|
||||
Expr a_size = 1;
|
||||
for (size_t i = 0; i < a_shape.size(); ++i) {
|
||||
a_size = a_size * a_shape[i];
|
||||
}
|
||||
|
||||
return compute(
|
||||
if (mode == "clip") {
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
Array<Expr> indices_position;
|
||||
for (size_t j = 0; j < indices->shape.size(); ++j) {
|
||||
indices_position.push_back(out_index[j]);
|
||||
}
|
||||
return a(UnravelIndex(indices(indices_position), a_shape));
|
||||
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
|
||||
return a(UnravelIndex(idx, a_shape));
|
||||
}, name, tag);
|
||||
} else { // mode == "wrap"
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
auto idx = (indices(out_index) % a_size + a_size) % a_size;
|
||||
return a(UnravelIndex(idx, a_shape));
|
||||
}, name, tag);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
|
@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
|
|||
inline Tensor take(const Tensor& a,
|
||||
const Tensor& indices,
|
||||
int axis,
|
||||
std::string mode = "clip",
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
if (axis < 0) {
|
||||
axis += static_cast<int>(a->shape.size());
|
||||
}
|
||||
CHECK_GE(axis, 0) << "axis out of bounds";
|
||||
CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
|
||||
auto axis_dim = a->shape[axis];
|
||||
|
||||
int indices_len = static_cast<int>(indices->shape.size());
|
||||
Array<Expr> out_shape;
|
||||
|
@ -655,7 +665,8 @@ inline Tensor take(const Tensor& a,
|
|||
out_shape.push_back(a->shape[i]);
|
||||
}
|
||||
}
|
||||
return compute(
|
||||
if (mode == "clip") {
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
Array<Expr> indices_position;
|
||||
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
|
||||
|
@ -665,12 +676,33 @@ inline Tensor take(const Tensor& a,
|
|||
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
|
||||
real_indices.push_back(out_index[j]);
|
||||
}
|
||||
real_indices.push_back(indices(indices_position));
|
||||
auto idx = tvm::min(tvm::max(0, indices(indices_position)),
|
||||
axis_dim - 1);
|
||||
real_indices.push_back(idx);
|
||||
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
|
||||
real_indices.push_back(out_index[j]);
|
||||
}
|
||||
return a(real_indices);
|
||||
}, name, tag);
|
||||
} else { // mode == "wrap"
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
Array<Expr> indices_position;
|
||||
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
|
||||
indices_position.push_back(out_index[j]);
|
||||
}
|
||||
Array<Expr> real_indices;
|
||||
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
|
||||
real_indices.push_back(out_index[j]);
|
||||
}
|
||||
auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
|
||||
real_indices.push_back(idx);
|
||||
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
|
||||
real_indices.push_back(out_index[j]);
|
||||
}
|
||||
return a(real_indices);
|
||||
}, name, tag);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
|
|
|
@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
|
|||
return cpp.split(ary, indices_or_sections, axis)
|
||||
|
||||
|
||||
def take(a, indices, axis=None):
|
||||
def take(a, indices, axis=None, mode="clip"):
|
||||
"""Take elements from an array along an axis.
|
||||
|
||||
Parameters
|
||||
|
@ -243,13 +243,18 @@ def take(a, indices, axis=None):
|
|||
The axis over which to select values. By default,
|
||||
the flattened input array is used.
|
||||
|
||||
mode : str, optional
|
||||
Specifies how out-of-bound indices will behave.
|
||||
clip - clip to the range (default)
|
||||
wrap - wrap around the indices
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret : tvm.Tensor
|
||||
"""
|
||||
if axis is None:
|
||||
return cpp.take(a, indices)
|
||||
return cpp.take(a, indices, int(axis))
|
||||
return cpp.take(a, indices, mode)
|
||||
return cpp.take(a, indices, int(axis), mode)
|
||||
|
||||
|
||||
def gather_nd(a, indices):
|
||||
|
|
|
@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")
|
|||
|
||||
TVM_REGISTER_GLOBAL("topi.take")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
if (args.size() == 2) {
|
||||
*rv = take(args[0], args[1]);
|
||||
if (args.size() == 3) {
|
||||
std::string mode = args[2];
|
||||
*rv = take(args[0], args[1], mode);
|
||||
} else {
|
||||
int axis = args[2];
|
||||
*rv = take(args[0], args[1], axis);
|
||||
std::string mode = args[3];
|
||||
*rv = take(args[0], args[1], axis, mode);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
@ -232,16 +232,16 @@ def verify_flip(in_shape, axis):
|
|||
for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
|
||||
check_device(device)
|
||||
|
||||
def verify_take(src_shape, indices_src, axis=None):
|
||||
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
|
||||
src_dtype = "float32"
|
||||
indices_dtype = "int32"
|
||||
indices_src = np.array(indices_src, dtype=indices_dtype)
|
||||
A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
|
||||
indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
|
||||
if axis is None:
|
||||
out_tensor = topi.take(a=A, indices=indices)
|
||||
out_tensor = topi.take(a=A, indices=indices, mode=mode)
|
||||
else:
|
||||
out_tensor = topi.take(a=A, indices=indices, axis=axis)
|
||||
out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)
|
||||
|
||||
def check_device(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
|
@ -259,9 +259,9 @@ def verify_take(src_shape, indices_src, axis=None):
|
|||
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
|
||||
|
||||
if axis is None:
|
||||
out_npys = np.take(data_npy, indices_src)
|
||||
out_npys = np.take(data_npy, indices_src, mode=mode)
|
||||
else:
|
||||
out_npys = np.take(data_npy, indices_src, axis=axis)
|
||||
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
|
||||
data_nd = tvm.nd.array(data_npy, ctx)
|
||||
indices_nd = tvm.nd.array(indices_src, ctx)
|
||||
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
|
||||
|
@ -498,6 +498,12 @@ def test_take():
|
|||
verify_take((2,2), [[[1,0],[0,1]]], 0)
|
||||
verify_take((2,2), [[[1,0],[0,1]]], 1)
|
||||
verify_take((4,3,5,6), [[2,1,0,0]], -2)
|
||||
verify_take((3,4), [-5, 20])
|
||||
verify_take((3,4), [-5, 20], mode="wrap")
|
||||
verify_take((3,4), [-1, 2], axis=0)
|
||||
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
|
||||
verify_take((3,4), [-1, 2], axis=1)
|
||||
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
|
||||
|
||||
def test_gather_nd():
|
||||
for indices_dtype in ['int32', 'float32']:
|
||||
|
|
Загрузка…
Ссылка в новой задаче