This commit is contained in:
Yao Wang 2018-07-12 20:04:41 -07:00 коммит произвёл Tianqi Chen
Родитель 6ea74d4119
Коммит ee3c1b09b9
6 изменённых файлов: 225 добавлений и 2 удалений

Просмотреть файл

@ -72,3 +72,7 @@ reg.register_schedule("strided_slice", _fschedule_injective)
# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)
# where
reg.register_pattern("where", OpPattern.INJECTIVE)
reg.register_schedule("where", _fschedule_injective)

Просмотреть файл

@ -1125,8 +1125,8 @@ Examples::
DMLC_REGISTER_PARAMETER(SliceLikeParam);
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like)
})
.set_support_level(4);
// where
inline bool WhereShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& cond_shape = in_attrs->at(0);
const TShape& x_shape = in_attrs->at(1);
const TShape& y_shape = in_attrs->at(2);
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: "
<< x_shape << " vs " << y_shape;
if (cond_shape != x_shape) {
CHECK_EQ(cond_shape.ndim(), 1)
<< "Shape of condition " << cond_shape
<< " must be either equal to x or has dimension of 1.";
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape);
return true;
}
inline bool WhereInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1));
return true;
}
inline bool WhereCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);
for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}
return true;
}
NNVM_REGISTER_OP(where)
.describe(R"code(
Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.
If condition does not have the same shape as x, it must be a 1D array whose
size is the same as xs first dimension size. Each row of the output array
is from xs row if the corresponding element from condition is true, and
from ys row if false.
Note that all non-zero values are interpreted as True in condition.
Examples::
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]
where(cond, x, y) = [[5, 2], [3, 8]]
cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]
)code" NNVM_ADD_FILELINE)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
.add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", WhereShape)
.set_attr<FInferType>("FInferType", WhereInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{
topi::where(inputs[0], inputs[1], inputs[2])
};
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"condition", "x", "y"};
})
.set_support_level(4);
} // namespace top
} // namespace nnvm

Просмотреть файл

@ -645,6 +645,36 @@ def test_slice_like():
axis = (2, 3)
verify_slice_like(np_data, np_shape_like, axis)
def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)])
else:
np_out = np.where(condition, x, y)
cond_var = sym.Variable("condition")
x_var = sym.Variable("x")
y_var = sym.Variable("y")
net = sym.where(cond_var, x_var, y_var)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(net, target, {"condition": condition.shape,
"x": x.shape, "y": y.shape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"condition": condition, "x": x, "y": y})
m.run()
out = m.get_output(0, tvm.nd.empty(x.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
def test_where():
shape = (13, 8, 224, 224, 6)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
if __name__ == "__main__":
test_reshape()
@ -665,4 +695,5 @@ if __name__ == "__main__":
test_multibox_transform_loc()
test_nms()
test_slice_like()
test_where()
print(nnvm.compiler.engine.dump())

Просмотреть файл

@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a,
}, name, tag);
}
/*!
* \brief Return the elements, either from x or y, depending on the condition.
*
* \param condition The condition array.
* \param x First array to be selected.
* \param y Second array to be selected.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor selected from x or y depending on condition.
*/
inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: "
<< x->shape.size() << " vs " << y->shape.size();
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
<< x->dtype << " vs " << y->dtype;
Array<Expr> oshape = x->shape;
Tensor out;
if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape, [&](const Array<Var>& indices) {
return tvm::select(condition(indices) != 0, x(indices), y(indices));
}, name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
<< "If condition is 1-D, the first dimension must be the same as x: "
<< condition->shape[0] << " vs " << x->shape[0];
out = compute(
oshape, [&](const Array<Var>& indices) {
Array<Expr> condition_idx{indices[0]};
return tvm::select(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
}
return out;
}
} // namespace topi
#endif // TOPI_TRANSFORM_H_

Просмотреть файл

@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});
TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);

Просмотреть файл

@ -206,6 +206,35 @@ def verify_take(src_shape, indices_src, axis=None):
for device in ["llvm", "opencl"]:
check_device(device)
def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)])
else:
np_out = np.where(condition, x, y)
A = tvm.placeholder(shape=condition.shape, dtype=dtype, name="condition")
B = tvm.placeholder(shape=x.shape, dtype=dtype, name="x")
C = tvm.placeholder(shape=y.shape, dtype=dtype, name="y")
out_tensor = topi.cpp.where(A, B, C)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
foo = tvm.build(s, [A, B, C, out_tensor], device, name="where")
tvm_out = tvm.nd.empty(x.shape, ctx=ctx, dtype=dtype)
foo(tvm.nd.array(condition, ctx), tvm.nd.array(x, ctx),
tvm.nd.array(y, ctx), tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_concatenate_split(shapes, axis, indices_or_sections):
tensor_l_concatenate = []
for i, shape in enumerate(shapes):
@ -324,6 +353,18 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
def test_where():
shape = (10, 3, 7, 13)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
def test_regression_1():
verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7])
verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4])
@ -340,5 +381,6 @@ if __name__ == "__main__":
test_squeeze()
test_split()
test_take()
test_where()
test_regression_1()
test_regression_2()