[LANG] Generalize compute to tensor region (#1476)
This commit is contained in:
Parent: 3d62cf7c37
Commit: b90620ea25
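This change generalizes `tvm.compute` so that an output region can be produced by a tensor intrinsic rather than a scalar expression: a new `TensorComputeOpNode` carries the intrinsic, its input tensors, and the regions of those inputs it consumes, and `Schedule::cache_write` learns to dispatch on the op type. The sketch below is distilled from `test_tensor_compute1` in the tests added by this commit; `vadd` is a hypothetical extern kernel, so treat it as an illustration of the API shape, not a definitive recipe.

import tvm

def intrin_vadd(n):
    # declare the computation the intrinsic stands in for
    x = tvm.placeholder((n,), name='x')
    y = tvm.placeholder((n,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        # emit a call to a hypothetical extern kernel 'vadd'
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr('r'),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('w')))
        return ib.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

m, factor = 1024, 16
vadd = intrin_vadd(factor)
A = tvm.placeholder((m // factor, factor), name='A')
B = tvm.placeholder((m // factor, factor), name='B')
# Passing tensor regions (A[i, 0:factor]) to the intrinsic is what makes
# compute() build a TensorComputeOp instead of a plain ComputeOp.
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))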
@@ -1 +1 @@
-Subproject commit 4f0564ec769477c66d480dd966088f172050c874
+Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
   TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
 };
 
+using Region = Array<Range>;
+
 /*!
  * \brief Type of iteration variable.
  *  Each IterVar have a specific type.
@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
   }
   /*!
    * \return The list of iteration variable at root
-   * \note root_iter_vars dedides the shape of the outputs.
+   * \note root_iter_vars decides the shape of the outputs.
    */
   virtual Array<IterVar> root_iter_vars() const = 0;
   /*!
@@ -239,6 +239,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
 };
 
+/*!
+ * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
+ */
+class TensorComputeOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar on each axis */
+  Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
+  Array<IterVar> reduce_axis;
+  /*! \brief number of axes that can be scheduled */
+  int schedulable_ndim;
+  /*! \brief TensorIntrin used to compute */
+  TensorIntrin intrin;
+  /*! \brief input tensors of intrin */
+  Array<Tensor> inputs;
+  /*! \brief region of input tensors */
+  Array<Region> input_regions;
+  /*! \brief constructor */
+  TensorComputeOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("schedulable_ndim", &schedulable_ndim);
+    v->Visit("intrin", &intrin);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_regions", &input_regions);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Array<IterVar> axis,
+                        Array<IterVar> reduce_axis,
+                        int schedulable_ndim,
+                        TensorIntrin intrin,
+                        Array<Tensor> tensors,
+                        Array<Region> regions);
+
+  static constexpr const char* _type_key = "TensorComputeOp";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
+};
+
 /*!
  * \brief Symbolic scan.
  */
@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
  public:
   /*! \brief The input tensors */
   Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representationinputs */
+  /*! \brief Symbolic placeholder representation of inputs */
   Array<Buffer> input_placeholders;
   /*! \brief Symbolic placeholder representation of outputs */
   Array<Buffer> output_placeholders;
@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
 inline const TensorIntrinNode* TensorIntrin::operator->() const {
   return static_cast<const TensorIntrinNode*>(node_.get());
 }
 
+
+// Internal node container of tensor intrinsic calling.
+class TensorIntrinCallNode;
+
+/*! \brief Tensor intrinsic calling node. */
+class TensorIntrinCall : public NodeRef {
+ public:
+  TensorIntrinCall() {}
+  explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorIntrinCallNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = TensorIntrinCallNode;
+};
+
+class TensorIntrinCallNode : public Node {
+ public:
+  /*! \brief the tensor intrinsic */
+  TensorIntrin intrin;
+  /*! \brief input tensors of the intrinsic */
+  Array<Tensor> tensors;
+  /*! \brief regions of input tensors */
+  Array<Region> regions;
+  /*!
+   * \brief IterVar on each reduction axis, if the
+   *  intrin will use the reduce axis
+   */
+  Array<IterVar> reduce_axis;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("intrin", &intrin);
+    v->Visit("tensors", &tensors);
+    v->Visit("regions", &regions);
+    v->Visit("reduce_axis", &reduce_axis);
+  }
+  static TensorIntrinCall make(TensorIntrin intrin,
+                               Array<Tensor> tensors,
+                               Array<Region> regions,
+                               Array<IterVar> reduce_axis);
+
+  static constexpr const char* _type_key = "TensorIntrinCall";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
+};
+
+inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
+  return static_cast<const TensorIntrinCallNode*>(node_.get());
+}
+
 }  // namespace tvm
 #endif  // TVM_TENSOR_INTRIN_H_
@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     # for python3
     shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
     ndim = len(shape)
     code = fcompute.__code__
 
-    if fcompute.__code__.co_argcount == 0:
+    out_ndim = ndim
+    if code.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
     else:
         arg_names = code.co_varnames[:code.co_argcount]
+        out_ndim = code.co_argcount
 
-    if ndim != len(arg_names):
+    if out_ndim != len(arg_names):
         raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
 
-    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
+    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
-    if not isinstance(body, (list, tuple)):
-        body = [body]
-    body = convert(body)
-    op_node = _api_internal._ComputeOp(
-        name, tag, attrs, dim_var, body)
+
+    if isinstance(body, _tensor.TensorIntrinCall):
+        for i, s in enumerate(shape[out_ndim:]):
+            var_name = "ax" + str(i)
+            dim_var.append(_IterVar((0, s), var_name, 4))
+        op_node = _api_internal._TensorComputeOp(name,
+                                                 tag,
+                                                 dim_var,
+                                                 body.reduce_axis,
+                                                 out_ndim,
+                                                 body.intrin,
+                                                 body.tensors,
+                                                 body.regions)
+    else:
+        if not isinstance(body, (list, tuple)):
+            body = [body]
+        body = convert(body)
+        op_node = _api_internal._ComputeOp(
+            name, tag, attrs, dim_var, body)
+
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
     return outputs[0] if num == 1 else outputs
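The key mechanic in the new `compute` path above is the split of `shape` at `out_ndim`: `fcompute` receives iteration variables only for the first `out_ndim` dimensions, and when its body is a `TensorIntrinCall` the remaining dimensions are appended as tensorized axes (iter type 4). A minimal, hypothetical illustration of the split:

# Hedged illustration of the out_ndim split in compute() above:
# a 2-D output whose fcompute takes a single argument.
shape = (64, 16)
out_ndim = 1                     # fcompute's argument count
schedulable = shape[:out_ndim]   # (64,) -> ordinary IterVars, visible to schedules
tensorized = shape[out_ndim:]    # (16,) -> appended as "ax0" with iter type 4
assert schedulable == (64,) and tensorized == (16,)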
@@ -529,14 +548,14 @@ def decl_buffer(shape,
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
     if offset_factor != 0 and elem_offset is None:
-        elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
+        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
+        elem_offset = var('%s_elem_offset' % name, shape_dtype)
     if data is None:
         data = var(name, "handle")
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor)
 
 
 def _IterVar(dom, name, iter_type, thread_tag=''):
     """Internal function to create IterVar
@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype
 
+@register_node
+class TensorIntrinCall(NodeBase):
+    """Intermediate structure for calling a tensor intrinsic."""
+    pass
+
 
 itervar_cls = None
@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp):
         return "%s.v%d" % (op.name, self.value_index)
 
 
+
 class Operation(NodeBase):
     """Represent an operation that generate a tensor"""
@@ -155,6 +161,12 @@ class ComputeOp(Operation):
         return self.__getattr__("reduce_axis")
 
 
+@register_node
+class TensorComputeOp(Operation):
+    """Tensor operation."""
+    pass
+
+
 @register_node
 class ScanOp(Operation):
     """Scan operation."""
@@ -6,9 +6,25 @@ from . import expr as _expr
 from . import stmt as _stmt
 from . import make as _make
+from . import tensor as _tensor
+from . import schedule as _schedule
 from .build_module import current_build_config
 from ._ffi.node import NodeBase, register_node
 
 
+def _get_region(tslice):
+    region = []
+    for idx in tslice.indices:
+        if isinstance(idx, slice):
+            assert idx.step is None
+            region.append(_api.Range(idx.start, idx.stop))
+        else:
+            if isinstance(idx, _schedule.IterVar):
+                begin = idx.var
+            else:
+                begin = idx
+            region.append(_make.range_by_min_extent(begin, 1))
+    return region
+
 @register_node
 class TensorIntrin(NodeBase):
     """Tensor intrinsic functions for certain computation.
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
     --------
     decl_tensor_intrin: Construct a TensorIntrin
     """
-    pass
+    def __call__(self, *args, **kwargs):
+        tensors = [x.tensor for x in args]
+        regions = [_get_region(x) for x in args]
+        reduce_axis = []
+        if "reduce_axis" in kwargs:
+            reduce_axis = kwargs["reduce_axis"]
+            if not isinstance(reduce_axis, (list, tuple)):
+                reduce_axis = [reduce_axis]
+            reduce_axis = _api.convert(reduce_axis)
+        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
 
 def decl_tensor_intrin(op,
                        fcompute,
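`TensorIntrin.__call__` and `_get_region` above are what turn an argument like `A[i, 0:16]` into per-dimension ranges: a slice maps to `Range(start, stop)`, while a scalar or `IterVar` index maps to a unit-extent range at that position. A hedged, tvm-free mimic of the mapping (the real code builds tvm Range nodes):

def get_region(indices):
    # pure-Python stand-in for _get_region; ranges shown as (min, extent)
    region = []
    for idx in indices:
        if isinstance(idx, slice):
            assert idx.step is None
            region.append((idx.start, idx.stop - idx.start))  # slice -> region
        else:
            region.append((idx, 1))  # scalar index -> unit-extent range
    return region

# A[i, 0:16] arrives with indices (i, slice(0, 16)):
assert get_region(('i', slice(0, 16))) == [('i', 1), (0, 16)]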
@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
                                 args[6]);
   });
 
+TVM_REGISTER_API("_TensorIntrinCall")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = TensorIntrinCallNode::make(args[0],
+                                      args[1],
+                                      args[2],
+                                      args[3]);
+  });
+
 TVM_REGISTER_API("_TensorEqual")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Tensor() == args[1].operator Tensor();
@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
                             args[7]);
   });
 
+TVM_REGISTER_API("_TensorComputeOp")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = TensorComputeOpNode::make(args[0],
+                                     args[1],
+                                     args[2],
+                                     args[3],
+                                     args[4],
+                                     args[5],
+                                     args[6],
+                                     args[7]);
+  });
+
 TVM_REGISTER_API("_ExternOp")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = ExternOpNode::make(args[0],
@@ -10,6 +10,8 @@
 
 namespace tvm {
 
+// Tensor
+
 Expr Tensor::operator()(Array<Var> indices) const {
   Array<Expr> arr(indices.begin(), indices.end());
   return operator()(arr);
@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }
 
+Tensor Operation::output(size_t i) const {
+  auto node = make_node<TensorNode>();
+  node->op = *this;
+  node->value_index = i;
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
+
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorNode);
 
-Tensor Operation::output(size_t i) const {
-  auto node = make_node<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}
-
+// TensorIntrin
+
 TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
+
+// TensorIntrinCall
+
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
+                                            Array<Tensor> tensors,
+                                            Array<Region> regions,
+                                            Array<IterVar> reduce_axis) {
+  auto n = make_node<TensorIntrinCallNode>();
+  n->intrin = std::move(intrin);
+  n->tensors = std::move(tensors);
+  n->regions = std::move(regions);
+  n->reduce_axis = std::move(reduce_axis);
+  return TensorIntrinCall(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
+    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+
 }  // namespace tvm
@@ -13,6 +13,7 @@
 #include "compute_op.h"
 #include "op_util.h"
 #include "../schedule/message_passing.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
   v.Run();
 }
 
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update) {
+  Array<Expr> conds;
+  std::unordered_set<const Variable*> banned;
+  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    auto iit = stage->iter_var_attrs.find(iv);
+    if (iit != stage->iter_var_attrs.end()) {
+      const IterVarAttr& attr = (*iit).second;
+      if (attr->iter_type == kTensorized) {
+        break;
+      }
+    }
+    if (iv->iter_type == kCommReduce) {
+      auto vit = dom_map.find(iv);
+      CHECK(vit != dom_map.end());
+      const Range& vrange = vit->second;
+      conds.push_back(likely(iv->var > vrange->min));
+      banned.insert(iv->var.get());
+    }
+  }
+  for (const Expr& pred : n.main_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize update transform failed, the condition "
+                 << pred << " has a conflict with the reset condition";
+    }
+  }
+
+  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
+                          update, body);
+}
 }  // namespace tvm
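`TransformUpdate` (moved here from tensorize.cc; its removal there appears below) handles intrinsics that define `body` and `reduce_update` but no `reduce_init`: it wraps the two in an `IfThenElse` so the first reduction iteration runs the full body as the reset and every later iteration runs the update. A plain-Python sketch of the emitted control structure, with hypothetical names:

def reduce_1d(xs):
    # mirrors IfThenElse(Or(k > min, ...), update, body) per iteration
    acc = None
    for k, x in enumerate(xs):
        if k > 0:        # likely(iv->var > vrange->min): update step
            acc += x
        else:            # first iteration: body doubles as the reset
            acc = x
    return acc

assert reduce_1d([1, 2, 3, 4]) == 10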
@@ -14,7 +14,7 @@
 
 namespace tvm {
 // loop nest structure for general compute
-// This the the loop nest structured used in compute.
+// This is the loop nest structure used in compute.
 // Does not include the loop body.
 struct ComputeLoopNest {
   // The common number of loops between init and main
@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Transform the update part when there is no init func in tensorizing
+ * \param stage The stage for tensorizing.
+ * \param dom_map The range of each iter var.
+ * \param n The loop nest structured used in compute.
+ * \param body The body func in tensorize intrin
+ * \param update The update func in tensorize intrin
+ * \return Transformed result.
+ */
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update);
 }  // namespace tvm
 
 #endif  // TVM_OP_COMPUTE_OP_H_
@@ -0,0 +1,361 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief Tensor Compute Op.
+ * \file tensor_compute_op.cc
+ */
+#include <tvm/operation.h>
+#include <tvm/arithmetic.h>
+#include <tvm/ir.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./op_util.h"
+#include "./compute_op.h"
+#include "../arithmetic/compute_expr.h"
+
+namespace tvm {
+using namespace ir;
+// TensorComputeOpNode
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op,
+                                      IRPrinter *p) {
+    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
+
+int TensorComputeOpNode::num_outputs() const {
+  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
+}
+
+Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
+  Array<IterVar> ret = axis;
+  for (IterVar iv : reduce_axis) {
+    ret.push_back(iv);
+  }
+  return ret;
+}
+
+Type TensorComputeOpNode::output_dtype(size_t i) const {
+  return this->intrin->buffers[this->inputs.size() + i]->dtype;
+}
+
+Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
+  Array<Expr> shape;
+  for (const auto& ivar : this->axis) {
+    shape.push_back(ivar->dom->extent);
+  }
+  return shape;
+}
+
+
+Operation TensorComputeOpNode::make(std::string name,
+                                    std::string tag,
+                                    Array<IterVar> axis,
+                                    Array<IterVar> reduce_axis,
+                                    int schedulable_ndim,
+                                    TensorIntrin intrin,
+                                    Array<Tensor> tensors,
+                                    Array<Region> regions) {
+  auto n = make_node<TensorComputeOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->axis = std::move(axis);
+  n->reduce_axis = std::move(reduce_axis);
+  n->schedulable_ndim = std::move(schedulable_ndim);
+  n->intrin = std::move(intrin);
+  n->inputs = std::move(tensors);
+  n->input_regions = std::move(regions);
+  return Operation(n);
+}
+
+Array<Tensor> TensorComputeOpNode::InputTensors() const {
+  return inputs;
+}
+
+Operation TensorComputeOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_node<TensorComputeOpNode>(*this);
+  auto intrin = make_node<TensorIntrinNode>(*(this->intrin.operator->()));
+  intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
+  if (intrin->reduce_init.defined()) {
+    intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
+  }
+  if (intrin->reduce_update.defined()) {
+    intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
+  }
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    Tensor t = n->inputs[i];
+    if (rmap.count(t)) {
+      n->inputs.Set(i, rmap.at(t));
+    }
+  }
+
+  if (intrin->body.same_as(n->intrin->body) &&
+      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
+      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
+      inputs.same_as(n->inputs)) {
+    return self;
+  } else {
+    n->intrin = TensorIntrin(intrin);
+    return Operation(n);
+  }
+}
+
+void TensorComputeOpNode::PropBoundToInputs(
+    const Operation& self,
+    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  for (size_t i = 0; i < this->inputs.size(); ++i) {
+    Tensor t = this->inputs[i];
+    Region region = input_regions[i];
+
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom& dom = it->second;
+    for (size_t j = 0; j < t.ndim(); ++j) {
+      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
+    }
+  }
+}
+
+void TensorComputeOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  const TensorDom& tdom = tensor_dom.at(self.output(0));
+  for (size_t i = 0; i < this->axis.size(); ++i) {
+    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+    CHECK(!out_dom_map->count(this->axis[i]));
+    (*out_dom_map)[this->axis[i]] = r;
+  }
+  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
+    CHECK(!out_dom_map->count(this->reduce_axis[i]));
+    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
+  }
+}
+
+Stmt TensorComputeOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  HalideIR::Internal::Region bounds;
+  for (IterVar iv : this->axis) {
+    bounds.push_back(realize_map.at(iv));
+  }
+  Stmt realize = body;
+  for (int i = this->num_outputs(); i > 0; --i) {
+    Tensor t = stage->op.output(i-1);
+    realize = ir::Realize::make(t->op, t->value_index,
+        t->dtype, bounds, const_true(), realize);
+    // alignment requirement, only useful for compute
+    for (int i = 0; i < schedulable_ndim; ++i) {
+      auto it = stage->iter_var_attrs.find(this->axis[i]);
+      if (it != stage->iter_var_attrs.end()) {
+        IterVarAttr attr = (*it).second;
+        if (attr->dim_align_factor != 0) {
+          Array<Expr> tuple = {static_cast<int>(i),
+                               attr->dim_align_factor,
+                               attr->dim_align_offset};
+          realize = ir::AttrStmt::make(
+              t, ir::attr::buffer_dim_align,
+              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+              realize);
+        }
+      }
+    }
+  }
+  return realize;
+}
+
+ComputeLoopNest MakeLoopNest(
+    const TensorComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) {
+  CHECK_EQ(stage->op.operator->(), self);
+  ComputeLoopNest ret;
+  // make main loop nest
+  ret.main_nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
+      debug_keep_trivial_loop);
+  ret.main_predicates = schedule::MakeBoundCheck(
+      stage, dom_map, ret.main_vmap, false,
+      std::unordered_set<IterVar>());
+  for (auto& e : ret.main_predicates) {
+    e = likely(e);
+  }
+  if (stage->store_predicate.defined()) {
+    ret.main_predicates.push_back(stage->store_predicate);
+  }
+  if (self->reduce_axis.size() != 0) {
+    // try to find the location to insert the initialization.
+    // Fuse the initialization and provide loop when possible.
+    std::unordered_map<IterVar, int> update_state;
+    for (IterVar iv : self->reduce_axis) {
+      update_state[iv] = 2;
+    }
+    for (int i = 0; i < self->schedulable_ndim; ++i) {
+      update_state[self->axis[i]] = 1;
+    }
+    // find which iter var is related to reduction and which is related to axis.
+    schedule::PassDownBitMaskOr(stage, &update_state);
+    auto leaf_iter_vars = stage->leaf_iter_vars;
+    // find the first loop that is related to reduction.
+    size_t begin_loop = leaf_iter_vars.size();
+    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
+      auto iv = leaf_iter_vars[i];
+      int flag = update_state.at(iv);
+      if ((flag & 2) != 0) {
+        begin_loop = i; break;
+      }
+      ret.init_vmap[iv] = ret.main_vmap.at(iv);
+    }
+    ret.num_common_loop = begin_loop;
+    // skip loops that do not relate to axis.
+    std::unordered_set<IterVar> skip_iter;
+    for (auto kv : update_state) {
+      int flag = kv.second;
+      if ((flag & 1) == 0) skip_iter.insert(kv.first);
+    }
+    ret.init_nest = op::MakeLoopNest(
+        stage, dom_map, begin_loop, true,
+        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
+    ret.init_predicates = schedule::MakeBoundCheck(
+        stage, dom_map, ret.init_vmap, true, skip_iter);
+    for (auto& e : ret.init_predicates) {
+      e = likely(e);
+    }
+  } else {
+    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+    ret.num_common_loop = stage->leaf_iter_vars.size();
+  }
+  // copy elision here.
+  return ret;
+}
+
+
+Stmt TensorComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+
+  // Start bind data.
+  Stmt nop = Evaluate::make(0);
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
+  Array<Tensor> inputs = this->InputTensors();
+
+  // input binding
+  size_t num_inputs = inputs.size();
+  for (size_t i = 0; i < num_inputs; ++i) {
+    Tensor tensor = inputs[i];
+    Region region = this->input_regions[i];
+    Buffer buffer = this->intrin->buffers[i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < region.size(); ++i) {
+      tuple.push_back(region[i]->min);
+      tuple.push_back(region[i]->extent);
+    }
+    input_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // output binding
+  for (int i = 0; i < this->num_outputs(); ++i) {
+    Tensor tensor = stage->op.output(i);
+    Buffer buffer = this->intrin->buffers[num_inputs + i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < this->axis.size(); ++i) {
+      auto ivar = this->axis[i];
+      if (i < static_cast<size_t>(this->schedulable_ndim)) {
+        tuple.push_back(ivar->var);
+        tuple.push_back(1);
+      } else {
+        Range dom = ivar->dom;
+        tuple.push_back(dom->min);
+        tuple.push_back(dom->extent);
+      }
+    }
+
+    output_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // Check variable remap
+  std::unordered_map<const Variable*, Expr> vmap;
+  ir::ArgBinder binder(&vmap);
+
+  size_t tloc = stage->leaf_iter_vars.size();
+  ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);
+
+  if (this->reduce_axis.size() == 0) {
+    std::vector<std::vector<Stmt> > nest(
+        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+    nest.emplace_back(op::MakeIfNest(n.main_predicates));
+    CHECK_EQ(n.init_predicates.size(), 0U);
+    CHECK(this->intrin->body.defined())
+        << "Normal store op for intrin " << this << " is not defined";
+    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
+    body = MergeNest(input_bind_nest, body);
+    body = ir::Substitute(body, vmap);
+    body = MergeNest(binder.asserts(), body);
+    body = op::Substitute(body, n.main_vmap);
+    Stmt ret = MergeNest(nest, body);
+    return ret;
+  } else {
+    // Need to split reduction
+    CHECK(this->intrin->reduce_update.defined())
+        << "Reduction update op is not defined";
+    // Need init and update steps
+    CHECK_NE(this->reduce_axis.size(), 0U);
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > update_nest(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
+
+    if (this->intrin->reduce_init.defined()) {
+      // init nest
+      std::vector<std::vector<Stmt> > init_nest(
+          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
+      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
+      init = op::Substitute(init, n.init_vmap);
+      init = MergeNest(init_nest, init);
+      // The update
+      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = op::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, Block::make(init, update));
+    } else {
+      // When init op is not available, use body op for reset in the first iter.
+      CHECK(this->intrin->body.defined())
+          << "Normal body op is not defined";
+      Stmt update = TransformUpdate(stage, dom_map, n,
+                                    this->intrin->body,
+                                    this->intrin->reduce_update);
+      update = MergeNest(output_bind_nest, update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = op::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, update);
+    }
+  }
+}
+
+}  // namespace tvm
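A subtle point in `BuildProvide` above is the `tvm_tuple` built for the output `buffer_bind_scope`: the first `schedulable_ndim` axes bind a unit-extent window at their loop variable, while the trailing tensorized axes bind their whole domain. A hedged sketch of that tuple construction, with hypothetical values:

def output_bind_tuple(axis_vars, doms, schedulable_ndim):
    # mirrors the (min, extent) pairs BuildProvide pushes per output axis
    tup = []
    for i, v in enumerate(axis_vars):
        if i < schedulable_ndim:
            tup += [v, 1]            # loop var, unit extent
        else:
            tup += list(doms[i])     # (dom.min, dom.extent)
    return tup

# e.g. output axes (i, ax0) with schedulable_ndim == 1 and ax0 over [0, 16):
assert output_bind_tuple(['i', 'ax0'], [(0, 64), (0, 16)], 1) == ['i', 1, 0, 16]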
@@ -10,7 +10,6 @@
 #include "op_util.h"
 #include "compute_op.h"
 #include "../schedule/message_passing.h"
-#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
   }
 }
 
-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structured used in compute.
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update) {
-  Array<Expr> conds;
-  std::unordered_set<const Variable*> banned;
-  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    IterVar iv = stage->leaf_iter_vars[i];
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (attr->iter_type == kTensorized) {
-        break;
-      }
-    }
-    if (iv->iter_type == kCommReduce) {
-      auto vit = dom_map.find(iv);
-      CHECK(vit != dom_map.end());
-      const Range& vrange = vit->second;
-      conds.push_back(likely(iv->var > vrange->min));
-      banned.insert(iv->var.get());
-    }
-  }
-  for (const Expr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize update transform failed, the condition "
-                 << pred << " has a conflict with the reset condition";
-    }
-  }
-
-  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
-                          update, body);
-}
-
 Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   // bind pointer and offset.
   if (is_zero(arg->elem_offset)) {
     CHECK(is_zero(value->elem_offset))
-        << "Trying to bind a Buffer with offset into one without offset";
+        << "Trying to bind a Buffer with offset into one without offset "
+        << " required elem_offset=" << arg->elem_offset
+        << ", provided elem_offset=" << value->elem_offset;
   }
 
   this->Bind(arg->data, value->data, arg_name + ".data");
@@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }
 
-// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
-                                     const Array<Tensor>& tensor_array,
-                                     const std::string& scope) {
-  size_t tensor_size = tensor_array.size();
-  sch->InvalidateCache();
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = sch[tensor->op];
-  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-  std::unordered_set<IterVar> red_axis;
-  for (IterVar iv : compute->reduce_axis) {
+template<typename OpType>
+void PrepareAxisMapping(Stage orig_stage,
+                        OpType* op,
+                        std::unordered_set<IterVar>* p_red_axis,
+                        Array<IterVar>* p_new_axis,
+                        std::unordered_map<IterVar, Range>* p_dom_map,
+                        std::unordered_map<const Variable*, Expr>* p_vsub,
+                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
+                        std::vector<Expr>* p_predicates) {
+  auto& red_axis = *p_red_axis;
+  auto& new_axis = *p_new_axis;
+  auto& dom_map = *p_dom_map;
+  auto& vsub = *p_vsub;
+  auto& vsub2newvar = *p_vsub2newvar;
+  auto& predicates = *p_predicates;
+
+  for (IterVar iv : op->reduce_axis) {
     red_axis.insert(iv);
   }
-  std::unordered_map<IterVar, Range> dom_map;
-  Array<IterVar> new_axis;
-
-  for (IterVar iv : compute->axis) {
+  for (IterVar iv : op->axis) {
     dom_map[iv] = iv->dom;
   }
   schedule::PassDownDomain(orig_stage, &dom_map, true);
-  std::unordered_map<const Variable*, Expr> vsub;
-  std::unordered_map<const Variable*, Expr> vsub2newvar;
-  std::vector<Expr> predicates;
   {
     // The source->cache
     std::unordered_map<IterVar, Expr> value_map;
@@ -178,17 +178,85 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
     }
     // skip reduction iteration.
     std::unordered_set<IterVar> skip_bound_check;
-    for (IterVar iv : compute->reduce_axis) {
+    for (IterVar iv : op->reduce_axis) {
       skip_bound_check.insert(iv);
     }
     schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
     predicates = schedule::MakeBoundCheck(
         orig_stage, dom_map, value_map, true, skip_bound_check);
     // The root axis
-    for (IterVar iv : compute->axis) {
-      vsub[iv->var.get()] = value_map.at(iv);
+    for (IterVar iv : op->axis) {
+      if (value_map.count(iv)) {
+        vsub[iv->var.get()] = value_map.at(iv);
+      }  // to handle tensor axis
     }
   }
 }
+
+Array<Tensor> ReplaceOriginalOp(Schedule sch,
+                                Stage orig_stage,
+                                const std::string& scope,
+                                Operation cache_op,
+                                Operation orig_new_op,
+                                size_t tensor_size) {
+  Array<Tensor> cache_tensor_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+  }
+  // The replace of the dataflow
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  }
+  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  orig_stage->relations = Array<IterVarRelation>();
+  // create schedule for new cached stage.
+  ArrayNode* stages = sch->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos,
+                      cache_stage.node_);
+  sch->stage_map.Set(cache_op, cache_stage);
+  // Update group
+  cache_stage->group = orig_stage->group;
+  if (cache_stage->group.defined()) {
+    ++cache_stage->group->num_child_stages;
+  }
+  return cache_tensor_list;
+}
+
+
+// Cache write and relayout the data according to loop pattern
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
+                                     const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, compute,
+      &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
   Expr body;
   Array<Expr> body_list;
@@ -198,7 +266,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
     body = InjectPredicate(predicates, body);
     body = VarReplacer(vsub2newvar).Mutate(body);
     // Reduce nodes in ONE computeOp must be the same except value_index
-    // This is right only if the oringinal body ensures Reduce nodes are the same
+    // This is right only if the original body ensures Reduce nodes are the same
     if (body->is_type<ir::Reduce>()) {
       const ir::Reduce* reduce_body = body.as<ir::Reduce>();
       if (first_reduce != nullptr) {
@@ -234,48 +302,107 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   Operation cache_op = ComputeOpNode::make(
       compute->name + "." + scope, compute->tag, compute->attrs,
       new_axis, body_list);
-  Array<Tensor> cache_tensor_list;
+
   Array<Expr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
-    cache_tensor_list.push_back(cache_tensor);
     cache_expr_list.push_back(cache_tensor(args));
   }
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->tag, compute->attrs,
       compute->axis, cache_expr_list);
-  // The replace of the dataflow
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
-  for (size_t i = 0; i < tensor_size; i++) {
-    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
-  }
-  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
-  // mutate orig stage
-  orig_stage->op = orig_new_op;
-  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
-  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
-  orig_stage->relations = Array<IterVarRelation>();
-  // create schedule for new cached stage.
-  ArrayNode* stages = sch->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos,
-                      cache_stage.node_);
-  sch->stage_map.Set(cache_op, cache_stage);
-  // Update group
-  cache_stage->group = orig_stage->group;
-  if (cache_stage->group.defined()) {
-    ++cache_stage->group->num_child_stages;
-  }
-  return cache_tensor_list;
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+      cache_op, orig_new_op, tensor_size);
 }
+
+
+// for tensor compute op
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
+                                           const Array<Tensor>& tensor_array,
+                                           const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
+  CHECK_EQ(tensor_op->num_outputs(), 1)
+      << "cache write only support single output tensor_compute_op";
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, tensor_op,
+      &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+
+  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar new_iv = IterVarNode::make(
+        iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+    new_axis.push_back(new_iv);
+  }
+  Array<Region> new_regions;
+  for (Region old_region : tensor_op->input_regions) {
+    Region region;
+    for (Range r : old_region) {
+      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
+      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
+      region.push_back(Range::make_by_min_extent(min, extent));
+    }
+    new_regions.push_back(region);
+  }
+
+  Operation cache_op = TensorComputeOpNode::make(
+      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+      tensor_op->intrin, tensor_op->inputs, new_regions);
+
+  // axis will be used in generating compute op
+  Array<IterVar> compute_axis = tensor_op->axis;
+  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
+    compute_axis.Set(i, aiv);
+  }
+
+  // The reader args
+  Array<Expr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : compute_axis) {
+      value_map[iv] = iv->var;
+    }
+    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+    // tensorized region axis
+    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+      IterVar iv = compute_axis[i];
+      args.push_back(value_map.at(iv));
+    }
+  }
+
+  Array<Expr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
+  Operation orig_new_op = ComputeOpNode::make(
+      tensor_op->name, tensor_op->tag, {},
+      compute_axis, cache_expr_list);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+      cache_op, orig_new_op, tensor_size);
+}
+
+
 Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
   (*this)->InvalidateCache();
@@ -291,23 +418,26 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
     CHECK(orig_stage.same_as(tmp_stage))
         << "Input tensor list must be generated by ONE computeOp";
   }
 
   return CacheWriteWithReLayout(*this, tensor_array, scope);
 }
 
 
 Tensor Schedule::cache_write(const Tensor& tensor,
                              const std::string& scope) {
+  // support original compute and tensor compute both
   (*this)->InvalidateCache();
   Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only take ComputeOp as writers";
-  CHECK_EQ(compute->num_outputs(), 1)
-      << "cache write only support single output ComputeOp";
-
-  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  const char* type_key = tensor->op->type_key();
+  if (!strcmp(type_key, "ComputeOp")) {
+    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  } else if (!strcmp(type_key, "TensorComputeOp")) {
+    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
+  } else {
+    LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
+    return Tensor();
+  }
 }
 
 
 void RebaseNonZeroMinLoop(const Schedule& sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
   for (Stage s : sch->stages) {
@@ -85,6 +85,78 @@ def test_tensor_reduce():
     assert(isinstance(C_loaded, tvm.tensor.Tensor))
     assert(str(C_loaded) == str(C))
 
+def test_tensor_compute1():
+    m = 1024
+    factor = 16
+    dtype = 'float32'
+
+    def intrin_vadd(n):
+        x = tvm.placeholder((n,))
+        y = tvm.placeholder((n,))
+        z = tvm.compute(x.shape, lambda i: x[i] + y[i])
+
+        def intrin_func(ins, outs):
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+            return ib.get()
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vadd = intrin_vadd(factor)
+
+    A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
+    C = tvm.compute((m//factor, factor),
+          lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
+
+def test_tensor_compute2():
+    M = 2048
+    N = 1024
+    L = 1024
+    factor = 16
+    factor1 = 32
+    factor2 = 32
+    dtype = 'float32'
+
+    def intrin_gemm(m, n, l):
+        k = tvm.reduce_axis((0, l))
+        x = tvm.placeholder((m, l))
+        y = tvm.placeholder((n, l))
+        # in theory, no relation
+        z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))
+
+        def intrin_func(ins, outs):
+            x_ptr = ins[0].access_ptr("r")
+            y_ptr = ins[1].access_ptr("r")
+            z_ptr = outs[0].access_ptr("w")
+            body = tvm.call_packed(
+                "gemv", x_ptr, y_ptr, z_ptr, m, n, l)
+            reset = tvm.call_packed(
+                "fill_zero", z_ptr, m, n)
+            update = tvm.call_packed(
+                "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
+            return body, reset, update
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vgemm = intrin_gemm(factor1, factor2, factor)
+
+    A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
+    k = tvm.reduce_axis((0, L//factor), name='k')
+    C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
+          lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
+
 def test_tensor_scan():
     m = tvm.var("m")
@@ -221,6 +293,8 @@ if __name__ == "__main__":
     test_conv1d()
     test_tensor_slice()
    test_tensor()
+    test_tensor_compute1()
+    test_tensor_compute2()
     test_tensor_reduce()
     test_tensor_scan()
     test_scan_multi_out()
@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
 
+
+def intrin_gemv(m, n):
+    w = tvm.placeholder((m, n), name='w')
+    x = tvm.placeholder((n,), name='x')
+    k = tvm.reduce_axis((0, n), name='k')
+    z = tvm.compute((m,), lambda i:
+                    tvm.sum(w[i, k] * x[k], axis=k), name='z')
+    Wb = tvm.decl_buffer(w.shape, w.dtype,
+                         name="W",
+                         offset_factor=16,
+                         strides=[tvm.var('ldw'), 1])
+    def intrin_func(ins, outs):
+        ww, xx = ins
+        zz = outs[0]
+        ww_ptr = ww.access_ptr("r")
+        xx_ptr = xx.access_ptr("r")
+        zz_ptr = zz.access_ptr("w")
+        body = tvm.call_packed(
+            "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        reset = tvm.call_packed(
+            "fill_zero", zz_ptr, n)
+        update = tvm.call_packed(
+            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        return body, reset, update
+
+    with tvm.build_config(data_alignment=16,
+                          offset_factor=16):
+        return tvm.decl_tensor_intrin(z.op, intrin_func,
+                                      binds={w: Wb})
+
+
+def test_schedule_tensor_compute1():
+    # basic: split, reorder, tile
+    M, N, L = 2048, 1024, 512
+    factor, rfactor = 16, 16
+    A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
+    B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
+    k = tvm.reduce_axis((0, L//rfactor), name='k')
+
+    gemv = intrin_gemv(factor, rfactor)
+    C = tvm.compute((N, M//factor, factor),
+        lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
+        name='C')
+
+    s = tvm.create_schedule(C.op)
+    ai, aj, ax = s[C].op.axis
+    aio, aii = s[C].split(ai, 16)
+    s[C].reorder(aio, aj, aii)
+    aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)
+
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def intrin_vadd(n, cache_read=False, cache_write=False):
+    scope_ubuf = 'local'
+    dtype = 'float32'
+    x = tvm.placeholder((n,), dtype=dtype, name='vx')
+    y = tvm.placeholder((n,), dtype=dtype, name='vy')
+    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    s = tvm.create_schedule(z.op)
+
+    def create_buffer(t):
+        return tvm.decl_buffer(t.shape, t.dtype,
+                               name='W'+t.name,
+                               scope=scope_ubuf,
+                               offset_factor=16)
+
+    binds = {}
+    if cache_read:
+        binds[x] = create_buffer(x)
+        binds[y] = create_buffer(y)
+    if cache_write:
+        binds[z] = create_buffer(z)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+        return ib.get()
+
+    with tvm.build_config(offset_factor=16):
+        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
+
+
+def test_schedule_tensor_compute2():
+    # cache_read, cache_write
+    M = 1024
+    factor = 16
+    dtype = 'float32'
+    scope_ubuf = 'local'
+
+    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
+
+    vadd = intrin_vadd(factor, True, True)
+    C = tvm.compute((M//factor, factor),
+        lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
+
+    s = tvm.create_schedule(C.op)
+    AL = s.cache_read(A, scope_ubuf, C)
+    BL = s.cache_read(B, scope_ubuf, C)
+    CL = s.cache_write(C, scope_ubuf)
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_tensor_compute3():
+    # compute_at
+    M = 1024
+    factor = 16
+    dtype = 'float32'
+    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
+    Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
+
+    vadd = intrin_vadd(factor)
+    C = tvm.compute((M//factor, factor),
+        lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
+    s = tvm.create_schedule(C.op)
+    s[Bi].compute_at(s[C], C.op.axis[0])
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
 if __name__ == "__main__":
     test_schedule_middle_cache()
     test_inline_multi_reduce()
@@ -294,3 +421,6 @@ if __name__ == "__main__":
     test_schedule2()
     test_schedule_cache()
     test_schedule_bound_condition()
+    test_schedule_tensor_compute1()
+    test_schedule_tensor_compute2()
+    test_schedule_tensor_compute3()