diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc index aa7c5b51..b5d4429e 100644 --- a/src/lang/schedule.cc +++ b/src/lang/schedule.cc @@ -154,7 +154,7 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, Expr x_factor, Expr y_factor) { // NOLINT(*) split(x_parent, p_x_outer, p_x_inner, x_factor); split(y_parent, p_y_outer, p_y_inner, y_factor); - reorder(Array({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer})); + reorder(Array({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); return *this; } diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 329d118e..7d3f25d6 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -165,8 +165,15 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) { {"shared", 1}, {"local", 2} }; - - return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag); + static std::unordered_map thread_tag_rank{ + {"gridIdx.x", 0}, + {"gridIdx.y", 0}, + {"gridIdx.z", 0}, + {"threadIdx.x", 1}, + {"threadIdx.y", 1}, + {"threadIdx.z", 1} + }; + return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag); } void InferBound( diff --git a/src/schedule/int_set.cc b/src/schedule/int_set.cc index b5e632e6..6a770b32 100644 --- a/src/schedule/int_set.cc +++ b/src/schedule/int_set.cc @@ -220,6 +220,8 @@ void PassUp(const SplitNode* s, *parent = IntSet::make_range(dom_map.at(s->parent)); return; } + CHECK(outer.defined()); + CHECK(inner.defined()); // copy construct auto n = std::make_shared(*(inner.operator->())); @@ -228,7 +230,6 @@ void PassUp(const SplitNode* s, n->base = Range::make_with_min_extent( AsNumber(outer) * s->factor + inner->base->min, inner->base->extent); - *parent = IntSet(n); } else { // default use all domains in the data. n->domain.push_back(outer->base); @@ -238,6 +239,7 @@ void PassUp(const SplitNode* s, n->stride.push_back(outer->stride[i] * s->factor); } } + *parent = IntSet(n); } void PassUp(const FuseNode* s, diff --git a/tests/python/test_bound_inference.py b/tests/python/test_bound_inference.py index 6de6e44e..fb169e60 100644 --- a/tests/python/test_bound_inference.py +++ b/tests/python/test_bound_inference.py @@ -1,6 +1,6 @@ import tvm -def test_bound_inference(): +def test_bound1(): m = tvm.Var('m') l = tvm.Var('l') A = tvm.placeholder((m, l), name='A') @@ -12,8 +12,42 @@ def test_bound_inference(): sA1.compute_at(sA2, xo) bounds = tvm.schedule.InferBound(sA2) assert isinstance(bounds, tvm.collections.Map) - print(bounds[A1.op.dim_var[0]]) - print(bounds[A1.op.dim_var[1]]) + assert(bounds[A1.op.dim_var[0]].extent.value == 8) + +def test_bound2(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + sA1 = tvm.Schedule(A1.op) + sA2 = tvm.Schedule(A2.op) + xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8) + sA1.compute_at(sA2, yo) + bounds = tvm.schedule.InferBound(sA2) + assert isinstance(bounds, tvm.collections.Map) + assert(bounds[A1.op.dim_var[0]].extent.value == 8) + assert(bounds[A1.op.dim_var[1]].extent.value == 8) + +def test_bound3(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + sA1 = tvm.Schedule(A1.op, scope="shared") + sA2 = tvm.Schedule(A2.op) + thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x") + xo, xi = sA2.split(A2.op.dim_var[0], 32) + xi0, xi1 = sA2.split(xi, outer=thread_x) + yo, yi = sA2.split(A2.op.dim_var[1], 16) + sA2.reorder(xo, xi0, yo, xi1, yi) + sA1.compute_at(sA2, yo) + + bounds = tvm.schedule.InferBound(sA2) + assert isinstance(bounds, tvm.collections.Map) + assert(bounds[A1.op.dim_var[0]].extent.value==32) + assert(bounds[A1.op.dim_var[1]].extent.value==16) def test_create_read_graph(): @@ -31,5 +65,7 @@ def test_create_read_graph(): if __name__ == "__main__": - test_bound_inference() + test_bound3() + test_bound1() + test_bound2() test_create_read_graph() diff --git a/tests/python/test_schedule.py b/tests/python/test_schedule.py index b08b5f6f..efdeab9a 100644 --- a/tests/python/test_schedule.py +++ b/tests/python/test_schedule.py @@ -34,6 +34,16 @@ def test_reorder(): sch_T.reorder(*order) assert tuple(sch_T.leaf_iter_vars) == order +def test_split(): + m = tvm.Var('m') + A = tvm.placeholder((m,), name='A') + T = tvm.compute((m,), lambda i: A[i]) + + sT = tvm.Schedule(T.op) + xo, xi = sT.split(T.op.dim_var[0], factor=10) + assert tuple(sT.leaf_iter_vars) == (xo, xi) + + def test_tile(): m = tvm.Var('m') n = tvm.Var('n') @@ -42,9 +52,10 @@ def test_tile(): sch_T = tvm.Schedule(T.op, scope="shared") xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5) - assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo) + assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi) if __name__ == "__main__": test_schedule_create() test_reorder() test_tile() + test_split()