Fix Tile, add a few more test cases on bound inference
This commit is contained in:
Родитель
0f693212ca
Коммит
ff26cd68d0
|
@ -154,7 +154,7 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
|
|||
Expr x_factor, Expr y_factor) { // NOLINT(*)
|
||||
split(x_parent, p_x_outer, p_x_inner, x_factor);
|
||||
split(y_parent, p_y_outer, p_y_inner, y_factor);
|
||||
reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
|
||||
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
@ -165,8 +165,15 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
|
|||
{"shared", 1},
|
||||
{"local", 2}
|
||||
};
|
||||
|
||||
return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag);
|
||||
static std::unordered_map<std::string, int> thread_tag_rank{
|
||||
{"gridIdx.x", 0},
|
||||
{"gridIdx.y", 0},
|
||||
{"gridIdx.z", 0},
|
||||
{"threadIdx.x", 1},
|
||||
{"threadIdx.y", 1},
|
||||
{"threadIdx.z", 1}
|
||||
};
|
||||
return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
|
||||
}
|
||||
|
||||
void InferBound(
|
||||
|
|
|
@ -220,6 +220,8 @@ void PassUp(const SplitNode* s,
|
|||
*parent = IntSet::make_range(dom_map.at(s->parent));
|
||||
return;
|
||||
}
|
||||
CHECK(outer.defined());
|
||||
CHECK(inner.defined());
|
||||
// copy construct
|
||||
auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
|
||||
|
||||
|
@ -228,7 +230,6 @@ void PassUp(const SplitNode* s,
|
|||
n->base = Range::make_with_min_extent(
|
||||
AsNumber(outer) * s->factor + inner->base->min,
|
||||
inner->base->extent);
|
||||
*parent = IntSet(n);
|
||||
} else {
|
||||
// default use all domains in the data.
|
||||
n->domain.push_back(outer->base);
|
||||
|
@ -238,6 +239,7 @@ void PassUp(const SplitNode* s,
|
|||
n->stride.push_back(outer->stride[i] * s->factor);
|
||||
}
|
||||
}
|
||||
*parent = IntSet(n);
|
||||
}
|
||||
|
||||
void PassUp(const FuseNode* s,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import tvm
|
||||
|
||||
def test_bound_inference():
|
||||
def test_bound1():
|
||||
m = tvm.Var('m')
|
||||
l = tvm.Var('l')
|
||||
A = tvm.placeholder((m, l), name='A')
|
||||
|
@ -12,8 +12,42 @@ def test_bound_inference():
|
|||
sA1.compute_at(sA2, xo)
|
||||
bounds = tvm.schedule.InferBound(sA2)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
print(bounds[A1.op.dim_var[0]])
|
||||
print(bounds[A1.op.dim_var[1]])
|
||||
assert(bounds[A1.op.dim_var[0]].extent.value == 8)
|
||||
|
||||
def test_bound2():
    """Check bound inference for a producer attached inside a tiled consumer.

    A2 is tiled with factors 8 and 8 and A1 is attached with ``compute_at``
    at the ``yo`` loop returned by ``tile``. Both axes of A1 are then
    expected to have an inferred extent of 8.
    """
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    A1 = tvm.compute((rows, cols), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((rows, cols), lambda i, j: A1[i, j] + 3, name='A2')

    producer_sch = tvm.Schedule(A1.op)
    consumer_sch = tvm.Schedule(A2.op)

    # tile() hands back the four loop variables in the order xo, yo, xi, yi.
    x_axis = A2.op.dim_var[0]
    y_axis = A2.op.dim_var[1]
    xo, yo, xi, yi = consumer_sch.tile(x_axis, y_axis, 8, 8)
    producer_sch.compute_at(consumer_sch, yo)

    bounds = tvm.schedule.InferBound(consumer_sch)
    assert isinstance(bounds, tvm.collections.Map)

    # Both A1 axes should be bounded by the tile size.
    for axis_index in (0, 1):
        assert bounds[A1.op.dim_var[axis_index]].extent.value == 8
|
||||
|
||||
def test_bound3():
    """Check bound inference when the consumer is split over a thread axis.

    A1 is scheduled with ``scope="shared"``. A2's first axis is split by 32
    and its inner part is split again with ``threadIdx.x`` as the outer
    loop; its second axis is split by 16. A1 is attached at ``yo``, and its
    two axes are expected to have inferred extents of 32 and 16.
    """
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    A1 = tvm.compute((rows, cols), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((rows, cols), lambda i, j: A1[i, j] + 3, name='A2')

    producer_sch = tvm.Schedule(A1.op, scope="shared")
    consumer_sch = tvm.Schedule(A2.op)
    thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")

    # The schedule is mutated step by step, so the original order is kept.
    xo, xi = consumer_sch.split(A2.op.dim_var[0], 32)
    xi0, xi1 = consumer_sch.split(xi, outer=thread_x)
    yo, yi = consumer_sch.split(A2.op.dim_var[1], 16)
    consumer_sch.reorder(xo, xi0, yo, xi1, yi)
    producer_sch.compute_at(consumer_sch, yo)

    bounds = tvm.schedule.InferBound(consumer_sch)
    assert isinstance(bounds, tvm.collections.Map)

    first_axis, second_axis = A1.op.dim_var[0], A1.op.dim_var[1]
    assert bounds[first_axis].extent.value == 32
    assert bounds[second_axis].extent.value == 16
|
||||
|
||||
|
||||
def test_create_read_graph():
|
||||
|
@ -31,5 +65,7 @@ def test_create_read_graph():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bound_inference()
|
||||
test_bound3()
|
||||
test_bound1()
|
||||
test_bound2()
|
||||
test_create_read_graph()
|
||||
|
|
|
@ -34,6 +34,16 @@ def test_reorder():
|
|||
sch_T.reorder(*order)
|
||||
assert tuple(sch_T.leaf_iter_vars) == order
|
||||
|
||||
def test_split():
    """Check that splitting the only axis leaves (outer, inner) as the leaves."""
    length = tvm.Var('m')
    A = tvm.placeholder((length,), name='A')
    T = tvm.compute((length,), lambda i: A[i])

    schedule = tvm.Schedule(T.op)
    axis = T.op.dim_var[0]
    xo, xi = schedule.split(axis, factor=10)

    # After one split the leaf iteration variables are the two new loops.
    leaves = tuple(schedule.leaf_iter_vars)
    assert leaves == (xo, xi)
|
||||
|
||||
|
||||
def test_tile():
|
||||
m = tvm.Var('m')
|
||||
n = tvm.Var('n')
|
||||
|
@ -42,9 +52,10 @@ def test_tile():
|
|||
|
||||
sch_T = tvm.Schedule(T.op, scope="shared")
|
||||
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
|
||||
assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)
|
||||
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_schedule_create()
|
||||
test_reorder()
|
||||
test_tile()
|
||||
test_split()
|
||||
|
|
Загрузка…
Ссылка в новой задаче