# onnxruntime-tvm/python/tvm/schedule.py
#
# 661 lines
# 20 KiB
# Python

"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
def convert(value):
    """Convert a python value into a TVM node or function.

    Parameters
    ----------
    value : python value
        The value to convert.

    Returns
    -------
    tvm_val : Node or Function
        ``value`` itself when it is already a Function or NodeBase,
        a TVM function when ``value`` is callable, otherwise the
        result of the generic node conversion.
    """
    # Already a TVM object: hand it back untouched.
    if isinstance(value, (Function, NodeBase)):
        return value
    # Callables go through the function converter, everything else
    # through the node converter.
    converter = _convert_tvm_func if callable(value) else _convert_to_node
    return converter(value)
@register_node
class Buffer(NodeBase):
    """Symbolic data buffer in TVM.

    A Buffer provides a way to represent the data layout
    specialization of a data structure in TVM.

    Do not construct directly, use :any:`decl_buffer` instead.
    See the documentation of :any:`decl_buffer` for more details.

    See Also
    --------
    decl_buffer : Declare a buffer
    """
    # Bit flags that can be OR-ed together to form an access mask.
    READ = 1
    WRITE = 2

    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
        """Get an access pointer to the head of the buffer.

        This is the recommended method to get the buffer data
        address when interacting with external functions.

        Parameters
        ----------
        access_mask : int or str
            The access pattern mask. Indicates whether the access will
            read or write the data content. Either a bitmask built from
            ``Buffer.READ`` / ``Buffer.WRITE`` or a string made of the
            characters ``"r"`` and ``"w"``.

        ptr_type : str, optional
            The data type of the result pointer. Do not specify
            unless we want to cast the pointer to a specific type.

        content_lanes : int, optional
            The number of lanes for the data type. This value
            is greater than one for vector types.

        offset : Expr, optional
            The offset of the pointer. We can use it to offset by
            the number of elements from the address of ptr.

        Returns
        -------
        The result of the internal ``_BufferAccessPtr`` call.

        Raises
        ------
        ValueError
            If a string ``access_mask`` contains a character other
            than ``"r"`` or ``"w"``.

        Examples
        --------
        .. code-block:: python

          from tvm.schedule import Buffer
          # Get access ptr for read
          buffer.access_ptr("r")
          # Get access ptr for read/write with bitmask
          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
          # Get access ptr for read/write with str flag
          buffer.access_ptr("rw")
          # Get access ptr for read with offset
          buffer.access_ptr("r", offset = 100)
        """
        if isinstance(access_mask, string_types):
            # Translate the string form character by character into the bitmask.
            flag_of = {"r": Buffer.READ, "w": Buffer.WRITE}
            bits = 0
            for letter in access_mask:
                if letter not in flag_of:
                    raise ValueError("Unknown access_mask %s" % access_mask)
                bits |= flag_of[letter]
            access_mask = bits
        # The offset is converted before being handed to the internal API.
        offset = convert(offset)
        return _api_internal._BufferAccessPtr(
            self, access_mask, ptr_type, content_lanes, offset)

    def vload(self, begin, dtype=None):
        """Generate an Expr that loads dtype from the begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in units of Buffer.dtype.
            A single int or Expr is wrapped into a one-element tuple.

        dtype : str, optional
            The data type to be loaded; can be a vector type whose
            lanes are a multiple of Buffer.dtype. Falls back to
            ``self.dtype`` when not given.

        Returns
        -------
        load : Expr
            The corresponding load expression.
        """
        if isinstance(begin, (int, _expr.Expr)):
            begin = (begin,)
        if not dtype:
            dtype = self.dtype
        return _api_internal._BufferVLoad(self, begin, dtype)

    def vstore(self, begin, value):
        """Generate a Stmt that stores value at the begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in units of Buffer.dtype.
            A single int or Expr is wrapped into a one-element tuple.

        value : Expr
            The value to be stored.

        Returns
        -------
        store : Stmt
            The corresponding store stmt.
        """
        if isinstance(begin, (int, _expr.Expr)):
            begin = (begin,)
        return _api_internal._BufferVStore(self, begin, value)
@register_node
class Split(NodeBase):
    """Node type for a split operation on an axis."""
@register_node
class Fuse(NodeBase):
    """Node type for a fuse operation on an axis."""
@register_node
class Singleton(NodeBase):
    """Node type for a singleton axis."""
@register_node
class IterVar(NodeBase, _expr.ExprOp):
    """Represent an iteration variable.

    An IterVar is normally created by an Operation to represent
    axis iterations in the computation.
    It can also be created by schedule primitives like
    :any:`tvm.schedule.Stage.split`.

    See Also
    --------
    tvm.thread_axis: Create thread axis IterVar.
    tvm.reduce_axis: Create reduce axis IterVar.
    """
    # Integer tags, one per kind of iteration variable.
    # NOTE(review): presumably these must stay in sync with an enum on the
    # backend side -- confirm before renumbering.
    DataPar = 0
    ThreadIndex = 1
    CommReduce = 2
    Ordered = 3
    DimInfo = 4
    Unrolled = 5
    Vectorized = 6
    Parallelized = 7
    Tensorized = 8
# Expose the IterVar class to the tensor module via its `iter_var_cls` attribute.
# NOTE(review): presumably assigned here rather than imported by the tensor
# module to avoid a circular import -- confirm.
_tensor.iter_var_cls = IterVar
def create_schedule(ops):
    """Create a schedule for a list of ops.

    Parameters
    ----------
    ops : list of Operations
        The source expression. A value that is neither a list nor a
        container Array is wrapped into a one-element list.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    already_sequence = isinstance(ops, (list, _container.Array))
    op_list = ops if already_sequence else [ops]
    return _api_internal._CreateSchedule(op_list)
@register_node
class Schedule(NodeBase):
    """Schedule for all the stages."""

    def __getitem__(self, k):
        """Return the ``stage_map`` entry for a Tensor or Operation.

        Parameters
        ----------
        k : Tensor or Operation
            The lookup key. A Tensor is replaced by its ``op``.

        Raises
        ------
        ValueError
            If the key is neither a Tensor nor an Operation, or if the
            operation is not present in ``stage_map``.
        """
        key = k.op if isinstance(k, _tensor.Tensor) else k
        if not isinstance(key, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if key not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(key)))
        return self.stage_map[key]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Inserts the necessary rebase to make certain iter vars start from 0.
        This is needed before bound inference and the follow-up steps.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _api_internal._ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create a stage group by giving the output and input boundary.

        The operators between outputs and inputs are placed as members of
        the group. Outputs are included in the group, while inputs are not.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group. A single Tensor is also accepted.

        inputs : list of Tensors
            The inputs of the group. A single Tensor is also accepted.

        include_inputs : boolean, optional
            Whether to include input operations in the group if they are
            used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage that represents the group; the user can use
            compute_at to move the attachment point of the group.
        """
        # Promote a lone Tensor on either side to a one-element list.
        out_list = [outputs] if isinstance(outputs, _tensor.Tensor) else outputs
        in_list = [inputs] if isinstance(inputs, _tensor.Tensor) else inputs
        return _api_internal._ScheduleCreateGroup(
            self, out_list, in_list, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of the original tensor for readers.

        This will mutate the body of the readers.
        A new cache stage will be created for the tensor.
        Call this before doing any split/fuse schedule.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.

        scope : str
            The scope of the cache.

        readers : list of Tensor or Operation
            The readers to read the cache. A single Tensor or Operation
            is also accepted.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        # The internal API receives operations, so swap each Tensor for its op.
        reader_ops = []
        for reader in readers:
            if isinstance(reader, _tensor.Tensor):
                reader_ops.append(reader.op)
            else:
                reader_ops.append(reader)
        return _api_internal._ScheduleCacheRead(self, tensor, scope, reader_ops)

    def cache_write(self, tensor, scope):
        """Create a cache write of the original tensor, before storing into it.

        This will mutate the body of the tensor.
        A new cache stage will be created before feeding into the tensor.

        This function can be used to support data layout transformation.
        If there is a split/fuse/reorder on the data parallel axis of the
        tensor before cache_write is called, the intermediate cache stores
        the data in the layout given by the iteration order of the leaf axes.
        The data will be transformed back to the original layout in the
        original tensor. The user can further call compute_inline to inline
        the original layout and keep the data stored in the transformed layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to be fed to. All the tensors must be produced by
            one computeOp.

        scope : str
            The scope of the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        return _api_internal._ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis, factor_axis=0):
        """Factor a reduction axis in the tensor's schedule into an explicit axis.

        This will create a new stage that generates the new tensor with axis
        as the first dimension. The tensor's body will be rewritten as a
        reduction over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.

        axis : IterVar
            The reduction axis in the schedule to be factored.

        factor_axis : int
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor. A result of length one is
            unwrapped to its single element.
        """
        results = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis)
        if len(results) == 1:
            return results[0]
        return results
@register_node
class Stage(NodeBase):
    """A Stage represents the schedule for one operation."""

    def split(self, parent, factor=None, nparts=None):
        """Split an iteration variable by inner factor or by number of outer parts.

        Exactly one of ``factor`` and ``nparts`` must be provided.

        Parameters
        ----------
        parent : IterVar
            The parent iter var.

        factor : Expr, optional
            The splitting factor.

        nparts : Expr, optional
            The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.

        Raises
        ------
        ValueError
            If both ``factor`` and ``nparts`` are given, or if neither is.
        """
        if nparts is not None:
            if factor is not None:
                # Name the parameters that actually conflict.
                raise ValueError("Do not need to provide both factor and nparts")
            outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single one.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            Consecutive IterVars, ordered from outer to inner.

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
        """
        return _api_internal._StageFuse(self, args)

    def set_scope(self, scope):
        """Set the thread scope of this stage.

        Parameters
        ----------
        scope : str
            The thread scope of this stage.

        Returns
        -------
        The result of the internal ``_StageSetScope`` call.
        """
        return _api_internal._StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to the thread index thread_ivar.

        Parameters
        ----------
        ivar : IterVar
            The iteration to be bound to the thread.

        thread_ivar : IterVar
            The thread to be bound.
        """
        _api_internal._StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of a composed op.

        Parameters
        ----------
        threads : list of threads
            The threads to be launched. A single IterVar is also accepted.
        """
        if isinstance(threads, IterVar):
            threads = [threads]
        _api_internal._StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set the predicate under which a store to the array can be performed.

        Use this when there are duplicated threads doing the same store and
        we only need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition for the store.
        """
        _api_internal._StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at the parent's scope.

        Parameters
        ----------
        parent : Stage
            The parent stage.

        scope : IterVar
            The loop scope to be attached to.
        """
        _api_internal._StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark the stage as inline.

        Takes no arguments besides the stage itself.
        """
        _api_internal._StageComputeInline(self)

    def compute_root(self):
        """Mark the stage as root.

        Takes no arguments besides the stage itself.
        """
        _api_internal._StageComputeRoot(self)

    def reorder(self, *args):
        """Reorder the iteration variables into the specified order.

        Parameters
        ----------
        args : list of IterVar
            The iteration variables, in the desired order.
        """
        _api_internal._StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions.

        The final loop order from outermost to innermost is
        [x_outer, y_outer, x_inner, y_inner].

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension.

        y_parent : IterVar
            The original y dimension.

        x_factor : Expr
            The stride factor on the x axis.

        y_factor : Expr
            The stride factor on the y axis.

        Returns
        -------
        x_outer : IterVar
            Outer axis of the x dimension.

        y_outer : IterVar
            Outer axis of the y dimension.

        x_inner : IterVar
            Inner axis of the x dimension.

        y_inner : IterVar
            Inner axis of the y dimension.
        """
        # Unpack so that callers always receive a plain 4-tuple.
        x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
            self, x_parent, y_parent, x_factor, y_factor)
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorized.
        """
        _api_internal._StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin.

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _api_internal._StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _api_internal._StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _api_internal._StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with a pragma.

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated.

        pragma_type : str
            The pragma string to be annotated.

        pragma_value : Expr, optional
            The pragma value to pass along with the pragma.
            A string value is converted before being passed on.

        Note
        ----
        Most pragmas are advanced/experimental features
        and may be subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into no-op.
          This is useful for debug purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of parallel construct.
          This pragma moves the launching point to even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs as BSP style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`
        """
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _api_internal._StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched.

        var : IterVar
            The loop point at which the prefetching is applied.

        offset : Expr
            The number of iterations to be prefetched before actual execution.
        """
        _api_internal._StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set the alignment requirement for a specific axis.

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set the memory layout for a more friendly memory
        access pattern. For example, we can set alignment to be
        factor=2, offset=1 to avoid bank conflict for thread access on
        higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.

        factor : int
            The factor in the alignment specification.

        offset : int
            The offset in the alignment specification.
        """
        _api_internal._StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to an intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _api_internal._StageDoubleBuffer(self)

    def opengl(self):
        """The special OpenGL schedule.

        Maps each output element to a pixel.
        """
        _api_internal._StageOpenGL(self)
# Initialize the API for the "tvm.schedule" namespace.
# NOTE(review): presumably this attaches the backend functions registered
# under that prefix to this module -- confirm against _init_api.
_init_api("tvm.schedule")