diff --git a/nnvm/python/nnvm/top/reduction.py b/nnvm/python/nnvm/top/reduction.py index 61973a75..fd8e2f8d 100644 --- a/nnvm/python/nnvm/top/reduction.py +++ b/nnvm/python/nnvm/top/reduction.py @@ -41,3 +41,11 @@ reg.register_schedule("min", _fschedule_reduce) # collapse sum reg.register_pattern("collapse_sum", OpPattern.COMM_REDUCE) reg.register_schedule("collapse_sum", _fschedule_reduce) + +# argmax +reg.register_pattern("argmax", OpPattern.COMM_REDUCE) +reg.register_schedule("argmax", _fschedule_reduce) + +# argmin +reg.register_pattern("argmin", OpPattern.COMM_REDUCE) +reg.register_schedule("argmin", _fschedule_reduce) diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index cb6848e1..3f948720 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -262,5 +262,62 @@ NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum) return Array{ topi::collapse_sum(inputs[0], inputs[1]->shape) }; }); +template +inline bool InferFixedType(const NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + // Static type inference for argmax operation. Argmax return indices which + // should have Int32 type as shapes do. + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, static_cast(Type)); + return true; +} + +NNVM_REGISTER_BASE_REDUCE_OP(argmax) +.describe(R"code(Creates an operation that finds the indices of the maximum +values over a given axis. + +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "The input") +.set_attr("FInferShape", ReduceShape) +.set_attr("FInferType", InferFixedType) +.set_attr("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) +.set_num_inputs(1) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const ReduceParam& param = nnvm::get(attrs.parsed); + TShape r_axes = GetReduceAxes(inputs[0]->shape.size(), + param.axis, param.exclude); + auto axis = ShapeToArray(r_axes); + return Array{ + topi::argmax(inputs[0], axis, param.keepdims) }; +}); + +NNVM_REGISTER_BASE_REDUCE_OP(argmin) +.describe(R"code(Creates an operation that finds the indices of the minimum +values over a given axis. + +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "The input") +.set_attr("FInferShape", ReduceShape) +.set_attr("FInferType", InferFixedType) +.set_attr("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) +.set_num_inputs(1) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const ReduceParam& param = nnvm::get(attrs.parsed); + TShape r_axes = GetReduceAxes(inputs[0]->shape.size(), + param.axis, param.exclude); + auto axis = ShapeToArray(r_axes); + return Array{ + topi::argmin(inputs[0], axis, param.keepdims) }; +}); + + } // namespace top } // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 236ac8e8..2d0c8aeb 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -71,21 +71,27 @@ def verify_transpose(dshape, axes): out = m.get_output(0, tvm.nd.empty(out_np.shape)) np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) - -def verify_reduce(dshape, fnp, fsym, **kwargs): +def verify_reduce_explicit(dshape, data, result, fsym, oshape=None, otype='float32', **kwargs): + """ Verify reduce operations by comparign its result with `result` """ x = sym.Variable("x") - y = fsym(x + 1, **kwargs) - dtype = "float32" + y = fsym(x + 0, **kwargs) for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = graph_runtime.create(graph, lib, ctx) # set input - data = np.random.uniform(size=dshape).astype(dtype) - out_np = fnp(data + 1, **kwargs) m.run(x=data) - out = m.get_output(0, tvm.nd.empty(out_np.shape)) - np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + # oshape set to None means do not test the shape-correctness + oshape = result.shape if oshape is None else oshape + out = m.get_output(0, tvm.nd.empty(oshape, dtype=otype)) + np.testing.assert_equal(out.asnumpy().shape, result.shape) + np.testing.assert_allclose(out.asnumpy(), result, atol=1e-5, rtol=1e-5) +def verify_reduce(dshape, fnp, fsym, oshape=None, otype='float32', **kwargs): + """ Verify reduce operations by generating data at random and calling numpy + version as reference """ + data = np.random.uniform(size=dshape).astype(otype) + result = fnp(data + 0, **kwargs) + verify_reduce_explicit(dshape, data, result, fsym, oshape=oshape, otype=otype, **kwargs) def verify_collapse(dshape, target_shape, fnp): x = sym.Variable("x", shape=dshape) @@ -109,11 +115,43 @@ def test_transpose(): def test_reduce(): + + def _with_keepdims(func): + """ Wrapper around numpy's argmax/argmin with `keepdims` argument supported """ + def wrapper(data, axis=None, keepdims=False): + if not keepdims: + return func(data, axis=axis) + else: + if axis is not None: + out_shape = list(data.shape) + out_shape[axis] = 1 + else: + out_shape = [1 for _ in range(len(data.shape))] + return func(data, axis=axis).reshape(out_shape) + return wrapper + verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True) verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True) verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2)) verify_reduce((4, 4, 3), np.sum, sym.sum) + data = np.array([[[1,2],[3,4]],[[3,44],[5,6]]], dtype=np.float32) + verify_reduce_explicit([2,2,2], data, np.array([[1,1],[1,0]]), sym.argmax, otype='int32', axis=[0,2], exclude=True) + verify_reduce_explicit([2,2,2], data, np.array([[0,0],[0,1]]), sym.argmin, otype='int32', axis=[0,2], exclude=True) + shape = [4, 4, 3] + for axis in [None, 0, 1, 2]: + for keepdims in [True,False]: + kwargs = { 'keepdims':keepdims } + if axis is None: + # FIXME: NNVM doesn't support setting `axis=None` explicitly. + kwargs.update({'oshape': [1,1,1] if keepdims else [] }) + else: + kwargs.update({'axis': axis}) + kwargs.update({'oshape': shape[:axis]+[1]+shape[axis+1:] if keepdims else shape[:axis]+shape[axis+1:]}) + + verify_reduce(shape, _with_keepdims(np.argmax), sym.argmax, otype='int32', **kwargs) + verify_reduce(shape, _with_keepdims(np.argmin), sym.argmin, otype='int32', **kwargs) + def test_collapse(): verify_collapse((2, 3, 4), (1,), lambda x: x.sum())