329 строки
9.4 KiB
C++
329 строки
9.4 KiB
C++
/*!
|
|
* Copyright (c) 2016 by Contributors
|
|
* \file schedule.h
|
|
* \brief Define a schedule.
|
|
*/
|
|
#ifndef TVM_SCHEDULE_H_
|
|
#define TVM_SCHEDULE_H_
|
|
|
|
#include <memory>
#include <string>
#include <utility>

#include "./base.h"
#include "./operation.h"
|
|
|
|
namespace tvm {
|
|
|
|
// Forward declarations of the node containers that back the reference
// handles declared below; their full definitions appear later in this header.
class StageNode;            // backs Stage
class ScheduleNode;         // backs Schedule
class IterVarRelationNode;  // backs IterVarRelation
|
|
|
|
/*!
 * \brief How a stage is attached to the loop nest.
 *
 * The underlying type is pinned to int, and each enumerator carries an
 * explicit, stable integer value.
 */
enum AttachType : int {
  kNone = 0,    // no attachment recorded (default of StageNode::attach_type)
  kRoot = 1,    // presumably set by Stage::compute_root() -- TODO confirm
  kInline = 2,  // presumably set by Stage::compute_inline() -- TODO confirm
  kScope = 3,   // presumably set by Stage::compute_at() -- TODO confirm
};
|
|
|
|
/*! \brief Stage, contains scheduling for a stage of computation. */
|
|
class Stage : public NodeRef {
|
|
public:
|
|
Stage() {}
|
|
explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief create a new schedule for op.
|
|
* \param op The operator in the schedule
|
|
*/
|
|
explicit Stage(Operation op);
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const StageNode* operator->() const;
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline StageNode* operator->();
|
|
/*!
|
|
* \brief set the memory scope of the stage
|
|
* \param scope The memory scope.
|
|
*/
|
|
Stage& set_scope(std::string scope); // NOLINT(*)
|
|
/*!
|
|
* \brief specify the schedule to be computed at the parent schedule's scope.
|
|
* \param parent The parent schedule.
|
|
* \param scope The iteration point to carry the schedule.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
|
|
/*!
|
|
* \brief Compute the function inline, attach it at parent.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& compute_inline(); // NOLINT(*)
|
|
/*!
|
|
* \brief Compute the function at root, attach it to its parent.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& compute_root(); // NOLINT(*)
|
|
/*!
|
|
* \brief Split the parent by factor, generate
|
|
* \param parent The parent iteration domain.
|
|
* \param p_outer The result outer domain
|
|
* \param p_inner The result inner domain.
|
|
* \param factor The split factor of the loop.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
|
|
/*!
|
|
* \brief Split the iteration with a given outer domain,
|
|
* the outer domain must have a thread-tag.
|
|
*
|
|
* \param parent The parent domain.
|
|
* \param outer The outer domain to be spliited, must have a thread_tag.
|
|
* \param p_inner The result inner domain.
|
|
* \param factor Optional, the factor of the split,
|
|
* factor must be provided such that factor * outer.extent >= parent.extent.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
|
|
/*!
|
|
* \brief Fuse the inner outer domain to the target
|
|
* \param inner The inner domain to be fused
|
|
* \param outer The outer domain to be fused.
|
|
* \param p_target The result target domain.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
|
|
/*!
|
|
* \brief Reorder the iteration
|
|
* \param order The order of iteration variable.
|
|
* \return reference to self.
|
|
*/
|
|
Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
|
|
/*!
|
|
* \brief Perform tiling on two dimensions
|
|
* The final loop order from outmost to inner most are
|
|
* [x_outer, y_outer, x_inner, y_inner]
|
|
*
|
|
* \param x_parent The original x dimension
|
|
* \param y_parent The original y dimension
|
|
* \param p_x_outer Outer axis of x dimension
|
|
* \param p_y_outer Outer axis of y dimension
|
|
* \param p_x_inner Inner axis of x dimension
|
|
* \param p_y_inner Inner axis of y dimension
|
|
* \param x_factor The stride factor on x axis
|
|
* \param y_factor The stride factor on y axis
|
|
* \return reference to self.
|
|
*/
|
|
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
|
|
IterVar* p_x_outer, IterVar* p_y_outer,
|
|
IterVar* p_x_inner, IterVar* p_y_inner,
|
|
Expr x_factor, Expr y_factor);
|
|
};
|
|
|
|
/*!
|
|
* \brief Global schedule container
|
|
* For operations and all the operations they depend on.
|
|
* The schedule per Operation is named as stage.
|
|
*/
|
|
class Schedule : public NodeRef {
|
|
public:
|
|
Schedule() {}
|
|
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief construct schedule for array of ops(and their dependencies).
|
|
* \param ops The ops to be scheduled.
|
|
*/
|
|
explicit Schedule(Array<Operation> ops);
|
|
/*!
|
|
* \brief Get the stage corresponds to the op
|
|
* \param op The operation.
|
|
*/
|
|
Stage operator[](const Operation& op);
|
|
/*!
|
|
* \brief Short hand for getting the stage of tensor's operation.
|
|
* \param tensor The tensor
|
|
* \return The stage corresponding to the tensor's op
|
|
*/
|
|
Stage operator[](const Tensor& tensor) {
|
|
return this->operator[](tensor->op);
|
|
}
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const ScheduleNode* operator->() const;
|
|
};
|
|
|
|
/*!
|
|
* \brief The schedule relation between IterVars
|
|
* can be Split, Fuse.
|
|
*/
|
|
class IterVarRelation : public NodeRef {
|
|
public:
|
|
IterVarRelation() {}
|
|
explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const IterVarRelationNode* operator->() const;
|
|
};
|
|
|
|
// defintion of node containers
|
|
/*!
|
|
* \brief represents the schedule of the tensor
|
|
*
|
|
* A schedule is a Directed acylic hypergraph.
|
|
* With each node is represented by a IterVar,
|
|
* and each hyper-edge is represented by a IterVarRelation.
|
|
*
|
|
* The relations can be Split/Fuse.
|
|
*
|
|
* The current data structure stores the hyper graph in its
|
|
* bipartite representation.
|
|
*
|
|
* The relations connects the IterVars in the graph.
|
|
*/
|
|
class StageNode : public Node {
|
|
public:
|
|
/*! \brief The operation to be scheduled */
|
|
Operation op;
|
|
/*! \brief The thread scope level of the stage */
|
|
std::string scope;
|
|
/*! \brief All the nodes in the iter var */
|
|
Array<IterVar> all_iter_vars;
|
|
/*!
|
|
* \brief The current leafs in the schedule.
|
|
* Operations can only be performed in leaves.
|
|
*/
|
|
Array<IterVar> leaf_iter_vars;
|
|
/*! \brief The relation bwteen of IterVars */
|
|
Array<IterVarRelation> relations;
|
|
/*! \brief The attachment type of the schedule */
|
|
AttachType attach_type{kNone};
|
|
/*! \brief The attach point of this schedule. */
|
|
IterVar attach_ivar;
|
|
/*! \brief The stage this node attaches to */
|
|
Stage attach_stage;
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("scope", &scope);
|
|
v->Visit("op", &op);
|
|
v->Visit("all_iter_vars", &all_iter_vars);
|
|
v->Visit("leaf_iter_vars", &leaf_iter_vars);
|
|
v->Visit("relations", &relations);
|
|
v->Visit("attach_type", &attach_type);
|
|
v->Visit("attach_ivar", &attach_ivar);
|
|
v->Visit("attach_stage", &attach_stage);
|
|
}
|
|
|
|
static constexpr const char* _type_key = "Stage";
|
|
TVM_DECLARE_NODE_TYPE_INFO(StageNode);
|
|
};
|
|
|
|
/*! \brief node container for schedule */
|
|
class ScheduleNode : public Node {
|
|
public:
|
|
/*! \brief The root operations */
|
|
Array<Operation> roots;
|
|
/*!
|
|
* \brief list of all stages for non-placeholder ops
|
|
* The stage are ordered in PostDFS order of their op.
|
|
*/
|
|
Array<Stage> stages;
|
|
/*! \brief map of operation to the stages */
|
|
Map<Operation, Stage> stage_map;
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("roots", &roots);
|
|
v->Visit("stages", &stages);
|
|
v->Visit("stage_map", &stage_map);
|
|
}
|
|
|
|
static constexpr const char* _type_key = "Schedule";
|
|
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
|
|
};
|
|
|
|
/*! \brief base node of iteration var */
|
|
class IterVarRelationNode : public Node {
|
|
};
|
|
|
|
/*!
|
|
* \brief Split the parent domain into product of
|
|
* outer and iter.
|
|
*/
|
|
class SplitNode : public IterVarRelationNode {
|
|
public:
|
|
/*! \brief The parent domain */
|
|
IterVar parent;
|
|
/*! \brief The outer domain */
|
|
IterVar outer;
|
|
/*! \brief The inner domain */
|
|
IterVar inner;
|
|
/*! \brief The split factor */
|
|
Expr factor;
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("parent", &parent);
|
|
v->Visit("outer", &outer);
|
|
v->Visit("inner", &inner);
|
|
v->Visit("factor", &factor);
|
|
}
|
|
|
|
static IterVarRelation make(
|
|
IterVar parent, IterVar outer,
|
|
IterVar inner, Expr factor);
|
|
|
|
static constexpr const char* _type_key = "Split";
|
|
TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
|
|
};
|
|
|
|
/*!
|
|
* \brief Fuse two domains into one domain.
|
|
*/
|
|
class FuseNode : public IterVarRelationNode {
|
|
public:
|
|
/*! \brief The outer domain */
|
|
IterVar outer;
|
|
/*! \brief The inner domain */
|
|
IterVar inner;
|
|
/*! \brief The target domain */
|
|
IterVar fused;
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("outer", &outer);
|
|
v->Visit("inner", &inner);
|
|
v->Visit("fused", &fused);
|
|
}
|
|
|
|
static IterVarRelation make(
|
|
IterVar outer, IterVar inner, IterVar fused);
|
|
|
|
static constexpr const char* _type_key = "Fuse";
|
|
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
|
|
};
|
|
|
|
// implementations
|
|
inline const StageNode* Stage::operator->() const {
|
|
return static_cast<const StageNode*>(node_.get());
|
|
}
|
|
inline StageNode* Stage::operator->() {
|
|
return static_cast<StageNode*>(node_.get());
|
|
}
|
|
|
|
// Unchecked downcast of the held node to ScheduleNode.
inline const ScheduleNode* Schedule::operator->() const {
  auto* raw = node_.get();
  return static_cast<const ScheduleNode*>(raw);
}
|
|
|
|
// Unchecked downcast of the held node to the relation base container.
inline const IterVarRelationNode* IterVarRelation::operator->() const {
  auto* raw = node_.get();
  return static_cast<const IterVarRelationNode*>(raw);
}
|
|
|
|
} // namespace tvm
|
|
#endif // TVM_SCHEDULE_H_
|