[NNVM][TOP] broadcast versions corresponding to topi: mod, max, min, pow, left_shift, right_shift greater, less, equal, not_equal, greater_equal and less_equal. (#1383)
This commit is contained in:
Родитель
0d673a9d6f
Коммит
3a0b757c10
|
@ -544,7 +544,7 @@ def _get_convert_map(opset):
|
|||
'Exp': Renamer('exp'),
|
||||
'Log': Renamer('log'),
|
||||
'Tanh': Renamer('tanh'),
|
||||
# 'Pow'
|
||||
'Pow': Renamer('broadcast_pow'),
|
||||
'PRelu': Prelu.get_converter(opset),
|
||||
'Sigmoid': Renamer('sigmoid'),
|
||||
# 'HardSigmoid'
|
||||
|
|
|
@ -168,6 +168,54 @@ reg.register_schedule("broadcast_mul", _fschedule_broadcast)
|
|||
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_div", _fschedule_broadcast)
|
||||
|
||||
# broadcast mod
|
||||
reg.register_pattern("broadcast_mod", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_mod", _fschedule_broadcast)
|
||||
|
||||
# broadcast max
|
||||
reg.register_pattern("broadcast_max", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_max", _fschedule_broadcast)
|
||||
|
||||
# broadcast min
|
||||
reg.register_pattern("broadcast_min", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_min", _fschedule_broadcast)
|
||||
|
||||
# broadcast pow
|
||||
reg.register_pattern("broadcast_pow", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_pow", _fschedule_broadcast)
|
||||
|
||||
# broadcast left_shift
|
||||
reg.register_pattern("broadcast_left_shift", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_left_shift", _fschedule_broadcast)
|
||||
|
||||
# broadcast right_shift
|
||||
reg.register_pattern("broadcast_right_shift", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_right_shift", _fschedule_broadcast)
|
||||
|
||||
# broadcast greater
|
||||
reg.register_pattern("broadcast_greater", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_greater", _fschedule_broadcast)
|
||||
|
||||
# broadcast less
|
||||
reg.register_pattern("broadcast_less", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_less", _fschedule_broadcast)
|
||||
|
||||
# broadcast equal
|
||||
reg.register_pattern("broadcast_equal", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_equal", _fschedule_broadcast)
|
||||
|
||||
# broadcast not_equal
|
||||
reg.register_pattern("broadcast_not_equal", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_not_equal", _fschedule_broadcast)
|
||||
|
||||
# broadcast greater_equal
|
||||
reg.register_pattern("broadcast_greater_equal", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_greater_equal", _fschedule_broadcast)
|
||||
|
||||
# broadcast less_equal
|
||||
reg.register_pattern("broadcast_less_equal", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_less_equal", _fschedule_broadcast)
|
||||
|
||||
# broadcast_to
|
||||
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_to", _fschedule_broadcast)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "../op_common.h"
|
||||
#include "../elemwise_op_common.h"
|
||||
#include "topi/broadcast.h"
|
||||
#include "topi/elemwise.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace top {
|
||||
|
@ -346,5 +347,251 @@ Example::
|
|||
return std::vector<NodeEntry>{ dlhs, drhs };
|
||||
});
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mod, mod)
|
||||
.add_alias("__mod_symbol__")
|
||||
.describe(R"code(Returns element-wise mod of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 3.]]
|
||||
|
||||
broadcast_mod(x, y) = [[ 1., 0., 1.],
|
||||
[ 1., 2., 0.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_max, maximum)
|
||||
.add_alias("__max_symbol__")
|
||||
.describe(R"code(Returns element-wise max of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 3.]]
|
||||
|
||||
broadcast_max(x, y) = [[ 2., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_min, minimum)
|
||||
.add_alias("__min_symbol__")
|
||||
.describe(R"code(Returns element-wise minimum of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 3.]]
|
||||
|
||||
broadcast_min(x, y) = [[ 1., 2., 2.],
|
||||
[ 3., 3., 3.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_pow, power)
|
||||
.add_alias("__pow_symbol__")
|
||||
.describe(R"code(Returns element-wise x^y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 1.],
|
||||
[ 2.]]
|
||||
|
||||
broadcast_pow(x, y) = [[ 1., 2., 3. ],
|
||||
[ 16., 25., 36.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_left_shift, left_shift)
|
||||
.add_alias("__left_shift_symbol__")
|
||||
.describe(R"code(Returns element-wise x << y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 1.]]
|
||||
|
||||
broadcast_left_shift(x, y) = [[ 4., 8., 12.],
|
||||
[ 8., 10., 12.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_right_shift, right_shift)
|
||||
.add_alias("__right_shift_symbol__")
|
||||
.describe(R"code(Returns element-wise x >> y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 4., 8., 12.],
|
||||
[ 8., 10., 12.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 1.]]
|
||||
|
||||
broadcast_right_shift(x, y) = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_greater, greater)
|
||||
.add_alias("__greater_symbol__")
|
||||
.describe(R"code(Returns element-wise x > y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 3.]]
|
||||
|
||||
broadcast_greater(x, y) = [[ 0., 0., 1.],
|
||||
[ 1., 1., 1.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{ topi::cast(topi::greater(inputs[0], inputs[1]), out_info[0]->dtype) };
|
||||
}, 11);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_less, less)
|
||||
.add_alias("__less_symbol__")
|
||||
.describe(R"code(Returns element-wise x < y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 3.]]
|
||||
|
||||
broadcast_less(x, y) = [[ 1., 0., 0.],
|
||||
[ 0., 0., 0.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{ topi::cast(topi::less(inputs[0], inputs[1]), out_info[0]->dtype) };
|
||||
}, 11);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_equal, equal)
|
||||
.add_alias("__equal_symbol__")
|
||||
.describe(R"code(Returns element-wise x == y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 5.]]
|
||||
|
||||
broadcast_equal(x, y) = [[ 0., 1., 0.],
|
||||
[ 0., 1., 0.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{ topi::cast(topi::equal(inputs[0], inputs[1]), out_info[0]->dtype) };
|
||||
}, 11);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_not_equal, not_equal)
|
||||
.add_alias("__not_equal_symbol__")
|
||||
.describe(R"code(Returns element-wise x != y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 4.]]
|
||||
|
||||
broadcast_not_equal(x, y) = [[ 1., 0., 1.],
|
||||
[ 0., 1., 1.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{ topi::cast(topi::not_equal(inputs[0],
|
||||
inputs[1]),
|
||||
out_info[0]->dtype) };
|
||||
}, 11);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_greater_equal, greater_equal)
|
||||
.add_alias("__greater_equal_symbol__")
|
||||
.describe(R"code(Returns element-wise x >= y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 2.],
|
||||
[ 6.]]
|
||||
|
||||
broadcast_greater_equal(x, y) = [[ 0., 1., 1.],
|
||||
[ 0., 0., 1.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{ topi::cast(topi::greater_equal(inputs[0],
|
||||
inputs[1]),
|
||||
out_info[0]->dtype) };
|
||||
}, 11);
|
||||
|
||||
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_less_equal, less_equal)
|
||||
.add_alias("__less_equal_symbol__")
|
||||
.describe(R"code(Returns element-wise x <= y of the input arrays with broadcasting.
|
||||
|
||||
Example::
|
||||
|
||||
x = [[ 1., 2., 3.],
|
||||
[ 4., 5., 6.]]
|
||||
|
||||
y = [[ 1.],
|
||||
[ 5.]]
|
||||
|
||||
broadcast_less_equal(x, y) = [[ 1., 0., 0.],
|
||||
[ 1., 1., 0.]]
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
return Array<Tensor>{ topi::cast(topi::less_equal(inputs[0],
|
||||
inputs[1]),
|
||||
out_info[0]->dtype) };
|
||||
}, 11);
|
||||
|
||||
} // namespace top
|
||||
} // namespace nnvm
|
||||
|
|
|
@ -9,17 +9,23 @@ from nnvm.testing.config import ctx_list
|
|||
|
||||
|
||||
def helper(symbol, inputs, dtype,
|
||||
np_forward, np_backward=None, need_input=True, need_head_grads=True):
|
||||
np_forward, np_backward=None,
|
||||
need_input=True, need_head_grads=True, in_range={}):
|
||||
ishapes = {}
|
||||
input_syms = []
|
||||
np_inputs = {}
|
||||
for (name, shape, s) in inputs:
|
||||
ishapes.update({name: shape})
|
||||
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
|
||||
if name in in_range:
|
||||
np_inputs.update({name: np.random.uniform(size=shape,
|
||||
low=in_range[name][0],
|
||||
high=in_range[name][1]).astype(dtype)})
|
||||
else:
|
||||
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
|
||||
input_syms.append(s)
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
|
||||
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, dtype=dtype)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
m.run(**np_inputs)
|
||||
y_np = np_forward(**np_inputs)
|
||||
|
@ -228,6 +234,49 @@ def test_broadcast():
|
|||
return da, db
|
||||
helper(y, inputs, dtype, lambda a, b: a / b, _backward_div)
|
||||
|
||||
y = sym.broadcast_mod(a, b)
|
||||
helper(y, inputs, 'int32',
|
||||
lambda a, b: np.mod(a, b),
|
||||
in_range={'a': (0.001, 100), 'b': (1, 100)})
|
||||
|
||||
y = sym.broadcast_max(a, b)
|
||||
helper(y, inputs, dtype, lambda a, b: np.maximum(a, b))
|
||||
|
||||
y = sym.broadcast_min(a, b)
|
||||
helper(y, inputs, dtype, lambda a, b: np.minimum(a, b))
|
||||
|
||||
y = sym.broadcast_pow(a, b)
|
||||
helper(y, inputs, dtype,
|
||||
lambda a, b: np.power(a, b),
|
||||
in_range={'a': (0.001, 100), 'b': (0.001, 2)})
|
||||
|
||||
y = sym.broadcast_left_shift(a, b)
|
||||
helper(y, inputs, 'int32', lambda a, b: a << b)
|
||||
|
||||
y = sym.broadcast_right_shift(a, b)
|
||||
helper(y, inputs, 'int32', lambda a, b: a >> b)
|
||||
|
||||
y = sym.broadcast_greater(a, b)
|
||||
helper(y, inputs, dtype, lambda a, b: np.greater(a, b))
|
||||
|
||||
y = sym.broadcast_less(a, b)
|
||||
helper(y, inputs, dtype, lambda a, b: np.less(a, b))
|
||||
|
||||
y = sym.broadcast_equal(a, b)
|
||||
helper(y, inputs, 'int32', lambda a, b: np.equal(a, b),
|
||||
in_range={'a': (-2, 2), 'b': (-2, 2)})
|
||||
|
||||
y = sym.broadcast_not_equal(a, b)
|
||||
helper(y, inputs, 'int32', lambda a, b: np.not_equal(a, b),
|
||||
in_range={'a': (-2, 2), 'b': (-2, 2)})
|
||||
|
||||
y = sym.broadcast_greater_equal(a, b)
|
||||
helper(y, inputs, 'int32', lambda a, b: np.greater_equal(a, b),
|
||||
in_range={'a': (-3, 3), 'b': (-3, 3)})
|
||||
|
||||
y = sym.broadcast_less_equal(a, b)
|
||||
helper(y, inputs, 'int32', lambda a, b: np.less_equal(a, b),
|
||||
in_range={'a': (-3, 3), 'b': (-3, 3)})
|
||||
|
||||
def test_greater():
|
||||
l = sym.Variable("l")
|
||||
|
|
|
@ -108,6 +108,50 @@ def test_reshape_like():
|
|||
|
||||
np.testing.assert_allclose(ref_shape, tvm_out.shape)
|
||||
|
||||
def _test_power_iteration(x_shape, y_shape):
|
||||
if isinstance(y_shape, int):
|
||||
y_shape = [y_shape]
|
||||
|
||||
x = np.random.uniform(size=x_shape).astype(np.float32)
|
||||
y = np.random.uniform(size=y_shape).astype(np.float32)
|
||||
|
||||
np_res = np.power(x, y).astype(np.float32)
|
||||
|
||||
res = helper.make_node("Pow", ['x', 'y'], ['out'])
|
||||
|
||||
graph = helper.make_graph([res],
|
||||
'power_test',
|
||||
inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
|
||||
helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
|
||||
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))])
|
||||
|
||||
model = helper.make_model(graph, producer_name='power_test')
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
new_sym, params = nnvm.frontend.from_onnx(model)
|
||||
|
||||
input_name = model.graph.input[0].name
|
||||
input_name1 = model.graph.input[1].name
|
||||
shape_dict = {input_name: x.shape, input_name1: y.shape}
|
||||
dtype_dict = {input_name: x.dtype, input_name1: y.dtype}
|
||||
|
||||
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
# set inputs
|
||||
m.set_input(input_name, tvm.nd.array(x))
|
||||
m.set_input(input_name1, tvm.nd.array(y))
|
||||
m.set_input(**params)
|
||||
m.run()
|
||||
# get outputs
|
||||
tvm_out = m.get_output(0, tvm.nd.empty(np_res.shape, np_res.dtype))
|
||||
|
||||
np.testing.assert_allclose(np_res, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_power():
|
||||
_test_power_iteration((1, 3), (1))
|
||||
_test_power_iteration((2, 3), (2, 3))
|
||||
_test_power_iteration((2, 3), (1, 3))
|
||||
|
||||
def test_squeeze():
|
||||
in_shape = (1, 3, 1, 3, 1, 1)
|
||||
out_shape = (3, 3)
|
||||
|
@ -247,6 +291,7 @@ if __name__ == '__main__':
|
|||
verify_resnet18()
|
||||
test_reshape()
|
||||
test_reshape_like()
|
||||
test_power()
|
||||
test_squeeze()
|
||||
test_unsqueeze()
|
||||
test_slice()
|
||||
|
|
Загрузка…
Ссылка в новой задаче