[Hybrid Script] Unify the symbol tables to one; support `tvm.container.Array` (#2366)
This commit is contained in:
Parent: 151f550b2e
Commit: a42d1e3c77
@ -52,7 +52,8 @@ The current parse interface looks like:

    parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function

If we pass these tvm tensors to this function, it returns a op node:
If we pass these tvm data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``,
or ``tvm.container.Array``, to this function, it returns an op node:

.. code-block:: python
@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns a op node:

    b = tvm.placeholder((99, ), name='b')
    c = outer_product(a, b, c) # return the output tensor(s) of the operator

**Under construction, we are still deciding what kind of node should be returned.**
You can use any method that can be applied to a TVM ``OpNode``, for example ``create_schedule``, although
so far the scheduling functionality is as limited as that of ``ExternOpNode``. At least, it can be built
to an LLVM module.
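
For instance, a rough sketch of scheduling and building the returned op node (assuming ``outer_product``
is decorated with ``@tvm.hybrid.script`` and returns its output tensor; variable names are illustrative):

.. code-block:: python

    a = tvm.placeholder((100, ), name='a')
    b = tvm.placeholder((99, ), name='b')
    c = outer_product(a, b)            # the output tensor returned by the hybrid function
    sch = tvm.create_schedule(c.op)    # scheduling is as limited as ExternOpNode for now
    module = tvm.build(sch, [a, b, c], target='llvm')  # at least it builds to an LLVM module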

Tuning
~~~~~~

**Under construction, not truly supported yet.**
**Under construction, not supported yet.**

Following the example above, you can use some TVM-like interfaces to tune the code:

@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize`
these **4** keywords to annotate the corresponding types of for loops.
The usage is roughly the same as Python standard ``range``.

Besides all the loop types supported in Halide, ``const_range`` is supported for some specific conditions.
Sometimes, it is desirable to pass a ``tvm.container.Array`` as an argument, but in TVM-HalideIR there is
no support for converting a ``tvm.container.Array`` to an ``Expr``. Thus, only a limited feature is supported.
Users can access such containers either with constant indices or inside loops annotated as ``const_range``.

.. code-block:: python

    @tvm.hybrid.script
    def foo(a, b): # b is a tvm.container.Array
        c = output_tensor(a.shape, a.dtype)
        for i in const_range(len(a)): # because b is accessed, i must be explicitly annotated as const_range
            c[i] = a[i] + b[i]
        return c

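A plain Python list can be turned into a ``tvm.container.Array`` with ``tvm.convert`` and then passed
as ``b``; a rough usage sketch of the function above:

.. code-block:: python

    a = tvm.placeholder((5, ), name='a', dtype='int32')
    b = tvm.convert([1, 2, 3, 4, 5])   # becomes a tvm.container.Array
    c = foo(a, b)                      # every access to b is resolved at compile time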

Variables
~~~~~~~~~

@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration.

        s += a[i, j] # do something with s
    b[i] = s # you can still use s at this level
a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
b = (1, 2) # this has NOT been supported yet!


Attributes
~~~~~~~~~~

So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a
tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
So far, ONLY tensors' ``shape`` and ``dtype`` attributes are supported!
The ``shape`` attribute is essentially a tuple, so you MUST access it as an array.
Currently, only constant-indexed access is supported.

.. code-block:: python

@ -133,8 +151,11 @@ Conditional Statement and Expression

.. code-block:: python

    if condition:
        # do something
    if condition1 and condition2 and condition3:
        # do something
    else:
        # do something else
    # Select
    a = b if condition else c

However, the ``True`` and ``False`` keywords are NOT supported yet.

@ -153,7 +174,9 @@ Array Allocation

**Under construction, this function will be supported later!**

Use a function call ``allocate(shape, type, share/local)`` to declare an array buffer.
The basic usage is roughly the same as a normal array.
The basic usage is roughly the same as a normal ``numpy.array``, and you should access
high-dimensional arrays in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``,
even for ``tvm.container.Array``, when compiling.
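
Since this feature is still under construction, the snippet below is only a sketch of the intended
usage; the ``moving_average`` name and the argument values are illustrative:

.. code-block:: python

    @tvm.hybrid.script
    def moving_average(a): # a is assumed to be a 1-D tensor of length 32
        b = output_tensor((30, ), 'float32')
        for i in range(30):
            window = allocate((3, ), 'float32', 'local') # scratch buffer local to each iteration
            for j in range(3):
                window[j] = a[i + j]
            b[i] = (window[0] + window[1] + window[2]) / 3.0
        return b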

Thread Bind
@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this:

Keywords
~~~~~~~~
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_range``
- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
@ -12,15 +12,17 @@ from .util import _internal_assert

#pylint: disable=redefined-builtin

LOOP_INTRIN = {
    'range' : For.Serial,
    'unroll' : For.Unrolled,
    'parallel' : For.Parallel,
    'vectorize': For.Vectorized,
    'range' : For.Serial,
    'unroll' : For.Unrolled,
    'parallel' : For.Parallel,
    'vectorize' : For.Vectorized,
    'const_range' : (For.Unrolled, ),
}


def _range(annotation, args):
    """Handling TVM loop types"""
    n = len(args)
    n = args.__len__()
    if n == 1:
        low, ext = _api.const(0, dtype='int32'), args[0]
    else:
@ -33,13 +35,13 @@ def _range(annotation, args):
    return iter_var, low, ext, for_type


range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name


def bind(func_id, args):
    """Handling TVM thread binding"""
    _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
    _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
    _internal_assert(isinstance(args[0], str), \
        "A loop bind's first argument should be a string!")
    iter_var = _api.thread_axis(args[0])
@ -56,7 +58,7 @@ sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: dis


def _min_max(func_id, args):
    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
    _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
    return getattr(_make, func_id.title())(args[0], args[1])

@ -66,7 +68,7 @@ min = max = _min_max #pylint: disable=invalid-name
def _allocate_tensor(func_id, args):
    """Handling TVM tensor allocation.
    You may refer hybrid.intrin.allocate for more details."""
    n = len(args)
    n = args.__len__()
    _internal_assert(isinstance(_api.convert(args[0]), Array), \
        "allocate's first argument should be a tuple of shape!")
    shape = args[0]
@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args):
    scope = 'global' if func_id != 'output_tensor' else 'output'
    return (shape, dtype, scope)


output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name


def len(func_id, args):
    """Interpret the len function"""
    _internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
    _internal_assert(func_id == "len", "This function cannot be directly invoked!")
    try:
        return _api.convert(args[0].__len__())
    except: #pylint: disable=bare-except
        _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
        return _api.convert(args[0].shape[0])
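
As a rough illustration of what this ``len`` intrinsic enables on the user side (the ``scale`` function
below is illustrative), ``len`` can be applied both to a ``tvm.container.Array`` argument and to a
one-dimensional tensor inside a hybrid script:

.. code-block:: python

    @tvm.hybrid.script
    def scale(a, coef): # a: 1-D Tensor, coef: tvm.container.Array of the same length
        b = output_tensor(a.shape, a.dtype)
        for i in const_range(len(coef)): # len() of a container folds to a compile-time constant
            b[i] = a[i] * coef[i]
        return b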

@ -2,32 +2,19 @@

import numpy

class _range(object):
    """Base class of the loop ranges in hybrid script"""
    def __init__(self, a, b=None):
        if b is None:
            self.low = 0
            self.ext = a
        else:
            self.low = a
            self.ext = b

class bind(object): #pylint: disable=invalid-name
    """GPU bind software emulation runtime."""
    def __init__(self, _, ext):
        self.ext = ext

    def __iter__(self):
        i = 0
        while i < self.ext:
            yield i + self.low
            yield i
            i += 1


class bind(_range): #pylint: disable=invalid-name
    def __init__(self, tag, ext):
        super(bind, self).__init__(ext)
        self.tag = tag


unroll = vectorize = parallel = _range #pylint: disable=invalid-name


def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
    """Allocate a buffer with given shape

@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
    """
    return numpy.zeros(shape).astype(dtype)

output_tensor = allocate #pylint: disable=invalid-name

def popcount(x):
    """
@ -87,17 +73,19 @@ def sigmoid(x):


HYBRID_GLOBALS = {
    'unroll' : unroll,
    'vectorize' : vectorize,
    'parallel' : parallel,
    'allocate' : allocate,
    'output_tensor': output_tensor,
    'len' : len,
    'unroll' : range,
    'vectorize' : range,
    'parallel' : range,
    'const_range' : range,
    'bind' : bind,
    'allocate' : allocate,
    'output_tensor': allocate,
    'sqrt' : numpy.sqrt,
    'log' : numpy.log,
    'tanh' : numpy.tanh,
    'power' : numpy.power,
    'exp' : numpy.exp,
    'sigmoid' : sigmoid,
    'popcount' : popcount
    'popcount' : popcount,
}
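
A hedged note on how this table is exercised: when a hybrid function is called with numpy values instead
of TVM tensors, it is expected to run as plain Python, with the loop keywords above behaving like the
builtin ``range``; the snippet below is illustrative only:

.. code-block:: python

    import numpy

    # Illustrative only: pure-Python execution path of a hybrid function.
    a = numpy.random.randn(8).astype('float32')
    b = numpy.zeros(8).astype('float32')
    # assuming `foo` is decorated with @tvm.hybrid.script and writes b from a,
    # calling it with numpy arrays executes the body directly via HYBRID_GLOBALS
    foo(a, b)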

@ -4,7 +4,10 @@ import ast
import operator
import logging
import sys
from numbers import Integral
import types
import numbers

from enum import Enum

from .util import _internal_assert
from . import calls
@ -12,18 +15,15 @@ from . import util
from .var_decl import determine_variable_usage
from ..api import all as _all
from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation
from .. import expr as _expr
from .. import make as _make
from .. import api as _api
from .. import ir_pass as _ir_pass

def list_to_block(visit, lst):
    """Convert a list of Python IR nodes to HalideIR Block"""
    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
    if not lst:
        return util.make_nop()

def pack_list_to_block(lst):
    if len(lst) == 1:
        return lst[0]
    body = lst[0]
@ -32,6 +32,29 @@ def list_to_block(visit, lst):
    return body


def visit_list_to_block(visit, lst):
    """Convert a list of Python IR nodes to HalideIR Block"""
    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
    if not lst:
        return util.make_nop()
    return pack_list_to_block(lst)


class Symbol(Enum):
    """Enumerates types in the symbol table"""
    Callable = 0
    Input = 1
    OutputBuffer = 2
    GlobalBuffer = 3
    LocalBuffer = 4
    SharedBuffer = 5
    ConstVar = 6
    BufferVar = 7
    LoopVar = 8
    ConstLoopVar = 9

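A hedged sketch of what the unified table may hold after parsing a simple function; all names below are
illustrative, only the ``(Symbol, entry)`` pair shape is taken from the code:

.. code-block:: python

    # Illustrative only: possible entries of HybridParser.symbols
    symbols = {
        'foo': (Symbol.Callable, foo_function),        # callables taken from the global context
        'a'  : (Symbol.Input, a_placeholder),          # function arguments
        'c'  : (Symbol.OutputBuffer, c_placeholder),   # buffers created by output_tensor()
        'i'  : (Symbol.LoopVar, i_loop_var),           # loop variables currently in scope
    }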

class HybridParser(ast.NodeVisitor):
    """Python AST visitor pass which finally lowers it to HalideIR"""

@ -82,77 +105,55 @@ class HybridParser(ast.NodeVisitor):
        """
        self.args = list(args)
        self.usage = usage.copy()
        self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
        self.alloc_buffers = {} # Buffers formed by explicit allocate instructions
        self.loops_above = {} # State variable that indicates loop levels above the current node
        self.variables = {} # The status of defined variables

        self.symbols = {} # Symbol table
        for k, v in symbols.items():
            if isinstance(v, types.FunctionType):
                self.symbols[k] = Symbol.Callable, v

        self.func_name = func_name # The name of the function to be lowered
        self.outputs = [] # Output tensors' name
        self.side_effect = set() # Tensors with side effects
        self.parsed_body = None # The parsed HalideIR body
        self.returned = False # If this function has a valid return
        self.symbols = symbols # The global context


    def wrap_up_realize(self, node, body):
        """Wrap up all the variables which will no longer be used"""
        pop_buf = []
        pop_var = []
        to_pop = []
        for key, val in self.usage.items():
            _, level, _ = val
            if level != node:
                continue
            if key in self._args.keys():
            _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)

            ty, entry = self.symbols[key] #pylint: disable=invalid-name
            if ty in [Symbol.Input, Symbol.OutputBuffer]:
                continue
            if key in self.alloc_buffers.keys():
                _buf, _scope = self.alloc_buffers[key]
                if _scope == 'output':
                    continue
                pop_buf.append(key)
            elif 'Buffer' in ty.name:
                _buf = entry
                _scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global'
                to_pop.append(key)
            else:
                _internal_assert(key in self.variables.keys(),
                    "Key should be either in one of args, buffers, and vars")
                if not isinstance(self.variables[key], tuple):
                    continue
                _buf, _scope = self.variables[key]
                pop_var.append(key)
                continue

            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
            _dtype = _buf.dtype
            _true = _api.convert(True)
            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)

        for elem in pop_buf:
            self.alloc_buffers.pop(elem)
        for elem in pop_var:
            self.variables.pop(elem)
        for elem in to_pop:
            self.symbols.pop(elem)

        return body


    def _get_buffer_from_id(self, s, for_provide=False):
        _internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1,
            "This %s is expected to be in either \
            argument list or allocated buffer!" % s)
        if s in self._args.keys():
            if for_provide:
                self.side_effect.add(self._args[s])
            return self._args[s]
        return self.alloc_buffers[s][0]

    def _const(self, value, dtype=None):
        if dtype is None:
            if isinstance(value, bool):
                dtype = "bool"
            elif isinstance(value, Integral):
                dtype = "int32"
            else:
                dtype = "float32"
        return _api.const(value, dtype)

    #pylint: disable=invalid-name, missing-docstring
    def visit_Module(self, node):
        _internal_assert(len(node.body) == 1, \
            "Only one-function source code can be fed to this parser!")
            "Only one-function source code will be fed to this parser!")
        return self.visit(node.body[0])

@ -164,8 +165,8 @@ class HybridParser(ast.NodeVisitor):
        self.func_name = node.name
        for idx, arg in enumerate(node.args.args):
            _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
            self._args[getattr(arg, _attr)] = self.args[idx]
        res = list_to_block(self.visit, node.body)
            self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx])
        res = visit_list_to_block(self.visit, node.body)
        res = self.wrap_up_realize(node, res)
        return res

@ -176,25 +177,31 @@ class HybridParser(ast.NodeVisitor):

    def visit_Name(self, node):
        name = node.id
        if name in self.loops_above.keys():
            return self.loops_above[name]
        elif name in self.variables.keys():
            res = self.variables[name]
            if isinstance(res, tuple):
                buf = res[0]
                if isinstance(node.ctx, ast.Load):
                    return _make.Call(buf.dtype, buf.name, [self._const(0)], \
                        _expr.Call.Halide, buf.op, buf.value_index)
                return buf, [self._const(0)]
        ty, entry = self.symbols[name]
        _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
        if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
            return entry
        elif ty is Symbol.ConstVar:
            return entry if isinstance(node.ctx, ast.Load) else None
        elif ty is Symbol.BufferVar:
            if isinstance(node.ctx, ast.Load):
                return res
            return None
        buf = self._get_buffer_from_id(name)
        return buf
                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
                    _expr.Call.Halide, entry.op, entry.value_index)
            return entry, [_api.const(0, 'int32')]
        # Do I need any assertion here?
        return entry


    def visit_Num(self, node):
        return self._const(node.n)
        if isinstance(node.n, numbers.Integral):
            dtype = "int32"
        elif isinstance(node.n, float):
            dtype = "float32"
        else:
            _internal_assert(isinstance(node.n, bool),
                "The data type should be one of (int, float, bool)")
            dtype = "bool"
        return _api.const(node.n, dtype)


    def visit_AugAssign(self, node):

@ -204,7 +211,7 @@ class HybridParser(ast.NodeVisitor):
            _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
            buf, args = buf
        else:
            args = [self._const(0)]
            args = [_api.const(0, 'int32')]
        _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)

@ -222,7 +229,7 @@ class HybridParser(ast.NodeVisitor):
        for i in range(rhs.num_outputs):
            _internal_assert(isinstance(node.targets[i], ast.Name),
                "You should bind a pure name to the tensors")
            self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global')
            self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i)
            rmap[rhs.outputs[i].op] = rhs.output(i)
        return util.replace_io(rhs.body, rmap)

@ -234,25 +241,26 @@ class HybridParser(ast.NodeVisitor):
        #TODO: support defined intermediate buffer later
        lhs_ = lhs
        lhs = lhs.id
        _internal_assert(lhs not in self.loops_above.keys(), \
            "Loop variable cannot be overwritten!")
        if lhs in self.symbols.keys():
            ty, _ = self.symbols[lhs]
            _internal_assert(ty != Symbol.LoopVar, \
                "Loop variable cannot be overwritten!")
        decl, _, rw = self.usage[lhs]
        if decl == lhs_:
            _internal_assert(lhs not in self.variables.keys() and
                lhs not in self.alloc_buffers.keys(), \
            _internal_assert(lhs not in self.symbols.keys(),
                "This value should not be defined before this point!")
            if isinstance(rhs, tuple):
                shape, dtype, scope = rhs
                ph = _api.placeholder(shape, dtype=dtype, name=lhs)
                self.alloc_buffers[lhs] = (ph, scope)
                self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph
                if scope == 'output':
                    self.outputs.append(lhs)
                return util.make_nop()
            if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                self.variables[lhs] = rhs
                self.symbols[lhs] = Symbol.ConstVar, rhs
            else:
                ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
                self.variables[lhs] = (ph, 'global')
                self.symbols[lhs] = Symbol.BufferVar, ph
        lhs = self.visit(lhs_)
        if lhs is not None:
            buf, args = lhs
@ -275,17 +283,30 @@ class HybridParser(ast.NodeVisitor):
    def visit_Attribute(self, node):
        _internal_assert(isinstance(node.value, ast.Name), \
            "For attribute access, only both names are supported so far!")
        buf = self._get_buffer_from_id(node.value.id)
        buf = self.visit(node.value)
        return getattr(buf, node.attr)


    def visit_Subscript(self, node):
        args = self.visit(node.slice)
        if isinstance(node.value, ast.Name):

            buf = self.visit(node.value)
            if isinstance(buf, Array):
                for i in args:
                    if isinstance(i, numbers.Integral):
                        buf = buf[i]
                    else:
                        _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
                            "All indices are supposed to be constants")
                        buf = buf[i.value]

                return buf

            if isinstance(node.ctx, ast.Load):
                return _make.Call(buf.dtype, buf.name, args, \
                    _expr.Call.Halide, buf.op, buf.value_index)

            return buf, args

        shape = self.visit(node.value)

@ -308,14 +329,14 @@ class HybridParser(ast.NodeVisitor):
        _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
        _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
        self.annotation[option.id] = context.func.id
        return list_to_block(self.visit, node.body)
        return visit_list_to_block(self.visit, node.body)


    def visit_If(self, node):
        cond = self.visit(node.test)
        if_body = list_to_block(self.visit, node.body)
        if_body = visit_list_to_block(self.visit, node.body)
        if node.orelse:
            else_body = list_to_block(self.visit, node.orelse)
            else_body = visit_list_to_block(self.visit, node.orelse)
        else:
            else_body = util.make_nop()
        return _make.IfThenElse(cond, if_body, else_body)

@ -376,7 +397,10 @@ class HybridParser(ast.NodeVisitor):
        except AttributeError:
            _internal_assert(func_id in self.symbols.keys(), \
                "The function called is not in the context either!")
            outs = self.symbols[func_id](*args)
            ty, entry = self.symbols[func_id]
            _internal_assert(ty is Symbol.Callable, \
                "Are you sure what you call is a function?!")
            outs = entry(*args)
        op = outs.op if isinstance(outs, Tensor) else outs[0].op
        return op

@ -385,41 +409,66 @@ class HybridParser(ast.NodeVisitor):
        iter_var, low, ext, for_type = self.visit(node.iter)
        _internal_assert(isinstance(node.target, ast.Name), \
            "The loop iterator should be a variable!")

        _name = node.target.id
        if iter_var is None:

        if isinstance(for_type, tuple):
            low = _ir_pass.Simplify(low)
            ext = _ir_pass.Simplify(ext)
            _internal_assert(isinstance(low, _expr.ConstExpr) and
                isinstance(ext, _expr.ConstExpr), \
                "Const range should start from a const" + \
                "and iterate const times")

            low, ext = low.value, ext.value
            if ext > 114514:
                logging.log(logging.CRITICAL, \
                    '[Warning] Are you sure to unroll a large loop in Python?')

            bodies = []
            for i in range(low, low + ext):
                self.symbols[_name] = Symbol.ConstLoopVar, i
                bodies.append(visit_list_to_block(self.visit, node.body))
            return pack_list_to_block(bodies)

        elif iter_var is None:
            _internal_assert(for_type is not None, "The loop bind function parse error!")
            offset = iter_var = _api.var(_name)
            if not _ir_pass.Equal(low, self._const(0)):
            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                offset = iter_var + low
            self.loops_above[_name] = offset
            self.symbols[_name] = Symbol.LoopVar, offset
            _body = visit_list_to_block(self.visit, node.body)
        else:
            _internal_assert(for_type is None, "The loop iterating function parse error!")
            self.loops_above[_name] = iter_var.var
            _body = list_to_block(self.visit, node.body)
            self.symbols[_name] = Symbol.LoopVar, iter_var.var
            _body = visit_list_to_block(self.visit, node.body)

        _body = self.wrap_up_realize(node, _body)

        if for_type is None:
            res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
        else:
            res = _make.For(iter_var, self._const(0), ext, for_type, 0, _body)
        self.loops_above.pop(_name)
        elif not isinstance(for_type, tuple):
            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
        self.symbols.pop(_name)
        return res


    def visit_Return(self, node):
        _internal_assert(not self.loops_above, "Return should not be in a loop body!")
        _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \
            "Return should not be in a loop body!")
        ids = []
        if isinstance(node.value, ast.Name):
            ids.append(node.value.id)
            ids = [node.value.id]
        else:
            _internal_assert(isinstance(node.value, ast.Tuple), \
                "You should return either a single tensor or a tuple")
            for i in node.value.elts:
                _internal_assert(isinstance(i, ast.Name), "What do you return?")
                ids.append(i.id)
            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
                "What do you return?")
            ids = [i.id for i in node.value.elts]
        _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
        if len(ids) < len(self.outputs):
            logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
        self.outputs = [self.alloc_buffers[i][0] for i in ids]
        self.outputs = [self.symbols[i][1] for i in ids]
        self.returned = True
        return util.make_nop()
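
As a hedged illustration of the const-range branch above (the ``add3`` function and the paraphrased IR
below are illustrative, not captured from a real run), a ``const_range`` loop is fully unrolled at parse
time into a block of statements instead of a ``For`` node:

.. code-block:: python

    @tvm.hybrid.script
    def add3(a, b): # a, b: 1-D tensors of length 3
        c = output_tensor((3, ), a.dtype)
        for i in const_range(3):
            c[i] = a[i] + b[i]
        return c

    # Lowering add3 yields roughly the following HalideIR body (paraphrased):
    #     c[0] = a[0] + b[0]
    #     c[1] = a[1] + b[1]
    #     c[2] = a[2] + b[2]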

@ -11,12 +11,13 @@ from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from ..container import Array
from ..tensor import Tensor


#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)

def _internal_assert(cond, err):
@ -13,7 +13,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
    ctx = tvm.context(target, 0)
    op = None

    outs = func(*args)
    outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
    op = outs[0].op if isinstance(outs, list) else outs.op

    emu_args = []

@ -23,13 +23,18 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
            shape = [tvm_val_2_py_val(j) for j in i.shape]
            emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
            nd_args.append(tvm.nd.array(emu_args[-1], ctx))
        else:
            assert isinstance(i, tvm.expr.Var)
        elif isinstance(i, tvm.expr.Var):
            emu_args.append(tvm_val_2_py_val(i))
            nd_args.append(emu_args[-1])
        else:
            assert isinstance(i, list)
            emu_args.append(numpy.array(i))

    sch = tvm.create_schedule(op)
    module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
    module = tvm.build(sch,
                       [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
                       (outs if isinstance(outs, list) else [outs]),
                       target=target)
    assert module

    out_tensors = []
@ -192,20 +197,20 @@ def test_fanout():
def test_looptype():
    @script
    def looptype(a, b, c):
        d = output_tensor((8, ), 'int32')
        e = output_tensor((8, ), 'int32')
        f = output_tensor((8, ), 'int32')
        for i in parallel(8):
        d = output_tensor((16, ), 'int32')
        e = output_tensor((16, ), 'int32')
        f = output_tensor((16, ), 'int32')
        for i in parallel(16):
            d[i] = a[i]
        for j in vectorize(8):
        for j in vectorize(16):
            e[j] = b[j]
        for k in unroll(8):
        for k in unroll(16):
            f[k] = c[k]
        return d, e, f

    a = tvm.placeholder((8, ), name='a', dtype='int32')
    b = tvm.placeholder((8, ), name='b', dtype='int32')
    c = tvm.placeholder((8, ), name='c', dtype='int32')
    a = tvm.placeholder((16, ), name='a', dtype='int32')
    b = tvm.placeholder((16, ), name='b', dtype='int32')
    c = tvm.placeholder((16, ), name='c', dtype='int32')
    try:
        d, e, f = looptype(a, b, c)
        ir = d.op.body
@ -509,9 +514,9 @@ def test_value_index():
def test_func_call():
    @tvm.hybrid.script
    def foo(a, b):
        for i in range(10):
        for i in range(len(a)):
            a[i] = i + 1.0
        for i in range(10):
        for i in range(len(a)):
            b[i] = i + 1.0
        c = outer_product(10, 10, a, b)
        d = output_tensor(c.shape, c.dtype)
@ -538,6 +543,26 @@ def test_bool():
    a = tvm.placeholder((10, ), name='a')
    run_and_check(foo, [a])

def test_const_range():
    @tvm.hybrid.script
    def foo(a, b):
        c = output_tensor(a.shape, a.dtype)
        d = output_tensor(a.shape, a.dtype)

        for i in const_range(2):
            for j in const_range(5):
                c[i, j] = a[i, j] + b[i, j]

        for i in const_range(len(b)):
            for j in const_range(len(b[0])):
                d[i, j] = a[i, j] + b[i, j]

        return c, d

    a = tvm.placeholder((2, 5), name='a', dtype='int32')
    b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
    run_and_check(foo, [a, b])

if __name__ == "__main__":
    test_outer_product()
    test_fanout()
@ -553,5 +578,6 @@ if __name__ == "__main__":
    test_value_index()
    test_func_call()
    test_bool()
    test_const_range()
    # TODO:
    # test_inplace()