[Topi] Fast mode in take op (#3325)
This commit is contained in:
Родитель
d4ca627a5a
Коммит
2c41fd2f03
|
@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
|
|||
TVM_ATTR_FIELD(mode).set_default("clip")
|
||||
.describe("Specify how out-of-bound indices will behave."
|
||||
"clip - clip to the range (default)"
|
||||
"wrap - wrap around the indices");
|
||||
"wrap - wrap around the indices"
|
||||
"fast - no clip or wrap around (user must make sure indices are in-bound)");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"):
|
|||
the flattened input array is used.
|
||||
|
||||
mode : str, optional
|
||||
Specifies how out-of-bound indices will behave [clip, wrap].
|
||||
Specifies how out-of-bound indices will behave [clip, wrap, fast].
|
||||
clip: clip to the range (default).
|
||||
wrap: wrap around the indices.
|
||||
fast: no clip or wrap around (user must make sure indices are in-bound).
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
@ -269,7 +269,8 @@ def test_take():
|
|||
|
||||
func = relay.Function([x, indices], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
|
||||
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
|
||||
np_mode = "raise" if mode == "fast" else mode
|
||||
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
for kind in ["graph", "debug"]:
|
||||
|
@ -291,6 +292,9 @@ def test_take():
|
|||
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
|
||||
verify_take((3,4), [-1, 2], axis=1)
|
||||
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
|
||||
verify_take((3,3,3), [[11,25]], mode="fast")
|
||||
verify_take((3,4), [0, 2], axis=0, mode="fast")
|
||||
verify_take((3,4), [0, 2], axis=1, mode="fast")
|
||||
|
||||
|
||||
def test_split_infer_type():
|
||||
|
|
|
@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a,
|
|||
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
|
||||
return a(UnravelIndex(idx, a_shape));
|
||||
}, name, tag);
|
||||
} else if (mode == "fast") {
|
||||
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
|
||||
"Make sure input indices are in bound";
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
return a(UnravelIndex(indices(out_index), a_shape));
|
||||
}, name, tag);
|
||||
} else { // mode == "wrap"
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
|
@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a,
|
|||
}
|
||||
return a(real_indices);
|
||||
}, name, tag);
|
||||
} else if (mode == "fast") {
|
||||
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
|
||||
"Make sure input indices are in bound";
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
Array<Expr> indices_position;
|
||||
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
|
||||
indices_position.push_back(out_index[j]);
|
||||
}
|
||||
Array<Expr> real_indices;
|
||||
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
|
||||
real_indices.push_back(out_index[j]);
|
||||
}
|
||||
real_indices.push_back(indices(indices_position));
|
||||
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
|
||||
real_indices.push_back(out_index[j]);
|
||||
}
|
||||
return a(real_indices);
|
||||
}, name, tag);
|
||||
} else { // mode == "wrap"
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& out_index) {
|
||||
|
|
|
@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"):
|
|||
Specifies how out-of-bound indices will behave.
|
||||
clip - clip to the range (default)
|
||||
wrap - wrap around the indices
|
||||
fast - no clip or wrap around (user must make sure indices are in-bound)
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
|
|||
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
|
||||
|
||||
if axis is None:
|
||||
out_npys = np.take(data_npy, indices_src, mode=mode)
|
||||
np_mode = "raise" if mode == "fast" else mode
|
||||
out_npys = np.take(data_npy, indices_src, mode=np_mode)
|
||||
else:
|
||||
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
|
||||
np_mode = "raise" if mode == "fast" else mode
|
||||
out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
|
||||
data_nd = tvm.nd.array(data_npy, ctx)
|
||||
indices_nd = tvm.nd.array(indices_src, ctx)
|
||||
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
|
||||
|
@ -521,6 +523,9 @@ def test_take():
|
|||
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
|
||||
verify_take((3,4), [-1, 2], axis=1)
|
||||
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
|
||||
verify_take((3,3,3), [[11,25]], mode="fast")
|
||||
verify_take((3,4), [0, 2], axis=0, mode="fast")
|
||||
verify_take((3,4), [0, 2], axis=1, mode="fast")
|
||||
|
||||
def test_gather_nd():
|
||||
for indices_dtype in ['int32', 'float32']:
|
||||
|
|
Загрузка…
Ссылка в новой задаче