/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/arithmetic.h
 * \brief Algebra and set operations and simplifications.
 */
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_

#include <vector>
#include <unordered_map>
#include <memory>
#include "expr.h"

namespace tvm {

class Tensor;

/*! \brief namespace of arithmetic */
namespace arith {
/*!
 * \brief Sign of an expression or set.
 *
 * Unscoped enumeration; the enumerators take the implicit values 0..3
 * in declaration order.
 */
enum SignType {
  kPositive,
  kNegative,
  kZero,
  kUnknown
};

// Forward declaration of the internal node container of IntSet.
struct IntSetNode;

/*!
|
|
* \brief Integer set class, represent a set of integers in one dimension.
|
|
*/
|
|
class IntSet : public NodeRef {
|
|
public:
|
|
/*! \brief constructor */
|
|
IntSet() {}
|
|
// constructor from not container.
|
|
explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const IntSetNode* operator->() const;
|
|
/*!
|
|
* \brief Find a range that covers the region.
|
|
* \param max_range The range to be covered.
|
|
* \return The covering range.
|
|
*/
|
|
Range cover_range(Range max_range) const;
|
|
/*!
|
|
* \brief find an interval that covers the set.
|
|
* \return The covering interval set.
|
|
*/
|
|
IntSet cover_interval() const;
|
|
/*! \return Lower bound of the set */
|
|
Expr min() const;
|
|
/*! \return upper bound of the set */
|
|
Expr max() const;
|
|
/*! \return Whether the set represent nothing */
|
|
bool is_nothing() const;
|
|
/*! \return Whether the set represent everything */
|
|
bool is_everything() const;
|
|
/*! \return Whether the set is a single point */
|
|
bool is_single_point() const;
|
|
/*! \return Whether the set is proved to be bigger than 0 */
|
|
bool can_prove_positive() const;
|
|
/*! \return Whether the set is proved to be smaller than 0 */
|
|
bool can_prove_negative() const;
|
|
/*! \return The sign of the elements in the integer set */
|
|
SignType sign_type() const;
|
|
/*!
|
|
* \brief The single point value, call only if is_single_point is true
|
|
* \return The point value.
|
|
*/
|
|
Expr point_value() const;
|
|
/*!
|
|
* \brief Try to match IntSet with range r.
|
|
*
|
|
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
|
|
* \return true if we can prove they are the same.
|
|
*/
|
|
bool match_range(const Range& r) const;
|
|
/*! \return The set contains nothing */
|
|
static IntSet nothing();
|
|
/*! \return The set contains everything */
|
|
static IntSet everything();
|
|
/*!
|
|
* \brief construct a point set.
|
|
* \param point The point in the set.
|
|
* \return construct a single point set
|
|
*/
|
|
static IntSet single_point(Expr point);
|
|
/*!
|
|
* \brief construct a integer set from vector expression.
|
|
* \param vec The vector expression, can also be single point.
|
|
* \return The result set containing the indices in the vector.
|
|
*/
|
|
static IntSet vector(Expr vec);
|
|
/*!
|
|
* \brief Construct a set representing a range.
|
|
* \param r The range
|
|
* \return constructed set.
|
|
*/
|
|
static IntSet range(Range r);
|
|
/*!
|
|
* \brief Construct a set representing a interval.
|
|
* \param min The minimum value of the interval.
|
|
* \param max The maximum value of the interval.
|
|
* \return constructed set.
|
|
*/
|
|
static IntSet interval(Expr min, Expr max);
|
|
};
|
|
|
|
/*!
 * \brief Range of a linear integer function.
 *  Used to specify the possible index values.
 *
 *  set = { coeff * x + base | x in Z }
 *
 *  When coeff != 0, it can also be written as
 *  set = { n | n % coeff == base }
 *
 *  This is useful to decide if the index is divisible by a certain value.
 *  For example, if index = 0 + 4 x, then we know it can be divided by 4.
 */
struct ModularEntry {
  /*! \brief linear co-efficient */
  int coeff{1};
  /*! \brief The base */
  int base{0};

  /*! \return entry that represents everything */
  static ModularEntry everything() {
    // always safe to set 0 + x, so it can be everything.
    ModularEntry e;
    e.coeff = 1;
    e.base = 0;
    return e;
  }
  /*!
   * \brief Add two modular entries together to get a new modular entry.
   * \param a The left operand.
   * \param b The right operand.
   * \return The combined modular entry.
   */
  static ModularEntry Add(const ModularEntry& a,
                          const ModularEntry& b);
};

/*!
|
|
* \brief Base class of all IntSet containers.
|
|
*/
|
|
struct IntSetNode : public Node {
|
|
static constexpr const char* _type_key = "IntSet";
|
|
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
|
|
};
|
|
|
|
/*!
|
|
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
|
|
* Where coeff[i] and base are invariant of var[j] for all i and j.
|
|
*
|
|
* \param e The expression to be detected.
|
|
* \param vars List of variables to be used in detection.
|
|
* \return [coeff[i]] if it is possible, empty array if it is not.
|
|
*/
|
|
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);
|
|
|
|
/*!
|
|
* \brief Detect if expression corresponds to clip bound of the vars
|
|
*
|
|
* \param e The expression to be detected.
|
|
* \param vars List of variables to be used in detection.
|
|
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
|
|
* return empty if the e does not match the pattern.
|
|
*/
|
|
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
|
|
|
|
/*!
|
|
* \brief Find an symbolic integer set that contains all possible values of
|
|
* e given the domain of each iteration variables.
|
|
*
|
|
* \param e The expression to be evaluated.
|
|
* \param dom_map The domain of each variable.
|
|
* \return An integer set that can cover all the possible values of e.
|
|
*/
|
|
IntSet EvalSet(Expr e,
|
|
const Map<IterVar, IntSet>& dom_map);
|
|
/*!
|
|
* \brief Same as EvalSet, but takes unordered_map
|
|
*
|
|
* \param e The expression to be evaluated.
|
|
* \param dom_map The domain of each variable.
|
|
* \return An integer set that can cover all the possible values of e.
|
|
*/
|
|
IntSet EvalSet(Expr e,
|
|
const std::unordered_map<const Variable*, IntSet>& dom_map);
|
|
|
|
/*!
|
|
* \brief Find an symbolic integer set that contains is union over
|
|
* all the possible conditional values in dom_map.
|
|
*
|
|
* \param r The initial range.
|
|
* \param dom_map The domain of each variable.
|
|
* \return An integer set that can cover all the possible values.
|
|
*/
|
|
IntSet EvalSet(Range r,
|
|
const Map<IterVar, IntSet>& dom_map);
|
|
|
|
/*!
|
|
* \brief Find an symbolic integer set that contains is union over
|
|
* all the possible conditional values in dom_map.
|
|
*
|
|
* \param s The initial set.
|
|
* \param dom_map The domain of each variable.
|
|
* \return An integer set that can cover all the possible values.
|
|
*/
|
|
IntSet EvalSet(IntSet s,
|
|
const std::unordered_map<const Variable*, IntSet>& dom_map);
|
|
/*!
|
|
* \brief Same as EvalSet, but takes unordered_map
|
|
*
|
|
* \param r The range to be evaluated.
|
|
* \param dom_map The domain of each variable.
|
|
* \return An integer set that can cover all the possible values of e.
|
|
*/
|
|
IntSet EvalSet(Range r,
|
|
const std::unordered_map<const Variable*, IntSet>& dom_map);
|
|
|
|
/*! \brief Map from Expr to IntSet */
|
|
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
|
|
/*!
|
|
* \brief Find the integer set of every sub-expression, given the
|
|
* domain of each iteration variables.
|
|
*
|
|
* \param e The expression to be evaluated.
|
|
* \param dom_map The domain of each variable.
|
|
* \return the map from the expression to its possible value.
|
|
*/
|
|
ExprIntSetMap EvalSetForEachSubExpr(
|
|
Expr e,
|
|
const std::unordered_map<const Variable*, IntSet>& dom_map);
|
|
|
|
/*!
|
|
* \brief Create an union set of all sets
|
|
* \param sets The sets to be unioned
|
|
* \return the set after union
|
|
*/
|
|
IntSet Union(const Array<IntSet>& sets);
|
|
|
|
/*!
|
|
* \brief Create an union set of all sets
|
|
* \param sets The sets to be intersected
|
|
* \return the set after intersected
|
|
*/
|
|
IntSet Intersect(const Array<IntSet>& sets);
|
|
|
|
/*!
|
|
* \brief Deduce the bound of the target variable in a expression,
|
|
* give the domain of each variables. Return undefined IntSet to
|
|
* represent failure.
|
|
*
|
|
* \param v The target variable to be deduced.
|
|
* \param cond The conditional expression.
|
|
* \param hint_map The domain of variable, used to help deduce.
|
|
* \param relax_map The domain of each variable, used to relax the domain,
|
|
* The deduce bound mush implies e for all value in relax_map
|
|
* \return An integer set that can cover all the possible values.
|
|
*/
|
|
IntSet DeduceBound(Expr v, Expr cond,
|
|
const Map<Var, IntSet>& hint_map,
|
|
const Map<Var, IntSet>& relax_map);
|
|
/*!
|
|
* \brief Same as DeduceBound with unordered_map signature.
|
|
*
|
|
* \param v The target variable to be deduced.
|
|
* \param cond The conditional expression.
|
|
* \param hint_map The domain of variable, used to help deduce.
|
|
* \param relax_map The domain of each variable, used to relax the domain,
|
|
* The deduce bound mush implies e for all value in relax_map
|
|
* \return An integer set that can cover all the possible values.
|
|
*/
|
|
IntSet DeduceBound(Expr v, Expr cond,
|
|
const std::unordered_map<const Variable*, IntSet>& hint_map,
|
|
const std::unordered_map<const Variable*, IntSet>& relax_map);
|
|
|
|
/*!
|
|
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
|
|
* \param body The given statement.
|
|
* \param tensor The name of the calls or provides.
|
|
* \param consider_calls If calls (read) are considered.
|
|
* \param consider_provides If provides (write) are considered.
|
|
* \return The domain that covers all the calls or provides within the given statement.
|
|
*/
|
|
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
|
|
|
|
/*!
|
|
* \brief Evaluate the expression with modular analysis
|
|
* \param e The expression to be evaluated.
|
|
* \param mod_map Map of modular statistics of known variables.
|
|
* \return The ModularEntry covering all possible value of e.
|
|
*/
|
|
ModularEntry EvalModular(
|
|
const Expr& e,
|
|
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
|
|
|
|
/*!
|
|
* \brief Same as EvalModular, used by front-end.
|
|
* \param e The expression to be evaluated.
|
|
* \param mod_map Map of modular statistics of known variables.
|
|
* \return A ModularSet covering all possible value of e.
|
|
*/
|
|
IntSet EvalModular(const Expr& e,
|
|
const Map<Var, IntSet>& mod_map);
|
|
// implementation
|
|
inline const IntSetNode* IntSet::operator->() const {
|
|
return static_cast<const IntSetNode*>(node_.get());
|
|
}
|
|
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_H_