Fix dependency problem of reducer condition (#712) (#721)

* Make duplicated function name checker working

* Fix dependency checking problem for reducer condition (#712); add test

* Fix dependency checking problem for reducer condition (#712); add test

* Specify R to be computed inlined
This commit is contained in:
Cody Hao Yu 2017-12-22 23:20:16 -08:00 коммит произвёл Tianqi Chen
Родитель aa55b1a9d0
Коммит 9e01367dce
3 изменённых файлов: 5 добавлений и 1 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -138,6 +138,7 @@ xcuserdata/
*.xcscmblueprint
.DS_Store
tags
cscope*
# vim temporary files
*.swp

Просмотреть файл

@ -134,6 +134,7 @@ DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
VisitArray(op->source, this);
this->Visit(op->condition);
}
void IRVisitor::Visit_(const Cast* op) {

Просмотреть файл

@ -7,8 +7,9 @@ def test_reduce_prims():
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(i>1)), name='B')
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
# schedule
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
@ -16,6 +17,7 @@ def test_reduce_prims():
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[R].compute_inline()
# one line to build the function.
def check_device(device, host="stackvm"):