checked buffer / schedule / code gen
Parent: d3ee03ebf4
Commit: de2be97ec3
@@ -5,5 +5,7 @@ from .op import *
 from .expr import Var, const
 from .expr_util import *
 from .tensor import Tensor
-from .domain import RDom, Range, infer_range
+from .domain import Range, RDom, infer_range
 from .split import Split
+from .buffer import Scope, Buffer
+from .schedule import Schedule
@@ -0,0 +1,76 @@
+from __future__ import absolute_import as _abs
+from . import expr as _expr
+from . import expr_util as _expr_util
+from . import var_name as _name
+
+
+def enum(*sequential, **named):
+    enums = dict(zip(sequential, range(len(sequential))), **named)
+    return type('Enum', (), enums)
+
+
+"""Scope defines the scope of a buffer
+
+Types
+-----
+Thread : thread private buffer (registers)
+Shared : shared buffer within a thread block (shared memory)
+Global : buffer in the global GPU RAM
+"""
+Scope = enum('Thread', 'Shared', 'Global')
+
+
+class Buffer(object):
+    def __init__(self, scope, name=None):
+        self.scope = scope
+        buf_name = 'Buffer_'
+        if name: buf_name += name
+        self.name = _name.NameManager.current.get(buf_name)
+        self.shape = []
+        self.offset_index = []
+
+    def reshape(self, domain):
+        for r in domain:
+            self.shape.append(r.extent)
+            self.offset_index.append(r.begin)
+
+    def __call__(self, *global_index):
+        if len(global_index) != len(self.shape):
+            raise ValueError("Need to provide %d index in buffer slice" % len(self.shape))
+        stride = [1]
+        for i in reversed(range(1, len(self.shape))):
+            stride.insert(0, self.shape[i] * stride[0])
+        local_index = []
+        for i in range(0, len(global_index)):
+            local_index.append(global_index[i] - self.offset_index[i])
+        index = local_index[0] * stride[0]
+        for i in range(1, len(local_index)):
+            index = index + local_index[i] * stride[i]
+        index = _expr_util.simplify(index)
+        return _expr.TensorRefExpr(self, [index])
+
+
+class BufferManager(object):
+    def __init__(self):
+        self._buffer_map = {}
+        self._old_manager = None
+
+    def get(self, tensor):
+        if tensor in self._buffer_map:
+            return self._buffer_map[tensor]
+        return None
+
+    def bind(self, tensor, buf):
+        self._buffer_map[tensor] = buf
+
+    def __enter__(self):
+        self._old_manager = BufferManager.current
+        BufferManager.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_manager
+        BufferManager.current = self._old_manager
+
+# initialize the default buffer manager
+BufferManager.current = BufferManager()
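Note on the new Buffer.__call__: it flattens an N-dimensional global index into a single row-major offset by building strides back-to-front from the shape, subtracting offset_index, and simplifying the resulting dot product. A standalone sketch of the same arithmetic with plain integers (hypothetical helper, independent of the tvm modules):

    # Row-major flattening as in Buffer.__call__, with ints instead of exprs.
    # 'shape' and 'offsets' play the roles of Buffer.shape / Buffer.offset_index.
    def flatten_index(global_index, shape, offsets):
        stride = [1]
        for i in reversed(range(1, len(shape))):
            stride.insert(0, shape[i] * stride[0])  # stride[i] = prod(shape[i+1:])
        local = [g - o for g, o in zip(global_index, offsets)]
        return sum(l * s for l, s in zip(local, stride))

    # index (3, 5) in a 32x16 buffer with zero offsets lands at 3*16 + 5
    assert flatten_index((3, 5), [32, 16], [0, 0]) == 53

This is the same layout the new test_buffer below asserts symbolically.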
@@ -0,0 +1,38 @@
+from __future__ import absolute_import as _abs
+from . import buffer as _buffer
+from . import expr as _expr
+from . import expr_util as _expr_util
+
+
+def gen_code(expr):
+    """change expression to string.
+
+    Parameters
+    ----------
+    expr : Expr
+        Input expression
+
+    Returns
+    -------
+    s : str
+        The string representation of expr
+    """
+    def make_str(e, result_children):
+        if isinstance(e, _expr.BinaryOpExpr):
+            return e.op.format_str(result_children[0], result_children[1])
+        elif isinstance(e, _expr.UnaryOpExpr):
+            return e.op.format_str(result_children[0])
+        elif isinstance(e, _expr.ConstExpr):
+            return str(e.value)
+        elif isinstance(e, _expr.Var):
+            return e.name
+        elif isinstance(e, _expr.TensorRefExpr):
+            buf = _buffer.BufferManager.current.get(e.tensor)
+            if buf:
+                return _expr_util.format_str(buf(*e.indices))
+            return _expr_util.format_str(e.tensor(*e.indices, flatten=True))
+        elif isinstance(e, _expr.ReduceExpr):
+            return e.op.format_reduce_stmt_str(result_children[0])
+        else:
+            raise TypeError("Do not know how to handle type " + str(type(e)))
+    return _expr_util.transform(expr, make_str)
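Note: gen_code delegates traversal to _expr_util.transform, which, as used here and in format_str, is a post-order fold: children are converted first and their string results are passed to make_str along with the node. A rough sketch of that traversal pattern (minimal hypothetical visitor, not the tvm implementation):

    # Post-order fold: reduce the children first, then combine at the parent.
    # Assumes each node exposes its children as a 'children' sequence.
    def transform(node, visitor):
        results = [transform(c, visitor) for c in getattr(node, 'children', ())]
        return visitor(node, results)

TensorRefExpr is the one case handled without the child results: it re-renders the whole reference through the bound Buffer, so buffered tensors print as flat buffer accesses.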
@@ -112,7 +112,7 @@ class UnaryOpExpr(Expr):


 class ReduceExpr(Expr):
     def __init__(self, op, src, rdom):
         self.op = op
         self.src = src
         self.rdom = rdom
@@ -121,8 +121,8 @@ class ReduceExpr(Expr):
         return (self.src,)


-class TensorReadExpr(Expr):
-    """Tensor read expression, tensor[indices]"""
+class TensorRefExpr(Expr):
+    """Tensor reference expression, tensor[indices]"""
     def __init__(self, tensor, indices):
         self.tensor = tensor
         self.indices = indices
@@ -27,8 +27,12 @@ def expr_with_new_children(e, children):
                 else _expr.BinaryOpExpr(e.op, children[0], children[1]))
         elif isinstance(e, _expr.UnaryOpExpr):
             return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0])
+        elif isinstance(e, _expr.TensorRefExpr):
+            return e if children == e.indices else _expr.TensorRefExpr(e.tensor, children)
+        elif isinstance(e, _expr.ReduceExpr):
+            return e if children[0] == e.src else _expr.ReduceExpr(e.op, children[0], e.rdom)
         else:
-            raise TypeError("donnot know how to handle Expr %s" % type(e))
+            raise TypeError("do not know how to handle Expr %s" % type(e))
     else:
         return e
@@ -92,8 +96,8 @@ def format_str(expr):
             return str(e.value)
         elif isinstance(e, _expr.Var):
             return e.name
-        elif isinstance(e, _expr.TensorReadExpr):
-            return "%s(%s)" % (e.tensor.name, ','.join(result_children))
+        elif isinstance(e, _expr.TensorRefExpr):
+            return "%s[%s]" % (e.tensor.name, ','.join(result_children))
         elif isinstance(e, _expr.ReduceExpr):
             return e.op.format_reduce_str(result_children[0], e.rdom.domain)
         else:
@@ -120,7 +124,7 @@ def simplify(expr):
         elif isinstance(e, _expr.UnaryOpExpr):
             return e.op.canonical(result_children[0])
         elif isinstance(e, _expr.ConstExpr):
-            return {_op.constant_canonical_key: e.value}
+            return {_op.const_canonical_key: e.value}
         elif isinstance(e, _expr.Var):
             return {e: 1}
         else:
@@ -1,12 +1,12 @@
 from __future__ import absolute_import as _abs
 from . import expr as _expr

-constant_canonical_key = '__constant__'
+const_canonical_key = '__constant__'

 def canonical_to_expr(c):
     elements = []
     for k, v in sorted(c.items()):
-        if k == constant_canonical_key and v != 0:
+        if k == const_canonical_key and v != 0:
             elements.append(_expr.const(v))
         elif v == 0:
             continue
@@ -35,6 +35,10 @@ class AddOp(BinaryOp):
     def format_reduce_str(self, src, rd):
         return "reduce_sum(%s, rdom=%s)" % (src, str(rd))

+    def format_reduce_stmt_str(self, src):
+        # a temporary hack for now
+        return "+ %s" % (src)
+
     def canonical(self, lhs, rhs):
         lhs = lhs.copy()
         for k, v in rhs.items():
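Note: the canonical methods work on the canonical form produced by simplify(): a dict mapping each variable (or opaque sub-expression) to its coefficient, with const_canonical_key holding the constant term. AddOp.canonical is then coefficient-wise merging, roughly:

    const_key = '__constant__'  # stands in for op.const_canonical_key

    def add_canonical(lhs, rhs):
        out = dict(lhs)
        for k, v in rhs.items():
            out[k] = out.get(k, 0) + v
        return out

    # (x + 2) + (2*x + 1)  ->  3*x + 3
    assert add_canonical({'x': 1, const_key: 2},
                         {'x': 2, const_key: 1}) == {'x': 3, const_key: 3}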
@@ -86,8 +90,15 @@ class DivOp(BinaryOp):
         erhs = canonical_to_expr(rhs)
         if isinstance(erhs, _expr.ConstExpr):
             lhs = lhs.copy()
+            remove = []
             for k, v in lhs.items():
-                lhs[k] /= float(erhs.value)
+                if k == const_canonical_key:
+                    lhs[k] = v / erhs.value
+                else:
+                    lhs[k / erhs] = 1
+                    remove.append(k)
+            for k in remove:
+                del lhs[k]
             return lhs
         elhs = canonical_to_expr(lhs)
         return {elhs / erhs: 1}
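Note: the reworked DivOp.canonical special-cases division by a constant: the constant term's value is divided in place, while each variable term is re-keyed under a division expression with coefficient 1. In dict form the transformation is roughly:

    # {x: 4, '__constant__': 6} / 2  ->  {x / 2: 1, '__constant__': 3}
    # ('div', k, c) stands in for the expression k / c built by the real code.
    def div_canonical(canon, c, const_key='__constant__'):
        out = {}
        for k, v in canon.items():
            if k == const_key:
                out[k] = v / c
            else:
                out[('div', k, c)] = 1
        return out

Non-constant divisors fall through to the final branch, which wraps the whole quotient as a single opaque term.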
@@ -0,0 +1,127 @@
+from __future__ import absolute_import as _abs
+from . import domain as _dom
+from . import expr as _expr
+from . import expr_util as _expr_util
+from . import split as _split
+from . import buffer as _buffer
+from . import codegen as _gen
+
+start_point_key = '__start__'
+TAB = ' '
+
+class Schedule(object):
+    """SUnit defines the compute schedule of a tensor
+
+    Parameters
+    ----------
+    tensor: tensor
+    """
+    def __init__(self, tensor, buffer=None):
+        self.tensor = tensor
+        self.buffer = buffer
+        self.parent = None
+        #self.children = []
+        self.splits = []
+        self.split_attach = {start_point_key: []}
+        self.implicit_splits = [_split.Split(i, 1) for i in range(tensor.ndim)]
+        if isinstance(tensor.expr, _expr.ReduceExpr):
+            for i in range(len(tensor.expr.rdom.domain)):
+                self.implicit_splits.append(_split.Split(i, 1, rdom=True))
+
+    def add_split(self, split):
+        self.splits.append(split)
+        self.split_attach[split] = []
+
+    def set_buffer(self, buf):
+        self.buffer = buf
+
+    def attach(self, split, other):
+        other.parent = self
+        if split is None:
+            self.split_attach[start_point_key].append(other)
+        else:
+            self.split_attach[split].append(other)
+
+    def infer_inner_domain(self, domain):
+        for split in self.splits:
+            domain = split.infer_inner_domain(domain)
+        return domain
+
+    def realize(self, domain=None, indent=''):
+
+        def realize_attach(lst):
+            attach_tensors = [sch.tensor for sch in lst]
+            attach_domains = self.tensor.infer_input_domains(domain, attach_tensors, red_domain=red_domain)
+            for sch in lst:
+                body.extend(sch.realize(attach_domains[sch.tensor], indent))
+
+        # init domain and red_domain
+        if domain is None:
+            domain = self.tensor.domain
+        red_domain = self.tensor.expr.rdom.domain if isinstance(self.tensor.expr, _expr.ReduceExpr) else None
+
+        # init buffer shape
+        if self.buffer:
+            if self.buffer.scope == _buffer.Scope.Global:
+                self.buffer.reshape(self.tensor.domain)
+            else:
+                # don't handle shared buffer for now
+                self.buffer.reshape(domain)
+            _buffer.BufferManager.current.bind(self.tensor, self.buffer)
+
+        body = []
+
+        if self.split_attach[start_point_key]:
+            realize_attach(self.split_attach[start_point_key])
+
+        # add loop conditions for splits
+        for split in self.splits:
+            if split.rdom:
+                red_domain = split.generate_loop_condition(red_domain, body, indent)
+            else:
+                domain = split.generate_loop_condition(domain, body, indent)
+            indent += TAB
+            if self.split_attach[split]:
+                realize_attach(self.split_attach[split])
+
+        # add implicit loop conditions
+        for split in self.implicit_splits:
+            if split.rdom:
+                red_domain = split.generate_loop_condition(red_domain, body, indent)
+            else:
+                domain = split.generate_loop_condition(domain, body, indent)
+            indent += TAB
+
+        # add loop body
+        expr = self.tensor.expr
+        global_index = [r.begin for r in domain]
+        global_rdom_index = [r.begin for r in red_domain] if red_domain else []
+        if expr is None:
+            if self.buffer:
+                lhs = self.buffer(*global_index)
+                rhs = self.tensor(*global_index, flatten=True)
+                body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _expr_util.format_str(rhs)))
+        else:
+            if self.buffer:
+                lhs = self.buffer(*global_index)
+            else:
+                lhs = self.tensor(*global_index, flatten=True)
+
+            bind_dict = {}
+            for i in range(self.tensor.ndim):
+                bind_dict[self.tensor.dim_index[i]] = global_index[i]
+            if isinstance(expr, _expr.ReduceExpr):
+                for i in range(len(expr.rdom.domain)):
+                    bind_dict[expr.rdom.index[i]] = global_rdom_index[i]
+            rhs = _expr_util.bind(expr, bind_dict)
+            body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _gen.gen_code(rhs)))

+        # add right brackets
+        for split in self.implicit_splits:
+            indent = indent[:-len(TAB)]
+            body.append('%s}' % indent)
+        for split in self.splits:
+            indent = indent[:-len(TAB)]
+            body.append('%s}' % indent)
+
+        return body
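Note: schedules form a tree. Each Schedule owns its splits; a child schedule hangs off a parent at a given split via attach, and realize walks the splits emitting loop headers, recurses into attached children at their attachment points, emits the assignment body, and closes the braces. A compressed usage sketch (unverified; same pattern as the new test_schedule below):

    import tvm

    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    C = tvm.Tensor(2, lambda i, j: A(i, j) + B(i, j),
                   shape=(A.shape[0], A.shape[1]), name='C')

    schA = tvm.Schedule(A, buffer=tvm.Buffer(tvm.Scope.Thread, name='A'))
    schC = tvm.Schedule(C)
    outer = tvm.Split(dim=0, factor=16)
    schC.add_split(outer)
    schC.attach(outer, schA)   # stage A inside C's outer loop
    print('\n'.join(schC.realize()))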
@@ -1,22 +1,35 @@
 from __future__ import absolute_import as _abs
 from . import expr as _expr
 from . import expr_util as _expr_util
 from . import domain as _dom
+from . import tensor as _tensor


 class Split(object):
-    def __init__(self, dim, factor):
+    def __init__(self, dim, factor, name=None, rdom=False):
         self.dim = dim
         self.factor = factor
-        self.loop_index = _expr.Var('loop_index_%d_' % dim)
+        self.rdom = rdom
+        if name is None:
+            name = 'loop_index_%d_' % dim
+        self.loop_index = _expr.Var(name)

-    def infer_inner_domain(self, out_domain):
-        assert self.dim < len(out_domain)
-        inner_domain = out_domain[:]
-        dim_out_range = out_domain[self.dim]
+    def infer_inner_domain(self, domain):
+        if isinstance(domain, _dom.RDom):
+            domain = domain.domain
+        assert self.dim < len(domain)
+        inner_domain = domain[:]
+        dim_out_range = domain[self.dim]
         dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
         inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
         return inner_domain
+
+    def generate_loop_condition(self, out_domain, body, indent):
+        assert self.dim < len(out_domain)
+        loop_range = _dom.Range(out_domain[self.dim].extent / self.factor)
+        stmt = '%sfor (int %s = 0; %s < %s; %s += 1) {' % (
+            indent,
+            self.loop_index.name,
+            self.loop_index.name,
+            _expr_util.format_str(loop_range.end),
+            self.loop_index.name)
+        body.append(stmt)
+        return self.infer_inner_domain(out_domain)
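Note: infer_inner_domain rewrites the split dimension's range in terms of the split's loop index: an outer range [b, b + e) with factor f becomes the tile [b + i*f, b + i*f + f) for loop index i, and generate_loop_condition emits the matching header, roughly `for (int i = 0; i < e / f; i += 1) {`. A small illustration (unverified; the updated test_split below asserts the same shape):

    import tvm

    A = tvm.Tensor(2, name='A')
    s = tvm.Split(dim=0, factor=64)
    # dimension 0 becomes a 64-wide tile indexed by s.loop_index:
    # [(i * 64, (i * 64) + 64), (0, A_shape_1_...)]
    print(str(s.infer_inner_domain(A.domain)))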
@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
 from . import expr as _expr
 from . import expr_util as _expr_util
 from . import domain as _dom
+from . import var_name as _name


 class Tensor(object):
@@ -23,13 +24,26 @@ class Tensor(object):
         self.shape = shape if shape else tuple(
             _expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))

-        self.name = name if name else "TensorObj"
+        self.name = name if name else _name.NameManager.current.get("TensorObj")
         self.inputs = None

-    def __call__(self, *indices):
+    def __call__(self, *indices, **option):
         if len(indices) != self.ndim:
             raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
-        return _expr.TensorReadExpr(self, indices)
+        if 'flatten' in option and option['flatten']:
+            stride = [1]
+            for i in reversed(range(1, len(indices))):
+                stride.insert(0, self.shape[i] * stride[0])
+            index = indices[0] * stride[0]
+            for i in range(1, len(indices)):
+                index = index + indices[i] * stride[i]
+            index = _expr_util.simplify(index)
+            return _expr.TensorRefExpr(self, [index])
+        return _expr.TensorRefExpr(self, indices)

     @property
     def domain(self):
         return _dom.Domain([_dom.Range(self.shape[i]) for i in range(self.ndim)])

     def input_tensors(self):
         """List of input tensors to this tensor.
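Note: Tensor.__call__ now mirrors the buffer logic. By default each dimension keeps its own index; with flatten=True the indices collapse into one simplified row-major offset, so the reference prints as a single-subscript access (and format_str now renders tensor references with brackets). Illustratively, with exact names depending on the NameManager suffixes:

    import tvm

    A = tvm.Tensor(2, name='A')
    i = tvm.Var('i')
    j = tvm.Var('j')
    print(tvm.format_str(A(i, j)))                # per-dimension, e.g. A[i0,j0]
    print(tvm.format_str(A(i, j, flatten=True)))  # e.g. A[(j0 + (i0 * A_shape_1_0))]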
@@ -43,7 +57,7 @@ class Tensor(object):
         inputs = []
         if self.expr:
             def collect(e):
-                if isinstance(e, _expr.TensorReadExpr):
+                if isinstance(e, _expr.TensorRefExpr):
                     inputs.append(e.tensor)
             _expr_util.visit(self.expr, collect)
         self.inputs = set(inputs)
@@ -93,7 +107,7 @@ class Tensor(object):
                 rd = e.rdom
                 for i in range(len(rd.domain)):
                     index_domains[rd.index[i]] = rd.domain[i]
-            elif isinstance(e, _expr.TensorReadExpr):
+            elif isinstance(e, _expr.TensorRefExpr):
                 if e.tensor in iset:
                     iset[e.tensor].append(e)
         _expr_util.visit(begin_expr, prepare)
@@ -0,0 +1,14 @@
+import tvm
+
+def test_buffer():
+    buf = tvm.Buffer(tvm.Scope.Thread)
+    shape = [32, 16]
+    domain = [tvm.Range(v) for v in shape]
+    buf.reshape(domain)
+    x = tvm.Var('x')
+    y = tvm.Var('y')
+    assert tvm.format_str(buf(y, x)) == '%s[(%s + (%s * %s))]' % (buf.name, x.name, y.name, shape[1])
+
+
+if __name__ == '__main__':
+    test_buffer()
@@ -0,0 +1,41 @@
+import tvm
+
+def test_schedule():
+    A = tvm.Tensor(2, name='A')
+    B = tvm.Tensor(2, name='B')
+    rd = tvm.RDom(tvm.Range(A.shape[1]))
+    T = tvm.Tensor(2, lambda i, j:
+                   tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
+                   shape=(A.shape[0], B.shape[0]), name="T")
+    C = tvm.Tensor(2, lambda i, j: T(i,j),
+                   shape=(A.shape[0], B.shape[0]), name="C")
+
+    bufA = tvm.Buffer(tvm.Scope.Thread, name='A')
+    bufB = tvm.Buffer(tvm.Scope.Thread, name='B')
+    bufT = tvm.Buffer(tvm.Scope.Thread, name='T')
+
+    schA = tvm.Schedule(A, buffer=bufA)
+    schB = tvm.Schedule(B, buffer=bufB)
+    schT = tvm.Schedule(T, buffer=bufT)
+    schC = tvm.Schedule(C)
+    Cx0 = tvm.Split(dim=0, factor=64)
+    Cy0 = tvm.Split(dim=1, factor=64)
+    Cx1 = tvm.Split(dim=0, factor=8)
+    Cy1 = tvm.Split(dim=1, factor=8)
+    Tk = tvm.Split(dim=0, factor=8, rdom=True)
+
+    schC.add_split(Cx0)
+    schC.add_split(Cy0)
+    schC.add_split(Cx1)
+    schC.add_split(Cy1)
+    schT.add_split(Tk)
+    schC.attach(Cy1, schT)
+    schT.attach(Tk, schA)
+    schT.attach(Tk, schB)
+
+    body = schC.realize()
+    print('\n'.join(body))
+
+
+if __name__ == "__main__":
+    test_schedule()
@@ -1,8 +1,8 @@
 import tvm


 def test_split_dom_infer():
     A = tvm.Tensor(2, name='A')
+    rd = tvm.RDom(tvm.Range(A.shape[1]))
     split1 = tvm.Split(0, 64)
     split2 = tvm.Split(1, 64)
     split3 = tvm.Split(0, 8)
@@ -10,14 +10,12 @@ def test_split_dom_infer():
     dom1 = split1.infer_inner_domain(dom)
     dom2 = split2.infer_inner_domain(dom1)
     dom3 = split3.infer_inner_domain(dom2)
+    dom4 = split3.infer_inner_domain(rd)
     i1 = split1.loop_index.name
     i2 = split2.loop_index.name
     i3 = split3.loop_index.name
     assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
     assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
     assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
+    assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3)


 if __name__ == "__main__":
@@ -12,7 +12,7 @@ def test_tensor_inputs():
     B = tvm.Tensor(2, name='B')
     T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                    shape=(A.shape[0], B.shape[0], A.shape[1]))
-    assert(T.input_tensors() == [A, B])
+    assert(T.input_tensors() == set([A, B]))


 def test_tensor_reduce():
     A = tvm.Tensor(2, name='A')
@@ -25,5 +25,6 @@ def test_tensor_reduce():
     print(tvm.format_str(C.expr))

 if __name__ == "__main__":
     test_tensor()
     test_tensor_inputs()
+    test_tensor_reduce()