diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR
index 86351c40..55ba1778 160000
--- a/3rdparty/HalideIR
+++ b/3rdparty/HalideIR
@@ -1 +1 @@
-Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90
+Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index af493823..9f172054 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
 
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer axis;
+  std::string mode;
 
   TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
     TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
         .describe("The axis over which to select values.");
+    TVM_ATTR_FIELD(mode).set_default("clip")
+        .describe("Specify how out-of-bound indices will behave."
+                  "clip - clip to the range (default)"
+                  "wrap - wrap around the indices");
   }
 };
 
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 47ad5d7a..8e36801f 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
     return _op.tile(inputs[0], **new_attrs)
 
 
+def _mx_take(inputs, attrs):
+    assert len(inputs) == 2
+    mode = attrs.get_str("mode", "clip")
+    if mode == "raise":
+        raise RuntimeError("take doesn't support raise mode")
+    axis = attrs.get_int("axis", 0)
+    return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
+
+
 def _mx_reverse(inputs, attrs):
     assert len(inputs) == 1
     new_attrs = {}
@@ -749,6 +758,7 @@ _convert_map = {
     "_full"         : _mx_full,
     "repeat"        : _mx_repeat,
     "tile"          : _mx_tile,
+    "take"          : _mx_take,
     "reverse"       : _mx_reverse,
     "squeeze"       : _mx_squeeze,
     "broadcast_axis": _mx_broadcast_axis,
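
The new mode attribute mirrors numpy's np.take out-of-bounds handling, minus "raise" (which the frontend rejects above, since it would need a runtime bounds check). As a quick reference, a minimal numpy sketch of the semantics being matched, using nothing beyond stock numpy:

    import numpy as np

    x = np.arange(12, dtype="float32").reshape(3, 4)

    # Axis 0 has length 3: "clip" saturates out-of-range indices into [0, 2],
    # while "wrap" reduces them modulo 3 (so -1 -> 2 and 5 -> 2).
    print(np.take(x, [-1, 5], axis=0, mode="clip"))  # rows 0 and 2
    print(np.take(x, [-1, 5], axis=0, mode="wrap"))  # rows 2 and 2
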
""" - return _make.take(data, indices, axis) + return _make.take(data, indices, axis, mode) def full(fill_value, shape=(), dtype=""): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a0ea8f2e..08b06a2a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -753,24 +753,26 @@ Array TakeCompute(const Attrs& attrs, const auto* param = attrs.as(); CHECK(param != nullptr); if (!param->axis.defined()) { - return Array{ topi::take(inputs[0], inputs[1]) }; + return Array{ topi::take(inputs[0], inputs[1], param->mode) }; } else { - return Array{ topi::take(inputs[0], inputs[1], param->axis) }; + return Array{ topi::take(inputs[0], inputs[1], param->axis, param->mode) }; } } Expr MakeTake(Expr data, Expr indices, - Integer axis) { + Integer axis, + std::string mode) { auto attrs = make_node(); attrs->axis = std::move(axis); + attrs->mode = std::move(mode); static const Op& op = Op::Get("take"); return CallNode::make(op, {data, indices}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.take") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeTake, args, rv); + runtime::detail::unpack_call(MakeTake, args, rv); }); RELAY_REGISTER_OP("take") diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index faccfbfd..9d0d5940 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -464,7 +464,6 @@ def test_forward_embedding(): verify((2, 2), (4, 5)) verify((2, 3, 4), (4, 5)) - def test_forward_smooth_l1(): data = mx.sym.var('data') mx_sym = mx.sym.smooth_l1(data) @@ -472,6 +471,26 @@ def test_forward_smooth_l1(): mx_sym = mx.sym.smooth_l1(data, scalar=1.0) verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4)) +def test_forward_take(): + def verify(shape, indices_src, axis, mode="clip"): + x_np = np.random.uniform(size=shape).astype("float32") + indices_np = np.array(indices_src, dtype="float32") + ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode) + mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np, indices_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((2,2), [[[1,0],[0,1]]], 0) + verify((2,2), [[[1,0],[0,1]]], 1) + verify((4,3,5,6), [[2,1,0,0]], -2) + verify((3,4), [-1, 5], 0) + verify((3,4), [-1, 5], 0, mode="wrap") + verify((3,4), [-1, 5], 1) + verify((3,4), [-1, 5], 1, mode="wrap") + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -507,3 +526,4 @@ if __name__ == '__main__': test_forward_full() test_forward_embedding() test_forward_smooth_l1() + test_forward_take() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 10ace54e..0cfbcc2c 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -243,17 +243,17 @@ def test_take_infer_type(): verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) def test_take(): - def verify_take(src_shape, indices_src, axis=None): + def verify_take(src_shape, indices_src, axis=None, mode="clip"): src_dtype = "float32" indices_dtype = "int32" indices_src = np.array(indices_src, dtype=indices_dtype) x = relay.var("x", 
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index faccfbfd..9d0d5940 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -464,7 +464,6 @@ def test_forward_embedding():
     verify((2, 2), (4, 5))
     verify((2, 3, 4), (4, 5))
 
-
 def test_forward_smooth_l1():
     data = mx.sym.var('data')
     mx_sym = mx.sym.smooth_l1(data)
@@ -472,6 +471,26 @@ def test_forward_smooth_l1():
     mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
     verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
 
+def test_forward_take():
+    def verify(shape, indices_src, axis, mode="clip"):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        indices_np = np.array(indices_src, dtype="float32")
+        ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode)
+        mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np, indices_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((2,2), [[[1,0],[0,1]]], 0)
+    verify((2,2), [[[1,0],[0,1]]], 1)
+    verify((4,3,5,6), [[2,1,0,0]], -2)
+    verify((3,4), [-1, 5], 0)
+    verify((3,4), [-1, 5], 0, mode="wrap")
+    verify((3,4), [-1, 5], 1)
+    verify((3,4), [-1, 5], 1, mode="wrap")
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -507,3 +526,4 @@ if __name__ == '__main__':
     test_forward_full()
     test_forward_embedding()
     test_forward_smooth_l1()
+    test_forward_take()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 10ace54e..0cfbcc2c 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -243,17 +243,17 @@
     verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
 
 def test_take():
-    def verify_take(src_shape, indices_src, axis=None):
+    def verify_take(src_shape, indices_src, axis=None, mode="clip"):
         src_dtype = "float32"
         indices_dtype = "int32"
         indices_src = np.array(indices_src, dtype=indices_dtype)
         x = relay.var("x", relay.TensorType(src_shape, src_dtype))
         indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
-        z = relay.take(x, indices, axis=axis)
+        z = relay.take(x, indices, axis=axis, mode=mode)
 
         func = relay.Function([x, indices], z)
         x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
-        ref_res = np.take(x_data, indices=indices_src, axis=axis)
+        ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
 
         for target, ctx in ctx_list():
             for kind in ["graph", "debug"]:
@@ -269,6 +269,12 @@ def test_take():
     verify_take((2,2), [[[1,0],[0,1]]], 0)
     verify_take((2,2), [[[1,0],[0,1]]], 1)
     verify_take((4,3,5,6), [[2,1,0,0]], -2)
+    verify_take((3,4), [-5, 20])
+    verify_take((3,4), [-5, 20], mode="wrap")
+    verify_take((3,4), [-1, 2], axis=0)
+    verify_take((3,4), [-1, 2], axis=0, mode="wrap")
+    verify_take((3,4), [-1, 2], axis=1)
+    verify_take((3,4), [-1, 2], axis=1, mode="wrap")
 
 
 def test_split_infer_type():
diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py
index a327650f..6ee3bc6b 100644
--- a/tests/python/unittest/test_arith_simplify.py
+++ b/tests/python/unittest/test_arith_simplify.py
@@ -39,6 +39,11 @@ def test_simplify_mod():
     stmt = tvm.ir_pass.CanonicalSimplify(body)
     diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
     assert diff.value == 0
+    # if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
+    index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
+    assert index != j
+    index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
+    assert index == j
     # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
     index = tvm.ir_pass.CanonicalSimplify(
         (j + n * 32) % 16, {j: tvm.Range(0, 6)})
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index 464bd6fa..bbe1a316 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
  */
 inline Tensor take(const Tensor& a,
                    const Tensor& indices,
+                   std::string mode = "clip",
                    std::string name = "tensor",
                    std::string tag = kInjective) {
   Array<Expr> a_shape = a->shape;
-  Array<Expr> out_shape;
-  for (size_t j = 0; j < indices->shape.size(); ++j) {
-    out_shape.push_back(indices->shape[j]);
+  Array<Expr> out_shape = indices->shape;
+  Expr a_size = 1;
+  for (size_t i = 0; i < a_shape.size(); ++i) {
+    a_size = a_size * a_shape[i];
   }
 
-  return compute(
+  if (mode == "clip") {
+    return compute(
       out_shape, [&](const Array<Var>& out_index) {
-        Array<Expr> indices_position;
-        for (size_t j = 0; j < indices->shape.size(); ++j) {
-          indices_position.push_back(out_index[j]);
-        }
-        return a(UnravelIndex(indices(indices_position), a_shape));
+        auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
+        return a(UnravelIndex(idx, a_shape));
       }, name, tag);
+  } else {  // mode == "wrap"
+    return compute(
+        out_shape, [&](const Array<Var>& out_index) {
+          auto idx = (indices(out_index) % a_size + a_size) % a_size;
+          return a(UnravelIndex(idx, a_shape));
+        }, name, tag);
+  }
 }
 
 /*!
@@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
                    const Tensor& indices,
                    int axis,
+                   std::string mode = "clip",
                    std::string name = "tensor",
                    std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(a->shape.size());
   }
+  CHECK_GE(axis, 0) << "axis out of bounds";
   CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
+  auto axis_dim = a->shape[axis];
 
   int indices_len = static_cast<int>(indices->shape.size());
   Array<Expr> out_shape;
@@ -655,7 +665,8 @@ inline Tensor take(const Tensor& a,
       out_shape.push_back(a->shape[i]);
     }
   }
-  return compute(
+  if (mode == "clip") {
+    return compute(
         out_shape, [&](const Array<Var>& out_index) {
           Array<Expr> indices_position;
           for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
@@ -665,12 +676,33 @@ inline Tensor take(const Tensor& a,
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          real_indices.push_back(indices(indices_position));
+          auto idx = tvm::min(tvm::max(0, indices(indices_position)),
+                              axis_dim - 1);
+          real_indices.push_back(idx);
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
           }
           return a(real_indices);
         }, name, tag);
+  } else {  // mode == "wrap"
+    return compute(
+        out_shape, [&](const Array<Var>& out_index) {
+          Array<Expr> indices_position;
+          for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+            indices_position.push_back(out_index[j]);
+          }
+          Array<Expr> real_indices;
+          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
+          real_indices.push_back(idx);
+          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          return a(real_indices);
+        }, name, tag);
+  }
 }
 
 /*!
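
A note on the wrap branch above: % on TVM expressions lowers like C's remainder, truncating toward zero, so a single "indices % axis_dim" would leave negative indices negative (e.g. -1 % 4 == -1). The double modulo "(i % n + n) % n" first lands in (-n, n), then shifts into [0, 2n), where truncated and floored remainders agree. The same subtlety motivates the test_simplify_mod addition earlier: (j + 16) % 16 only folds to j once j's range proves it non-negative. A quick Python sanity check, emulating the truncated remainder:

    def trunc_mod(a, n):
        # C-style remainder, truncating toward zero: trunc_mod(-1, 4) == -1
        return a - int(a / n) * n

    n = 4
    for i in [-5, -1, 0, 3, 5]:
        wrapped = (trunc_mod(i, n) + n) % n  # the double modulo from the wrap branch
        assert wrapped == i % n  # Python's % is floored, matching np.take mode="wrap"
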
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 2c109cd9..e674b9e1 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
     return cpp.split(ary, indices_or_sections, axis)
 
 
-def take(a, indices, axis=None):
+def take(a, indices, axis=None, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -243,13 +243,18 @@
         The axis over which to select values. By default,
        the flattened input array is used.
 
+    mode : str, optional
+        Specifies how out-of-bound indices will behave.
+        clip - clip to the range (default)
+        wrap - wrap around the indices
+
     Returns
     -------
     ret : tvm.Tensor
     """
     if axis is None:
-        return cpp.take(a, indices)
-    return cpp.take(a, indices, int(axis))
+        return cpp.take(a, indices, mode)
+    return cpp.take(a, indices, int(axis), mode)
 
 
 def gather_nd(a, indices):
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index aed2eab9..1df73d8f 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")
 
 TVM_REGISTER_GLOBAL("topi.take")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  if (args.size() == 2) {
-    *rv = take(args[0], args[1]);
+  if (args.size() == 3) {
+    std::string mode = args[2];
+    *rv = take(args[0], args[1], mode);
   } else {
     int axis = args[2];
-    *rv = take(args[0], args[1], axis);
+    std::string mode = args[3];
+    *rv = take(args[0], args[1], axis, mode);
   }
 });
 
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 59c10904..b56df9f2 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -232,16 +232,16 @@ def verify_flip(in_shape, axis):
     for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
-def verify_take(src_shape, indices_src, axis=None):
+def verify_take(src_shape, indices_src, axis=None, mode="clip"):
     src_dtype = "float32"
     indices_dtype = "int32"
     indices_src = np.array(indices_src, dtype=indices_dtype)
     A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
     indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
     if axis is None:
-        out_tensor = topi.take(a=A, indices=indices)
+        out_tensor = topi.take(a=A, indices=indices, mode=mode)
     else:
-        out_tensor = topi.take(a=A, indices=indices, axis=axis)
+        out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -259,9 +259,9 @@ def verify_take(src_shape, indices_src, axis=None):
 
         data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
         if axis is None:
-            out_npys = np.take(data_npy, indices_src)
+            out_npys = np.take(data_npy, indices_src, mode=mode)
         else:
-            out_npys = np.take(data_npy, indices_src, axis=axis)
+            out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
         data_nd = tvm.nd.array(data_npy, ctx)
         indices_nd = tvm.nd.array(indices_src, ctx)
         out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
@@ -498,6 +498,12 @@ def test_take():
     verify_take((2,2), [[[1,0],[0,1]]], 0)
     verify_take((2,2), [[[1,0],[0,1]]], 1)
     verify_take((4,3,5,6), [[2,1,0,0]], -2)
+    verify_take((3,4), [-5, 20])
+    verify_take((3,4), [-5, 20], mode="wrap")
+    verify_take((3,4), [-1, 2], axis=0)
+    verify_take((3,4), [-1, 2], axis=0, mode="wrap")
+    verify_take((3,4), [-1, 2], axis=1)
+    verify_take((3,4), [-1, 2], axis=1, mode="wrap")
 
 def test_gather_nd():
     for indices_dtype in ['int32', 'float32']:
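
At the TOPI level the packed function now dispatches on argument count (3 arguments for the flattened take with mode, 4 for take along an axis), which is why the Python wrapper always passes mode explicitly. A standalone usage sketch at this level (the unoptimized schedule and llvm target are illustrative choices, not part of the patch):

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((3, 4), dtype="float32", name="A")
    I = tvm.placeholder((2,), dtype="int32", name="I")
    T = topi.take(A, I, axis=1, mode="wrap")  # output shape (3, 2)

    s = tvm.create_schedule(T.op)
    f = tvm.build(s, [A, I, T], "llvm")

    a = tvm.nd.array(np.random.uniform(size=(3, 4)).astype("float32"))
    i = tvm.nd.array(np.array([-1, 5], dtype="int32"))
    out = tvm.nd.empty((3, 2), dtype="float32")
    f(a, i, out)
    np.testing.assert_allclose(out.asnumpy(),
                               np.take(a.asnumpy(), [-1, 5], axis=1, mode="wrap"))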