checkin domain
This commit is contained in:
Родитель
bda9581727
Коммит
6819145a0c
|
@ -5,3 +5,4 @@ from .op import *
|
|||
from .expr import Var, const
|
||||
from .expr_util import *
|
||||
from .tensor import Tensor
|
||||
from .domain import RDom, Range
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
from . import expr as _expr
|
||||
from . import expr_util as _expr_util
|
||||
|
||||
|
||||
class Range(object):
|
||||
"""Represent a range in one dimension.
|
||||
"""
|
||||
def __init__(self, begin, end=None):
|
||||
if end is None:
|
||||
end = begin
|
||||
begin = _expr.const(0)
|
||||
self.begin = _expr._symbol(begin)
|
||||
self.end = _expr._symbol(end)
|
||||
self.extent = _expr_util.simplify(end - begin)
|
||||
|
||||
def __str__(self):
|
||||
return "(%s, %s)" % (
|
||||
_expr_util.format_str(self.begin),
|
||||
_expr_util.format_str(self.end))
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
class RDom(object):
|
||||
"""reduction Domain
|
||||
"""
|
||||
def __init__(self, domain):
|
||||
if isinstance(domain, Range):
|
||||
domain = [domain]
|
||||
self.index = []
|
||||
self.domain = domain
|
||||
for i in range(len(domain)):
|
||||
self.index.append(_expr.Var("rd_index_%d_" % i))
|
||||
|
||||
|
||||
"""Use list of ranges as domain"""
|
||||
Domain = list
|
|
@ -108,7 +108,27 @@ class UnaryOpExpr(Expr):
|
|||
self.src = _symbol(src)
|
||||
|
||||
def children(self):
|
||||
return (self.src)
|
||||
return (self.src,)
|
||||
|
||||
|
||||
class ReduceExpr(Expr):
|
||||
def __init__(self, op, src, rdom):
|
||||
self.op = op
|
||||
self.src = src
|
||||
self.rdom = rdom
|
||||
|
||||
def children(self):
|
||||
return (self.src,)
|
||||
|
||||
|
||||
class TensorReadExpr(Expr):
|
||||
"""Tensor read expression, tensor[indices]"""
|
||||
def __init__(self, tensor, indices):
|
||||
self.tensor = tensor
|
||||
self.indices = indices
|
||||
|
||||
def children(self):
|
||||
return self.indices
|
||||
|
||||
|
||||
def const(value):
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
from . import expr as _expr
|
||||
from . import op as _op
|
||||
from . import tensor as _tensor
|
||||
|
||||
def expr_with_new_children(e, children):
|
||||
"""Returns same expr as e but with new children
|
||||
|
@ -50,10 +49,27 @@ def transform(e, f):
|
|||
result : return value of f
|
||||
The final result of transformation.
|
||||
"""
|
||||
assert isinstance(e, _expr.Expr)
|
||||
if not isinstance(e, _expr.Expr):
|
||||
raise TypeError("Cannot handle type %s" % type(e))
|
||||
return f(e , [transform(c, f) for c in e.children()])
|
||||
|
||||
|
||||
def visit(e, f):
|
||||
"""Apply f to each element of e
|
||||
|
||||
Parameters
|
||||
----------
|
||||
e : Expr
|
||||
The input expression.
|
||||
|
||||
f : function with signiture (e)
|
||||
"""
|
||||
assert isinstance(e, _expr.Expr)
|
||||
for c in e.children():
|
||||
visit(c, f)
|
||||
f(e)
|
||||
|
||||
|
||||
def format_str(expr):
|
||||
"""change expression to string.
|
||||
|
||||
|
@ -76,12 +92,15 @@ def format_str(expr):
|
|||
return str(e.value)
|
||||
elif isinstance(e, _expr.Var):
|
||||
return e.name
|
||||
elif isinstance(e, _tensor.TensorReadExpr):
|
||||
elif isinstance(e, _expr.TensorReadExpr):
|
||||
return "%s(%s)" % (e.tensor.name, ','.join(result_children))
|
||||
elif isinstance(e, _expr.ReduceExpr):
|
||||
return e.op.format_reduce_str(result_children[0], e.rdom.domain)
|
||||
else:
|
||||
raise TypeError("Do not know how to handle type " + str(type(e)))
|
||||
return transform(expr, make_str)
|
||||
|
||||
|
||||
def simplify(expr):
|
||||
"""simplify expression
|
||||
|
||||
|
|
|
@ -22,15 +22,20 @@ def canonical_to_expr(c):
|
|||
else:
|
||||
return _expr.const(0)
|
||||
|
||||
|
||||
class BinaryOp(object):
|
||||
"""Base class of binary operator"""
|
||||
def __call__(self, lhs, rhs):
|
||||
return _expr.BinaryOpExpr(self, lhs, rhs)
|
||||
|
||||
|
||||
class AddOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s + %s)' % (lhs, rhs)
|
||||
|
||||
def format_reduce_str(self, src, rd):
|
||||
return "reduce_sum(%s, rdom=%s)" % (src, str(rd))
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
lhs = lhs.copy()
|
||||
for k, v in rhs.items():
|
||||
|
@ -40,6 +45,7 @@ class AddOp(BinaryOp):
|
|||
lhs[k] = v
|
||||
return lhs
|
||||
|
||||
|
||||
class SubOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s - %s)' % (lhs, rhs)
|
||||
|
@ -53,6 +59,7 @@ class SubOp(BinaryOp):
|
|||
lhs[k] = -v
|
||||
return lhs
|
||||
|
||||
|
||||
class MulOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s * %s)' % (lhs, rhs)
|
||||
|
@ -72,6 +79,7 @@ class MulOp(BinaryOp):
|
|||
return rhs
|
||||
return {elhs * erhs: 1}
|
||||
|
||||
|
||||
class DivOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s / %s)' % (lhs, rhs)
|
||||
|
@ -86,6 +94,7 @@ class DivOp(BinaryOp):
|
|||
elhs = canonical_to_expr(lhs)
|
||||
return {elhs / erhs: 1}
|
||||
|
||||
|
||||
class MaxOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return 'max(%s, %s)' % (lhs, rhs)
|
||||
|
@ -97,6 +106,7 @@ class MaxOp(BinaryOp):
|
|||
return lhs if ediff.value >= 0 else rhs
|
||||
return {MaxOp()(lhs, rhs): 1}
|
||||
|
||||
|
||||
class MinOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return 'min(%s, %s)' % (lhs, rhs)
|
||||
|
@ -120,3 +130,16 @@ _expr.__addop__ = add
|
|||
_expr.__subop__ = sub
|
||||
_expr.__mulop__ = mul
|
||||
_expr.__divop__ = div
|
||||
|
||||
|
||||
def reduce_sum(expr, rdom):
|
||||
return _expr.ReduceExpr(add, expr, rdom)
|
||||
|
||||
def reduce_prod(expr, rdom):
|
||||
return _expr.ReduceExpr(mul, expr, rdom)
|
||||
|
||||
def reduce_min(expr, rdom):
|
||||
return _expr.ReduceExpr(min, expr, rdom)
|
||||
|
||||
def reduce_max(expr, rdom):
|
||||
return _expr.ReduceExpr(max, expr, rdom)
|
||||
|
|
|
@ -1,33 +1,70 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
from . import expr as _expr
|
||||
|
||||
class TensorReadExpr(_expr.Expr):
|
||||
def __init__(self, tensor, indices):
|
||||
self.tensor = tensor
|
||||
self.indices = indices
|
||||
|
||||
def children(self):
|
||||
return self.indices
|
||||
from . import expr_util as _expr_util
|
||||
|
||||
|
||||
class Tensor(object):
|
||||
def __init__(self, ndim, fcompute=None, name=None):
|
||||
def __init__(self, ndim, fcompute=None, name=None, shape=None):
|
||||
self.ndim = ndim
|
||||
if fcompute:
|
||||
arg_names = fcompute.func_code.co_varnames
|
||||
assert(len(arg_names) == ndim)
|
||||
self.dim_index = [_expr.Var(n) for n in arg_names]
|
||||
self.expr = fcompute(*self.dim_index)
|
||||
if shape is None:
|
||||
raise ValueError("argument shape need to be given for intermediate tensor")
|
||||
self.shape = shape
|
||||
else:
|
||||
self.expr = None
|
||||
self.dim_index = None
|
||||
shape_name = '_shape'
|
||||
if name: shape_name = name + shape_name
|
||||
self.shape = tuple(_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
|
||||
self.shape = shape if shape else tuple(
|
||||
_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
|
||||
|
||||
self.name = name if name else "TensorObj"
|
||||
self.inputs = None
|
||||
|
||||
def __call__(self, *indices):
|
||||
if len(indices) != self.ndim:
|
||||
raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
|
||||
return TensorReadExpr(self, indices)
|
||||
return _expr.TensorReadExpr(self, indices)
|
||||
|
||||
def input_tensors(self):
|
||||
"""List of input tensors to this tensor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
inputs : list of input tensors
|
||||
"""
|
||||
if self.inputs is not None:
|
||||
return self.inputs
|
||||
self.inputs = []
|
||||
if self.expr:
|
||||
def collect(e):
|
||||
if isinstance(e, _expr.TensorReadExpr):
|
||||
self.inputs.append(e.tensor)
|
||||
_expr_util.visit(self.expr, collect)
|
||||
return self.inputs
|
||||
|
||||
def infer_input_domains(self, out_domain):
|
||||
"""Infer the input domains of each domain given output domains
|
||||
|
||||
Parameters
|
||||
----------
|
||||
out_domain : list of Range
|
||||
Domain of each dimension.
|
||||
|
||||
Returns
|
||||
-------
|
||||
in_domains: dict Tensor->Domain
|
||||
"""
|
||||
assert self.expr
|
||||
assert len(out_domain) == len(self.dim_index)
|
||||
index_domains = {
|
||||
self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
|
||||
}
|
||||
def collect(e):
|
||||
if isinstance(e, _expr.TensorReadExpr):
|
||||
self.inputs.append(e.tensor)
|
||||
_expr_util.visit(self.expr, collect)
|
||||
|
|
|
@ -3,8 +3,27 @@ import tvm
|
|||
def test_tensor():
|
||||
A = tvm.Tensor(2, name='A')
|
||||
B = tvm.Tensor(2, name='B')
|
||||
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k))
|
||||
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
|
||||
shape=(A.shape[0], B.shape[0], A.shape[1]))
|
||||
print(tvm.format_str(T.expr))
|
||||
|
||||
def test_tensor_inputs():
|
||||
A = tvm.Tensor(2, name='A')
|
||||
B = tvm.Tensor(2, name='B')
|
||||
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
|
||||
shape=(A.shape[0], B.shape[0], A.shape[1]))
|
||||
assert(T.input_tensors() == [A, B])
|
||||
|
||||
def test_tensor_reduce():
|
||||
A = tvm.Tensor(2, name='A')
|
||||
B = tvm.Tensor(2, name='B')
|
||||
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
|
||||
shape=(A.shape[0], B.shape[0], A.shape[1]))
|
||||
rd = tvm.RDom(tvm.Range(A.shape[1]))
|
||||
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
|
||||
shape=(A.shape[0], B.shape[0]))
|
||||
print(tvm.format_str(C.expr))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tensor()
|
||||
test_tensor_inputs()
|
||||
test_tensor_reduce()
|
||||
|
|
Загрузка…
Ссылка в новой задаче