[SCHEDULE] Improve bound inference, support reduce codegen. (#30)
This commit is contained in:
Родитель
d4af7ad6ab
Коммит
a2c8a29b21
|
@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
|
|||
using Halide::Internal::Variable;
|
||||
|
||||
using Halide::Internal::make_const;
|
||||
using Halide::Internal::make_zero;
|
||||
using Halide::Internal::as_const_int;
|
||||
using Halide::Internal::as_const_uint;
|
||||
|
||||
|
||||
inline Type TVMType2Type(TVMType t) {
|
||||
|
@ -126,25 +129,25 @@ using Halide::abs;
|
|||
using Halide::select;
|
||||
|
||||
/*!
|
||||
* \brief sum of of source expression over rdom
|
||||
* \brief sum of source expression over axis
|
||||
* \param source The source expression.
|
||||
* \param rdom List of iteration variables that will be used for reduction.
|
||||
* \param axis List of iteration variables that will be used for reduction.
|
||||
*/
|
||||
Expr sum(Expr source, Array<IterVar> rdom);
|
||||
Expr sum(Expr source, Array<IterVar> axis);
|
||||
|
||||
/*!
|
||||
* \brief max of of source expression over rdom
|
||||
* \brief max of source expression over axis
|
||||
* \param source The source expression.
|
||||
* \param rdom List of iteration variables that will be used for reduction.
|
||||
* \param axis List of iteration variables that will be used for reduction.
|
||||
*/
|
||||
Expr max(Expr source, Array<IterVar> rdom);
|
||||
Expr max(Expr source, Array<IterVar> axis);
|
||||
|
||||
/*!
|
||||
* \brief max of of source expression over rdom
|
||||
* \brief min of source expression over axis
|
||||
* \param source The source expression.
|
||||
* \param rdom List of iteration variables that will be used for reduction.
|
||||
* \param axis List of iteration variables that will be used for reduction.
|
||||
*/
|
||||
Expr min(Expr source, Array<IterVar> rdom);
|
||||
Expr min(Expr source, Array<IterVar> axis);
|
||||
|
||||
|
||||
// print functions for expr
|
||||
|
|
|
@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
|
|||
std::string op;
|
||||
/*! \brief The source operand */
|
||||
Expr source;
|
||||
/*! \brief The reduction domains */
|
||||
Array<IterVar> rdom;
|
||||
/*! \brief The reduction axis */
|
||||
Array<IterVar> axis;
|
||||
|
||||
/*! \brief construct expr from op and rdom */
|
||||
static Expr make(std::string op, Expr src, Array<IterVar> rdom);
|
||||
|
@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
|
|||
v->Visit("dtype", &type);
|
||||
v->Visit("op", &op);
|
||||
v->Visit("source", &source);
|
||||
v->Visit("rdom", &rdom);
|
||||
v->Visit("axis", &axis);
|
||||
}
|
||||
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
|
||||
static constexpr const char* _type_key = "Reduce";
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
* \file ir_pass.h
|
||||
* \brief Collection of IR pass functions
|
||||
*
|
||||
* All the pass functions in this file are for Stmt,
|
||||
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
|
||||
* When the pass functions in this file are for Stmt,
|
||||
* we can use PassFunction(Evaluate(expr)) to apply it to Expr
|
||||
*/
|
||||
#ifndef TVM_IR_PASS_H_
|
||||
#define TVM_IR_PASS_H_
|
||||
|
@ -37,15 +37,6 @@ inline Stmt Simplify(Stmt a) {
|
|||
return Halide::Internal::simplify(a);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Schedule s' dependent operations.
|
||||
*
|
||||
* \param s The schedule to be realized
|
||||
* \param dom_map The domain of each iter vars.
|
||||
* \return the result Stmt
|
||||
*/
|
||||
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
|
||||
|
||||
/*!
|
||||
* \brief verifies whether the IR stmt or Expr is in SSA form.
|
||||
* That is: each VarExpr is defined and assigned once(in Let/For)
|
||||
|
@ -69,6 +60,14 @@ bool HasSideEffect(const Expr& e);
|
|||
*/
|
||||
Stmt ConvertSSA(Stmt stmt);
|
||||
|
||||
/*!
|
||||
* \brief Substitute every occurrence of key->var with the value mapped to that key.
|
||||
* \param stmt The source statement to be substituted
|
||||
* \param value_map The map of new values.
|
||||
* \return The converted form.
|
||||
*/
|
||||
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
|
||||
|
||||
/*!
|
||||
* \brief inline all calls of f in stmt.
|
||||
*
|
||||
|
|
|
@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
|
|||
public:
|
||||
/*! \brief IterVar on each axis */
|
||||
Array<IterVar> axis;
|
||||
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
|
||||
Array<IterVar> reduce_axis;
|
||||
/*! \brief the compute expression */
|
||||
Expr body;
|
||||
/*! \brief constructor */
|
||||
|
@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
|
|||
void VisitAttrs(AttrVisitor* v) final {
|
||||
v->Visit("name", &name);
|
||||
v->Visit("axis", &axis);
|
||||
v->Visit("reduce_axis", &reduce_axis);
|
||||
v->Visit("body", &body);
|
||||
}
|
||||
static Operation make(std::string name,
|
||||
|
|
|
@ -123,6 +123,8 @@ class Stage : public NodeRef {
|
|||
IterVar* p_x_outer, IterVar* p_y_outer,
|
||||
IterVar* p_x_inner, IterVar* p_y_inner,
|
||||
Expr x_factor, Expr y_factor);
|
||||
// declare container type
|
||||
using ContainerType = StageNode;
|
||||
};
|
||||
|
||||
/*!
|
||||
|
@ -152,11 +154,22 @@ class Schedule : public NodeRef {
|
|||
Stage operator[](const Tensor& tensor) {
|
||||
return this->operator[](tensor->op);
|
||||
}
|
||||
/*!
|
||||
* \brief Normalize the schedule.
|
||||
* This is needed before bound inference.
|
||||
* Insert necessary RebaseNode to make sure all leaf_iter_vars
|
||||
* are in form [0, extent)
|
||||
*
|
||||
* \note Declared to return void: the normalization is applied to this schedule, which may be left unchanged.
|
||||
*/
|
||||
void normalize();
|
||||
/*!
|
||||
* \brief access the internal node container
|
||||
* \return the pointer to the internal node container
|
||||
*/
|
||||
inline const ScheduleNode* operator->() const;
|
||||
// declare container type
|
||||
using ContainerType = ScheduleNode;
|
||||
};
|
||||
|
||||
/*!
|
||||
|
@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
|
|||
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Rebase the iteration to make min to be 0.
|
||||
* This is useful to normalize the Schedule
|
||||
* to make every leaf variable's min to be 0.
|
||||
*/
|
||||
class RebaseNode : public IterVarRelationNode {
|
||||
public:
|
||||
/*! \brief The parent domain */
|
||||
IterVar parent;
|
||||
/*! \brief The inner domain */
|
||||
IterVar rebased;
|
||||
|
||||
void VisitAttrs(AttrVisitor* v) final {
|
||||
v->Visit("parent", &parent);
|
||||
v->Visit("rebased", &rebased);
|
||||
}
|
||||
|
||||
static IterVarRelation make(IterVar parent, IterVar rebased);
|
||||
|
||||
static constexpr const char* _type_key = "Rebase";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
|
||||
};
|
||||
|
||||
|
||||
// implementations
|
||||
inline const StageNode* Stage::operator->() const {
|
||||
return static_cast<const StageNode*>(node_.get());
|
||||
|
|
|
@ -24,6 +24,15 @@ namespace schedule {
|
|||
*/
|
||||
Map<IterVar, Range> InferBound(Schedule sch);
|
||||
|
||||
/*!
|
||||
* \brief Schedule s' dependent operations.
|
||||
*
|
||||
* \param s The schedule to be realized
|
||||
* \param dom_map The domain of each iter vars.
|
||||
* \return the result Stmt
|
||||
*/
|
||||
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
|
||||
|
||||
} // namespace schedule
|
||||
} // namespace tvm
|
||||
#endif // TVM_SCHEDULE_PASS_H_
|
||||
|
|
|
@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
|
|||
return _api_internal._IterVar(dom, name, thread_tag)
|
||||
|
||||
|
||||
def sum(expr, rdom):
|
||||
"""Create a sum expression over rdom
|
||||
def sum(expr, axis):
|
||||
"""Create a sum expression over axis
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The source expression.
|
||||
|
||||
rdom : RDomain
|
||||
The reduction domainx
|
||||
axis : IterVar
|
||||
The reduction IterVar axis
|
||||
"""
|
||||
rdom = rdom if isinstance(rdom, list) else [rdom]
|
||||
x = _make.Reduce("Add", expr, rdom)
|
||||
axis = axis if isinstance(axis, list) else [axis]
|
||||
x = _make.Reduce("Add", expr, axis)
|
||||
return x
|
||||
|
||||
|
||||
def min(expr, rdom):
|
||||
"""Create a min expression over rdom
|
||||
def min(expr, axis):
|
||||
"""Create a min expression over axis
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The source expression.
|
||||
|
||||
rdom : RDomain
|
||||
The reduction domainx
|
||||
axis : IterVar
|
||||
The reduction IterVar axis
|
||||
"""
|
||||
rdom = rdom if isinstance(rdom, list) else [rdom]
|
||||
x = _make.Reduce("Min", expr, rdom)
|
||||
axis = axis if isinstance(axis, list) else [axis]
|
||||
x = _make.Reduce("Min", expr, axis)
|
||||
return x
|
||||
|
||||
|
||||
def max(expr, rdom):
|
||||
"""Create a min expression over rdom
|
||||
def max(expr, axis):
|
||||
"""Create a max expression over axis
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The source expression.
|
||||
|
||||
rdom : RDomain
|
||||
The reduction domainx
|
||||
axis : IterVar
|
||||
The reduction IterVar axis
|
||||
"""
|
||||
rdom = rdom if isinstance(rdom, list) else [rdom]
|
||||
x = _make.Reduce("Max", expr, rdom)
|
||||
axis = axis if isinstance(axis, list) else [axis]
|
||||
x = _make.Reduce("Max", expr, axis)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -62,9 +62,10 @@ def build(sch,
|
|||
|
||||
# lowering
|
||||
bounds = schedule.InferBound(sch)
|
||||
stmt = ir_pass.ScheduleOps(sch, bounds)
|
||||
stmt = schedule.ScheduleOps(sch, bounds)
|
||||
stmt = ir_pass.StorageFlatten(stmt, binds)
|
||||
stmt = ir_pass.Simplify(stmt)
|
||||
print(stmt)
|
||||
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
|
||||
fsplits = codegen.SplitHostDevice(fapi)
|
||||
|
||||
|
@ -73,7 +74,8 @@ def build(sch,
|
|||
for i, f in enumerate(fsplits):
|
||||
t = target if i >= 1 else "c"
|
||||
record_codes.append(codegen.CompileToC(f, output_ssa, t))
|
||||
|
||||
for c in record_codes:
|
||||
print(c)
|
||||
if target == "cuda":
|
||||
ret = codegen.BuildNVRTC(fsplits, "stackvm")
|
||||
elif target == "opencl":
|
||||
|
|
|
@ -33,6 +33,14 @@ class Schedule(NodeBase):
|
|||
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
|
||||
return self.stage_map[k]
|
||||
|
||||
def normalize(self):
|
||||
"""Build a normalized schedule.
|
||||
|
||||
Insert the necessary rebase relations so that certain iter vars start from 0.
|
||||
This is needed before bound inference and followup step.
|
||||
"""
|
||||
_api_internal._ScheduleNormalize(self)
|
||||
|
||||
@register_node
|
||||
class Stage(NodeBase):
|
||||
"""A Stage represents schedule for one operation."""
|
||||
|
|
|
@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
|
|||
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ScheduleNormalize)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
args[0].operator Schedule()
|
||||
.normalize();
|
||||
});
|
||||
|
||||
} // namespace tvm
|
||||
|
|
|
@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
|
|||
REGISTER_PASS1(ConvertSSA);
|
||||
REGISTER_PASS1(VerifySSA);
|
||||
REGISTER_PASS4(Inline);
|
||||
REGISTER_PASS2(ScheduleOps);
|
||||
REGISTER_PASS2(StorageFlatten);
|
||||
|
||||
} // namespace ir
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace schedule {
|
|||
REGISTER_SCHEDULE_PASS1(InferBound);
|
||||
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
|
||||
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
|
||||
REGISTER_SCHEDULE_PASS2(ScheduleOps);
|
||||
|
||||
} // namespace schedule
|
||||
} // namespace tvm
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
* Copyright (c) 2017 by Contributors
|
||||
* \file codegen_c.cc
|
||||
*/
|
||||
#include <iomanip>
|
||||
#include "./codegen_c.h"
|
||||
|
||||
namespace tvm {
|
||||
|
@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
|
|||
switch (op->type.bits()) {
|
||||
case 64: case 32: {
|
||||
std::ostringstream temp;
|
||||
temp << op->value;
|
||||
temp << std::scientific << op->value;
|
||||
if (op->type.bits() == 32) temp << 'f';
|
||||
p->MarkConst(temp.str());
|
||||
os << temp.str();
|
||||
|
@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
|
|||
case 16: {
|
||||
os << '(';
|
||||
p->PrintType(op->type, os);
|
||||
os << ')' << op->value << 'f';
|
||||
os << ')' << std::scientific <<op->value << 'f';
|
||||
break;
|
||||
}
|
||||
default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
|
||||
|
|
|
@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
|
|||
<< op->op
|
||||
<< ", ";
|
||||
p->print(op->source);
|
||||
p->stream << ", rdom=" << op->rdom << ")";
|
||||
p->stream << ", axis=" << op->axis << ")";
|
||||
});
|
||||
|
||||
} // namespace Internal
|
||||
|
@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
|
|||
namespace tvm {
|
||||
namespace ir {
|
||||
|
||||
Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
|
||||
Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
|
||||
auto n = std::make_shared<Reduce>();
|
||||
CHECK(source.defined());
|
||||
for (size_t i = 0; i < rdom.size(); ++i) {
|
||||
CHECK(rdom[i].defined());
|
||||
for (size_t i = 0; i < axis.size(); ++i) {
|
||||
CHECK(axis[i].defined());
|
||||
}
|
||||
n->type = source.type();
|
||||
n->source = source;
|
||||
n->op = op;
|
||||
n->rdom = rdom;
|
||||
n->axis = axis;
|
||||
return Expr(n);
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
*/
|
||||
#include <tvm/operation.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <memory>
|
||||
|
||||
namespace tvm {
|
||||
|
@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
|
|||
|
||||
// ComputeOpNode
|
||||
Array<IterVar> ComputeOpNode::root_iter_vars() const {
|
||||
return axis;
|
||||
if (reduce_axis.size() == 0) return axis;
|
||||
Array<IterVar> ret = axis;
|
||||
for (IterVar iv : reduce_axis) {
|
||||
ret.push_back(iv);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Type ComputeOpNode::output_dtype(size_t i) const {
|
||||
|
@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
|
|||
n->name = name;
|
||||
n->axis = axis;
|
||||
n->body = body;
|
||||
if (n->body->is_type<ir::Reduce>()) {
|
||||
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
|
||||
}
|
||||
return Operation(n);
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
|
|||
}
|
||||
}
|
||||
|
||||
inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
|
||||
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
|
||||
std::vector<IterVar> new_dom(rdom.size());
|
||||
bool changed = false;
|
||||
for (size_t i = 0; i < rdom.size(); i++) {
|
||||
|
@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
|
|||
|
||||
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
|
||||
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
|
||||
Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
|
||||
Array<IterVar> new_axis = MutateIterVarArr(op->axis, m);
|
||||
Expr new_source = m->Mutate(op->source);
|
||||
if (op->rdom.same_as(new_rdom) &&
|
||||
if (op->axis.same_as(new_axis) &&
|
||||
op->source.same_as(new_source)) {
|
||||
return e;
|
||||
} else {
|
||||
return Reduce::make(op->op, new_source, new_rdom);
|
||||
return Reduce::make(op->op, new_source, new_axis);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) {
|
|||
|
||||
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
|
||||
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
|
||||
VisitRDom(op->rdom, v);
|
||||
VisitRDom(op->axis, v);
|
||||
v->Visit(op->source);
|
||||
})
|
||||
.set_dispatch<IntImm>(NoOp)
|
||||
|
|
|
@ -1,334 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file schedule_ops.cc
|
||||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <tvm/schedule_pass.h>
|
||||
|
||||
#include "./scope.h"
|
||||
#include "./ir_util.h"
|
||||
#include "../schedule/graph.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace ir {
|
||||
|
||||
/*!
|
||||
* \brief use message passing to calculate the assignment of each Var inside the loop body.
|
||||
* \param s The schedule to be used.
|
||||
* \param dom_map The domain map of each iteration variable's domain
|
||||
* \param p_state The message passing state
|
||||
* IterVar->The assignment.
|
||||
*/
|
||||
void PassUpOffset(const Stage& s,
|
||||
const Map<IterVar, Range>& dom_map,
|
||||
std::unordered_map<IterVar, Expr>* p_state) {
|
||||
auto& state = *p_state;
|
||||
for (size_t i = s->relations.size(); i != 0; --i) {
|
||||
IterVarRelation rel = s->relations[i - 1];
|
||||
if (rel.as<SplitNode>()) {
|
||||
const SplitNode* s = rel.as<SplitNode>();
|
||||
Expr outer = state.at(s->outer);
|
||||
Expr inner = state.at(s->inner);
|
||||
Expr factor = dom_map.at(s->inner)->extent;
|
||||
Expr parent_min = dom_map.at(s->parent)->min;
|
||||
state[s->parent] = inner + outer * factor;
|
||||
// add min if they exist
|
||||
if (!is_zero(parent_min)) {
|
||||
state[s->parent] = parent_min + state[s->parent];
|
||||
}
|
||||
} else if (rel.as<FuseNode>()) {
|
||||
const FuseNode* s = rel.as<FuseNode>();
|
||||
Expr value = state.at(s->fused);
|
||||
Expr factor = dom_map.at(s->inner)->extent;
|
||||
Expr outer_min = dom_map.at(s->outer)->min;
|
||||
Expr inner_min = dom_map.at(s->inner)->min;
|
||||
state[s->outer] = value / factor;
|
||||
state[s->inner] = value % factor;
|
||||
// add min if they exist
|
||||
if (!is_zero(outer_min)) {
|
||||
state[s->outer] = outer_min + state[s->outer];
|
||||
}
|
||||
if (!is_zero(inner_min)) {
|
||||
state[s->inner] = outer_min + state[s->inner];
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "unknown relation type";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief split the expr by addition.
|
||||
* \param expr The expression to be splitted.
|
||||
* \param loop_level The loop level of each Variable
|
||||
* \param result vector of (level, expr)
|
||||
* The level gives the mimimum loop level this expression need to be computed.
|
||||
* The Expr gives the expression content.
|
||||
*/
|
||||
void SplitByAdd(Expr expr,
|
||||
const std::unordered_map<const Variable*, size_t>& loop_level,
|
||||
std::vector<std::pair<size_t, Expr> > *result) {
|
||||
const Add* op = expr.as<Add>();
|
||||
if (op != nullptr) {
|
||||
SplitByAdd(op->a, loop_level, result);
|
||||
SplitByAdd(op->b, loop_level, result);
|
||||
} else {
|
||||
size_t max_level = 0;
|
||||
auto fvisit = [&max_level, &loop_level](const NodeRef& n) {
|
||||
const Variable* op = n.as<Variable>();
|
||||
if (op != nullptr) {
|
||||
auto it = loop_level.find(op);
|
||||
if (it != loop_level.end()) {
|
||||
max_level = std::max(max_level, it->second);
|
||||
}
|
||||
}
|
||||
};
|
||||
PostOrderVisit(expr, fvisit);
|
||||
result->push_back(std::make_pair(max_level, expr));
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Make the loop nest of the correspondings schedule.
|
||||
* \param sch The schedule.
|
||||
* \param dom_map The domain map.
|
||||
*
|
||||
* \return a nested representation of loop statements.
|
||||
* The flattened Stmt are ordered from outmost to inner most order.
|
||||
*/
|
||||
std::vector<std::vector<Stmt> > MakeLoopNest(
|
||||
const Stage& sch,
|
||||
const Map<IterVar, Range>& dom_map) {
|
||||
// optional, use let to define some CSE in dom_map.
|
||||
auto leaf_iter_vars = sch->leaf_iter_vars;
|
||||
std::unordered_map<IterVar, Expr> offset;
|
||||
std::unordered_map<const Variable*, size_t> loop_level;
|
||||
Stmt no_op = Evaluate::make(0);
|
||||
// create the loop nest
|
||||
std::vector<std::vector<Stmt> > nest;
|
||||
nest.resize(leaf_iter_vars.size() + 1);
|
||||
|
||||
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
|
||||
auto iv = leaf_iter_vars[i];
|
||||
Range dom = dom_map.at(iv);
|
||||
// initialize the offset and loop_level
|
||||
offset[iv] = iv->var;
|
||||
loop_level[iv->var.as<Variable>()] = i + 1;
|
||||
// Mark the iter var in the IR, to remember the point
|
||||
if (iv->thread_tag.length() == 0) {
|
||||
if (is_zero(dom->min)) {
|
||||
nest[i + 1].emplace_back(
|
||||
For::make(iv->var, 0, dom->extent,
|
||||
ForType::Serial, DeviceAPI::None, no_op));
|
||||
} else {
|
||||
Var idx(iv->var->name_hint + ".idx", iv->var.type());
|
||||
nest[i + 1].emplace_back(
|
||||
For::make(idx, 0, dom->extent,
|
||||
ForType::Serial, DeviceAPI::None, no_op));
|
||||
nest[i + 1].emplace_back(
|
||||
LetStmt::make(iv->var, dom->min + idx, no_op));
|
||||
}
|
||||
} else {
|
||||
// Always restrict threaded IterVar to starts from 0.
|
||||
CHECK(is_zero(dom->min));
|
||||
// annotate the extent of the IterVar
|
||||
nest[i + 1].emplace_back(
|
||||
AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
|
||||
}
|
||||
// annotate the extent of the IterVar
|
||||
nest[i + 1].emplace_back(
|
||||
AttrStmt::make(iv, "scope", iv->var, no_op));
|
||||
}
|
||||
// message passing to get offset of root iter vars.
|
||||
PassUpOffset(sch, dom_map, &offset);
|
||||
|
||||
for (IterVar iv : sch->op->root_iter_vars()) {
|
||||
Expr value = offset.at(iv);
|
||||
if (!value.same_as(iv->var)) {
|
||||
using Entry = std::pair<size_t, Expr>;
|
||||
std::vector<Entry> splits;
|
||||
SplitByAdd(value, loop_level, &splits);
|
||||
|
||||
Expr offset = 0;
|
||||
size_t nsplit_left = splits.size() - 1;
|
||||
for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
|
||||
size_t hit = 0;
|
||||
for (const auto& kv : splits) {
|
||||
if (kv.first == i) {
|
||||
if (is_zero(offset)) {
|
||||
offset = kv.second;
|
||||
} else {
|
||||
offset = offset + kv.second;
|
||||
++hit;
|
||||
}
|
||||
}
|
||||
}
|
||||
nsplit_left -= hit;
|
||||
if (hit != 0) {
|
||||
std::ostringstream os;
|
||||
os << iv->var->name_hint << ".at.l" << i;
|
||||
Var base_offset(os.str());
|
||||
if (nsplit_left == 0) {
|
||||
base_offset = iv->var;
|
||||
}
|
||||
nest[i].emplace_back(
|
||||
LetStmt::make(base_offset, offset, no_op));
|
||||
offset = base_offset;
|
||||
}
|
||||
}
|
||||
Range dom = dom_map.at(iv);
|
||||
if (!offset.same_as(iv->var)) {
|
||||
// define the iv->var
|
||||
nest.back().emplace_back(
|
||||
LetStmt::make(iv->var, offset, no_op));
|
||||
}
|
||||
Expr condition = (iv->var - dom->min) < dom->extent;
|
||||
// Boundary condition checking
|
||||
// Need better boundary condition here.
|
||||
nest.back().emplace_back(IfThenElse::make(condition, no_op));
|
||||
}
|
||||
}
|
||||
return nest;
|
||||
}
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Make pipeline specifically for compute op node.
|
||||
* \param op The compute node
|
||||
* \param tensors The tensors generated by provide.
|
||||
*/
|
||||
Stmt MakeProvide(const ComputeOpNode* op,
|
||||
const std::vector<Tensor>& tensors) {
|
||||
Tensor t = tensors[0];
|
||||
Array<Expr> args;
|
||||
for (IterVar iv : op->axis) {
|
||||
args.push_back(iv->var);
|
||||
}
|
||||
return Provide::make(t->op, t->value_index, op->body, args);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Make pipeline specifically for compute op node.
|
||||
* \param op The compute node
|
||||
* \param dom_map The domain map
|
||||
* \param tensors The tensors generated by provide.
|
||||
* \param body The content of the pipeline.
|
||||
*/
|
||||
Stmt MakeRealize(const ComputeOpNode* op,
|
||||
const Map<IterVar, Range>& dom_map,
|
||||
const std::vector<Tensor>& tensors,
|
||||
Stmt body) {
|
||||
Tensor t = tensors[0];
|
||||
Halide::Internal::Region bounds;
|
||||
for (IterVar iv : op->axis) {
|
||||
bounds.push_back(dom_map.at(iv));
|
||||
}
|
||||
return Realize::make(t->op, t->value_index, t->dtype,
|
||||
bounds, make_const(Bool(1), true), body);
|
||||
}
|
||||
|
||||
/*!
 * \brief Make the producer/consumer pipeline of a stage.
 * \param sch The stage to be realized; only compute ops are supported.
 * \param dom_map The domain map.
 * \param consumer The statement that consumes the result, can be undefined.
 * \return The Realize statement that wraps the producer and,
 *  when it is defined, the consumer.
 */
Stmt MakePipeline(const Stage& sch,
                  const Map<IterVar, Range>& dom_map,
                  Stmt consumer) {
  const ComputeOpNode* compute = sch->op.as<ComputeOpNode>();

  // collect every output tensor of the op
  std::vector<Tensor> outputs;
  for (int i = 0; i < sch->op->num_outputs(); ++i) {
    outputs.emplace_back(sch->op.output(i));
  }

  Stmt provide;
  if (compute != nullptr) {
    provide = MakeProvide(compute, outputs);
  } else {
    LOG(FATAL) << "not supported op " << sch->op->type_key();
  }

  // producer: the provide wrapped into the loop nest of the stage
  std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
  Stmt producer = ProducerConsumer::make(
      sch->op, true, MergeNest(nest, provide));

  // without a consumer the pipeline is the producer alone
  Stmt pipeline = producer;
  if (consumer.defined()) {
    consumer = ProducerConsumer::make(sch->op, false, consumer);
    pipeline = Block::make(producer, consumer);
  }

  if (compute == nullptr) {
    LOG(FATAL) << "not supported op";
    return Stmt();
  }
  return MakeRealize(compute, dom_map, outputs, pipeline);
}
|
||||
|
||||
// inject the operator's realization on the stmt.
|
||||
class InjectRealize : public IRMutator {
|
||||
public:
|
||||
InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
|
||||
: schedule(schedule), dom_map(dom_map) {}
|
||||
|
||||
Stmt Mutate(Stmt stmt) final {
|
||||
CHECK(stmt.defined());
|
||||
stmt = IRMutator::Mutate(stmt);
|
||||
const AttrStmt* op = stmt.as<AttrStmt>();
|
||||
if (op != nullptr &&
|
||||
op->type_key == "scope") {
|
||||
if (op->node == schedule->attach_ivar) {
|
||||
CHECK(!found_attach);
|
||||
found_attach = true;
|
||||
stmt = AttrStmt::make(
|
||||
op->node, op->type_key, op->value,
|
||||
MakePipeline(schedule, dom_map,
|
||||
IRMutator::Mutate(op->body)));
|
||||
}
|
||||
}
|
||||
return stmt;
|
||||
}
|
||||
// the operations to be carried
|
||||
Stage schedule;
|
||||
// domain map
|
||||
Map<IterVar, Range> dom_map;
|
||||
// whether attach point is found
|
||||
bool found_attach{false};
|
||||
};
|
||||
|
||||
/*!
 * \brief Inline the calls of a compute op inside a statement.
 * \param op The operation to be inlined, must be a compute op.
 * \param body The statement to inline into, must be defined.
 * \return The result of Inline with the axis variables of the op as
 *  arguments and the body of the op as the inlined expression.
 */
Stmt InjectInline(const Operation op, Stmt body) {
  CHECK(body.defined());
  const ComputeOpNode* node = op.as<ComputeOpNode>();
  CHECK(node != nullptr)
      << "can only inline compute op";
  // the formal arguments are the axis variables, in axis order
  Array<Var> params;
  for (size_t i = 0; i < node->axis.size(); ++i) {
    params.push_back(node->axis[i]->var);
  }
  return Inline(body, op, params, node->body);
}
|
||||
|
||||
|
||||
/*!
 * \brief Build the statement that realizes the stages of a schedule.
 * \param sch The schedule to be realized.
 * \param dom_map The domain of each iteration variable.
 * \return The resulting statement; undefined when no stage produced one.
 */
Stmt ScheduleOps(
    Schedule sch, Map<IterVar, Range> dom_map) {
  Stmt body = Stmt();
  // walk the stages in reverse order
  for (size_t idx = sch->stages.size(); idx-- > 0;) {
    Stage stage = sch->stages[idx];
    // placeholder ops need no statement
    if (stage->op.as<PlaceholderOpNode>()) continue;
    switch (stage->attach_type) {
      case kInline: {
        body = InjectInline(stage->op, body);
        break;
      }
      case kRoot:
      case kNone: {
        // the statement built so far becomes the consumer
        body = MakePipeline(stage, dom_map, body);
        break;
      }
      case kScope: {
        // realize the stage inside the scope of its attach variable
        CHECK(body.defined());
        InjectRealize mutator(stage, dom_map);
        body = mutator.Mutate(body);
        CHECK(mutator.found_attach)
            << "did not find attachment point";
        break;
      }
      default: {
        // other attach types leave the statement untouched
        break;
      }
    }
  }
  return body;
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace tvm
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
|
||||
namespace tvm {
|
||||
|
@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) {
|
|||
v.Visit(e);
|
||||
return v.has_side_effect_;
|
||||
}
|
||||
|
||||
class IRSubstitue : public IRMutator {
|
||||
public:
|
||||
Expr Mutate_(const Variable* op, const Expr& e) final {
|
||||
auto it = smap.find(op);
|
||||
if (it != smap.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
std::unordered_map<const Variable*, Expr> smap;
|
||||
};
|
||||
|
||||
/*!
 * \brief Replace the variable of each key in value_map with the mapped value.
 * \param stmt The source statement.
 * \param value_map The map from iteration variable to its new value.
 * \return The statement after substitution.
 */
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
  IRSubstitue substitutor;
  // index the replacements by the underlying Variable node
  for (auto entry : value_map) {
    const Variable* key = entry.first->var.get();
    substitutor.smap[key] = entry.second;
  }
  Stmt replaced = substitutor.Mutate(stmt);
  return replaced;
}
|
||||
} // namespace ir
|
||||
} // namespace tvm
|
||||
|
|
|
@ -54,6 +54,11 @@ void PassDown(const Stage& s,
|
|||
const Range& range_inner = state.at(r->inner);
|
||||
state[r->fused] = Range::make_with_min_extent(
|
||||
0, range_outer->extent * range_inner->extent);
|
||||
} else if (rel.as<RebaseNode>()) {
|
||||
const RebaseNode* r = rel.as<RebaseNode>();
|
||||
CHECK(state.count(r->parent));
|
||||
state[r->rebased] = Range::make_with_min_extent(
|
||||
0, state.at(r->parent)->extent);
|
||||
} else {
|
||||
LOG(FATAL) << "unknown relation type";
|
||||
}
|
||||
|
@ -85,6 +90,13 @@ void PassUp(const Stage& s,
|
|||
&outer, &inner);
|
||||
state[r->outer] = outer;
|
||||
state[r->inner] = inner;
|
||||
} else if (rel.as<RebaseNode>()) {
|
||||
IntSet parent;
|
||||
const RebaseNode* r = rel.as<RebaseNode>();
|
||||
PassUp(r, dom_map,
|
||||
state.at(r->rebased),
|
||||
&parent);
|
||||
state[r->parent] = parent;
|
||||
} else {
|
||||
LOG(FATAL) << "unknown relation type";
|
||||
}
|
||||
|
@ -109,9 +121,15 @@ void PassToOperation(
|
|||
// Eventually, we need to change the inference to be a Pull style inference
|
||||
if (tensor->op.as<ComputeOpNode>()) {
|
||||
auto root_iter_vars = tensor->op->root_iter_vars();
|
||||
CHECK_EQ(tensor.ndim(), root_iter_vars.size());
|
||||
for (size_t i = 0; i < tensor.ndim(); ++i) {
|
||||
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
|
||||
const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
|
||||
CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
|
||||
for (size_t i = 0; i < op->axis.size(); ++i) {
|
||||
(*result)[op->axis[i]].push_back(dim_bounds[i]);
|
||||
}
|
||||
// reduction.
|
||||
for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
|
||||
(*result)[op->reduce_axis[i]].push_back(
|
||||
IntSet::range(op->reduce_axis[i]->dom));
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
|
||||
|
@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
|
|||
{"local", 2}
|
||||
};
|
||||
static std::unordered_map<std::string, int> thread_tag_rank{
|
||||
{"gridIdx.x", 0},
|
||||
{"gridIdx.y", 0},
|
||||
{"gridIdx.z", 0},
|
||||
{"blockIdx.x", 0},
|
||||
{"blockIdx.y", 0},
|
||||
{"blockIdx.z", 0},
|
||||
{"threadIdx.x", 1},
|
||||
{"threadIdx.y", 1},
|
||||
{"threadIdx.z", 1}
|
||||
|
@ -194,8 +212,6 @@ void InferBound(const Stage& stage,
|
|||
(*rmap)[iv] = iv->dom;
|
||||
}
|
||||
}
|
||||
// get range of all child iter vars.
|
||||
PassDown(stage, rmap);
|
||||
|
||||
if (stage->attach_type == kScope) {
|
||||
Stage parent = stage->attach_stage;
|
||||
|
@ -206,10 +222,18 @@ void InferBound(const Stage& stage,
|
|||
|
||||
bool fix_value = true;
|
||||
for (auto iv : parent->leaf_iter_vars) {
|
||||
Range vrange = rmap->at(iv);
|
||||
CHECK(is_zero(vrange->min))
|
||||
<< "InferBound requires every leaf iter var's min equals 0, "
|
||||
<< "call schedule.normalize to achieve this.";
|
||||
// special optimization to remove trivial loop
|
||||
if (is_one(vrange->extent)) {
|
||||
up_state[iv] = IntSet::single_point(vrange->min);
|
||||
}
|
||||
if (fix_value && !ScopeRelax(iv, stage->scope)) {
|
||||
up_state[iv] = IntSet::make_point(iv->var);
|
||||
up_state[iv] = IntSet::single_point(iv->var);
|
||||
} else {
|
||||
up_state[iv] = IntSet::make_range(rmap->at(iv));
|
||||
up_state[iv] = IntSet::range(vrange);
|
||||
}
|
||||
if (stage->attach_ivar == iv) {
|
||||
fix_value = false;
|
||||
|
@ -223,12 +247,30 @@ void InferBound(const Stage& stage,
|
|||
bp_state[iv] = {up_state.at(iv)};
|
||||
}
|
||||
auto result = BoundProp(post_order, &bp_state);
|
||||
|
||||
// Set relaxation
|
||||
Map<IterVar, IntSet> relax_set;
|
||||
Stage s = stage;
|
||||
while (s->attach_type == kScope) {
|
||||
s = s->attach_stage;
|
||||
for (auto iv : s->leaf_iter_vars) {
|
||||
if (ScopeRelax(iv, stage->scope)) {
|
||||
relax_set.Set(iv, IntSet::range(rmap->at(iv)));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto iv : stage->op->root_iter_vars()) {
|
||||
CHECK(result.count(iv));
|
||||
CHECK(!rmap->count(iv));
|
||||
(*rmap)[iv] = result.at(iv).GetCoverRange();
|
||||
Range r = result.at(iv).cover_range(iv->dom);
|
||||
if (relax_set.size() != 0) {
|
||||
r = EvalSet(r, relax_set).cover_range(iv->dom);
|
||||
}
|
||||
(*rmap)[iv] = r;
|
||||
}
|
||||
}
|
||||
// get range of all child iter vars.
|
||||
PassDown(stage, rmap);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file compute_expr.h
|
||||
* \brief Utility integer expression with quick eager simplification.
|
||||
* This is weaker than Simplify but can be done Eagerly.
|
||||
*/
|
||||
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
|
||||
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <pass/Interval.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace schedule {
|
||||
|
||||
using Halide::Internal::add_would_overflow;
|
||||
using Halide::Internal::sub_would_overflow;
|
||||
using Halide::Internal::mul_would_overflow;
|
||||
|
||||
/*!
|
||||
* \brief Compute the expression with the given binary op.
|
||||
* \param lhs The left operand
|
||||
* \param rhs The right operand
|
||||
* \return The result.
|
||||
*/
|
||||
template<typename OP>
|
||||
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
|
||||
return OP::make(lhs, rhs);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool GetConst(Expr e, T* out);
|
||||
|
||||
template<>
|
||||
bool GetConst<int64_t>(Expr e, int64_t *out) {
|
||||
if (e.type().is_vector()) return false;
|
||||
const int64_t *v = as_const_int(e);
|
||||
if (v) {
|
||||
*out = *v; return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
template<>
|
||||
bool GetConst<uint64_t>(Expr e, uint64_t *out) {
|
||||
if (e.type().is_vector()) return false;
|
||||
const uint64_t *v = as_const_uint(e);
|
||||
if (v) {
|
||||
*out = *v; return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Constant folding helper, expanded inside a ComputeExpr specialization
// whose operands are named `a` and `b`.
//
// If both operands are signed integer constants, returns the folded
// IntImm (after a fatal check for signed overflow via
// <OP_NAME>_would_overflow). If both are unsigned constants, returns the
// folded UIntImm. Otherwise falls through to the code after the macro.
//
// OP_NAME: prefix of the overflow predicate (add / sub / mul).
// OP:      the C++ operator token applied to the constant values.
#define TVM_CONST_PROPAGATION(OP_NAME, OP)                       \
  int64_t ia = 0, ib = 0;                                        \
  if (GetConst(a, &ia) && GetConst(b, &ib)) {                    \
    if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) {   \
      LOG(FATAL) << "signed int overflow";                       \
    }                                                            \
    return ir::IntImm::make(a.type(), ia OP ib);                 \
  }                                                              \
  uint64_t ua = 0, ub = 0;                                       \
  if (GetConst(a, &ua) && GetConst(b, &ub)) {                    \
    /* apply the requested operator, not a hard-coded '+' */     \
    return ir::UIntImm::make(a.type(), ua OP ub);                \
  }
|
||||
|
||||
template<>
|
||||
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
|
||||
if (is_zero(a)) return b;
|
||||
if (is_zero(b)) return a;
|
||||
TVM_CONST_PROPAGATION(add, +);
|
||||
return ir::Add::make(a, b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
|
||||
if (is_zero(b)) return a;
|
||||
TVM_CONST_PROPAGATION(sub, -);
|
||||
return ir::Add::make(a, b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
|
||||
if (is_one(a)) return b;
|
||||
if (is_one(b)) return a;
|
||||
TVM_CONST_PROPAGATION(mul, *);
|
||||
return ir::Mul::make(a, b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
|
||||
if (is_one(b)) return a;
|
||||
return ir::Mul::make(a, b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
|
||||
return Halide::Internal::Interval::make_max(a, b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
|
||||
return Halide::Internal::Interval::make_min(a, b);
|
||||
}
|
||||
|
||||
} // namespace schedule
|
||||
} // namespace tvm
|
||||
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
|
|
@ -1,212 +1,355 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file int_set.cc
|
||||
* \file int_set_impl.cc
|
||||
* \brief The integer set functions
|
||||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <pass/Interval.h>
|
||||
#include "./int_set.h"
|
||||
#include "./compute_expr.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace schedule {
|
||||
|
||||
using Halide::Internal::Interval;
|
||||
|
||||
using namespace ir;
|
||||
|
||||
/*!
|
||||
* \brief Internal node container of int set.
|
||||
*/
|
||||
class IntSetNode : public Node {
|
||||
public:
|
||||
/*! \brief The base range scope */
|
||||
Range base;
|
||||
/*! \brief additional strided domain */
|
||||
Array<Range> domain;
|
||||
/*! \brief The stride of each strided domain */
|
||||
Array<Expr> stride;
|
||||
/*!
|
||||
* \brief The concrete set,
|
||||
* used when concrete execution is enabled.
|
||||
*/
|
||||
std::vector<int32_t> concrete;
|
||||
/*! \brief Set of continuous interval */
|
||||
struct IntervalSet : public IntSetNode {
|
||||
/*! \brief the internal interval*/
|
||||
Interval i;
|
||||
|
||||
void VisitAttrs(AttrVisitor* v) final {
|
||||
v->Visit("base", &base);
|
||||
v->Visit("domain", &domain);
|
||||
v->Visit("stride", &stride);
|
||||
static IntSet make(Interval i) {
|
||||
std::shared_ptr<IntervalSet> n =
|
||||
std::make_shared<IntervalSet>();
|
||||
n->i = i;
|
||||
return IntSet(n);
|
||||
}
|
||||
static IntSet make(Expr min, Expr max) {
|
||||
std::shared_ptr<IntervalSet> n =
|
||||
std::make_shared<IntervalSet>();
|
||||
n->i.min = min;
|
||||
n->i.max = max;
|
||||
return IntSet(n);
|
||||
}
|
||||
|
||||
static constexpr const char* _type_key = "IntSet";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(IntSetNode);
|
||||
static constexpr const char* _type_key = "IntervalSet";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
|
||||
};
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(IntSetNode);
|
||||
/*!
|
||||
* \brief set represented by strided integers
|
||||
* Reserved for cases where strided access is supported.
|
||||
*/
|
||||
struct StrideSet : public IntSetNode {
|
||||
/*! \brief the base interval */
|
||||
Interval base;
|
||||
/*! \brief additional extents in positive number */
|
||||
Array<Expr> extents;
|
||||
/*! \brief additional strides in positive number */
|
||||
Array<Expr> strides;
|
||||
|
||||
namespace {
|
||||
static constexpr const char* _type_key = "StrideSet";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
|
||||
};
|
||||
|
||||
inline bool Match(const Expr& e, int64_t value) {
|
||||
const ir::IntImm* v = e.as<ir::IntImm>();
|
||||
return v != nullptr && v->value;
|
||||
}
|
||||
|
||||
// whether a exactly matches b.
|
||||
inline bool Match(const IntSet& a,
|
||||
const Range& b) {
|
||||
if (a->base == b &&
|
||||
a->domain.size() == 0 &&
|
||||
a->concrete.size() == 0) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
/*!
 * \brief Find an interval set that covers this set.
 *
 *  An IntervalSet is returned as is. For a StrideSet the upper bound of
 *  the base interval is extended by extent * stride - stride for every
 *  strided dimension. Any other node type is a fatal error.
 *
 * \return The covering interval set.
 */
inline IntSet IntSet::cover_interval() const {
  if (this->as<IntervalSet>() != nullptr) {
    return *this;
  }
  const StrideSet* strided = this->as<StrideSet>();
  if (strided == nullptr) {
    LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval";
    return IntSet::everything();
  }
  CHECK_NE(strided->extents.size(), 0U);
  Expr upper = strided->base.max;
  const size_t ndim = strided->extents.size();
  for (size_t k = 0; k < ndim; ++k) {
    upper = upper + strided->extents[k] * strided->strides[k] - strided->strides[k];
  }
  return IntervalSet::make(strided->base.min, upper);
}
|
||||
|
||||
// whether a exactly matches b.
|
||||
inline bool Match(const IntSet& a,
|
||||
const Expr& b) {
|
||||
if (a->domain.size() == 0 &&
|
||||
a->concrete.size() == 0) {
|
||||
return Match(a->base->extent, 1) && a->base->min.same_as(b);
|
||||
} else {
|
||||
return false;
|
||||
Range IntSet::cover_range(Range max_range) const {
|
||||
IntSet temp;
|
||||
const IntervalSet* s_int = (*this).as<IntervalSet>();
|
||||
if (s_int == nullptr) {
|
||||
temp = this->cover_interval();
|
||||
s_int = temp.as<IntervalSet>();
|
||||
}
|
||||
}
|
||||
|
||||
inline bool IsNumber(const IntSet& s) {
|
||||
if (s->domain.size() != 0) return false;
|
||||
if (s->concrete.size() != 0) {
|
||||
return s->concrete.size() == 1;
|
||||
if (s_int->i.is_bounded()) {
|
||||
return Range::make_with_min_extent(
|
||||
s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
|
||||
}
|
||||
return Match(s->base->extent, 1);
|
||||
return max_range;
|
||||
}
|
||||
|
||||
inline Expr AsNumber(const IntSet& s) {
|
||||
return s->base->min;
|
||||
bool IntSet::is_everything() const {
|
||||
const IntervalSet* s_int = (*this).as<IntervalSet>();
|
||||
return (s_int && s_int->i.is_everything());
|
||||
}
|
||||
|
||||
// set combination rule by operators
|
||||
template<typename T>
|
||||
inline IntSet BinaryCombine(IntSet a, IntSet b) {
|
||||
LOG(WARNING) << "cannot evaluate binary op " << T::_type_key;
|
||||
return IntSet::make_all_set();
|
||||
bool IntSet::is_single_point() const {
|
||||
const IntervalSet* s_int = (*this).as<IntervalSet>();
|
||||
return (s_int && s_int->i.is_single_point());
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet BinaryCombine<Add>(IntSet a, IntSet b) {
|
||||
auto n = std::make_shared<IntSetNode>(*(a.operator->()));
|
||||
for (size_t i = 0; i < b->domain.size(); ++i) {
|
||||
n->domain.push_back(b->domain[i]);
|
||||
n->stride.push_back(b->stride[i]);
|
||||
// Construct the interval set wrapping Interval::everything().
IntSet IntSet::everything() {
  Interval unbounded = Interval::everything();
  return IntervalSet::make(unbounded);
}
|
||||
|
||||
// Construct the interval set wrapping Interval::single_point(x).
IntSet IntSet::single_point(Expr x) {
  Interval point = Interval::single_point(x);
  return IntervalSet::make(point);
}
|
||||
|
||||
IntSet IntSet::range(Range r) {
|
||||
// must make sure it can be matched back by MatchRange.
|
||||
if (is_one(r->extent)) {
|
||||
return IntSet::single_point(r->min);
|
||||
}
|
||||
|
||||
if (IsNumber(a)) {
|
||||
n->base = Range::make_with_min_extent(
|
||||
a->base->min + b->base->min,
|
||||
b->base->extent);
|
||||
} else if (IsNumber(b)) {
|
||||
n->base = Range::make_with_min_extent(
|
||||
a->base->min + b->base->min,
|
||||
a->base->extent);
|
||||
} else {
|
||||
n->base = Range::make_with_min_extent(
|
||||
a->base->min + b->base->min,
|
||||
a->base->extent + b->base->extent - 1);
|
||||
if (is_positive_const(r->extent) && is_const(r->min)) {
|
||||
return IntervalSet::make(
|
||||
r->min, ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1));
|
||||
}
|
||||
return IntSet(n);
|
||||
return IntervalSet::make(r->min, (r->extent + r->min) - 1);
|
||||
}
|
||||
|
||||
inline Range Negation(Range a) {
|
||||
if (Match(a->extent, 1)) {
|
||||
return Range::make_with_min_extent(-a->min, a->extent);
|
||||
} else {
|
||||
return Range::make_with_min_extent(-(a->min + a->extent - 1), a->extent);
|
||||
// Check if a is created from b.
|
||||
inline bool MatchRange(const IntSet& a,
|
||||
const Range& b) {
|
||||
const IntervalSet* a_int = a.as<IntervalSet>();
|
||||
if (!a_int) return false;
|
||||
const Interval& i = a_int->i;
|
||||
if (!i.min.same_as(b)) return false;
|
||||
if (is_one(b->extent)) return i.is_single_point();
|
||||
if (is_positive_const(b->extent) && is_const(b->min)) {
|
||||
// deep equality
|
||||
return Equal(
|
||||
ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1),
|
||||
a_int->i.max);
|
||||
}
|
||||
const Sub* sub = i.max.as<Sub>();
|
||||
if (!sub) return false;
|
||||
if (is_one(sub->b)) return false;
|
||||
const Add* add = sub->a.as<Add>();
|
||||
return add &&
|
||||
add->a.same_as(b->min) &&
|
||||
add->b.same_as(b->extent);
|
||||
}
|
||||
|
||||
inline IntSet Negation(IntSet a) {
|
||||
CHECK_EQ(a->concrete.size(), 0U);
|
||||
auto n = std::make_shared<IntSetNode>();
|
||||
n->base = Negation(a->base);
|
||||
for (size_t i = 0; i < a->domain.size(); ++i) {
|
||||
n->domain.push_back(Negation(a->domain[i]));
|
||||
n->stride.push_back(a->stride[i]);
|
||||
}
|
||||
return IntSet(a);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet BinaryCombine<Sub>(IntSet a, IntSet b) {
|
||||
return BinaryCombine<Add>(a, Negation(b));
|
||||
}
|
||||
|
||||
inline IntSet BinaryMul(IntSet a, Expr b) {
|
||||
// copy construct
|
||||
if (Match(b, 1)) return a;
|
||||
if (Match(b, -1)) return Negation(a);
|
||||
auto n = std::make_shared<IntSetNode>();
|
||||
n->base = Range::make_with_min_extent(0, 1);
|
||||
n->domain.push_back(a->base);
|
||||
n->stride.push_back(b);
|
||||
for (size_t i = 0; i < a->domain.size(); ++i) {
|
||||
n->domain.push_back(a->domain[i]);
|
||||
n->stride.push_back(a->stride[i] * b);
|
||||
}
|
||||
return IntSet(a);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet BinaryCombine<Mul>(IntSet a, IntSet b) {
|
||||
if (IsNumber(a)) {
|
||||
return BinaryMul(a, AsNumber(b));
|
||||
} else if (IsNumber(b)) {
|
||||
return BinaryMul(b, AsNumber(a));
|
||||
} else {
|
||||
return IntSet::make_all_set();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
inline const IntSetNode* IntSet::operator->() const {
|
||||
return static_cast<const IntSetNode*>(node_.get());
|
||||
}
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
|
||||
.set_dispatch<IntSetNode>([](const IntSetNode *op, IRPrinter *p) {
|
||||
p->stream << "int-set(base=";
|
||||
p->print(op->base);
|
||||
p->stream << ')';
|
||||
});
|
||||
|
||||
IntSet IntSet::make_range(Range dom) {
|
||||
auto n = std::make_shared<IntSetNode>();
|
||||
n->base = dom;
|
||||
return IntSet(n);
|
||||
}
|
||||
|
||||
Range IntSet::GetCoverRange() const {
|
||||
const IntSetNode* s = operator->();
|
||||
CHECK(s != nullptr) << "empty set";
|
||||
if (s->domain.size() == 0 && s->concrete.size() == 0) {
|
||||
return s->base;
|
||||
}
|
||||
LOG(FATAL) << "not yet implemented";
|
||||
return Range();
|
||||
}
|
||||
|
||||
IntSet IntSet::make_point(Expr point) {
|
||||
return IntSet::make_range(Range::make_with_min_extent(point, 1));
|
||||
}
|
||||
|
||||
IntSet IntSet::make_all_set() {
|
||||
LOG(FATAL) << "TODO";
|
||||
return IntSet();
|
||||
inline bool MatchPoint(const IntSet& a,
|
||||
const Expr& b) {
|
||||
const IntervalSet* a_int = a.as<IntervalSet>();
|
||||
if (!a_int) return false;
|
||||
const Interval& i = a_int->i;
|
||||
return i.is_single_point() && i.min.same_as(b);
|
||||
}
|
||||
|
||||
IntSet Union(const Array<IntSet>& set) {
|
||||
if (set.size() == 1) return set[0];
|
||||
LOG(FATAL) << "TODO";
|
||||
return IntSet();
|
||||
Interval x = set[0].cover_interval().as<IntervalSet>()->i;
|
||||
for (size_t i = 1; i < set.size(); ++i) {
|
||||
x.include(set[i].cover_interval().as<IntervalSet>()->i);
|
||||
}
|
||||
return IntervalSet::make(x);
|
||||
}
|
||||
|
||||
// type traits
// Primary template: an operator type is not a logical op unless a
// specialization (see TVM_DECLARE_LOGICAL_OP) says otherwise.
template<typename OP>
struct is_logical_op {
  static const bool value = false;
};
|
||||
|
||||
#define TVM_DECLARE_LOGICAL_OP(OP) \
|
||||
template<> \
|
||||
struct is_logical_op<ir::OP> { \
|
||||
static const bool value = true; \
|
||||
};
|
||||
|
||||
// interval related.
|
||||
template<typename OP>
|
||||
inline IntSet CombineInterval(Interval a, Interval b) {
|
||||
if (a.is_single_point() && b.is_single_point()) {
|
||||
return IntSet::single_point(ComputeExpr<OP>(a.min, b.min));
|
||||
}
|
||||
LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key;
|
||||
return IntSet::everything();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineInterval<Add>(Interval a, Interval b) {
|
||||
if (a.is_single_point() && b.is_single_point()) {
|
||||
return IntSet::single_point(ComputeExpr<Add>(a.min, b.min));
|
||||
}
|
||||
Interval r = Interval::everything();
|
||||
if (a.has_lower_bound() && b.has_lower_bound()) {
|
||||
r.min = ComputeExpr<Add>(a.min, b.min);
|
||||
}
|
||||
if (a.has_upper_bound() && b.has_upper_bound()) {
|
||||
r.max = ComputeExpr<Add>(a.max, b.max);
|
||||
}
|
||||
return IntervalSet::make(r);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineInterval<Sub>(Interval a, Interval b) {
|
||||
if (a.is_single_point() && b.is_single_point()) {
|
||||
return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min));
|
||||
}
|
||||
Interval r = Interval::everything();
|
||||
if (a.has_lower_bound() && b.has_upper_bound()) {
|
||||
r.min = ComputeExpr<Sub>(a.min, b.max);
|
||||
}
|
||||
if (a.has_upper_bound() && b.has_lower_bound()) {
|
||||
r.max = ComputeExpr<Sub>(a.max, b.min);
|
||||
}
|
||||
return IntervalSet::make(r);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
|
||||
if (a.is_single_point() && b.is_single_point()) {
|
||||
return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min));
|
||||
}
|
||||
if (a.is_single_point() && !b.is_single_point()) {
|
||||
std::swap(a, b);
|
||||
}
|
||||
if (b.is_single_point()) {
|
||||
if (is_zero(b.min)) return IntSet::single_point(0);
|
||||
if (is_one(b.min)) return IntervalSet::make(a);
|
||||
Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
|
||||
Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
|
||||
// This is relaxiation
|
||||
// TODO(tqchen): consider convert to StrideSet.
|
||||
if (is_positive_const(b.min)) {
|
||||
return IntervalSet::make(e1, e2);
|
||||
} else if (is_negative_const(b.min)) {
|
||||
return IntervalSet::make(e2, e1);
|
||||
} else if (a.is_bounded()) {
|
||||
Expr cmp = b.min >= make_zero(b.min.type().element_of());
|
||||
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
|
||||
}
|
||||
}
|
||||
LOG(WARNING) << "Return Everything in CombineInterval Mul";
|
||||
return IntSet::everything();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
|
||||
if (a.is_single_point() && b.is_single_point()) {
|
||||
return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
|
||||
}
|
||||
return IntervalSet::make(Interval::make_max(a.min, b.min),
|
||||
Interval::make_max(a.max, b.max));
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineInterval<Min>(Interval a, Interval b) {
|
||||
if (a.is_single_point() && b.is_single_point()) {
|
||||
return IntSet::single_point(ComputeExpr<Min>(a.min, b.min));
|
||||
}
|
||||
return IntervalSet::make(Interval::make_min(a.min, b.min),
|
||||
Interval::make_min(a.max, b.max));
|
||||
}
|
||||
|
||||
template<typename OP>
|
||||
inline IntSet CombineInterval_(IntSet a, IntSet b) {
|
||||
return CombineInterval<OP>(
|
||||
a.as<IntervalSet>()->i, b.as<IntervalSet>()->i);
|
||||
}
|
||||
|
||||
// stride related
|
||||
// Convert a set into a StrideSet.
// A StrideSet is returned unchanged. A bounded IntervalSet becomes a
// StrideSet whose base is that interval and which has no strided
// dimensions. Any other input is a fatal error.
inline IntSet AsStrideSet(IntSet a) {
  if (a.as<StrideSet>()) return a;
  const IntervalSet* s = a.as<IntervalSet>();
  // Fail with a diagnostic instead of dereferencing a null pointer when
  // the set is neither of the two known node types.
  CHECK(s != nullptr)
      << "AsStrideSet expects an IntervalSet or a StrideSet";
  CHECK(s->i.is_bounded());
  std::shared_ptr<StrideSet> n = std::make_shared<StrideSet>();
  n->base = s->i;
  return IntSet(n);
}
|
||||
template<typename OP>
|
||||
inline IntSet CombineSets(IntSet a, IntSet b) {
|
||||
return CombineInterval_<OP>(a.cover_interval(), b.cover_interval());
|
||||
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
|
||||
const IntervalSet* a_int = a.as<IntervalSet>();
|
||||
const IntervalSet* b_int = b.as<IntervalSet>();
|
||||
if (a_int && is_zero(a_int->i.min)) return b;
|
||||
if (b_int && is_zero(b_int->i.min)) return a;
|
||||
a = AsStrideSet(a);
|
||||
b = AsStrideSet(b);
|
||||
const StrideSet* a_stride = a.as<StrideSet>();
|
||||
const StrideSet* b_stride = b.as<StrideSet>();
|
||||
auto n = std::make_shared<StrideSet>(*a_stride);
|
||||
for (size_t i = 0; i < b_stride->extents.size(); ++i) {
|
||||
n->extents.push_back(b_stride->extents[i]);
|
||||
n->strides.push_back(b_stride->strides[i]);
|
||||
}
|
||||
n->base = CombineInterval<Add>(
|
||||
a_stride->base, b_stride->base).as<IntervalSet>()->i;
|
||||
return IntSet(n);
|
||||
}
|
||||
|
||||
// Negate a set.
// A single point p becomes the point -p. An interval [lo, hi] becomes
// [-hi, -lo], where a missing bound stays unbounded. A set that is not
// an IntervalSet is first relaxed to its covering interval.
inline IntSet NegateSet(IntSet a) {
  const IntervalSet* a_int = a.as<IntervalSet>();
  if (a_int == nullptr) {
    return NegateSet(a.cover_interval());
  }
  const Interval& src = a_int->i;
  if (src.is_single_point()) {
    return IntSet::single_point(-src.min);
  }
  Interval negated = Interval::everything();
  if (src.has_upper_bound()) {
    negated.min = -(src.max);
  }
  if (src.has_lower_bound()) {
    negated.max = -(src.min);
  }
  return IntervalSet::make(negated);
}
|
||||
|
||||
template<>
|
||||
inline IntSet CombineSets<Sub>(IntSet a, IntSet b) {
|
||||
return CombineSets<Add>(a, NegateSet(b));
|
||||
}
|
||||
|
||||
TVM_DECLARE_LOGICAL_OP(And);
|
||||
TVM_DECLARE_LOGICAL_OP(Or);
|
||||
TVM_DECLARE_LOGICAL_OP(EQ);
|
||||
TVM_DECLARE_LOGICAL_OP(NE);
|
||||
TVM_DECLARE_LOGICAL_OP(GE);
|
||||
TVM_DECLARE_LOGICAL_OP(GT);
|
||||
TVM_DECLARE_LOGICAL_OP(LE);
|
||||
TVM_DECLARE_LOGICAL_OP(LT);
|
||||
TVM_DECLARE_LOGICAL_OP(Not);
|
||||
|
||||
// generic combine operations of two sets
|
||||
template<typename OP>
|
||||
inline IntSet Combine(const IntSet& a, const IntSet &b) {
|
||||
if (is_logical_op<OP>::value) {
|
||||
return IntervalSet::make(0, 1);
|
||||
}
|
||||
const IntervalSet* a_int = a.as<IntervalSet>();
|
||||
const IntervalSet* b_int = b.as<IntervalSet>();
|
||||
if (a_int && a_int->i.is_everything()) return a;
|
||||
if (b_int && b_int->i.is_everything()) return b;
|
||||
if (a_int && b_int) {
|
||||
return CombineInterval<OP>(a_int->i, b_int->i);
|
||||
}
|
||||
if (a_int && !(a_int->i.is_bounded())) {
|
||||
return CombineInterval_<OP>(a, b.cover_interval());
|
||||
}
|
||||
if (b_int && !(b_int->i.is_bounded())) {
|
||||
return CombineInterval_<OP>(a.cover_interval(), b);
|
||||
}
|
||||
return CombineSets<OP>(a, b);
|
||||
}
|
||||
|
||||
// Implementation of Evaluations and passing.
|
||||
void PassUp(const SplitNode* s,
|
||||
const std::unordered_map<IterVar, Range>& dom_map,
|
||||
const IntSet& outer,
|
||||
|
@ -215,33 +358,21 @@ void PassUp(const SplitNode* s,
|
|||
if (dom_map.count(s->outer) &&
|
||||
dom_map.count(s->inner) &&
|
||||
dom_map.count(s->parent) &&
|
||||
Match(outer, dom_map.at(s->outer)) &&
|
||||
Match(inner, dom_map.at(s->inner))) {
|
||||
*parent = IntSet::make_range(dom_map.at(s->parent));
|
||||
MatchRange(outer, dom_map.at(s->outer)) &&
|
||||
MatchRange(inner, dom_map.at(s->inner))) {
|
||||
*parent = IntSet::range(dom_map.at(s->parent));
|
||||
return;
|
||||
}
|
||||
Expr factor = dom_map.at(s->inner)->extent;
|
||||
Expr parent_min = dom_map.at(s->parent)->min;
|
||||
CHECK(outer.defined());
|
||||
CHECK(inner.defined());
|
||||
CHECK(factor.defined());
|
||||
// copy construct
|
||||
auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
|
||||
|
||||
if (IsNumber(outer)) {
|
||||
// shift the base offset
|
||||
n->base = Range::make_with_min_extent(
|
||||
AsNumber(outer) * factor + inner->base->min,
|
||||
inner->base->extent);
|
||||
} else {
|
||||
// default use all domains in the data.
|
||||
n->domain.push_back(outer->base);
|
||||
n->stride.push_back(factor);
|
||||
for (size_t i = 0; i < outer->domain.size(); ++i) {
|
||||
n->domain.push_back(outer->domain[i]);
|
||||
n->stride.push_back(outer->stride[i] * factor);
|
||||
}
|
||||
}
|
||||
*parent = IntSet(n);
|
||||
*parent = Combine<Add>(
|
||||
Combine<Add>(
|
||||
Combine<Mul>(outer, IntSet::single_point(factor)), inner),
|
||||
IntSet::single_point(parent_min));
|
||||
}
|
||||
|
||||
void PassUp(const FuseNode* s,
|
||||
|
@ -253,29 +384,51 @@ void PassUp(const FuseNode* s,
|
|||
CHECK(dom_map.count(s->inner));
|
||||
CHECK(dom_map.count(s->fused));
|
||||
|
||||
if (Match(fused, dom_map.at(s->fused))) {
|
||||
*outer = IntSet::make_range(dom_map.at(s->outer));
|
||||
*inner = IntSet::make_range(dom_map.at(s->inner));
|
||||
if (MatchRange(fused, dom_map.at(s->fused))) {
|
||||
*outer = IntSet::range(dom_map.at(s->outer));
|
||||
*inner = IntSet::range(dom_map.at(s->inner));
|
||||
return;
|
||||
}
|
||||
|
||||
if (IsNumber(fused)) {
|
||||
Expr value = AsNumber(fused);
|
||||
Expr outer_min = dom_map.at(s->outer)->min;
|
||||
Expr inner_min = dom_map.at(s->inner)->min;
|
||||
|
||||
const IntervalSet* fused_int = fused.as<IntervalSet>();
|
||||
|
||||
if (fused_int && fused_int->i.is_single_point()) {
|
||||
Expr value = fused_int->i.min;
|
||||
Expr factor = dom_map.at(s->inner)->extent;
|
||||
*outer = IntSet::make_point(value / factor);
|
||||
*inner = IntSet::make_point(value % factor);
|
||||
Expr v_outer = value / factor;
|
||||
Expr v_inner = value % factor;
|
||||
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
|
||||
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
|
||||
*outer = IntSet::single_point(v_outer);
|
||||
*inner = IntSet::single_point(v_inner);
|
||||
} else {
|
||||
LOG(WARNING) << "use fallback inference rule in fuse";
|
||||
// simply use the entire set, this rule can be enhanced.
|
||||
*outer = IntSet::make_range(dom_map.at(s->outer));
|
||||
*inner = IntSet::make_range(dom_map.at(s->inner));
|
||||
*outer = IntSet::range(dom_map.at(s->outer));
|
||||
*inner = IntSet::range(dom_map.at(s->inner));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// evaluator to evaluate the int set
|
||||
class IRSetEvaluator {
|
||||
|
||||
void PassUp(const RebaseNode* s,
|
||||
const std::unordered_map<IterVar, Range>& dom_map,
|
||||
const IntSet& rebased,
|
||||
IntSet* parent) {
|
||||
CHECK(dom_map.count(s->parent));
|
||||
if (MatchRange(rebased, dom_map.at(s->rebased))) {
|
||||
*parent = IntSet::range(dom_map.at(s->parent));
|
||||
return;
|
||||
}
|
||||
Expr parent_min = dom_map.at(s->parent)->min;
|
||||
*parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
|
||||
}
|
||||
|
||||
// Evaluator to evaluate the expression.
|
||||
class IntSetEvaluator {
|
||||
public:
|
||||
inline IntSet Eval(Expr expr) {
|
||||
static const FType& f = vtable();
|
||||
|
@ -283,11 +436,11 @@ class IRSetEvaluator {
|
|||
return f(expr, expr, this);
|
||||
} else {
|
||||
LOG(WARNING) << "cannot evaluate set type " << expr->type_key();
|
||||
return IntSet::make_all_set();
|
||||
return IntSet::everything();
|
||||
}
|
||||
}
|
||||
|
||||
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IRSetEvaluator *)>;
|
||||
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IntSetEvaluator *)>;
|
||||
static FType& vtable() { // NOLINT(*)
|
||||
static FType inst; return inst;
|
||||
}
|
||||
|
@ -295,76 +448,84 @@ class IRSetEvaluator {
|
|||
std::unordered_map<const Variable*, IntSet> dom_map;
|
||||
};
|
||||
|
||||
inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) {
|
||||
return IntSet::make_point(e);
|
||||
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
|
||||
return IntSet::single_point(e);
|
||||
}
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
|
||||
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
|
||||
.set_dispatch<IntImm>(ConstOp)
|
||||
.set_dispatch<UIntImm>(ConstOp)
|
||||
.set_dispatch<FloatImm>(ConstOp);
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
|
||||
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IRSetEvaluator* m) {
|
||||
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
|
||||
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IntSetEvaluator* m) {
|
||||
auto it = m->dom_map.find(op);
|
||||
if (it != m->dom_map.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
return IntSet::make_point(e);
|
||||
return IntSet::single_point(e);
|
||||
}
|
||||
});
|
||||
|
||||
// binary operator
|
||||
template<typename T>
|
||||
inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) {
|
||||
inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) {
|
||||
IntSet a = m->Eval(op->a);
|
||||
IntSet b = m->Eval(op->b);
|
||||
if (IsNumber(a) && IsNumber(b)) {
|
||||
if (Match(a, op->a) &&
|
||||
Match(b, op->b)) {
|
||||
return IntSet::make_point(e);
|
||||
} else {
|
||||
return IntSet::make_point(T::make(AsNumber(a), AsNumber(b)));
|
||||
}
|
||||
} else {
|
||||
return BinaryCombine<T>(a, b);
|
||||
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
|
||||
return IntSet::single_point(e);
|
||||
}
|
||||
IntSet r = Combine<T>(a, b);
|
||||
return r;
|
||||
}
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
|
||||
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
|
||||
.set_dispatch<Add>(Binary<Add>)
|
||||
.set_dispatch<Sub>(Binary<Sub>)
|
||||
.set_dispatch<Mul>(Binary<Mul>)
|
||||
.set_dispatch<Div>(Binary<Div>)
|
||||
.set_dispatch<Mod>(Binary<Mod>)
|
||||
.set_dispatch<Min>(Binary<Min>)
|
||||
.set_dispatch<Max>(Binary<Max>);
|
||||
|
||||
// use simply bound for logical expressions for now.
|
||||
inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) {
|
||||
return IntSet::make_range(Range::make_with_min_extent(0, 2));
|
||||
}
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
|
||||
.set_dispatch<EQ>(Logical)
|
||||
.set_dispatch<NE>(Logical)
|
||||
.set_dispatch<LT>(Logical)
|
||||
.set_dispatch<LE>(Logical)
|
||||
.set_dispatch<GT>(Logical)
|
||||
.set_dispatch<GE>(Logical)
|
||||
.set_dispatch<And>(Logical)
|
||||
.set_dispatch<Or>(Logical);
|
||||
|
||||
} // namespace
|
||||
.set_dispatch<Max>(Binary<Max>)
|
||||
.set_dispatch<EQ>(Binary<EQ>)
|
||||
.set_dispatch<NE>(Binary<NE>)
|
||||
.set_dispatch<LT>(Binary<LT>)
|
||||
.set_dispatch<LE>(Binary<LE>)
|
||||
.set_dispatch<GT>(Binary<GT>)
|
||||
.set_dispatch<GE>(Binary<GE>)
|
||||
.set_dispatch<And>(Binary<And>)
|
||||
.set_dispatch<Or>(Binary<Or>);
|
||||
|
||||
IntSet EvalSet(Expr e,
|
||||
const Map<IterVar, IntSet>& dom_map) {
|
||||
IRSetEvaluator m;
|
||||
IntSetEvaluator m;
|
||||
for (auto kv : dom_map) {
|
||||
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
|
||||
}
|
||||
return m.Eval(e);
|
||||
}
|
||||
|
||||
// Evaluate the set covered by range r when its min and extent are
// expressions over variables whose domains are given in dom_map.
// The result is eval(min) + [0, upper(eval(extent)) - 1]; if the extent
// has no upper bound the result is everything.
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map) {
  IntSetEvaluator evaluator;
  for (auto entry : dom_map) {
    const Variable* var = entry.first->var.as<Variable>();
    evaluator.dom_map[var] = entry.second;
  }
  IntSet min_set = evaluator.Eval(r->min);
  IntSet extent_cover = evaluator.Eval(r->extent).cover_interval();
  const Interval& extent_iv = extent_cover.as<IntervalSet>()->i;
  if (!extent_iv.has_upper_bound()) {
    return IntSet::everything();
  }
  // Offsets reachable within the range: 0 .. (largest extent - 1).
  IntSet offset_set = IntervalSet::make(0, ComputeExpr<Sub>(extent_iv.max, 1));
  return Combine<Add>(min_set, offset_set);
}
|
||||
|
||||
// Printer for IntervalSet: emits "interval-set[min, max]".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
    // One opening bracket only, so it pairs with the closing ']'.
    p->stream << "interval-set"
              << "[" << op->i.min << ", "
              << op->i.max << ']';
  });
|
||||
|
||||
|
||||
} // namespace schedule
|
||||
} // namespace tvm
|
||||
|
|
|
@ -22,35 +22,48 @@ class IntSet : public NodeRef {
|
|||
public:
|
||||
/*! \brief constructor */
|
||||
IntSet() {}
|
||||
// constructor from not deontainer.
|
||||
// constructor from not container.
|
||||
explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
|
||||
/*! \return whether the set is empty */
|
||||
inline bool is_empty() const {
|
||||
return !defined();
|
||||
}
|
||||
/*!
|
||||
* \return a range that covers the IntSet
|
||||
*/
|
||||
Range GetCoverRange() const;
|
||||
/*!
|
||||
* \brief access the internal node container
|
||||
* \return the pointer to the internal node container
|
||||
*/
|
||||
inline const IntSetNode* operator->() const;
|
||||
/*!
|
||||
* \param dom The domain to be created.
|
||||
* \return create integer set from existing domain
|
||||
* \brief Find a range that covers the region.
|
||||
* \param max_range The range to be covered.
|
||||
* \return The covering range.
|
||||
*/
|
||||
static IntSet make_range(Range dom);
|
||||
Range cover_range(Range max_range) const;
|
||||
/*!
|
||||
* \param point
|
||||
* \return create integer set that only contains one point
|
||||
* \brief find an interval that covers the set.
|
||||
* \return The covering interval set.
|
||||
*/
|
||||
static IntSet make_point(Expr point);
|
||||
IntSet cover_interval() const;
|
||||
/*! \return Whether the set represent everything */
|
||||
bool is_everything() const;
|
||||
/*! \return Whether the set is a single point */
|
||||
bool is_single_point() const;
|
||||
/*! \return The set that contains everything */
|
||||
static IntSet everything();
|
||||
/*!
|
||||
* \return create integer set that represents everything
|
||||
* \brief construct a point set.
|
||||
* \param point The point in the set.
|
||||
* \return construct a single point set
|
||||
*/
|
||||
static IntSet make_all_set();
|
||||
static IntSet single_point(Expr point);
|
||||
/*!
|
||||
* \brief Construct a set representing a range.
|
||||
* \param r The range
|
||||
* \return constructed set.
|
||||
*/
|
||||
static IntSet range(Range r);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Base class of all IntSet containers.
|
||||
*/
|
||||
struct IntSetNode : public Node {
|
||||
};
|
||||
|
||||
/*!
|
||||
|
@ -63,6 +76,18 @@ class IntSet : public NodeRef {
|
|||
*/
|
||||
IntSet EvalSet(Expr e,
|
||||
const Map<IterVar, IntSet>& dom_map);
|
||||
|
||||
/*!
|
||||
* \brief Find a symbolic integer set that contains the union over
|
||||
* all the possible conditional values in dom_map.
|
||||
*
|
||||
* \param r The initial range.
|
||||
* \param dom_map The domain of each variable.
|
||||
* \return An integer set that can cover all the possible values.
|
||||
*/
|
||||
IntSet EvalSet(Range r,
|
||||
const Map<IterVar, IntSet>& dom_map);
|
||||
|
||||
/*!
|
||||
* \brief Conditional upward message passing.
|
||||
*
|
||||
|
@ -99,6 +124,23 @@ void PassUp(const FuseNode* s,
|
|||
const IntSet& fused,
|
||||
IntSet* outer,
|
||||
IntSet* inner);
|
||||
|
||||
/*!
|
||||
* \brief Conditional upward message passing.
|
||||
*
|
||||
* Get domain of parent, condition on domain of children.
|
||||
* Domain is represented as IntSet.
|
||||
*
|
||||
* \param s The Rebase relation node.
|
||||
* \param dom_map The old domain result from downward message passing.
|
||||
* Contains the domain set if all the children are full set.
|
||||
* \param rebased domain of rebased iteration.
|
||||
* \param parent The result domain of parent iteration.
|
||||
*/
|
||||
void PassUp(const RebaseNode* s,
|
||||
const std::unordered_map<IterVar, Range>& dom_map,
|
||||
const IntSet& fused,
|
||||
IntSet* parent);
|
||||
/*!
|
||||
* \brief Create an union set of all sets
|
||||
* \param sets The sets to be unioned
|
||||
|
@ -106,6 +148,11 @@ void PassUp(const FuseNode* s,
|
|||
*/
|
||||
IntSet Union(const Array<IntSet>& sets);
|
||||
|
||||
// implementation
|
||||
inline const IntSetNode* IntSet::operator->() const {
|
||||
return static_cast<const IntSetNode*>(node_.get());
|
||||
}
|
||||
|
||||
} // namespace schedule
|
||||
} // namespace tvm
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
|
|||
}
|
||||
}
|
||||
CHECK(found)
|
||||
<< "Cannot compute at a iteration variable that is not part of parent leaf vars";
|
||||
<< "Cannot find the specified axis in parent stage's leaf_iter_vars";
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
|
|||
return *this;
|
||||
}
|
||||
|
||||
|
||||
Schedule::Schedule(Array<Operation> ops) {
|
||||
auto n = std::make_shared<ScheduleNode>();
|
||||
n->roots = ops;
|
||||
|
@ -203,9 +202,53 @@ IterVarRelation FuseNode::make(
|
|||
return IterVarRelation(n);
|
||||
}
|
||||
|
||||
// Create a Rebase relation that links `parent` to its `rebased` counterpart.
IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
  std::shared_ptr<RebaseNode> node = std::make_shared<RebaseNode>();
  node->rebased = rebased;
  node->parent = parent;
  return IterVarRelation(node);
}
|
||||
|
||||
void Schedule::normalize() {
|
||||
std::unordered_map<IterVar, IterVar> rebase_map;
|
||||
std::unordered_map<const Node*, int> attach_mark;
|
||||
|
||||
|
||||
for (Stage s : (*this)->stages) {
|
||||
if (s->attach_type == kScope) {
|
||||
attach_mark[s->attach_stage.get()] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (Stage s : (*this)->stages) {
|
||||
if (!attach_mark.count(s.get())) continue;
|
||||
auto root_iter_vars = s->op->root_iter_vars();
|
||||
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
|
||||
|
||||
for (IterVar iv : root_iter_vars) {
|
||||
size_t idx = FindIterVar(leaf_vars, iv);
|
||||
if (idx < leaf_vars->data.size()) {
|
||||
// insert rebase
|
||||
IterVar rebased(Range(), iv->var->name_hint + ".rb");
|
||||
s->relations.push_back(RebaseNode::make(iv, rebased));
|
||||
leaf_vars->data[idx] = rebased.node_;
|
||||
rebase_map[iv] = rebased;
|
||||
}
|
||||
}
|
||||
}
|
||||
// remap the parent relation
|
||||
for (Stage s : (*this)->stages) {
|
||||
if (s->attach_type != kScope) continue;
|
||||
if (rebase_map.count(s->attach_ivar)) {
|
||||
s->attach_ivar = rebase_map.at(s->attach_ivar);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(StageNode);
|
||||
TVM_REGISTER_NODE_TYPE(SplitNode);
|
||||
TVM_REGISTER_NODE_TYPE(FuseNode);
|
||||
TVM_REGISTER_NODE_TYPE(RebaseNode);
|
||||
TVM_REGISTER_NODE_TYPE(ScheduleNode);
|
||||
|
||||
} // namespace tvm
|
||||
|
|
|
@ -0,0 +1,388 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file schedule_ops.cc
|
||||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <tvm/schedule_pass.h>
|
||||
|
||||
#include "../pass/ir_util.h"
|
||||
#include "./int_set.h"
|
||||
#include "./graph.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace schedule {
|
||||
|
||||
using namespace ir;
|
||||
|
||||
/*!
|
||||
* \brief message passing to find if IterVar is related to reduction.
|
||||
* \param s The stage to be used.
|
||||
* \param p_state The message passing state
|
||||
* IterVar->flag
|
||||
*/
|
||||
void PassDownFlag(const Stage& s,
|
||||
std::unordered_map<IterVar, int>* p_state) {
|
||||
auto& state = *p_state;
|
||||
for (IterVarRelation rel : s->relations) {
|
||||
if (rel.as<SplitNode>()) {
|
||||
const SplitNode* s = rel.as<SplitNode>();
|
||||
int flag = state.at(s->parent);
|
||||
state[s->outer] = flag;
|
||||
state[s->inner] = flag;
|
||||
} else if (rel.as<FuseNode>()) {
|
||||
const FuseNode* s = rel.as<FuseNode>();
|
||||
int flag_outer = state.at(s->outer);
|
||||
int flag_inner = state.at(s->inner);
|
||||
state[s->fused] = flag_outer | flag_inner;
|
||||
} else if (rel.as<RebaseNode>()) {
|
||||
const RebaseNode* s = rel.as<RebaseNode>();
|
||||
int flag = state.at(s->parent);
|
||||
state[s->rebased] = flag;
|
||||
} else {
|
||||
LOG(FATAL) << "unknown relation type";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief use message passing to calculate the assignment of each Var inside the loop body.
|
||||
* \param s The schedule to be used.
|
||||
* \param dom_map The domain map of each iteration variable's domain
|
||||
* \param p_state The message passing state
|
||||
* IterVar->The assignment.
|
||||
*/
|
||||
void PassUpOffset(const Stage& s,
|
||||
const Map<IterVar, Range>& dom_map,
|
||||
std::unordered_map<IterVar, Expr>* p_state) {
|
||||
auto& state = *p_state;
|
||||
for (size_t i = s->relations.size(); i != 0; --i) {
|
||||
IterVarRelation rel = s->relations[i - 1];
|
||||
if (rel.as<SplitNode>()) {
|
||||
const SplitNode* s = rel.as<SplitNode>();
|
||||
Expr outer = state.at(s->outer);
|
||||
Expr inner = state.at(s->inner);
|
||||
Expr factor = dom_map.at(s->inner)->extent;
|
||||
Expr parent_min = dom_map.at(s->parent)->min;
|
||||
state[s->parent] = inner + outer * factor;
|
||||
// add min if they exist
|
||||
if (!is_zero(parent_min)) {
|
||||
state[s->parent] = state[s->parent] + parent_min;
|
||||
}
|
||||
} else if (rel.as<FuseNode>()) {
|
||||
const FuseNode* s = rel.as<FuseNode>();
|
||||
Expr value = state.at(s->fused);
|
||||
Expr factor = dom_map.at(s->inner)->extent;
|
||||
Expr outer_min = dom_map.at(s->outer)->min;
|
||||
Expr inner_min = dom_map.at(s->inner)->min;
|
||||
state[s->outer] = value / factor;
|
||||
state[s->inner] = value % factor;
|
||||
// add min if they exist
|
||||
if (!is_zero(outer_min)) {
|
||||
state[s->outer] = state[s->outer] + outer_min;
|
||||
}
|
||||
if (!is_zero(inner_min)) {
|
||||
state[s->inner] = state[s->inner] + inner_min;
|
||||
}
|
||||
} else if (rel.as<RebaseNode>()) {
|
||||
const RebaseNode* s = rel.as<RebaseNode>();
|
||||
Expr value = state.at(s->rebased);
|
||||
Expr parent_min = dom_map.at(s->parent)->min;
|
||||
// add min if they exist
|
||||
if (!is_zero(parent_min)) {
|
||||
state[s->parent] = value + parent_min;
|
||||
} else {
|
||||
state[s->parent] = value;
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "unknown relation type";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<Stmt> >
|
||||
MakeLoopNest(const Stage& sch,
|
||||
const Map<IterVar, Range>& dom_map,
|
||||
size_t begin_loop,
|
||||
bool reduce_init_loop,
|
||||
std::unordered_map<IterVar, Expr>* p_value_map,
|
||||
const std::unordered_map<IterVar, bool>& skip_iter) {
|
||||
auto leaf_iter_vars = sch->leaf_iter_vars;
|
||||
Stmt no_op = Evaluate::make(0);
|
||||
// create the loop nest
|
||||
std::vector<std::vector<Stmt> > nest;
|
||||
nest.resize(leaf_iter_vars.size() + 1);
|
||||
std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
|
||||
|
||||
for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
|
||||
auto iv = leaf_iter_vars[i];
|
||||
if (skip_iter.count(iv) && skip_iter.at(iv)) {
|
||||
// skip this iteration.
|
||||
value_map[iv] = iv->var;
|
||||
continue;
|
||||
}
|
||||
|
||||
Range dom = dom_map.at(iv);
|
||||
// initialize the offset and loop_level
|
||||
Var var = iv->var;
|
||||
if (reduce_init_loop) {
|
||||
var = Var(iv->var->name_hint + ".init", iv->var.type());
|
||||
}
|
||||
// Mark the iter var in the IR, to remember the point
|
||||
if (iv->thread_tag.length() == 0) {
|
||||
if (is_one(dom->extent)) {
|
||||
nest[i + 1].emplace_back(
|
||||
LetStmt::make(var, dom->min, no_op));
|
||||
value_map[iv] = dom->min;
|
||||
} else if (is_zero(dom->min)) {
|
||||
nest[i + 1].emplace_back(
|
||||
For::make(var, 0, dom->extent,
|
||||
ForType::Serial, DeviceAPI::None, no_op));
|
||||
value_map[iv] = var;
|
||||
} else {
|
||||
Var idx(iv->var->name_hint + ".idx", iv->var.type());
|
||||
nest[i + 1].emplace_back(
|
||||
For::make(idx, 0, dom->extent,
|
||||
ForType::Serial, DeviceAPI::None, no_op));
|
||||
Expr new_value = dom->min + idx;
|
||||
value_map[iv] = new_value;
|
||||
nest[i + 1].emplace_back(
|
||||
LetStmt::make(var, new_value, no_op));
|
||||
}
|
||||
} else {
|
||||
// Always restrict threaded IterVar to starts from 0.
|
||||
CHECK(is_zero(dom->min));
|
||||
// annotate the extent of the IterVar
|
||||
nest[i + 1].emplace_back(
|
||||
AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
|
||||
value_map[iv] = var;
|
||||
}
|
||||
if (!reduce_init_loop) {
|
||||
// annotate the extent of the IterVar
|
||||
nest[i + 1].emplace_back(
|
||||
AttrStmt::make(iv, "scope", iv->var, no_op));
|
||||
}
|
||||
}
|
||||
// message passing to get offset of root iter vars.
|
||||
PassUpOffset(sch, dom_map, &value_map);
|
||||
return nest;
|
||||
}
|
||||
|
||||
/*!
 * \brief Wrap the provide statement of a stage in its loop nest, and place
 *  the reduction initialization when one is given.
 *
 *  Without init, the result is the full loop nest around provide.
 *  With init, the leaf loops in front of the first reduction-related loop
 *  are shared; inside them the initialization (with its own loops over the
 *  remaining non-reduction variables) is followed by the remaining loops
 *  around provide.
 *
 * \param s The stage to generate loops for.
 * \param dom_map The domain of each iteration variable.
 * \param provide The statement that stores the (updated) value.
 * \param init The initialization statement, undefined when there is none.
 * \return The generated statement.
 */
Stmt MakeLoop(const Stage& s,
              const Map<IterVar, Range>& dom_map,
              Stmt provide,
              Stmt init) {
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = MakeLoopNest(s, dom_map, 0, false, &value_map, {});
  provide = Substitute(provide, value_map);
  if (!init.defined()) {
    return MergeNest(nest, provide);
  }

  // try to find the location to insert the initialization.
  // Fuse the initialization and provide loop when possible.
  const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
  CHECK(compute != nullptr)
      << "reduction initialization requires a compute op";
  // Flag bit 1: related to a data axis; bit 2: related to a reduction axis.
  std::unordered_map<IterVar, int> reduce_state;
  for (IterVar iv : compute->reduce_axis) {
    reduce_state[iv] = 2;
  }
  for (IterVar iv : compute->axis) {
    reduce_state[iv] = 1;
  }
  // find which iter var is related to reduction and which is related to axis.
  PassDownFlag(s, &reduce_state);

  auto leaf_iter_vars = s->leaf_iter_vars;
  std::unordered_map<IterVar, Expr> init_value_map;
  // find the first loop that is related to reduction.
  size_t begin_loop = leaf_iter_vars.size();
  for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if ((reduce_state.at(iv) & 2) != 0) {
      begin_loop = i;
      break;
    }
    // Loops in front of the reduction are shared with the init statement.
    init_value_map[iv] = value_map.at(iv);
  }
  // skip loops that do not relate to a data axis.
  std::unordered_map<IterVar, bool> skip_iter;
  for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if ((reduce_state.at(iv) & 1) == 0) skip_iter[iv] = true;
  }
  auto init_nest = MakeLoopNest(
      s, dom_map, begin_loop, true, &init_value_map, skip_iter);
  init = Substitute(init, init_value_map);
  init = MergeNest(init_nest, init);

  // MakeLoopNest stores the statements of leaf loop i in nest[i + 1], so the
  // loops in front of begin_loop occupy nest[0 .. begin_loop] inclusive.
  // Splitting at begin_loop + 1 keeps every shared loop around the init
  // statement; begin_loop <= leaf_iter_vars.size() < nest.size() keeps the
  // iterator in range.
  std::vector<std::vector<Stmt> > common(
      nest.begin(), nest.begin() + begin_loop + 1);
  std::vector<std::vector<Stmt> > reduce(
      nest.begin() + begin_loop + 1, nest.end());
  provide = MergeNest(reduce, provide);
  return MergeNest(
      common, Block::make(init, provide));
}
|
||||
|
||||
// Build the Provide statement that stores op->body into the first output
// tensor, indexed by the variables of the op's axis.
Stmt MakeProvide(const ComputeOpNode* op,
                 const std::vector<Tensor>& tensors) {
  Array<Expr> indices;
  for (IterVar axis : op->axis) {
    indices.push_back(axis->var);
  }
  Tensor out = tensors[0];
  return Provide::make(out->op, out->value_index, op->body, indices);
}
|
||||
|
||||
// Wrap body in a Realize of the first output tensor; the realized region is
// the domain of every axis of the op and the condition is constant true.
Stmt MakeRealize(const ComputeOpNode* op,
                 const Map<IterVar, Range>& dom_map,
                 const std::vector<Tensor>& tensors,
                 Stmt body) {
  Halide::Internal::Region region;
  for (IterVar axis : op->axis) {
    region.push_back(dom_map.at(axis));
  }
  Tensor out = tensors[0];
  Expr always = make_const(Bool(1), true);
  return Realize::make(out->op, out->value_index, out->dtype,
                       region, always, body);
}
|
||||
|
||||
|
||||
void MakeReduction(const ComputeOpNode* op,
|
||||
const std::vector<Tensor>& tensors,
|
||||
const Map<IterVar, Range>& dom_map,
|
||||
Stmt* init,
|
||||
Stmt* provide) {
|
||||
Stmt no_op = Evaluate::make(0);
|
||||
Tensor t = tensors[0];
|
||||
std::vector<Stmt> nest;
|
||||
Array<Expr> args;
|
||||
for (IterVar iv : op->axis) {
|
||||
args.push_back(iv->var);
|
||||
}
|
||||
const Reduce* reduce = op->body.as<Reduce>();
|
||||
CHECK(reduce);
|
||||
Expr init_value, update_value;
|
||||
if (reduce->op == "Add") {
|
||||
init_value = make_zero(reduce->type);
|
||||
update_value = Add::make(t(args), reduce->source);
|
||||
} else if (reduce->op == "Max") {
|
||||
init_value = reduce->type.min();
|
||||
update_value = Max::make(t(args), reduce->source);
|
||||
} else if (reduce->op == "Min") {
|
||||
init_value = reduce->type.max();
|
||||
update_value = Min::make(t(args), reduce->source);
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported reduction " << reduce->op;
|
||||
}
|
||||
*init = Provide::make(t->op, t->value_index, init_value, args);
|
||||
*provide = Provide::make(t->op, t->value_index, update_value, args);
|
||||
}
|
||||
|
||||
/*!
 * \brief Generate the producer of a stage, chain it with an optional
 *  consumer, and wrap the result in the realization of the stage's output.
 *
 * \param sch The stage to generate; only compute ops are supported.
 * \param dom_map The domain of each iteration variable.
 * \param consumer The statement consuming the stage's output; may be undefined.
 * \return The Realize statement containing the pipeline.
 */
Stmt MakePipeline(const Stage& sch,
                  const Map<IterVar, Range>& dom_map,
                  Stmt consumer) {
  std::vector<Tensor> outputs;
  for (int i = 0; i < sch->op->num_outputs(); ++i) {
    outputs.emplace_back(sch->op.output(i));
  }

  const ComputeOpNode* compute = sch->op.as<ComputeOpNode>();
  Stmt init, provide;
  if (compute == nullptr) {
    LOG(FATAL) << "not supported op " << sch->op->type_key();
  } else if (compute->reduce_axis.size() != 0) {
    // Reductions need a separate initialization statement.
    MakeReduction(compute, outputs, dom_map, &init, &provide);
  } else {
    provide = MakeProvide(compute, outputs);
  }

  Stmt pipeline = ProducerConsumer::make(
      sch->op, true, MakeLoop(sch, dom_map, provide, init));
  if (consumer.defined()) {
    pipeline = Block::make(
        pipeline, ProducerConsumer::make(sch->op, false, consumer));
  }

  if (compute != nullptr) {
    return MakeRealize(compute, dom_map, outputs, pipeline);
  }
  LOG(FATAL) << "not supported op";
  return Stmt();
}
|
||||
|
||||
// inject the operator's realization on the stmt.
|
||||
class InjectRealize : public IRMutator {
|
||||
public:
|
||||
InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
|
||||
: schedule(schedule), dom_map(dom_map) {}
|
||||
|
||||
Stmt Mutate(Stmt stmt) final {
|
||||
CHECK(stmt.defined());
|
||||
stmt = IRMutator::Mutate(stmt);
|
||||
const AttrStmt* op = stmt.as<AttrStmt>();
|
||||
if (op != nullptr &&
|
||||
op->type_key == "scope") {
|
||||
if (op->node == schedule->attach_ivar) {
|
||||
CHECK(!found_attach);
|
||||
found_attach = true;
|
||||
stmt = AttrStmt::make(
|
||||
op->node, op->type_key, op->value,
|
||||
MakePipeline(schedule, dom_map,
|
||||
IRMutator::Mutate(op->body)));
|
||||
}
|
||||
}
|
||||
return stmt;
|
||||
}
|
||||
// the operations to be carried
|
||||
Stage schedule;
|
||||
// domain map
|
||||
Map<IterVar, Range> dom_map;
|
||||
// whether attach point is found
|
||||
bool found_attach{false};
|
||||
};
|
||||
|
||||
// Inline a compute op into body, passing the variables of the op's axis as
// the arguments and the op's body as the inlined expression.
Stmt InjectInline(const Operation op, Stmt body) {
  CHECK(body.defined());
  const ComputeOpNode* compute = op.as<ComputeOpNode>();
  CHECK(compute != nullptr)
      << "can only inline compute op";
  Array<Var> axis_vars;
  for (IterVar axis : compute->axis) {
    axis_vars.push_back(axis->var);
  }
  return Inline(body, op, axis_vars, compute->body);
}
|
||||
|
||||
/*!
 * \brief Generate the statement for a schedule, visiting its stages from
 *  last to first.
 *
 *  Placeholder ops are skipped. Inline stages are inlined into the body
 *  built so far, root/none stages wrap it in their pipeline, and scope
 *  stages are injected at their attach point inside it.
 *
 * \param sch The schedule to generate.
 * \param dom_map The domain of each iteration variable.
 * \return The generated statement.
 */
Stmt ScheduleOps(
    Schedule sch, Map<IterVar, Range> dom_map) {
  Stmt body;
  // reverse the post DFS order.
  for (size_t remaining = sch->stages.size(); remaining > 0; --remaining) {
    Stage stage = sch->stages[remaining - 1];
    // no need to specify place holder op.
    if (stage->op.as<PlaceholderOpNode>() != nullptr) continue;

    if (stage->attach_type == kInline) {
      body = InjectInline(stage->op, body);
      continue;
    }
    if (stage->attach_type == kRoot || stage->attach_type == kNone) {
      body = MakePipeline(stage, dom_map, body);
      continue;
    }
    if (stage->attach_type == kScope) {
      CHECK(body.defined());
      InjectRealize mutator(stage, dom_map);
      body = mutator.Mutate(body);
      CHECK(mutator.found_attach)
          << "did not find attachment point";
    }
  }
  return body;
}
|
||||
|
||||
} // namespace schedule
|
||||
} // namespace tvm
|
|
@ -18,7 +18,8 @@ def test_add():
|
|||
|
||||
# one line to build the function.
|
||||
codes = []
|
||||
fadd = tvm.build(s, args=[A, B, C],
|
||||
fadd = tvm.build(s,
|
||||
args=[A, B, C],
|
||||
target="cuda", name="myadd",
|
||||
record_codes=codes)
|
||||
for c in codes:
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
def test_sum():
    """Build a row-sum reduction for OpenCL and compare it with numpy."""
    # Symbolic shape and the reduction graph: B[i] = sum_k A[i, k].
    rows = tvm.Var('n')
    cols = tvm.Var('m')
    A = tvm.placeholder((rows, cols), name='A')
    k = tvm.IterVar((0, cols))
    B = tvm.compute((rows,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

    # Schedule: bind the data axis to block and thread indices.
    s = tvm.Schedule(B.op)
    num_thread = 1
    block_x = tvm.IterVar(thread_tag="blockIdx.x")
    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
    _, inner = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
    _, inner = s[B].split(inner, outer=thread_x)

    # Build the function and print the recorded source.
    tvm.init_opencl()
    codes = []
    fsum = tvm.build(
        s,
        args=[A, B],
        target="opencl",
        name="myadd",
        record_codes=codes)
    for code in codes:
        print(code)

    num_device = 1
    for dev_id in range(num_device):
        ctx = tvm.opencl(dev_id)
        if ctx.enabled:
            # Launch the kernel on concrete data and check the result.
            n_rows = 1028
            n_cols = 129
            a = tvm.nd.array(
                np.random.uniform(size=(n_rows, n_cols)).astype(A.dtype), ctx)
            b = tvm.nd.array(np.zeros(n_rows, dtype=B.dtype), ctx)
            fsum(a, b)
            expected = np.sum(a.asnumpy(), axis=1)
            np.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sum()
|
|
@ -18,8 +18,7 @@ def test_add_pipeline():
|
|||
|
||||
# compile to IR
|
||||
bounds = tvm.schedule.InferBound(s)
|
||||
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
|
||||
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
|
||||
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
|
||||
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
|
||||
|
|
|
@ -10,12 +10,13 @@ def test_makeapi():
|
|||
s = tvm.Schedule(C.op)
|
||||
|
||||
bounds = tvm.schedule.InferBound(s)
|
||||
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
|
||||
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
|
||||
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
|
||||
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
|
||||
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
|
||||
|
||||
num_packed_args = 2
|
||||
f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
|
||||
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
|
||||
|
|
|
@ -26,7 +26,7 @@ def test_tensor_reduce():
|
|||
B = tvm.placeholder((n, l), name='B')
|
||||
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
|
||||
rv = tvm.IterVar((0, A.shape[1]), name="k")
|
||||
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))
|
||||
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv))
|
||||
# json load save
|
||||
C_json = tvm.save_json(C)
|
||||
C_loaded = tvm.load_json(C_json)
|
||||
|
|
|
@ -12,7 +12,7 @@ def test_flatten2():
|
|||
s[A1].compute_at(s[A2], xo)
|
||||
bounds = tvm.schedule.InferBound(s)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
|
||||
print(stmt)
|
||||
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
|
||||
|
|
|
@ -11,7 +11,7 @@ def test_schedule0():
|
|||
|
||||
bounds = tvm.schedule.InferBound(s)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
print(stmt)
|
||||
|
||||
def test_schedule1():
|
||||
|
@ -24,7 +24,7 @@ def test_schedule1():
|
|||
xo, xi = s[A1].split(A1.op.axis[0], 8)
|
||||
bounds = tvm.schedule.InferBound(s)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
print(stmt)
|
||||
|
||||
def test_schedule2():
|
||||
|
@ -39,7 +39,7 @@ def test_schedule2():
|
|||
s[A1].compute_at(s[A2], xo)
|
||||
bounds = tvm.schedule.InferBound(s)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
print(stmt)
|
||||
|
||||
|
Загрузка…
Ссылка в новой задаче