192 строки
4.8 KiB
C++
192 строки
4.8 KiB
C++
/*!
|
|
* Copyright (c) 2016 by Contributors
|
|
* \file expr.h
|
|
* \brief The Expr and related elements in DataFlow construction.
|
|
*/
|
|
#ifndef TVM_EXPR_H_
|
|
#define TVM_EXPR_H_
|
|
|
|
#include <ir/Expr.h>
|
|
#include <ir/IRPrinter.h>
|
|
#include <ir/IROperator.h>
|
|
#include <string>
|
|
#include <algorithm>
|
|
#include "./base.h"
|
|
#include "./runtime/packed_func.h"
|
|
|
|
namespace tvm {
|
|
|
|
using Halide::Type;
|
|
using Halide::Float;
|
|
using Halide::Bool;
|
|
using Halide::Int;
|
|
using Halide::UInt;
|
|
using Halide::Handle;
|
|
|
|
using Halide::Expr;
|
|
using Halide::VarExpr;
|
|
using Halide::IR::FunctionRef;
|
|
using Halide::IR::FunctionBaseNode;
|
|
using Halide::Internal::Stmt;
|
|
using Halide::Internal::IRPrinter;
|
|
using Halide::Internal::Variable;
|
|
|
|
using Halide::Internal::make_const;
|
|
|
|
/*! \brief a named variable in TVM */
|
|
class Var : public Halide::VarExpr {
|
|
public:
|
|
explicit Var(const std::string& name_hint = "v",
|
|
Type t = Int(32)) : VarExpr(name_hint, t) {}
|
|
|
|
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
|
|
|
|
/*! \brief type indicate the container type */
|
|
using ContainerType = Variable;
|
|
};
|
|
|
|
|
|
/*! \brief container class of iteration variable. */
|
|
class IterVarNode;
|
|
|
|
/*!
|
|
* \brief same as Halide::IR::Range
|
|
* except it provide an constructor with (begin, end)
|
|
*
|
|
* \note Traditional Halide's Range have a constructor with
|
|
* (begin, extent), which does not match the convention in e.g. python.
|
|
* We decided to correct it by removing the constructor in HalideIR,
|
|
* and add it back in TVM's range.
|
|
*/
|
|
class Range : public Halide::IR::Range {
|
|
public:
|
|
/*! \brief constructor */
|
|
Range() {}
|
|
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
|
|
/*!
|
|
* \brief constructor by begin and end
|
|
* \param begin The begin of the range.
|
|
* \param end The end of the range.
|
|
*/
|
|
Range(Expr begin, Expr end);
|
|
|
|
static Range make_with_min_extent(Expr min, Expr extent);
|
|
};
|
|
|
|
/*!
|
|
* \brief Iteration Variable,
|
|
* represents an iteration over an integer interval.
|
|
*/
|
|
class IterVar : public NodeRef {
|
|
public:
|
|
// construct a new iter var without a domain
|
|
IterVar() {}
|
|
// construct from shared ptr.
|
|
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief construction of iteration variable.
|
|
* \param dom The iteration domain.
|
|
* \param var_name The name of iteration variable.
|
|
* \param thread_tag The additional tag to indicate whether the var is binded to fixed-thread.
|
|
*/
|
|
explicit IterVar(Range dom, std::string var_name = "i", std::string thread_tag = "");
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const IterVarNode* operator->() const;
|
|
/*!
|
|
* \return the corresponding var in the IterVar.
|
|
*/
|
|
inline operator Expr() const;
|
|
/*! \brief specify container node */
|
|
using ContainerType = IterVarNode;
|
|
};
|
|
|
|
using Domain = Array<Range>;
|
|
|
|
// functions
|
|
using Halide::cast;
|
|
using Halide::min;
|
|
using Halide::max;
|
|
using Halide::abs;
|
|
using Halide::select;
|
|
|
|
/*!
 * \brief sum of source expression over rdom
 * \param source The source expression.
 * \param rdom List of iteration variables that will be used for reduction.
 */
|
|
Expr sum(Expr source, Array<IterVar> rdom);
|
|
|
|
/*!
 * \brief max of source expression over rdom
 * \param source The source expression.
 * \param rdom List of iteration variables that will be used for reduction.
 */
|
|
Expr max(Expr source, Array<IterVar> rdom);
|
|
|
|
/*!
 * \brief min of source expression over rdom
 * \param source The source expression.
 * \param rdom List of iteration variables that will be used for reduction.
 */
|
|
Expr min(Expr source, Array<IterVar> rdom);
|
|
|
|
|
|
// print functions for expr
|
|
std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
|
|
|
|
// definition of Node.
|
|
/*!
|
|
* \brief An iteration variable representing an iteration
|
|
* over a one dimensional interval.
|
|
*/
|
|
class IterVarNode : public Node {
|
|
public:
|
|
/*!
|
|
* \brief the domain of iteration, if known, can be None
|
|
* For the intermediate schedule node, before schedule.
|
|
*/
|
|
Range dom;
|
|
/*! \brief The looping variable */
|
|
Var var;
|
|
/*!
|
|
* \brief additional tag on the iteration variable,
|
|
* set this if this is binded already to a known thread tag.
|
|
*/
|
|
std::string thread_tag;
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("dom", &dom);
|
|
v->Visit("var", &var);
|
|
v->Visit("thread_tag", &thread_tag);
|
|
}
|
|
|
|
static IterVar make(Range dom, Var var, std::string thread_tag);
|
|
|
|
static constexpr const char* _type_key = "IterVar";
|
|
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
|
|
};
|
|
|
|
// inline implementations
|
|
/*!
 * \brief Access the internal node container.
 * \return node_ viewed as a const IterVarNode pointer. The conversion is
 *  an unchecked static_cast, so no type or null check happens here.
 */
inline const IterVarNode* IterVar::operator->() const {
  return static_cast<const IterVarNode*>(node_.get());
}
|
|
|
|
/*!
 * \brief Convert the IterVar to an expression.
 * \return The var field of the underlying IterVarNode, converted to Expr.
 */
inline IterVar::operator Expr() const {
  return (*this)->var;
}
|
|
|
|
} // namespace tvm
|
|
|
|
namespace std {
|
|
template <>
|
|
struct hash<::tvm::IterVar> {
|
|
std::size_t operator()(const ::tvm::IterVar& k) const {
|
|
return k.hash();
|
|
}
|
|
};
|
|
}
|
|
#endif // TVM_EXPR_H_
|