[TOPI]Add where operator (#1416)
This commit is contained in:
Родитель
6ea74d4119
Коммит
ee3c1b09b9
|
@ -72,3 +72,7 @@ reg.register_schedule("strided_slice", _fschedule_injective)
|
|||
# slice_like
|
||||
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
|
||||
reg.register_schedule("slice_like", _fschedule_injective)
|
||||
|
||||
# where
|
||||
reg.register_pattern("where", OpPattern.INJECTIVE)
|
||||
reg.register_schedule("where", _fschedule_injective)
|
||||
|
|
|
@ -1125,8 +1125,8 @@ Examples::
|
|||
DMLC_REGISTER_PARAMETER(SliceLikeParam);
|
||||
|
||||
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
|
||||
std::vector<TShape>* in_attrs,
|
||||
std::vector<TShape>* out_attrs) {
|
||||
std::vector<TShape>* in_attrs,
|
||||
std::vector<TShape>* out_attrs) {
|
||||
CHECK_EQ(in_attrs->size(), 2U);
|
||||
CHECK_EQ(out_attrs->size(), 1U);
|
||||
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
|
||||
|
@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like)
|
|||
})
|
||||
.set_support_level(4);
|
||||
|
||||
// where
|
||||
inline bool WhereShape(const nnvm::NodeAttrs& attrs,
|
||||
std::vector<TShape>* in_attrs,
|
||||
std::vector<TShape>* out_attrs) {
|
||||
CHECK_EQ(in_attrs->size(), 3U);
|
||||
CHECK_EQ(out_attrs->size(), 1U);
|
||||
const TShape& cond_shape = in_attrs->at(0);
|
||||
const TShape& x_shape = in_attrs->at(1);
|
||||
const TShape& y_shape = in_attrs->at(2);
|
||||
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: "
|
||||
<< x_shape << " vs " << y_shape;
|
||||
if (cond_shape != x_shape) {
|
||||
CHECK_EQ(cond_shape.ndim(), 1)
|
||||
<< "Shape of condition " << cond_shape
|
||||
<< " must be either equal to x or has dimension of 1.";
|
||||
}
|
||||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape);
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool WhereInferType(const NodeAttrs &attrs,
|
||||
std::vector<int> *in_attrs,
|
||||
std::vector<int> *out_attrs) {
|
||||
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool WhereCorrectLayout(const NodeAttrs& attrs,
|
||||
std::vector<Layout> *ilayouts,
|
||||
const std::vector<Layout> *last_ilayouts,
|
||||
std::vector<Layout> *olayouts) {
|
||||
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
|
||||
CHECK_EQ(olayouts->size(), 1U);
|
||||
|
||||
for (size_t i = 0; i < ilayouts->size(); ++i) {
|
||||
const Layout& input = last_ilayouts->at(i).defined() ?
|
||||
last_ilayouts->at(i) : ilayouts->at(i);
|
||||
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
NNVM_REGISTER_OP(where)
|
||||
.describe(R"code(
|
||||
Return the elements, either from x or y, depending on the condition.
|
||||
|
||||
Given three ndarrays, condition, x, and y, return an ndarray with the elements
|
||||
from x or y, depending on the elements from condition are true or false.
|
||||
x and y must have the same shape. If condition has the same shape as x,
|
||||
each element in the output array is from x if the corresponding element
|
||||
in the condition is true, and from y if false.
|
||||
|
||||
If condition does not have the same shape as x, it must be a 1D array whose
|
||||
size is the same as x’s first dimension size. Each row of the output array
|
||||
is from x’s row if the corresponding element from condition is true, and
|
||||
from y’s row if false.
|
||||
|
||||
Note that all non-zero values are interpreted as True in condition.
|
||||
|
||||
Examples::
|
||||
|
||||
x = [[1, 2], [3, 4]]
|
||||
y = [[5, 6], [7, 8]]
|
||||
cond = [[0, 1], [-1, 0]]
|
||||
where(cond, x, y) = [[5, 2], [3, 8]]
|
||||
|
||||
|
||||
cond = [1, 0]
|
||||
where(cond, x, y) = [[1, 2], [7, 8]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.add_argument("condition", "Tensor", "Condition array")
|
||||
.add_argument("x", "Tensor", "First array to be selected")
|
||||
.add_argument("y", "Tensor", "Second array to be selected")
|
||||
.set_num_inputs(3)
|
||||
.set_num_outputs(1)
|
||||
.set_attr<FInferShape>("FInferShape", WhereShape)
|
||||
.set_attr<FInferType>("FInferType", WhereInferType)
|
||||
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{
|
||||
topi::where(inputs[0], inputs[1], inputs[2])
|
||||
};
|
||||
})
|
||||
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
|
||||
return std::vector<std::string>{"condition", "x", "y"};
|
||||
})
|
||||
.set_support_level(4);
|
||||
|
||||
} // namespace top
|
||||
} // namespace nnvm
|
||||
|
|
|
@ -645,6 +645,36 @@ def test_slice_like():
|
|||
axis = (2, 3)
|
||||
verify_slice_like(np_data, np_shape_like, axis)
|
||||
|
||||
def verify_where(condition, x, y):
|
||||
dtype = "float32"
|
||||
if len(condition.shape) == 1:
|
||||
np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)])
|
||||
else:
|
||||
np_out = np.where(condition, x, y)
|
||||
cond_var = sym.Variable("condition")
|
||||
x_var = sym.Variable("x")
|
||||
y_var = sym.Variable("y")
|
||||
net = sym.where(cond_var, x_var, y_var)
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(net, target, {"condition": condition.shape,
|
||||
"x": x.shape, "y": y.shape})
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
m.set_input(**{"condition": condition, "x": x, "y": y})
|
||||
m.run()
|
||||
out = m.get_output(0, tvm.nd.empty(x.shape, dtype))
|
||||
np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_where():
|
||||
shape = (13, 8, 224, 224, 6)
|
||||
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
|
||||
x = np.random.uniform(size=shape).astype("float32")
|
||||
y = np.random.uniform(size=shape).astype("float32")
|
||||
verify_where(condition, x, y)
|
||||
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
|
||||
x = np.random.uniform(size=shape).astype("float32")
|
||||
y = np.random.uniform(size=shape).astype("float32")
|
||||
verify_where(condition, x, y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_reshape()
|
||||
|
@ -665,4 +695,5 @@ if __name__ == "__main__":
|
|||
test_multibox_transform_loc()
|
||||
test_nms()
|
||||
test_slice_like()
|
||||
test_where()
|
||||
print(nnvm.compiler.engine.dump())
|
||||
|
|
|
@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a,
|
|||
}, name, tag);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Return the elements, either from x or y, depending on the condition.
|
||||
*
|
||||
* \param condition The condition array.
|
||||
* \param x First array to be selected.
|
||||
* \param y Second array to be selected.
|
||||
* \param name The name of the operation.
|
||||
* \param tag The tag to mark the operation.
|
||||
*
|
||||
* \return A Tensor selected from x or y depending on condition.
|
||||
*/
|
||||
inline Tensor where(const Tensor& condition,
|
||||
const Tensor& x,
|
||||
const Tensor& y,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
CHECK_EQ(x->shape.size(), y->shape.size())
|
||||
<< "x and y must have the same shape.Got different number of dimension: "
|
||||
<< x->shape.size() << " vs " << y->shape.size();
|
||||
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
|
||||
<< x->dtype << " vs " << y->dtype;
|
||||
Array<Expr> oshape = x->shape;
|
||||
Tensor out;
|
||||
|
||||
if (condition->shape.size() != 1) {
|
||||
CHECK_EQ(condition->shape.size(), x->shape.size())
|
||||
<< "condition array must be either have the same shape as x or to be a "
|
||||
"1-D array.Got different number of dimension: "
|
||||
<< condition->shape.size() << " vs " << x->shape.size();
|
||||
out = compute(
|
||||
oshape, [&](const Array<Var>& indices) {
|
||||
return tvm::select(condition(indices) != 0, x(indices), y(indices));
|
||||
}, name, tag);
|
||||
} else {
|
||||
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
|
||||
<< "If condition is 1-D, the first dimension must be the same as x: "
|
||||
<< condition->shape[0] << " vs " << x->shape[0];
|
||||
out = compute(
|
||||
oshape, [&](const Array<Var>& indices) {
|
||||
Array<Expr> condition_idx{indices[0]};
|
||||
return tvm::select(condition(condition_idx) != 0,
|
||||
x(indices), y(indices));
|
||||
}, name, tag);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
} // namespace topi
|
||||
#endif // TOPI_TRANSFORM_H_
|
||||
|
|
|
@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
|
|||
}
|
||||
});
|
||||
|
||||
TVM_REGISTER_GLOBAL("topi.where")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = where(args[0], args[1], args[2]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_GLOBAL("topi.strided_slice")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = strided_slice(args[0], args[1], args[2], args[3]);
|
||||
|
|
|
@ -206,6 +206,35 @@ def verify_take(src_shape, indices_src, axis=None):
|
|||
for device in ["llvm", "opencl"]:
|
||||
check_device(device)
|
||||
|
||||
def verify_where(condition, x, y):
|
||||
dtype = "float32"
|
||||
if len(condition.shape) == 1:
|
||||
np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)])
|
||||
else:
|
||||
np_out = np.where(condition, x, y)
|
||||
A = tvm.placeholder(shape=condition.shape, dtype=dtype, name="condition")
|
||||
B = tvm.placeholder(shape=x.shape, dtype=dtype, name="x")
|
||||
C = tvm.placeholder(shape=y.shape, dtype=dtype, name="y")
|
||||
out_tensor = topi.cpp.where(A, B, C)
|
||||
|
||||
def check_device(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_injective(out_tensor)
|
||||
|
||||
foo = tvm.build(s, [A, B, C, out_tensor], device, name="where")
|
||||
tvm_out = tvm.nd.empty(x.shape, ctx=ctx, dtype=dtype)
|
||||
foo(tvm.nd.array(condition, ctx), tvm.nd.array(x, ctx),
|
||||
tvm.nd.array(y, ctx), tvm_out)
|
||||
np.testing.assert_allclose(tvm_out.asnumpy(), np_out)
|
||||
|
||||
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
|
||||
check_device(device)
|
||||
|
||||
def verify_concatenate_split(shapes, axis, indices_or_sections):
|
||||
tensor_l_concatenate = []
|
||||
for i, shape in enumerate(shapes):
|
||||
|
@ -324,6 +353,18 @@ def test_take():
|
|||
verify_take((2,2), [[[1,0],[0,1]]], 1)
|
||||
verify_take((4,3,5,6), [[2,1,0,0]], -2)
|
||||
|
||||
def test_where():
|
||||
shape = (10, 3, 7, 13)
|
||||
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
|
||||
x = np.random.uniform(size=shape).astype("float32")
|
||||
y = np.random.uniform(size=shape).astype("float32")
|
||||
verify_where(condition, x, y)
|
||||
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
|
||||
x = np.random.uniform(size=shape).astype("float32")
|
||||
y = np.random.uniform(size=shape).astype("float32")
|
||||
verify_where(condition, x, y)
|
||||
|
||||
|
||||
def test_regression_1():
|
||||
verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7])
|
||||
verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4])
|
||||
|
@ -340,5 +381,6 @@ if __name__ == "__main__":
|
|||
test_squeeze()
|
||||
test_split()
|
||||
test_take()
|
||||
test_where()
|
||||
test_regression_1()
|
||||
test_regression_2()
|
||||
|
|
Загрузка…
Ссылка в новой задаче