Merge branch 'master' of ssh://github.com/tqchen/tvm
This commit is contained in:
Коммит
05e871d4b4
|
@ -6,3 +6,4 @@ from .expr import Var, const
|
|||
from .expr_util import *
|
||||
from .tensor import Tensor
|
||||
from .domain import RDom, Range, infer_range
|
||||
from .split import Split
|
||||
|
|
|
@ -17,7 +17,7 @@ class Range(object):
|
|||
self.extent = _expr_util.simplify(end - begin)
|
||||
|
||||
def is_value(self):
|
||||
return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1
|
||||
return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1
|
||||
|
||||
def __str__(self):
|
||||
return "(%s, %s)" % (
|
||||
|
|
|
@ -6,7 +6,7 @@ constant_canonical_key = '__constant__'
|
|||
def canonical_to_expr(c):
|
||||
elements = []
|
||||
for k, v in sorted(c.items()):
|
||||
if k == constant_canonical_key:
|
||||
if k == constant_canonical_key and v != 0:
|
||||
elements.append(_expr.const(v))
|
||||
elif v == 0:
|
||||
continue
|
||||
|
@ -87,7 +87,7 @@ class DivOp(BinaryOp):
|
|||
if isinstance(erhs, _expr.ConstExpr):
|
||||
lhs = lhs.copy()
|
||||
for k, v in lhs.items():
|
||||
lhs[k] /= erhs.value
|
||||
lhs[k] /= float(erhs.value)
|
||||
return lhs
|
||||
elhs = canonical_to_expr(lhs)
|
||||
return {elhs / erhs: 1}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
from . import expr as _expr
|
||||
from . import domain as _dom
|
||||
from . import tensor as _tensor
|
||||
|
||||
|
||||
class Split(object):
|
||||
def __init__(self, dim, factor):
|
||||
self.dim = dim
|
||||
self.factor = factor
|
||||
self.loop_index = _expr.Var('loop_index_%d_' % dim)
|
||||
|
||||
def infer_inner_domain(self, domain):
|
||||
if isinstance(domain, _dom.RDom):
|
||||
domain = domain.domain
|
||||
assert self.dim < len(domain)
|
||||
inner_domain = domain[:]
|
||||
dim_out_range = domain[self.dim]
|
||||
dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
|
||||
inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
|
||||
return inner_domain
|
||||
|
|
@ -25,7 +25,6 @@ class Tensor(object):
|
|||
|
||||
self.name = name if name else "TensorObj"
|
||||
self.inputs = None
|
||||
self.rdom = None
|
||||
|
||||
def __call__(self, *indices):
|
||||
if len(indices) != self.ndim:
|
||||
|
|
|
@ -26,6 +26,6 @@ def test_simplify():
|
|||
assert tvm.format_str(tvm.simplify(e4)) == '0'
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simplify()
|
||||
test_basic()
|
||||
test_bind()
|
||||
test_simplify()
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
import tvm
|
||||
|
||||
def test_split_dom_infer():
|
||||
A = tvm.Tensor(2, name='A')
|
||||
rd = tvm.RDom(tvm.Range(A.shape[1]))
|
||||
split1 = tvm.Split(0, 64)
|
||||
split2 = tvm.Split(1, 64)
|
||||
split3 = tvm.Split(0, 8)
|
||||
dom = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]
|
||||
dom1 = split1.infer_inner_domain(dom)
|
||||
dom2 = split2.infer_inner_domain(dom1)
|
||||
dom3 = split3.infer_inner_domain(dom2)
|
||||
dom4 = split3.infer_inner_domain(rd)
|
||||
i1 = split1.loop_index.name
|
||||
i2 = split2.loop_index.name
|
||||
i3 = split3.loop_index.name
|
||||
assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
|
||||
assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
|
||||
assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
|
||||
assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_split_dom_infer()
|
Загрузка…
Ссылка в новой задаче