[Relay] Add logical operators (#2743)
This commit is contained in:
Родитель
695647db94
Коммит
2239508be4
|
@ -366,7 +366,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and)
|
|||
.describe(R"code(Elementwise compute the logical AND
|
||||
|
||||
)code")
|
||||
.set_support_level(1)
|
||||
.set_support_level(4)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
|
@ -378,7 +378,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or)
|
|||
.describe(R"code(Elementwise compute the logical OR
|
||||
|
||||
)code")
|
||||
.set_support_level(1)
|
||||
.set_support_level(4)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
|
@ -413,7 +413,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not)
|
|||
.describe(R"code(Elementwise compute the logical NOT
|
||||
|
||||
)code" NNVM_ADD_FILELINE)
|
||||
.set_support_level(3)
|
||||
.set_support_level(4)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
|
|
|
@ -849,6 +849,11 @@ def _softmax():
|
|||
transforms={'axis': ('axis', 1)})([inputs[0]], attr)
|
||||
return _impl
|
||||
|
||||
def _logical(name):
|
||||
def _impl(inputs, attr, params):
|
||||
return AttrCvt(op_name=name)(inputs, attr)
|
||||
return _impl
|
||||
|
||||
# compatible operators that do NOT require any conversion.
|
||||
_identity_list = []
|
||||
|
||||
|
@ -909,6 +914,9 @@ _convert_map = {
|
|||
'Transpose' : _transpose(),
|
||||
'Tanh' : AttrCvt('tanh'),
|
||||
'Mean' : _mean(),
|
||||
'LogicalAnd' : _logical('logical_and'),
|
||||
'LogicalOr' : _logical('logical_or'),
|
||||
'LogicalNot' : _logical('logical_not'),
|
||||
'Less' : _broadcast('less'),
|
||||
'Greater' : _broadcast('greater'),
|
||||
'LessEqual' : _broadcast('less_equal'),
|
||||
|
|
|
@ -18,6 +18,7 @@ register_schedule("trunc", schedule_broadcast)
|
|||
register_schedule("round", schedule_broadcast)
|
||||
register_schedule("abs", schedule_broadcast)
|
||||
register_schedule("tanh", schedule_broadcast)
|
||||
register_schedule("logical_not", schedule_broadcast)
|
||||
register_schedule("negative", schedule_broadcast)
|
||||
register_schedule("copy", schedule_broadcast)
|
||||
|
||||
|
@ -27,6 +28,8 @@ register_schedule("multiply", schedule_broadcast)
|
|||
register_schedule("divide", schedule_broadcast)
|
||||
register_schedule("power", schedule_injective)
|
||||
register_schedule("mod", schedule_broadcast)
|
||||
register_schedule("logical_and", schedule_broadcast)
|
||||
register_schedule("logical_or", schedule_broadcast)
|
||||
register_schedule("equal", schedule_broadcast)
|
||||
register_schedule("not_equal", schedule_broadcast)
|
||||
register_schedule("less", schedule_broadcast)
|
||||
|
|
|
@ -191,6 +191,22 @@ def negative(data):
|
|||
return _make.negative(data)
|
||||
|
||||
|
||||
def logical_not(data):
|
||||
"""Compute element-wise logical not of data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : relay.Expr
|
||||
The input data
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : relay.Expr
|
||||
The computed result.
|
||||
"""
|
||||
return _make.logical_not(data)
|
||||
|
||||
|
||||
def add(lhs, rhs):
|
||||
"""Addition with numpy-style broadcasting.
|
||||
|
||||
|
@ -307,6 +323,42 @@ def mod(lhs, rhs):
|
|||
return _make.mod(lhs, rhs)
|
||||
|
||||
|
||||
def logical_and(lhs, rhs):
|
||||
"""logical AND with numpy-style broadcasting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lhs : relay.Expr
|
||||
The left hand side input data
|
||||
rhs : relay.Expr
|
||||
The right hand side input data
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : relay.Expr
|
||||
The computed result.
|
||||
"""
|
||||
return _make.logical_and(lhs, rhs)
|
||||
|
||||
|
||||
def logical_or(lhs, rhs):
|
||||
"""logical OR with numpy-style broadcasting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lhs : relay.Expr
|
||||
The left hand side input data
|
||||
rhs : relay.Expr
|
||||
The right hand side input data
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : relay.Expr
|
||||
The computed result.
|
||||
"""
|
||||
return _make.logical_or(lhs, rhs)
|
||||
|
||||
|
||||
def equal(lhs, rhs):
|
||||
"""Broadcasted elementwise test for (lhs == rhs).
|
||||
|
||||
|
|
|
@ -82,6 +82,18 @@ RELAY_REGISTER_BINARY_OP("mod")
|
|||
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
|
||||
|
||||
|
||||
RELAY_REGISTER_BINARY_OP("logical_and")
|
||||
.describe("Elementwise logical AND with broadcasting")
|
||||
.set_support_level(4)
|
||||
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and));
|
||||
|
||||
|
||||
RELAY_REGISTER_BINARY_OP("logical_or")
|
||||
.describe("Elementwise logical OR with broadcasting")
|
||||
.set_support_level(4)
|
||||
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
|
||||
|
||||
|
||||
RELAY_REGISTER_CMP_OP("equal")
|
||||
.describe("Elementwise equal compare with broadcasting")
|
||||
.set_support_level(4)
|
||||
|
|
|
@ -178,5 +178,16 @@ RELAY_REGISTER_UNARY_OP("negative")
|
|||
.set_support_level(3)
|
||||
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
|
||||
|
||||
|
||||
RELAY_REGISTER_UNARY_OP("logical_not")
|
||||
.describe(R"code(Returns the logical inverse of input array, computed element-wise.
|
||||
|
||||
.. math::
|
||||
~(x)
|
||||
|
||||
)code" TVM_ADD_FILELINE)
|
||||
.set_support_level(4)
|
||||
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -682,6 +682,49 @@ def test_forward_pad():
|
|||
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT")
|
||||
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0)
|
||||
|
||||
#######################################################################
|
||||
# Logical operators
|
||||
# --------------------
|
||||
def test_logical_and():
|
||||
with tf.Graph().as_default():
|
||||
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
|
||||
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
|
||||
out = tf.logical_and(in1, in2, name='out')
|
||||
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
|
||||
|
||||
def test_logical_or():
|
||||
with tf.Graph().as_default():
|
||||
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
|
||||
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
|
||||
out = tf.logical_or(in1, in2, name='out')
|
||||
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
|
||||
|
||||
def test_logical_xor():
|
||||
with tf.Graph().as_default():
|
||||
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
|
||||
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
|
||||
out = tf.logical_xor(in1, in2, name='out')
|
||||
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
|
||||
|
||||
def test_logical_not():
|
||||
with tf.Graph().as_default():
|
||||
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
|
||||
out = tf.logical_not(in1, name='out')
|
||||
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
|
||||
compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
|
||||
|
||||
def test_forward_logical():
|
||||
test_logical_and()
|
||||
test_logical_or()
|
||||
test_logical_xor()
|
||||
test_logical_not()
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Inception V3
|
||||
|
@ -1109,5 +1152,4 @@ if __name__ == '__main__':
|
|||
|
||||
# Relational ops
|
||||
test_forward_rel_ops()
|
||||
|
||||
|
||||
test_forward_logical()
|
||||
|
|
Загрузка…
Ссылка в новой задаче