checked buffer / schedule / code gen

This commit is contained in:
Haichen Shen 2016-10-19 15:56:01 -07:00
Родитель d3ee03ebf4
Коммит de2be97ec3
13 изменённых файлов: 367 добавлений и 28 удалений

Просмотреть файл

@ -5,5 +5,7 @@ from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range, infer_range
from .domain import Range, RDom, infer_range
from .split import Split
from .buffer import Scope, Buffer
from .schedule import Schedule

76
python/tvm/buffer.py Normal file
Просмотреть файл

@ -0,0 +1,76 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import var_name as _name
def enum(*sequential, **named):
enums = dict(zip(sequential, range(len(sequential))), **named)
return type('Enum', (), enums)
"""Scope defines the scope of a buffer
Types
-----
Thread : thread private buffer (registers)
Shared : shared buffer within a thread block (shared memory)
Global : buffer in the global GPU RAM
"""
Scope = enum('Thread', 'Shared', 'Global')
class Buffer(object):
def __init__(self, scope, name=None):
self.scope = scope
buf_name = 'Buffer_'
if name: buf_name += name
self.name = _name.NameManager.current.get(buf_name)
self.shape = []
self.offset_index = []
def reshape(self, domain):
for r in domain:
self.shape.append(r.extent)
self.offset_index.append(r.begin)
def __call__(self, *global_index):
if len(global_index) != len(self.shape):
raise ValueError("Need to provide %d index in buffer slice" % len(self.shape))
stride = [1]
for i in reversed(range(1, len(self.shape))):
stride.insert(0, self.shape[i] * stride[0])
local_index = []
for i in range(0, len(global_index)):
local_index.append(global_index[i] - self.offset_index[i])
index = local_index[0] * stride[0]
for i in range(1, len(local_index)):
index = index + local_index[i] * stride[i]
index = _expr_util.simplify(index)
return _expr.TensorRefExpr(self, [index])
class BufferManager(object):
def __init__(self):
self._buffer_map = {}
self._old_manager = None
def get(self, tensor):
if tensor in self._buffer_map:
return self._buffer_map[tensor]
return None
def bind(self, tensor, buf):
self._buffer_map[tensor] = buf
def __enter__(self):
self._old_manager = BufferManager.current
BufferManager.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_manager
BufferManager.current = self._old_manager
# initialize the default buffer manager
BufferManager.current = BufferManager()

38
python/tvm/codegen.py Normal file
Просмотреть файл

@ -0,0 +1,38 @@
from __future__ import absolute_import as _abs
from . import buffer as _buffer
from . import expr as _expr
from . import expr_util as _expr_util
def gen_code(expr):
"""change expression to string.
Parameters
----------
expr : Expr
Input expression
Returns
-------
s : str
The string representation of expr
"""
def make_str(e, result_children):
if isinstance(e, _expr.BinaryOpExpr):
return e.op.format_str(result_children[0], result_children[1])
elif isinstance(e, _expr.UnaryOpExpr):
return e.op.format_str(result_children[0])
elif isinstance(e, _expr.ConstExpr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
elif isinstance(e, _expr.TensorRefExpr):
buf = _buffer.BufferManager.current.get(e.tensor)
if buf:
return _expr_util.format_str(buf(*e.indices))
return _expr_util.format_str(e.tensor(*e.indices, flatten=True))
elif isinstance(e, _expr.ReduceExpr):
return e.op.format_reduce_stmt_str(result_children[0])
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return _expr_util.transform(expr, make_str)

Просмотреть файл

@ -112,7 +112,7 @@ class UnaryOpExpr(Expr):
class ReduceExpr(Expr):
def __init__(self, op, src, rdom):
def __init__(self, op, src, rdom):
self.op = op
self.src = src
self.rdom = rdom
@ -121,8 +121,8 @@ class ReduceExpr(Expr):
return (self.src,)
class TensorReadExpr(Expr):
"""Tensor read expression, tensor[indices]"""
class TensorRefExpr(Expr):
"""Tensor reference expression, tensor[indices]"""
def __init__(self, tensor, indices):
self.tensor = tensor
self.indices = indices

Просмотреть файл

@ -27,8 +27,12 @@ def expr_with_new_children(e, children):
else _expr.BinaryOpExpr(e.op, children[0], children[1]))
elif isinstance(e, _expr.UnaryOpExpr):
return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0])
elif isinstance(e, _expr.TensorRefExpr):
return e if children == e.indices else _expr.TensorRefExpr(e.tensor, children)
elif isinstance(e, _expr.ReduceExpr):
return e if children[0] == e.src else _expr.ReduceExpr(e.op, children[0], e.rdom)
else:
raise TypeError("donnot know how to handle Expr %s" % type(e))
raise TypeError("do not know how to handle Expr %s" % type(e))
else:
return e
@ -92,8 +96,8 @@ def format_str(expr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
elif isinstance(e, _expr.TensorReadExpr):
return "%s(%s)" % (e.tensor.name, ','.join(result_children))
elif isinstance(e, _expr.TensorRefExpr):
return "%s[%s]" % (e.tensor.name, ','.join(result_children))
elif isinstance(e, _expr.ReduceExpr):
return e.op.format_reduce_str(result_children[0], e.rdom.domain)
else:
@ -120,7 +124,7 @@ def simplify(expr):
elif isinstance(e, _expr.UnaryOpExpr):
return e.op.canonical(result_children[0])
elif isinstance(e, _expr.ConstExpr):
return {_op.constant_canonical_key: e.value}
return {_op.const_canonical_key: e.value}
elif isinstance(e, _expr.Var):
return {e: 1}
else:

Просмотреть файл

@ -1,12 +1,12 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
constant_canonical_key = '__constant__'
const_canonical_key = '__constant__'
def canonical_to_expr(c):
elements = []
for k, v in sorted(c.items()):
if k == constant_canonical_key and v != 0:
if k == const_canonical_key and v != 0:
elements.append(_expr.const(v))
elif v == 0:
continue
@ -35,6 +35,10 @@ class AddOp(BinaryOp):
def format_reduce_str(self, src, rd):
return "reduce_sum(%s, rdom=%s)" % (src, str(rd))
def format_reduce_stmt_str(self, src):
# a temporary hack for now
return "+ %s" % (src)
def canonical(self, lhs, rhs):
lhs = lhs.copy()
for k, v in rhs.items():
@ -86,8 +90,15 @@ class DivOp(BinaryOp):
erhs = canonical_to_expr(rhs)
if isinstance(erhs, _expr.ConstExpr):
lhs = lhs.copy()
remove = []
for k, v in lhs.items():
lhs[k] /= float(erhs.value)
if k == const_canonical_key:
lhs[k] = v / erhs.value
else:
lhs[k / erhs] = 1
remove.append(k)
for k in remove:
del lhs[k]
return lhs
elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1}

127
python/tvm/schedule.py Normal file
Просмотреть файл

@ -0,0 +1,127 @@
from __future__ import absolute_import as _abs
from . import domain as _dom
from . import expr as _expr
from . import expr_util as _expr_util
from . import split as _split
from . import buffer as _buffer
from . import codegen as _gen
start_point_key = '__start__'
TAB = ' '
class Schedule(object):
"""SUnit defines the compute schedule of a tensor
Parameters
----------
tensor: tensor
"""
def __init__(self, tensor, buffer=None):
self.tensor = tensor
self.buffer = buffer
self.parent = None
#self.children = []
self.splits = []
self.split_attach = {start_point_key: []}
self.implicit_splits = [_split.Split(i, 1) for i in range(tensor.ndim)]
if isinstance(tensor.expr, _expr.ReduceExpr):
for i in range(len(tensor.expr.rdom.domain)):
self.implicit_splits.append(_split.Split(i, 1, rdom=True))
def add_split(self, split):
self.splits.append(split)
self.split_attach[split] = []
def set_buffer(self, buf):
self.buffer = buf
def attach(self, split, other):
other.parent = self
if split is None:
self.split_attach[start_point_key].append(other)
else:
self.split_attach[split].append(other)
def infer_inner_domain(self, domain):
for split in self.splits:
domain = split.infer_inner_domain(domain)
return domain
def realize(self, domain=None, indent=''):
def realize_attach(lst):
attach_tensors = [sch.tensor for sch in lst]
attach_domains = self.tensor.infer_input_domains(domain, attach_tensors, red_domain=red_domain)
for sch in lst:
body.extend(sch.realize(attach_domains[sch.tensor], indent))
# init domain and red_domain
if domain is None:
domain = self.tensor.domain
red_domain = self.tensor.expr.rdom.domain if isinstance(self.tensor.expr, _expr.ReduceExpr) else None
# init buffer shape
if self.buffer:
if self.buffer.scope == _buffer.Scope.Global:
self.buffer.reshape(self.tensor.domain)
else:
# don't handle shared buffer for now
self.buffer.reshape(domain)
_buffer.BufferManager.current.bind(self.tensor, self.buffer)
body = []
if self.split_attach[start_point_key]:
realize_attach(self.split_attach[start_point_key])
# add loop conditions for splits
for split in self.splits:
if split.rdom:
red_domain = split.generate_loop_condition(red_domain, body, indent)
else:
domain = split.generate_loop_condition(domain, body, indent)
indent += TAB
if self.split_attach[split]:
realize_attach(self.split_attach[split])
# add implicit loop conditions
for split in self.implicit_splits:
if split.rdom:
red_domain = split.generate_loop_condition(red_domain, body, indent)
else:
domain = split.generate_loop_condition(domain, body, indent)
indent += TAB
# add loop body
expr = self.tensor.expr
global_index = [r.begin for r in domain]
global_rdom_index = [r.begin for r in red_domain] if red_domain else []
if expr is None:
if self.buffer:
lhs = self.buffer(*global_index)
rhs = self.tensor(*global_index, flatten=True)
body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _expr_util.format_str(rhs)))
else:
if self.buffer:
lhs = self.buffer(*global_index)
else:
lhs = self.tensor(*global_index, flatten=True)
bind_dict = {}
for i in range(self.tensor.ndim):
bind_dict[self.tensor.dim_index[i]] = global_index[i]
if isinstance(expr, _expr.ReduceExpr):
for i in range(len(expr.rdom.domain)):
bind_dict[expr.rdom.index[i]] = global_rdom_index[i]
rhs = _expr_util.bind(expr, bind_dict)
body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _gen.gen_code(rhs)))
# add right brackets
for split in self.implicit_splits:
indent = indent[:-len(TAB)]
body.append('%s}' % indent)
for split in self.splits:
indent = indent[:-len(TAB)]
body.append('%s}' % indent)
return body

Просмотреть файл

@ -1,22 +1,35 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import domain as _dom
from . import tensor as _tensor
class Split(object):
def __init__(self, dim, factor):
def __init__(self, dim, factor, name=None, rdom=False):
self.dim = dim
self.factor = factor
self.loop_index = _expr.Var('loop_index_%d_' % dim)
self.rdom = rdom
if name is None:
name = 'loop_index_%d_' % dim
self.loop_index = _expr.Var(name)
def infer_inner_domain(self, domain):
if isinstance(domain, _dom.RDom):
domain = domain.domain
assert self.dim < len(domain)
inner_domain = domain[:]
dim_out_range = domain[self.dim]
def infer_inner_domain(self, out_domain):
assert self.dim < len(out_domain)
inner_domain = out_domain[:]
dim_out_range = out_domain[self.dim]
dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
return inner_domain
def generate_loop_condition(self, out_domain, body, indent):
assert self.dim < len(out_domain)
loop_range = _dom.Range(out_domain[self.dim].extent / self.factor)
stmt = '%sfor (int %s = 0; %s < %s; %s += 1) {' % (
indent,
self.loop_index.name,
self.loop_index.name,
_expr_util.format_str(loop_range.end),
self.loop_index.name)
body.append(stmt)
return self.infer_inner_domain(out_domain)

Просмотреть файл

@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import domain as _dom
from . import var_name as _name
class Tensor(object):
@ -23,13 +24,26 @@ class Tensor(object):
self.shape = shape if shape else tuple(
_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
self.name = name if name else "TensorObj"
self.name = name if name else _name.NameManager.current.get("TensorObj")
self.inputs = None
def __call__(self, *indices):
def __call__(self, *indices, **option):
if len(indices) != self.ndim:
raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
return _expr.TensorReadExpr(self, indices)
if 'flatten' in option and option['flatten']:
stride = [1]
for i in reversed(range(1, len(indices))):
stride.insert(0, self.shape[i] * stride[0])
index = indices[0] * stride[0]
for i in range(1, len(indices)):
index = index + indices[i] * stride[i]
index = _expr_util.simplify(index)
return _expr.TensorRefExpr(self, [index])
return _expr.TensorRefExpr(self, indices)
@property
def domain(self):
return _dom.Domain([_dom.Range(self.shape[i]) for i in range(self.ndim)])
def input_tensors(self):
"""List of input tensors to this tensor.
@ -43,7 +57,7 @@ class Tensor(object):
inputs = []
if self.expr:
def collect(e):
if isinstance(e, _expr.TensorReadExpr):
if isinstance(e, _expr.TensorRefExpr):
inputs.append(e.tensor)
_expr_util.visit(self.expr, collect)
self.inputs = set(inputs)
@ -93,7 +107,7 @@ class Tensor(object):
rd = e.rdom
for i in range(len(rd.domain)):
index_domains[rd.index[i]] = rd.domain[i]
elif isinstance(e, _expr.TensorReadExpr):
elif isinstance(e, _expr.TensorRefExpr):
if e.tensor in iset:
iset[e.tensor].append(e)
_expr_util.visit(begin_expr, prepare)

Просмотреть файл

@ -0,0 +1,14 @@
import tvm
def test_buffer():
buf = tvm.Buffer(tvm.Scope.Thread)
shape = [32, 16]
domain = [tvm.Range(v) for v in shape]
buf.reshape(domain)
x = tvm.Var('x')
y = tvm.Var('y')
assert tvm.format_str(buf(y, x)) == '%s[(%s + (%s * %s))]' % (buf.name, x.name, y.name, shape[1])
if __name__ == '__main__':
test_buffer()

Просмотреть файл

@ -0,0 +1,41 @@
import tvm
def test_schedule():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
rd = tvm.RDom(tvm.Range(A.shape[1]))
T = tvm.Tensor(2, lambda i, j:
tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]), name="T")
C = tvm.Tensor(2, lambda i, j: T(i,j),
shape=(A.shape[0], B.shape[0]), name="C")
bufA = tvm.Buffer(tvm.Scope.Thread, name='A')
bufB = tvm.Buffer(tvm.Scope.Thread, name='B')
bufT = tvm.Buffer(tvm.Scope.Thread, name='T')
schA = tvm.Schedule(A, buffer=bufA)
schB = tvm.Schedule(B, buffer=bufB)
schT = tvm.Schedule(T, buffer=bufT)
schC = tvm.Schedule(C)
Cx0 = tvm.Split(dim=0, factor=64)
Cy0 = tvm.Split(dim=1, factor=64)
Cx1 = tvm.Split(dim=0, factor=8)
Cy1 = tvm.Split(dim=1, factor=8)
Tk = tvm.Split(dim=0, factor=8, rdom=True)
schC.add_split(Cx0)
schC.add_split(Cy0)
schC.add_split(Cx1)
schC.add_split(Cy1)
schT.add_split(Tk)
schC.attach(Cy1, schT)
schT.attach(Tk, schA)
schT.attach(Tk, schB)
body = schC.realize()
print('\n'.join(body))
if __name__ == "__main__":
test_schedule()

Просмотреть файл

@ -1,8 +1,8 @@
import tvm
def test_split_dom_infer():
A = tvm.Tensor(2, name='A')
rd = tvm.RDom(tvm.Range(A.shape[1]))
split1 = tvm.Split(0, 64)
split2 = tvm.Split(1, 64)
split3 = tvm.Split(0, 8)
@ -10,14 +10,12 @@ def test_split_dom_infer():
dom1 = split1.infer_inner_domain(dom)
dom2 = split2.infer_inner_domain(dom1)
dom3 = split3.infer_inner_domain(dom2)
dom4 = split3.infer_inner_domain(rd)
i1 = split1.loop_index.name
i2 = split2.loop_index.name
i3 = split3.loop_index.name
assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3)
if __name__ == "__main__":

Просмотреть файл

@ -12,7 +12,7 @@ def test_tensor_inputs():
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
assert(T.input_tensors() == [A, B])
assert(T.input_tensors() == set([A, B]))
def test_tensor_reduce():
A = tvm.Tensor(2, name='A')
@ -25,5 +25,6 @@ def test_tensor_reduce():
print(tvm.format_str(C.expr))
if __name__ == "__main__":
test_tensor()
test_tensor_inputs()
test_tensor_reduce()