Fix inconsistent python/cpp API behavior for if_then_else, power (#3829)

* fix inconsistent python/cpp APIs for if_then_else

* fix error message

* fix power consistency

* fix

* fix bug

* add test
This commit is contained in:
Xingjian Shi 2019-08-26 11:31:10 -07:00 коммит произвёл Yao Wang
Родитель 92b6ca7127
Коммит 283b0c3f7b
5 изменённых файлов: 102 добавлений и 17 удалений

Просмотреть файл

@ -419,7 +419,7 @@ def power(x, y):
z : Expr
The result.
"""
return call_pure_intrin(x.dtype, "pow", x, y)
return _make._OpPow(convert(x), convert(y))
def popcount(x):
@ -482,12 +482,7 @@ def if_then_else(cond, t, f):
Unlike Select, if_then_else cannot be vectorized
if some lanes in the vector have different conditions.
"""
t = convert(t)
f = convert(f)
cond = convert(cond)
if cond.dtype != "bool":
raise TypeError("The condition's data type has to be bool")
return call_pure_intrin(t.dtype, "tvm_if_then_else", cond, t, f)
return _make._OpIfThenElse(convert(cond), convert(t), convert(f))
# Intrinsic rule related code

Просмотреть файл

@ -196,6 +196,7 @@ REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
@ -211,6 +212,10 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_API("make._OpIfThenElse")
.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr cond, Expr true_value, Expr false_value) {
return if_then_else(cond, true_value, false_value);
});
} // namespace ir
} // namespace tvm

Просмотреть файл

@ -238,7 +238,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
using ir::IntImm;
using ir::UIntImm;
CHECK(cond.type() == Bool(1))
<< "if_then_else only accept a single condition";
<< "if_then_else only accept the condition to be boolean type.";
BinaryOpMatchTypes(true_value, false_value);
if (const UIntImm* op = cond.as<UIntImm>()) {
if (op->value != 0) {

Просмотреть файл

@ -27,7 +27,7 @@ def test_array_save_load_json():
a = tvm.convert([1,2,3])
json_str = tvm.save_json(a)
a_loaded = tvm.load_json(json_str)
assert(a[1].value == 2)
assert(a_loaded[1].value == 2)
def test_map():

Просмотреть файл

@ -16,6 +16,15 @@
# under the License.
import tvm
def check_throws(f):
try:
f()
except tvm.TVMError:
pass
else:
raise AssertionError("Should have raised an exception but didn't.")
def test_const_fold():
def check(f, *args):
x = f(*[tvm.const(x, "int32") for x in args])
@ -47,14 +56,6 @@ def test_const_fold2():
assert isinstance((1 / x), tvm.expr.Div)
def test_const_fold3():
def check_throws(f):
try:
f()
except tvm.TVMError:
pass
else:
raise AssertionError("Should have raised an exception but didn't.")
# Test that using ints with logic operations is forbidden
x = tvm.var("x")
for val in [0, 1]:
@ -100,8 +101,92 @@ def test_const_fold4():
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
def test_binary_dtype_match():
def verify_general_dtype_support(f, is_conditional=False):
rules = [[('bool', 'int32'), 'int32'],
[('int32', 'float32'), 'float32'],
[('int32', 'int64'), 'int64'],
[('uint32', 'int32'), 'int32']]
for (lhs_dtype, rhs_dtype), out_dtype in rules:
lhs = tvm.var('lhs', dtype=lhs_dtype)
rhs = tvm.var('rhs', dtype=rhs_dtype)
out = f(lhs, rhs)
if not is_conditional:
assert out.dtype == out_dtype
else:
assert out.dtype == 'bool'
if hasattr(out, 'a'):
assert out.a.dtype == out_dtype
assert out.b.dtype == out_dtype
elif hasattr(out, 'args'):
# CallOp
assert out.args[0].dtype == out_dtype
assert out.args[1].dtype == out_dtype
else:
raise ValueError('Unknown binary op format!')
def verify_callop_float_only(f):
for lhs_dtype in ['int32', 'float32', 'float64']:
for rhs_dtype in ['int32', 'float32', 'float64']:
lhs = tvm.var('lhs', dtype=lhs_dtype)
rhs = tvm.var('rhs', dtype=rhs_dtype)
if 'float' not in lhs_dtype and 'float' not in rhs_dtype:
check_throws(lambda: f(lhs, rhs))
elif 'float' in lhs_dtype and 'float' in rhs_dtype and lhs_dtype != rhs_dtype:
check_throws(lambda: f(lhs, rhs))
elif 'float' in lhs_dtype:
out = f(lhs, rhs)
assert out.dtype == lhs_dtype
assert out.args[0].dtype == lhs_dtype
assert out.args[1].dtype == lhs_dtype
else:
out = f(lhs, rhs)
assert out.dtype == rhs_dtype
assert out.args[0].dtype == rhs_dtype
assert out.args[1].dtype == rhs_dtype
verify_general_dtype_support(lambda a, b: a + b)
verify_general_dtype_support(lambda a, b: a * b)
verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True)
verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True)
verify_callop_float_only(lambda a, b: tvm.power(a, b))
def test_if_then_else():
cases = [[(tvm.var('cond', dtype='bool'), 'bool', 'int32'), 'int32'],
[(True, 'int32', 'float32'), 'float32'],
[(False, 'int32', 'int64'), 'int64'],
[(tvm.var('cond', dtype='bool'), 'uint32', 'int32'), 'int32'],
[(tvm.var('cond', dtype='int32'), 'uint32', 'int32'), 'int32']]
for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
lhs = tvm.var('lhs', dtype=lhs_dtype)
rhs = tvm.var('rhs', dtype=rhs_dtype)
if cond is True or cond is False:
out = tvm.if_then_else(cond, lhs, rhs)
out2 = tvm.if_then_else(not cond, rhs, lhs)
out3 = tvm.if_then_else(not cond, lhs, rhs)
assert tvm.ir_pass.Equal(out, out2) == 1
if cond:
assert tvm.ir_pass.Equal(out, lhs.astype(out_dtype)) == 1
assert tvm.ir_pass.Equal(out3, rhs.astype(out_dtype)) == 1
else:
assert tvm.ir_pass.Equal(out, rhs.astype(out_dtype)) == 1
assert tvm.ir_pass.Equal(out3, lhs.astype(out_dtype)) == 1
elif cond.dtype == 'bool':
out = tvm.if_then_else(cond, lhs, rhs)
assert out.dtype == out_dtype
assert out.args[1].dtype == out_dtype
assert out.args[2].dtype == out_dtype
elif cond.dtype != 'bool':
check_throws(lambda: tvm.if_then_else(cond, lhs, rhs))
else:
raise ValueError('Unknown combinations')
if __name__ == "__main__":
test_const_fold()
test_const_fold2()
test_const_fold3()
test_const_fold4()
test_binary_dtype_match()
test_if_then_else()