[LANG] Generalize compute to tensor region (#1476)
This commit is contained in:
Parent: 3d62cf7c37
Commit: b90620ea25

@@ -1 +1 @@
-Subproject commit 4f0564ec769477c66d480dd966088f172050c874
+Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
   TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
 };
 
+using Region = Array<Range>;
+
 /*!
  * \brief Type of iteration variable.
  * Each IterVar have a specific type.
@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
   }
   /*!
    * \return The list of iteration variable at root
-   * \note root_iter_vars dedides the shape of the outputs.
+   * \note root_iter_vars decides the shape of the outputs.
    */
   virtual Array<IterVar> root_iter_vars() const = 0;
   /*!
@@ -239,6 +239,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
 };
 
+/*!
+ * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
+ */
+class TensorComputeOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar on each axis */
+  Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
+  Array<IterVar> reduce_axis;
+  /*! \brief number of axes that can be scheduled */
+  int schedulable_ndim;
+  /*! \brief TensorIntrin used to compute */
+  TensorIntrin intrin;
+  /*! \brief input tensors of intrin */
+  Array<Tensor> inputs;
+  /*! \brief region of input tensors */
+  Array<Region> input_regions;
+  /*! \brief constructor */
+  TensorComputeOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("schedulable_ndim", &schedulable_ndim);
+    v->Visit("intrin", &intrin);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_regions", &input_regions);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Array<IterVar> axis,
+                        Array<IterVar> reduce_axis,
+                        int schedulable_ndim,
+                        TensorIntrin intrin,
+                        Array<Tensor> tensors,
+                        Array<Region> regions);
+
+  static constexpr const char* _type_key = "TensorComputeOp";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
+};
+
 /*!
  * \brief Symbolic scan.
  */
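For orientation, this is the user-facing pattern the new node enables: a tensor intrinsic applied to tensor regions inside tvm.compute now produces a TensorComputeOp rather than a plain ComputeOp. A minimal Python sketch, adapted from the tests added later in this commit (the "vadd" extern symbol is illustrative only):

import tvm

# Declare a 1-D vector-add intrinsic over n elements.
def intrin_vadd(n):
    x = tvm.placeholder((n,), name="vx")
    y = tvm.placeholder((n,), name="vy")
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name="vz")

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        # "vadd" is a placeholder extern symbol used for illustration.
        ib.emit(tvm.call_extern(outs[0].dtype, "vadd",
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr("r"),
                                outs[0].access_ptr("w")))
        return ib.get()

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

m, factor = 1024, 16
vadd = intrin_vadd(factor)
A = tvm.placeholder((m // factor, factor), name="A")
B = tvm.placeholder((m // factor, factor), name="B")
# The lambda takes only the schedulable axis i; the trailing axis is covered
# by the 0:factor regions handed to the intrinsic, so C.op is a
# TensorComputeOp with schedulable_ndim == 1.
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))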
@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
  public:
   /*! \brief The input tensors */
   Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representationinputs */
+  /*! \brief Symbolic placeholder representation of inputs */
   Array<Buffer> input_placeholders;
   /*! \brief Symbolic placeholder representation of outputs */
   Array<Buffer> output_placeholders;
@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
 inline const TensorIntrinNode* TensorIntrin::operator->() const {
   return static_cast<const TensorIntrinNode*>(node_.get());
 }
 
+
+// Internal node container of tensor intrinsic calling.
+class TensorIntrinCallNode;
+
+/*! \brief Tensor intrinsic calling node. */
+class TensorIntrinCall : public NodeRef {
+ public:
+  TensorIntrinCall() {}
+  explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorIntrinCallNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = TensorIntrinCallNode;
+};
+
+class TensorIntrinCallNode : public Node {
+ public:
+  /*! \brief the tensor intrinsic */
+  TensorIntrin intrin;
+  /*! \brief input tensors of the intrinsic */
+  Array<Tensor> tensors;
+  /*! \brief regions of input tensors */
+  Array<Region> regions;
+  /*!
+   * \brief IterVar on each reduction axis, if the
+   * intrin will use the reduce axis
+   */
+  Array<IterVar> reduce_axis;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("intrin", &intrin);
+    v->Visit("tensors", &tensors);
+    v->Visit("regions", &regions);
+    v->Visit("reduce_axis", &reduce_axis);
+  }
+  static TensorIntrinCall make(TensorIntrin intrin,
+                               Array<Tensor> tensors,
+                               Array<Region> regions,
+                               Array<IterVar> reduce_axis);
+
+  static constexpr const char* _type_key = "TensorIntrinCall";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
+};
+
+inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
+  return static_cast<const TensorIntrinCallNode*>(node_.get());
+}
+
 }  // namespace tvm
 #endif  // TVM_TENSOR_INTRIN_H_
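On the Python side, the value that flows back into tvm.compute is a TensorIntrinCall carrying exactly the fields of this node. Continuing the vadd sketch above (attribute names follow the VisitAttrs list; the snippet is an illustration, not new API):

body = vadd(A[0, 0:factor], B[0, 0:factor])
assert isinstance(body, tvm.tensor.TensorIntrinCall)
print(body.intrin)       # the TensorIntrin being invoked
print(body.tensors)      # the input tensors, here [A, B]
print(body.regions)      # one list of Ranges per input tensor
print(body.reduce_axis)  # empty unless passed via the reduce_axis keyword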
@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    # for python3
+    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
     ndim = len(shape)
     code = fcompute.__code__
 
-    if fcompute.__code__.co_argcount == 0:
+    out_ndim = ndim
+    if code.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
     else:
         arg_names = code.co_varnames[:code.co_argcount]
+        out_ndim = code.co_argcount
 
-    if ndim != len(arg_names):
+    if out_ndim != len(arg_names):
         raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
 
-    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
+    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
-    if not isinstance(body, (list, tuple)):
-        body = [body]
-    body = convert(body)
-    op_node = _api_internal._ComputeOp(
-        name, tag, attrs, dim_var, body)
+
+    if isinstance(body, _tensor.TensorIntrinCall):
+        for i, s in enumerate(shape[out_ndim:]):
+            var_name = "ax" + str(i)
+            dim_var.append(_IterVar((0, s), var_name, 4))
+        op_node = _api_internal._TensorComputeOp(name,
+                                                 tag,
+                                                 dim_var,
+                                                 body.reduce_axis,
+                                                 out_ndim,
+                                                 body.intrin,
+                                                 body.tensors,
+                                                 body.regions)
+    else:
+        if not isinstance(body, (list, tuple)):
+            body = [body]
+        body = convert(body)
+        op_node = _api_internal._ComputeOp(
+            name, tag, attrs, dim_var, body)
+
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
     return outputs[0] if num == 1 else outputs
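In short, the number of lambda arguments, rather than len(shape), now determines how many leading axes stay schedulable; any trailing shape dimensions are appended as tensorized IterVars (iter type 4) whose extents are consumed by the intrinsic. Continuing the vadd sketch above (a hedged illustration):

C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
print(type(C.op))             # tvm.tensor.TensorComputeOp
print(C.op.schedulable_ndim)  # 1: only axis i can be scheduled
print(C.op.axis)              # schedulable axis i plus the tensorized axis ax0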
@@ -529,14 +548,14 @@ def decl_buffer(shape,
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
     if offset_factor != 0 and elem_offset is None:
-        elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
+        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
+        elem_offset = var('%s_elem_offset' % name, shape_dtype)
     if data is None:
         data = var(name, "handle")
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor)


 def _IterVar(dom, name, iter_type, thread_tag=''):
     """Internal function to create IterVar
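The decl_buffer tweak matters for intrinsics declared with plain integer shapes: elem_offset used to be typed from shape[0].dtype, which a Python int does not have. A small sketch of the case that now works (assuming the usual decl_buffer defaults):

import tvm

# offset_factor != 0 forces creation of a symbolic elem_offset; with an
# integer shape its dtype now falls back to "int32" instead of failing.
buf = tvm.decl_buffer((16, 16), dtype="float32", name="wbuf", offset_factor=16)
print(buf.elem_offset.dtype)  # int32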
@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype
 
+@register_node
+class TensorIntrinCall(NodeBase):
+    """Intermediate structure for calling a tensor intrinsic."""
+    pass
+
 
 itervar_cls = None
@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp):
         return "%s.v%d" % (op.name, self.value_index)
 
+
 
 class Operation(NodeBase):
     """Represent an operation that generate a tensor"""
@@ -155,6 +161,12 @@ class ComputeOp(Operation):
         return self.__getattr__("reduce_axis")
 
 
+@register_node
+class TensorComputeOp(Operation):
+    """Tensor operation."""
+    pass
+
+
 @register_node
 class ScanOp(Operation):
     """Scan operation."""
@@ -6,9 +6,25 @@ from . import expr as _expr
 from . import stmt as _stmt
 from . import make as _make
 from . import tensor as _tensor
+from . import schedule as _schedule
 from .build_module import current_build_config
 from ._ffi.node import NodeBase, register_node
 
+
+def _get_region(tslice):
+    region = []
+    for idx in tslice.indices:
+        if isinstance(idx, slice):
+            assert idx.step is None
+            region.append(_api.Range(idx.start, idx.stop))
+        else:
+            if isinstance(idx, _schedule.IterVar):
+                begin = idx.var
+            else:
+                begin = idx
+            region.append(_make.range_by_min_extent(begin, 1))
+    return region
+
 @register_node
 class TensorIntrin(NodeBase):
     """Tensor intrinsic functions for certain computation.
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
     --------
     decl_tensor_intrin: Construct a TensorIntrin
     """
-    pass
+    def __call__(self, *args, **kwargs):
+        tensors = [x.tensor for x in args]
+        regions = [_get_region(x) for x in args]
+        reduce_axis = []
+        if "reduce_axis" in kwargs:
+            reduce_axis = kwargs["reduce_axis"]
+            if not isinstance(reduce_axis, (list, tuple)):
+                reduce_axis = [reduce_axis]
+            reduce_axis = _api.convert(reduce_axis)
+        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
+
 
 def decl_tensor_intrin(op,
                        fcompute,
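Put together, calling a declared intrinsic on tensor slices records one region per input: slices become Ranges and scalar indices (including reduction IterVars) become unit-extent ranges. Continuing the vadd sketch, the recorded regions can be inspected on the resulting op (output shown is illustrative):

C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
for region in C.op.input_regions:
    # e.g. [range(min=i, ext=1), range(min=0, ext=16)] for each of A and B
    print(region)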
@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
                                 args[6]);
   });
 
+TVM_REGISTER_API("_TensorIntrinCall")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = TensorIntrinCallNode::make(args[0],
+                                      args[1],
+                                      args[2],
+                                      args[3]);
+  });
+
 TVM_REGISTER_API("_TensorEqual")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Tensor() == args[1].operator Tensor();
@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
                              args[7]);
   });
 
+TVM_REGISTER_API("_TensorComputeOp")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = TensorComputeOpNode::make(args[0],
+                                     args[1],
+                                     args[2],
+                                     args[3],
+                                     args[4],
+                                     args[5],
+                                     args[6],
+                                     args[7]);
+  });
+
 TVM_REGISTER_API("_ExternOp")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = ExternOpNode::make(args[0],
@@ -10,6 +10,8 @@
 
 namespace tvm {
 
+// Tensor
+
 Expr Tensor::operator()(Array<Var> indices) const {
   Array<Expr> arr(indices.begin(), indices.end());
   return operator()(arr);
@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }
 
+Tensor Operation::output(size_t i) const {
+  auto node = make_node<TensorNode>();
+  node->op = *this;
+  node->value_index = i;
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
+
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorNode);
 
-Tensor Operation::output(size_t i) const {
-  auto node = make_node<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}
+// TensorIntrin
 
 TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
+
+// TensorIntrinCall
+
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
+                                            Array<Tensor> tensors,
+                                            Array<Region> regions,
+                                            Array<IterVar> reduce_axis) {
+  auto n = make_node<TensorIntrinCallNode>();
+  n->intrin = std::move(intrin);
+  n->tensors = std::move(tensors);
+  n->regions = std::move(regions);
+  n->reduce_axis = std::move(reduce_axis);
+  return TensorIntrinCall(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
+    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+
 }  // namespace tvm
@@ -13,6 +13,7 @@
 #include "compute_op.h"
 #include "op_util.h"
 #include "../schedule/message_passing.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
   v.Run();
 }
 
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update) {
+  Array<Expr> conds;
+  std::unordered_set<const Variable*> banned;
+  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    auto iit = stage->iter_var_attrs.find(iv);
+    if (iit != stage->iter_var_attrs.end()) {
+      const IterVarAttr& attr = (*iit).second;
+      if (attr->iter_type == kTensorized) {
+        break;
+      }
+    }
+    if (iv->iter_type == kCommReduce) {
+      auto vit = dom_map.find(iv);
+      CHECK(vit != dom_map.end());
+      const Range& vrange = vit->second;
+      conds.push_back(likely(iv->var > vrange->min));
+      banned.insert(iv->var.get());
+    }
+  }
+  for (const Expr& pred : n.main_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize update transform failed, the condition "
+                 << pred << " has a conflict with the reset condition";
+    }
+  }
+
+  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
+                          update, body);
+}
 }  // namespace tvm
@@ -14,7 +14,7 @@
 
 namespace tvm {
 // loop nest structure for general compute
-// This the the loop nest structured used in compute.
+// This the loop nest structured used in compute.
 // Does not include the loop body.
 struct ComputeLoopNest {
   // The common number of loops between init and main
@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Transform the update part when there is no init func in tensorizing
+ * \param stage The stage for tensorizing.
+ * \param dom_map The range of each iter var.
+ * \param n The loop nest structured used in compute.
+ * \param body The body func in tensorize intrin
+ * \param update The update func in tensorize intrin
+ * \return Transformed result.
+ */
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update);
 }  // namespace tvm
 
 #endif  // TVM_OP_COMPUTE_OP_H_
@@ -0,0 +1,361 @@ (new file: tensor_compute_op.cc)
/*!
 * Copyright (c) 2017 by Contributors
 * \brief Tensor Compute Op.
 * \file tensor_compute_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "./compute_op.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
using namespace ir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op,
                                      IRPrinter *p) {
    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
  });

TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);

int TensorComputeOpNode::num_outputs() const {
  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
}

Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
  Array<IterVar> ret = axis;
  for (IterVar iv : reduce_axis) {
    ret.push_back(iv);
  }
  return ret;
}

Type TensorComputeOpNode::output_dtype(size_t i) const {
  return this->intrin->buffers[this->inputs.size() + i]->dtype;
}

Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
  Array<Expr> shape;
  for (const auto& ivar : this->axis) {
    shape.push_back(ivar->dom->extent);
  }
  return shape;
}


Operation TensorComputeOpNode::make(std::string name,
                                    std::string tag,
                                    Array<IterVar> axis,
                                    Array<IterVar> reduce_axis,
                                    int schedulable_ndim,
                                    TensorIntrin intrin,
                                    Array<Tensor> tensors,
                                    Array<Region> regions) {
  auto n = make_node<TensorComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->axis = std::move(axis);
  n->reduce_axis = std::move(reduce_axis);
  n->schedulable_ndim = std::move(schedulable_ndim);
  n->intrin = std::move(intrin);
  n->inputs = std::move(tensors);
  n->input_regions = std::move(regions);
  return Operation(n);
}

Array<Tensor> TensorComputeOpNode::InputTensors() const {
  return inputs;
}

Operation TensorComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_node<TensorComputeOpNode>(*this);
  auto intrin = make_node<TensorIntrinNode>(*(this->intrin.operator->()));
  intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
  if (intrin->reduce_init.defined()) {
    intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
  }
  if (intrin->reduce_update.defined()) {
    intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
  }
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  if (intrin->body.same_as(n->intrin->body) &&
      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
      inputs.same_as(n->inputs)) {
    return self;
  } else {
    n->intrin = TensorIntrin(intrin);
    return Operation(n);
  }
}

void TensorComputeOpNode::PropBoundToInputs(
    const Operation& self,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  for (size_t i = 0; i < this->inputs.size(); ++i) {
    Tensor t = this->inputs[i];
    Region region = input_regions[i];

    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t j = 0; j < t.ndim(); ++j) {
      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
    }
  }
}

void TensorComputeOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
    CHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    CHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}

Stmt TensorComputeOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  HalideIR::Internal::Region bounds;
  for (IterVar iv : this->axis) {
    bounds.push_back(realize_map.at(iv));
  }
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    Tensor t = stage->op.output(i-1);
    realize = ir::Realize::make(t->op, t->value_index,
        t->dtype, bounds, const_true(), realize);
    // alignment requirement, only useful for compute
    for (int i = 0; i < schedulable_ndim; ++i) {
      auto it = stage->iter_var_attrs.find(this->axis[i]);
      if (it != stage->iter_var_attrs.end()) {
        IterVarAttr attr = (*it).second;
        if (attr->dim_align_factor != 0) {
          Array<Expr> tuple = {static_cast<int>(i),
                               attr->dim_align_factor,
                               attr->dim_align_offset};
          realize = ir::AttrStmt::make(
              t, ir::attr::buffer_dim_align,
              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
              realize);
        }
      }
    }
  }
  return realize;
}

ComputeLoopNest MakeLoopNest(
    const TensorComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) {
  CHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
      debug_keep_trivial_loop);
  ret.main_predicates = schedule::MakeBoundCheck(
      stage, dom_map, ret.main_vmap, false,
      std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (int i = 0; i < self->schedulable_ndim; ++i) {
      update_state[self->axis[i]] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    schedule::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // first first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that does not relates to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if ((flag & 1) == 0) skip_iter.insert(kv.first);
    }
    ret.init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
    ret.init_predicates = schedule::MakeBoundCheck(
        stage, dom_map, ret.init_vmap, true, skip_iter);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
    ret.num_common_loop = stage->leaf_iter_vars.size();
  }
  // copy elison here.
  return ret;
}


Stmt TensorComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);

  // Start bind data.
  Stmt nop = Evaluate::make(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> inputs = this->InputTensors();

  // input binding
  size_t num_inputs = inputs.size();
  for (size_t i = 0; i < num_inputs; ++i) {
    Tensor tensor = inputs[i];
    Region region = this->input_regions[i];
    Buffer buffer = this->intrin->buffers[i];
    Array<NodeRef> bind_spec{buffer, tensor};

    Array<Expr> tuple;
    for (size_t i = 0; i < region.size(); ++i) {
      tuple.push_back(region[i]->min);
      tuple.push_back(region[i]->extent);
    }
    input_bind_nest.emplace_back(AttrStmt::make(
        bind_spec, ir::attr::buffer_bind_scope,
        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
  }

  // output binding
  for (int i = 0; i < this->num_outputs(); ++i) {
    Tensor tensor = stage->op.output(i);
    Buffer buffer = this->intrin->buffers[num_inputs + i];
    Array<NodeRef> bind_spec{buffer, tensor};

    Array<Expr> tuple;
    for (size_t i = 0; i < this->axis.size(); ++i) {
      auto ivar = this->axis[i];
      if (i < static_cast<size_t>(this->schedulable_ndim)) {
        tuple.push_back(ivar->var);
        tuple.push_back(1);
      } else {
        Range dom = ivar->dom;
        tuple.push_back(dom->min);
        tuple.push_back(dom->extent);
      }
    }

    output_bind_nest.emplace_back(AttrStmt::make(
        bind_spec, ir::attr::buffer_bind_scope,
        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
  }

  // Check variable remap
  std::unordered_map<const Variable*, Expr> vmap;
  ir::ArgBinder binder(&vmap);

  size_t tloc = stage->leaf_iter_vars.size();
  ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);

  if (this->reduce_axis.size() == 0) {
    std::vector<std::vector<Stmt> > nest(
        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(op::MakeIfNest(n.main_predicates));
    CHECK_EQ(n.init_predicates.size(), 0U);
    CHECK(this->intrin->body.defined())
        << "Normal store op for intrin " << this << " is not defined";
    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = ir::Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = op::Substitute(body, n.main_vmap);
    Stmt ret = MergeNest(nest, body);
    return ret;
  } else {
    // Need to split reduction
    CHECK(this->intrin->reduce_update.defined())
        << "Reduction update op is not defined";
    // Need init and update steps
    CHECK_NE(this->reduce_axis.size(), 0U);
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > update_nest(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));

    if (this->intrin->reduce_init.defined()) {
      // init nest
      std::vector<std::vector<Stmt> > init_nest(
          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
      init = op::Substitute(init, n.init_vmap);
      init = MergeNest(init_nest, init);
      // The update
      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
      update = MergeNest(input_bind_nest, update);
      update = ir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = op::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, Block::make(init, update));
    } else {
      // When init op is not available, use body op for reset in the first iter.
      CHECK(this->intrin->body.defined())
          << "Normal body op is not defined";
      Stmt update = TransformUpdate(stage, dom_map, n,
                                    this->intrin->body,
                                    this->intrin->reduce_update);
      update = MergeNest(output_bind_nest, update);
      update = MergeNest(input_bind_nest, update);
      update = ir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = op::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, update);
    }
  }
}

}  // namespace tvm
@@ -10,7 +10,6 @@
 #include "op_util.h"
 #include "compute_op.h"
 #include "../schedule/message_passing.h"
-#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
   }
 }
 
-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structured used in compute.
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update) {
-  Array<Expr> conds;
-  std::unordered_set<const Variable*> banned;
-  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    IterVar iv = stage->leaf_iter_vars[i];
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (attr->iter_type == kTensorized) {
-        break;
-      }
-    }
-    if (iv->iter_type == kCommReduce) {
-      auto vit = dom_map.find(iv);
-      CHECK(vit != dom_map.end());
-      const Range& vrange = vit->second;
-      conds.push_back(likely(iv->var > vrange->min));
-      banned.insert(iv->var.get());
-    }
-  }
-  for (const Expr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize update transform failed, the condition "
-                 << pred << " has a conflict with the reset condition";
-    }
-  }
-
-  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
-                          update, body);
-}
-
 Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   // bind pointer and offset.
   if (is_zero(arg->elem_offset)) {
     CHECK(is_zero(value->elem_offset))
-        << "Trying to bind a Buffer with offset into one without offset";
+        << "Trying to bind a Buffer with offset into one without offset "
+        << " required elem_offset=" << arg->elem_offset
+        << ", provided elem_offset=" << value->elem_offset;
   }
 
   this->Bind(arg->data, value->data, arg_name + ".data");
@@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }
 
-// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
-                                     const Array<Tensor>& tensor_array,
-                                     const std::string& scope) {
-  size_t tensor_size = tensor_array.size();
-  sch->InvalidateCache();
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = sch[tensor->op];
-  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-  std::unordered_set<IterVar> red_axis;
-  for (IterVar iv : compute->reduce_axis) {
+template<typename OpType>
+void PrepareAxisMapping(Stage orig_stage,
+                        OpType* op,
+                        std::unordered_set<IterVar>* p_red_axis,
+                        Array<IterVar>* p_new_axis,
+                        std::unordered_map<IterVar, Range>* p_dom_map,
+                        std::unordered_map<const Variable*, Expr>* p_vsub,
+                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
+                        std::vector<Expr>* p_predicates) {
+  auto& red_axis = *p_red_axis;
+  auto& new_axis = *p_new_axis;
+  auto& dom_map = *p_dom_map;
+  auto& vsub = *p_vsub;
+  auto& vsub2newvar = *p_vsub2newvar;
+  auto& predicates = *p_predicates;
+
+  for (IterVar iv : op->reduce_axis) {
     red_axis.insert(iv);
   }
-  std::unordered_map<IterVar, Range> dom_map;
-  Array<IterVar> new_axis;
-  for (IterVar iv : compute->axis) {
+  for (IterVar iv : op->axis) {
     dom_map[iv] = iv->dom;
   }
   schedule::PassDownDomain(orig_stage, &dom_map, true);
-  std::unordered_map<const Variable*, Expr> vsub;
-  std::unordered_map<const Variable*, Expr> vsub2newvar;
-  std::vector<Expr> predicates;
   {
     // The source->cache
     std::unordered_map<IterVar, Expr> value_map;
@@ -178,17 +178,85 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
     }
     // skip reduction iteration.
     std::unordered_set<IterVar> skip_bound_check;
-    for (IterVar iv : compute->reduce_axis) {
+    for (IterVar iv : op->reduce_axis) {
       skip_bound_check.insert(iv);
     }
     schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
     predicates = schedule::MakeBoundCheck(
         orig_stage, dom_map, value_map, true, skip_bound_check);
     // The root axis
-    for (IterVar iv : compute->axis) {
-      vsub[iv->var.get()] = value_map.at(iv);
+    for (IterVar iv : op->axis) {
+      if (value_map.count(iv)) {
+        vsub[iv->var.get()] = value_map.at(iv);
+      }  // to handle tensor axis
     }
   }
+}
+
+Array<Tensor> ReplaceOriginalOp(Schedule sch,
+                                Stage orig_stage,
+                                const std::string& scope,
+                                Operation cache_op,
+                                Operation orig_new_op,
+                                size_t tensor_size) {
+  Array<Tensor> cache_tensor_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+  }
+  // The replace of the dataflow
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  }
+  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  orig_stage->relations = Array<IterVarRelation>();
+  // create schedule for new cached stage.
+  ArrayNode* stages = sch->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos,
+                      cache_stage.node_);
+  sch->stage_map.Set(cache_op, cache_stage);
+  // Update group
+  cache_stage->group = orig_stage->group;
+  if (cache_stage->group.defined()) {
+    ++cache_stage->group->num_child_stages;
+  }
+  return cache_tensor_list;
+}
+
+
+// Cache write and relayout the data according to loop pattern
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
+                                     const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, compute,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
   Expr body;
   Array<Expr> body_list;
@@ -198,7 +266,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   body = InjectPredicate(predicates, body);
   body = VarReplacer(vsub2newvar).Mutate(body);
   // Reduce nodes in ONE computeOp must be the same except value_index
-  // This is right only if the oringinal body ensures Reduce nodes are the same
+  // This is right only if the original body ensures Reduce nodes are the same
   if (body->is_type<ir::Reduce>()) {
     const ir::Reduce* reduce_body = body.as<ir::Reduce>();
     if (first_reduce != nullptr) {
@@ -234,48 +302,107 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   Operation cache_op = ComputeOpNode::make(
       compute->name + "." + scope, compute->tag, compute->attrs,
       new_axis, body_list);
-  Array<Tensor> cache_tensor_list;
   Array<Expr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
-    cache_tensor_list.push_back(cache_tensor);
     cache_expr_list.push_back(cache_tensor(args));
   }
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->tag, compute->attrs,
       compute->axis, cache_expr_list);
-  // The replace of the dataflow
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
+}
+
+
+// for tensor compute op
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
+                                           const Array<Tensor>& tensor_array,
+                                           const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
+  CHECK_EQ(tensor_op->num_outputs(), 1)
+      << "cache write only support single output tensor_compute_op";
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, tensor_op,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar new_iv = IterVarNode::make(
+      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+    new_axis.push_back(new_iv);
+  }
+  Array<Region> new_regions;
+  for (Region old_region : tensor_op->input_regions) {
+    Region region;
+    for (Range r : old_region) {
+      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
+      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
+      region.push_back(Range::make_by_min_extent(min, extent));
+    }
+    new_regions.push_back(region);
+  }
+
+  Operation cache_op = TensorComputeOpNode::make(
+      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+      tensor_op->intrin, tensor_op->inputs, new_regions);
+
+  // axis will be used in generating compute op
+  Array<IterVar> compute_axis = tensor_op->axis;
+  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
+    compute_axis.Set(i, aiv);
+  }
+
+  // The reader args
+  Array<Expr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : compute_axis) {
+      value_map[iv] = iv->var;
+    }
+    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+    // tensorized region axis
+    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+      IterVar iv = compute_axis[i];
+      args.push_back(value_map.at(iv));
+    }
+  }
+
+  Array<Expr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
-    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+    Tensor cache_tensor = cache_op.output(i);
+    cache_expr_list.push_back(cache_tensor(args));
   }
-  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
-  // mutate orig stage
-  orig_stage->op = orig_new_op;
-  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
-  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
-  orig_stage->relations = Array<IterVarRelation>();
-  // create schedule for new cached stage.
-  ArrayNode* stages = sch->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos,
-                      cache_stage.node_);
-  sch->stage_map.Set(cache_op, cache_stage);
-  // Update group
-  cache_stage->group = orig_stage->group;
-  if (cache_stage->group.defined()) {
-    ++cache_stage->group->num_child_stages;
-  }
-  return cache_tensor_list;
+  Operation orig_new_op = ComputeOpNode::make(
+      tensor_op->name, tensor_op->tag, {},
+      compute_axis, cache_expr_list);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
 }
 
 
 Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
   (*this)->InvalidateCache();
@@ -291,22 +418,25 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
     CHECK(orig_stage.same_as(tmp_stage))
         << "Input tensor list must be generated by ONE computeOp";
   }
 
   return CacheWriteWithReLayout(*this, tensor_array, scope);
 }
 
+
 Tensor Schedule::cache_write(const Tensor& tensor,
                              const std::string& scope) {
+  // support original compute and tensor compute both
   (*this)->InvalidateCache();
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only take ComputeOp as writers";
-  CHECK_EQ(compute->num_outputs(), 1)
-      << "cache write only support single output ComputeOp";
-
-  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  const char* type_key = tensor->op->type_key();
+  if (!strcmp(type_key, "ComputeOp")) {
+    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  } else if (!strcmp(type_key, "TensorComputeOp")) {
+    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
+  } else {
+    LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
+    return Tensor();
+  }
 }
 
 
 void RebaseNonZeroMinLoop(const Schedule& sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
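With the dispatch above, cache_write also accepts a single-output tensor compute stage, taking the new CacheWriteWithReLayoutTensor path. A minimal sketch, reusing the vadd-style C from the sketches above and the tests below (illustrative):

s = tvm.create_schedule(C.op)
CC = s.cache_write(C, "local")  # writer is a TensorComputeOp
print(type(CC.op))              # tvm.tensor.TensorComputeOp (the cache stage)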
@@ -85,6 +85,78 @@ def test_tensor_reduce():
     assert(isinstance(C_loaded, tvm.tensor.Tensor))
     assert(str(C_loaded) == str(C))
 
+def test_tensor_compute1():
+    m = 1024
+    factor = 16
+    dtype = 'float32'
+
+    def intrin_vadd(n):
+        x = tvm.placeholder((n,))
+        y = tvm.placeholder((n,))
+        z = tvm.compute(x.shape, lambda i: x[i] + y[i])
+
+        def intrin_func(ins, outs):
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+            return ib.get()
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vadd = intrin_vadd(factor)
+
+    A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
+    C = tvm.compute((m//factor, factor),
+          lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
+
+def test_tensor_compute2():
+    M = 2048
+    N = 1024
+    L = 1024
+    factor = 16
+    factor1 = 32
+    factor2 = 32
+    dtype = 'float32'
+
+    def intrin_gemm(m, n, l):
+        k = tvm.reduce_axis((0, l))
+        x = tvm.placeholder((m, l))
+        y = tvm.placeholder((n, l))
+        # in theory, no relation
+        z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))
+
+        def intrin_func(ins, outs):
+            x_ptr = ins[0].access_ptr("r")
+            y_ptr = ins[1].access_ptr("r")
+            z_ptr = outs[0].access_ptr("w")
+            body = tvm.call_packed(
+                "gemv", x_ptr, y_ptr, z_ptr, m, n, l)
+            reset = tvm.call_packed(
+                "fill_zero", z_ptr, m, n)
+            update = tvm.call_packed(
+                "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
+            return body, reset, update
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vgemm = intrin_gemm(factor1, factor2, factor)
+
+    A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
+    k = tvm.reduce_axis((0, L//factor), name='k')
+    C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
+          lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
+
 def test_tensor_scan():
     m = tvm.var("m")
@ -221,6 +293,8 @@ if __name__ == "__main__":
|
||||||
test_conv1d()
|
test_conv1d()
|
||||||
test_tensor_slice()
|
test_tensor_slice()
|
||||||
test_tensor()
|
test_tensor()
|
||||||
|
test_tensor_compute1()
|
||||||
|
test_tensor_compute2()
|
||||||
test_tensor_reduce()
|
test_tensor_reduce()
|
||||||
test_tensor_scan()
|
test_tensor_scan()
|
||||||
test_scan_multi_out()
|
test_scan_multi_out()
|
||||||
|
|
|
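A note on test_tensor_compute2 above: the intrinsic returns a (body, reset, update) triple because the outer reduce axis k is carried by the TensorComputeOp, so the generated code may initialize the output once and then accumulate per k iteration. A plain NumPy sketch of the semantics those three packed calls are assumed to implement (the 'gemv', 'fill_zero', and 'gemv_add' externs themselves are not provided by this commit):

import numpy as np

def fill_zero(z):
    # "reset": initialize the accumulation buffer once per output tile.
    z[...] = 0.0

def gemv_add(x, y, z):
    # "update": accumulate one k-slice, z[i, j] += sum_l x[i, l] * y[j, l].
    z += x.dot(y.T)

def gemv(x, y, z):
    # "body": the non-accumulating form, equivalent to reset followed by update.
    fill_zero(z)
    gemv_add(x, y, z)

# Tile sizes from the test: x is (factor1, factor), y is (factor2, factor), z is (factor1, factor2).
x = np.random.rand(32, 16).astype('float32')
y = np.random.rand(32, 16).astype('float32')
z = np.empty((32, 32), dtype='float32')
gemv(x, y, z)
np.testing.assert_allclose(z, x.dot(y.T), rtol=1e-5)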
@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
    stmt = tvm.ir_pass.Simplify(stmt)
    assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))


def intrin_gemv(m, n):
    w = tvm.placeholder((m, n), name='w')
    x = tvm.placeholder((n,), name='x')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((m,), lambda i:
                    tvm.sum(w[i, k] * x[k], axis=k), name='z')
    Wb = tvm.decl_buffer(w.shape, w.dtype,
                         name="W",
                         offset_factor=16,
                         strides=[tvm.var('ldw'), 1])

    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        ww_ptr = ww.access_ptr("r")
        xx_ptr = xx.access_ptr("r")
        zz_ptr = zz.access_ptr("w")
        body = tvm.call_packed(
            "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        reset = tvm.call_packed(
            "fill_zero", zz_ptr, n)
        update = tvm.call_packed(
            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        return body, reset, update

    with tvm.build_config(data_alignment=16,
                          offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func,
                                      binds={w: Wb})


def test_schedule_tensor_compute1():
    # basic: split, reorder, tile
    M, N, L = 2048, 1024, 512
    factor, rfactor = 16, 16
    A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
    B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
    k = tvm.reduce_axis((0, L//rfactor), name='k')

    gemv = intrin_gemv(factor, rfactor)
    C = tvm.compute((N, M//factor, factor),
                    lambda i, j: gemv(A[i, k, 0:factor, 0:factor],
                                      B[j, k, 0:rfactor],
                                      reduce_axis=k),
                    name='C')

    s = tvm.create_schedule(C.op)
    ai, aj, ax = s[C].op.axis
    aio, aii = s[C].split(ai, 16)
    s[C].reorder(aio, aj, aii)
    aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)


def intrin_vadd(n, cache_read=False, cache_write=False):
    scope_ubuf = 'local'
    dtype = 'float32'
    x = tvm.placeholder((n,), dtype=dtype, name='vx')
    y = tvm.placeholder((n,), dtype=dtype, name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    s = tvm.create_schedule(z.op)

    def create_buffer(t):
        return tvm.decl_buffer(t.shape, t.dtype,
                               name='W'+t.name,
                               scope=scope_ubuf,
                               offset_factor=16)

    binds = {}
    if cache_read:
        binds[x] = create_buffer(x)
        binds[y] = create_buffer(y)
    if cache_write:
        binds[z] = create_buffer(z)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)


def test_schedule_tensor_compute2():
    # cache_read, cache_write
    M = 1024
    factor = 16
    dtype = 'float32'
    scope_ubuf = 'local'

    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)

    vadd = intrin_vadd(factor, True, True)
    C = tvm.compute((M//factor, factor),
                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')

    s = tvm.create_schedule(C.op)
    AL = s.cache_read(A, scope_ubuf, C)
    BL = s.cache_read(B, scope_ubuf, C)
    CL = s.cache_write(C, scope_ubuf)
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_schedule_tensor_compute3():
    # compute_at
    M = 1024
    factor = 16
    dtype = 'float32'
    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
    Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")

    vadd = intrin_vadd(factor)
    C = tvm.compute((M//factor, factor),
                    lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
    s = tvm.create_schedule(C.op)
    s[Bi].compute_at(s[C], C.op.axis[0])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)


if __name__ == "__main__":
    test_schedule_middle_cache()
    test_inline_multi_reduce()

@@ -294,3 +421,6 @@ if __name__ == "__main__":
    test_schedule2()
    test_schedule_cache()
    test_schedule_bound_condition()
    test_schedule_tensor_compute1()
    test_schedule_tensor_compute2()
    test_schedule_tensor_compute3()
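The three schedule tests above stop at ScheduleOps rather than asserting on the lowered result; the normalize / InferBound / ScheduleOps sequence they call is the manual front half of what tvm.lower (used by the earlier tensor tests) drives before the later IR passes. A small self-contained sketch of that relationship on a trivial schedule (names are illustrative):

import tvm

n = 16
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)

# Manual pipeline, as used by test_schedule_tensor_compute1/2/3:
sn = s.normalize()
bounds = tvm.schedule.InferBound(sn)
stmt = tvm.schedule.ScheduleOps(sn, bounds)

# tvm.lower performs the same steps internally and then applies further IR
# passes (flattening, simplification) before returning the statement.
lowered = tvm.lower(s, [A, B], simple_mode=True)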