Consistent result of DetectLinearEquation() when an empy vars is passed (#2860)
This commit is contained in:
Родитель
c162e7d662
Коммит
b590c4f225
|
@ -127,25 +127,21 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
|
|||
Expr base = e;
|
||||
Array<Expr> coeff;
|
||||
|
||||
if (0 == vars.size()) {
|
||||
coeff.push_back(make_const(Int(32), 1));
|
||||
} else {
|
||||
for (Var v : vars) {
|
||||
LinearEqEntry ret;
|
||||
if (!LinearEqDetector(v).Detect(base, &ret)) {
|
||||
return Array<Expr>();
|
||||
}
|
||||
coeff.push_back(ret.coeff);
|
||||
base = std::move(ret.base);
|
||||
for (Var v : vars) {
|
||||
LinearEqEntry ret;
|
||||
if (!LinearEqDetector(v).Detect(base, &ret)) {
|
||||
return Array<Expr>();
|
||||
}
|
||||
coeff.push_back(ret.coeff);
|
||||
base = std::move(ret.base);
|
||||
}
|
||||
|
||||
std::unordered_set<const Variable*> vset;
|
||||
for (size_t i = vars.size(); i != 1; --i) {
|
||||
vset.insert(vars[i - 1].get());
|
||||
// The previous coeff contains the variable
|
||||
if (ExprUseVar(coeff[i - 2], vset)) {
|
||||
return Array<Expr>();
|
||||
}
|
||||
std::unordered_set<const Variable*> vset;
|
||||
for (size_t i = vars.size(); i > 1; --i) {
|
||||
vset.insert(vars[i - 1].get());
|
||||
// The previous coeff contains the variable
|
||||
if (ExprUseVar(coeff[i - 2], vset)) {
|
||||
return Array<Expr>();
|
||||
}
|
||||
}
|
||||
coeff.push_back(base);
|
||||
|
|
|
@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator {
|
|||
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
|
||||
using namespace arith;
|
||||
Stmt body = stmt;
|
||||
bool is_single_point_copy = false;
|
||||
|
||||
// strip the loops
|
||||
std::vector<const For*> loops;
|
||||
|
@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator {
|
|||
const Cast* cast = store->value.as<Cast>();
|
||||
const Load* load = store->value.as<Load>();
|
||||
if (0 == loops.size()) {
|
||||
is_single_point_copy = true;
|
||||
CHECK(!has_cond);
|
||||
}
|
||||
// for now only support true condition matching
|
||||
|
@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator {
|
|||
arith::DetectLinearEquation(load->index, loop_vars);
|
||||
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
|
||||
Array<Expr> dst_shape;
|
||||
auto loop_var_size = loop_vars.size();
|
||||
if (is_single_point_copy) {
|
||||
loop_var_size = 1;
|
||||
const size_t loop_var_size = loop_vars.size();
|
||||
if (loop_var_size == 0) {
|
||||
dst_shape.push_back(make_const(Int(32), 1));
|
||||
} else {
|
||||
for (const For* op : loops) {
|
||||
|
@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator {
|
|||
CHECK_EQ(load_strides.size(), loop_var_size + 1);
|
||||
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
|
||||
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
|
||||
if (loop_var_size == 0) {
|
||||
src_strides.push_back(make_const(Int(32), 1));
|
||||
dst_strides.push_back(make_const(Int(32), 1));
|
||||
}
|
||||
Buffer dst = BufferNode::make(
|
||||
Var(store->buffer_var.node_),
|
||||
store->value.type(),
|
||||
|
|
|
@ -20,6 +20,10 @@ def test_basic():
|
|||
m = tvm.arith.DetectLinearEquation(b * 7, [a])
|
||||
assert m[0].value == 0
|
||||
|
||||
m = tvm.arith.DetectLinearEquation(b * 7, [])
|
||||
assert len(m) == 1
|
||||
assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
|
||||
|
||||
def test_multivariate():
|
||||
v = [tvm.var("v%d" % i) for i in range(4)]
|
||||
b = tvm.var("b")
|
||||
|
@ -42,6 +46,10 @@ def test_multivariate():
|
|||
assert(m[0].value == 0)
|
||||
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
|
||||
|
||||
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [])
|
||||
assert(len(m) == 1)
|
||||
assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_basic()
|
||||
test_multivariate()
|
||||
|
|
Загрузка…
Ссылка в новой задаче