[NNVM] Add argmax and argmin operations from topi (#1462)
Parent: 0fddc35214
Commit: cf9db7ea66
@@ -41,3 +41,11 @@ reg.register_schedule("min", _fschedule_reduce)
 # collapse sum
 reg.register_pattern("collapse_sum", OpPattern.COMM_REDUCE)
 reg.register_schedule("collapse_sum", _fschedule_reduce)
+
+# argmax
+reg.register_pattern("argmax", OpPattern.COMM_REDUCE)
+reg.register_schedule("argmax", _fschedule_reduce)
+
+# argmin
+reg.register_pattern("argmin", OpPattern.COMM_REDUCE)
+reg.register_schedule("argmin", _fschedule_reduce)

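Registering the pattern and schedule above lets the generic reduce schedule handle the new operators. For context, here is a minimal sketch (not part of the commit) of driving the new argmax symbol end to end through the NNVM frontend; it assumes an llvm-enabled TVM build and mirrors the graph_runtime pattern used in the tests further down:

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
from tvm.contrib import graph_runtime

x = sym.Variable("x")
y = sym.argmax(x, axis=1)                       # indices come back as int32
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": (2, 3)})
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.run(x=np.array([[1., 5., 2.], [7., 0., 3.]], dtype="float32"))
out = m.get_output(0, tvm.nd.empty((2,), dtype="int32"))
print(out.asnumpy())                            # expected: [1 0]
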
@@ -262,5 +262,62 @@ NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
     return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
 });
 
+template<int Type>
+inline bool InferFixedType(const NodeAttrs& attrs,
+                           std::vector<int>* in_attrs,
+                           std::vector<int>* out_attrs) {
+  // Static type inference for the argmax operation. Argmax returns indices,
+  // which should have Int32 type, as shapes do.
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, static_cast<int>(Type));
+  return true;
+}
+
+NNVM_REGISTER_BASE_REDUCE_OP(argmax)
+.describe(R"code(Creates an operation that finds the indices of the maximum
+values over a given axis.
+
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "The input")
+.set_attr<FInferShape>("FInferShape", ReduceShape)
+.set_attr<FInferType>("FInferType", InferFixedType<kInt32>)
+.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
+.set_num_inputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
+                                  param.axis, param.exclude);
+    auto axis = ShapeToArray(r_axes);
+    return Array<Tensor>{
+      topi::argmax(inputs[0], axis, param.keepdims) };
+});
+
+NNVM_REGISTER_BASE_REDUCE_OP(argmin)
+.describe(R"code(Creates an operation that finds the indices of the minimum
+values over a given axis.
+
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "The input")
+.set_attr<FInferShape>("FInferShape", ReduceShape)
+.set_attr<FInferType>("FInferType", InferFixedType<kInt32>)
+.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
+.set_num_inputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
+                                  param.axis, param.exclude);
+    auto axis = ShapeToArray(r_axes);
+    return Array<Tensor>{
+      topi::argmin(inputs[0], axis, param.keepdims) };
+});
+
 
 }  // namespace top
 }  // namespace nnvm

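Both FTVMCompute lambdas above funnel param.axis and param.exclude through GetReduceAxes before handing off to topi::argmax/topi::argmin. As a rough illustration of that axis handling, here is a Python sketch only; the helper name reduce_axes is made up here, and the real logic lives in the C++ GetReduceAxes:

def reduce_axes(ndim, axis, exclude):
    # An empty axis list means "reduce over every axis".
    axes = sorted(axis) if axis else list(range(ndim))
    if exclude:
        # exclude=True flips the selection: reduce over all axes *not* listed.
        return [a for a in range(ndim) if a not in axes]
    return axes

assert reduce_axes(3, [0, 2], exclude=False) == [0, 2]
assert reduce_axes(3, [0, 2], exclude=True) == [1]   # the case used by the argmax/argmin tests below
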
@@ -71,21 +71,27 @@ def verify_transpose(dshape, axes):
         out = m.get_output(0, tvm.nd.empty(out_np.shape))
         np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
 
-def verify_reduce(dshape, fnp, fsym, **kwargs):
+def verify_reduce_explicit(dshape, data, result, fsym, oshape=None, otype='float32', **kwargs):
+    """ Verify reduce operations by comparing its result with `result` """
     x = sym.Variable("x")
-    y = fsym(x + 1, **kwargs)
-    dtype = "float32"
+    y = fsym(x + 0, **kwargs)
     for target, ctx in ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = graph_runtime.create(graph, lib, ctx)
         # set input
-        data = np.random.uniform(size=dshape).astype(dtype)
-        out_np = fnp(data + 1, **kwargs)
         m.run(x=data)
-        out = m.get_output(0, tvm.nd.empty(out_np.shape))
-        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
+        # oshape set to None means do not test the shape correctness
+        oshape = result.shape if oshape is None else oshape
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype=otype))
+        np.testing.assert_equal(out.asnumpy().shape, result.shape)
+        np.testing.assert_allclose(out.asnumpy(), result, atol=1e-5, rtol=1e-5)
+
+
+def verify_reduce(dshape, fnp, fsym, oshape=None, otype='float32', **kwargs):
+    """ Verify reduce operations by generating data at random and calling the numpy
+    version as reference """
+    data = np.random.uniform(size=dshape).astype(otype)
+    result = fnp(data + 0, **kwargs)
+    verify_reduce_explicit(dshape, data, result, fsym, oshape=oshape, otype=otype, **kwargs)
+
 
 def verify_collapse(dshape, target_shape, fnp):
     x = sym.Variable("x", shape=dshape)

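The refactor above splits the old helper in two: verify_reduce still draws random input and uses the numpy implementation as the reference, while verify_reduce_explicit checks a caller-supplied result and can pin the output shape and dtype, which the int32-returning argmax/argmin need. A brief usage sketch (shapes and values are illustrative, not taken from the commit):

# Random data, numpy as the reference implementation:
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))

# Fixed data with an explicit expected result and an int32 output buffer:
data = np.array([[1., 5., 2.], [7., 0., 3.]], dtype='float32')
verify_reduce_explicit((2, 3), data, np.array([1, 0]), sym.argmax, otype='int32', axis=1)
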
@@ -109,11 +115,43 @@ def test_transpose():
 
 
 def test_reduce():
+
+    def _with_keepdims(func):
+        """ Wrapper around numpy's argmax/argmin with `keepdims` argument supported """
+        def wrapper(data, axis=None, keepdims=False):
+            if not keepdims:
+                return func(data, axis=axis)
+            else:
+                if axis is not None:
+                    out_shape = list(data.shape)
+                    out_shape[axis] = 1
+                else:
+                    out_shape = [1 for _ in range(len(data.shape))]
+                return func(data, axis=axis).reshape(out_shape)
+        return wrapper
+
     verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True)
     verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
     verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
     verify_reduce((4, 4, 3), np.sum, sym.sum)
+
+    data = np.array([[[1,2],[3,4]],[[3,44],[5,6]]], dtype=np.float32)
+    verify_reduce_explicit([2,2,2], data, np.array([[1,1],[1,0]]), sym.argmax, otype='int32', axis=[0,2], exclude=True)
+    verify_reduce_explicit([2,2,2], data, np.array([[0,0],[0,1]]), sym.argmin, otype='int32', axis=[0,2], exclude=True)
+    shape = [4, 4, 3]
+    for axis in [None, 0, 1, 2]:
+        for keepdims in [True, False]:
+            kwargs = {'keepdims': keepdims}
+            if axis is None:
+                # FIXME: NNVM doesn't support setting `axis=None` explicitly.
+                kwargs.update({'oshape': [1, 1, 1] if keepdims else []})
+            else:
+                kwargs.update({'axis': axis})
+                kwargs.update({'oshape': shape[:axis] + [1] + shape[axis+1:] if keepdims else shape[:axis] + shape[axis+1:]})
+
+            verify_reduce(shape, _with_keepdims(np.argmax), sym.argmax, otype='int32', **kwargs)
+            verify_reduce(shape, _with_keepdims(np.argmin), sym.argmin, otype='int32', **kwargs)
 
 
 def test_collapse():
     verify_collapse((2, 3, 4), (1,), lambda x: x.sum())

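The expected index tensors in the explicit argmax/argmin cases above can be re-derived with plain numpy; with axis=[0, 2] and exclude=True the reduction effectively runs over the remaining axis 1:

import numpy as np

data = np.array([[[1, 2], [3, 4]], [[3, 44], [5, 6]]], dtype=np.float32)
print(np.argmax(data, axis=1))   # [[1 1] [1 0]] -- matches the expected result passed to verify_reduce_explicit
print(np.argmin(data, axis=1))   # [[0 0] [0 1]]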