onnxruntime-tvm/include/tvm/ir.h

/*!
* Copyright (c) 2016 by Contributors
* \file tvm/ir.h
* \brief Additional high level nodes in the IR
*/
#ifndef TVM_IR_H_
#define TVM_IR_H_
#include <ir/Expr.h>
#include <ir/IR.h>
#include <type_traits>
#include <sstream>
#include <string>
#include "base.h"
#include "expr.h"
#include "runtime/util.h"
namespace tvm {
namespace ir {
using HalideIR::Internal::ExprNode;
using HalideIR::Internal::StmtNode;
using HalideIR::Internal::IRNodeType;
using HalideIR::Internal::ForType;
using HalideIR::DeviceAPI;
// Node container for CommReducer
struct CommReducerNode;
struct CommReducer : public NodeRef {
CommReducer() {}
explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const CommReducerNode* get() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const CommReducerNode* operator->() const;
/*! \brief the container type */
using ContainerType = CommReducerNode;
};
/*!
* \brief A commutative reducer node to represent a commutative
* binary operator with identity element
*/
struct CommReducerNode : public Node {
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
Array<Var> rhs;
/*! \brief The result of reducer */
Array<Expr> result;
/*!
* \brief The identity element of the reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation this reducer uses.
*/
Array<Expr> identity_element;
/*! \brief Function call operator to combine a and b */
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
TVM_DLL static CommReducer make(Array<Var> lhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
static constexpr const char* _type_key = "CommReducer";
TVM_DECLARE_NODE_TYPE_INFO(CommReducerNode, Node);
};
inline const CommReducerNode* CommReducer::get() const {
return static_cast<CommReducerNode*>(node_.get());
}
inline const CommReducerNode* CommReducer::operator->() const {
return static_cast<CommReducerNode*>(node_.get());
}
/*! \brief Reduction operator */
struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Array<Expr> source;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
* \brief Predicate on the reduction
* Only add the body to reduction if condition is true.
*/
Expr condition;
/*! \brief the index of this reduce node */
int value_index;
/*! \brief construct expr from combiner and rdom */
TVM_DLL static Expr make(CommReducer combiner,
Array<Expr> src,
Array<IterVar> rdom,
Expr condition,
int value_index);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("combiner", &combiner);
v->Visit("source", &source);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
};
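/*!
* Example (a minimal sketch): a sum reduction over a reduction axis, reusing
* the sum_reducer sketch above. Here A is a hypothetical source expression
* indexed by a hypothetical reduction IterVar k, and const_true is assumed
* from HalideIR's expression helpers.
* \code
* Expr red = Reduce::make(sum_reducer, {A}, {k}, const_true(), 0);
* \endcode
*/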
/*!
* \brief Auxiliary data structure used in IR passes to indicate a tensor.
*/
struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
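/*!
* Example (illustrative): TensorKey works as a key of std::unordered_map
* through the std::hash specialization at the end of this header; `provide`
* and `provide_body` are hypothetical names.
* \code
* std::unordered_map<TensorKey, Stmt> body_of;
* body_of[TensorKey{provide->func, provide->value_index}] = provide_body;
* \endcode
*/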
/*! \brief namespace of possible attributes in AttrStmt.attr_key */
namespace attr {
/*! \brief Mark launching extent of thread, used by device API. */
constexpr const char* thread_extent = "thread_extent";
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark that the region is processed by a co-processor. */
constexpr const char* coproc_scope = "coproc_scope";
/*!
* \brief Mark that the region creates coprocessor micro ops;
* it can be reused if the corresponding variable is independent.
*/
constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark the scope as generated by an extern primitive.
* Such a scope can contain an arbitrary IR program, so we need to be
* careful when making assumptions about its structure.
*/
constexpr const char* extern_scope = "extern_scope";
/*!
* \brief Mark the scope where computation starts to happen.
* This can hint some code generators to create a new function for the computation.
*/
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage scope of buffers */
constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage alignment requirement of buffers */
constexpr const char* storage_alignment = "storage_alignment";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation context for global malloc in host. */
constexpr const char* device_context_id = "device_context_id";
/*! \brief The device type. */
constexpr const char* device_context_type = "device_context_type";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark that the region is guarded by a pragma extension */
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import LLVM source or file into the final codegen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*!
* \brief Mark of prefetch scope, value=offset;
* runs prefetch of the Tensor in the current loop scope
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
* \brief Marks production of double buffer data
*/
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
* \brief Marks region used by double buffer write
*/
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*!
* \brief Mark alignment of a buffer dimension.
* stmt.node is the Tensor.
* stmt.value is tvm_tuple(dim, align, offset).
* This hints that the stride of dim should be k * align + offset.
*/
constexpr const char* buffer_dim_align = "buffer_dim_align";
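// Example (illustrative): an AttrStmt with value tvm_tuple(1, 16, 0) hints
// that stride[1] of the annotated Tensor should be a multiple of 16.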
/*!
* \brief Bind the buffer specification to the region of the op.
* When this scope occurs, stmt.node is an Array<NodeRef> = [buffer, tensor]
* and stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
* The scope represents that we need to bind the storage region of the tensor
* to the buffer. During the storage-flattening phase, variables inside the
* scope that correspond to fields of the buffer are replaced by the actual
* expressions of the tensor.
*/
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
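// Example (illustrative): stmt.node = [buffer, tensor] with
// stmt.value = tvm_tuple(0, 16, 0, 16) binds buffer to the 16x16 region
// tensor[0:16, 0:16].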
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_read_advance = "channel_read_advance";
/*! \brief channel write scope */
constexpr const char* channel_write_scope = "channel_write_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_write_advance = "channel_write_advance";
/*! \brief pipeline stage scope, implies always execution */
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
/*!
* \brief Mark that this stage is an OpenGL shader. Since an OpenGL shader
* only allows writing to one element of the output texture, the Provide
* node is translated to a special Call::glsl_texture_store statement
* instead of a Store statement.
*/
constexpr const char* opengl_stage_scope = "opengl_stage_scope";
/*!
* \brief Mark that it is in the device scope.
*/
constexpr const char* device_scope = "device_scope";
/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
* \return true if it is a pragma key
*/
inline bool IsPragmaKey(const std::string& attr_key) {
return attr_key.compare(0, 7, "pragma_") == 0;
}
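/*!
* Example (illustrative): pragma keys are formed by appending a name to
* pragma_scope_prefix; IsPragmaKey then recognizes them.
* \code
* std::string key = std::string(attr::pragma_scope_prefix) + "unroll";
* bool is_pragma = attr::IsPragmaKey(key);  // true
* \endcode
*/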
} // namespace attr
/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
/*!
* \brief See pseudo code
*
* Handle tvm_address_of(Load *op) {
* return &op->buffer_var[op->index];
* }
*/
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
* \brief Same as select, used for unsafe memory access.
*
* Type tvm_if_then_else(cond, a, b) {
* return cond ? a : b;
* }
*/
constexpr const char* tvm_if_then_else = "tvm_if_then_else";
/*!
* \brief Get head access address with memory access pattern info.
*
* This operator also marks range of the memory access
* The offset and extent are in units of the DType (including the vectorization factor).
* rw_mask is a bit mask setting whether the access is a read (1) or write (2).
* The access is assumed to happen in the current expression.
*
* PtrType tvm_access_ptr(Expr dtype, DType* data,
* int offset, int extent,
* int rw_mask) {
* // DType == dtype.type();
* return &data[offset];
* }
*/
constexpr const char* tvm_access_ptr = "tvm_access_ptr";
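/*!
* Example (a hedged sketch of forming this intrinsic call; buf is a
* hypothetical buffer Var): a read-write (mask 1|2 == 3) pointer to 16
* float32 elements starting at offset 0. TypeAnnotation is defined at the
* end of this header.
* \code
* Expr ptr = Call::make(Handle(), intrinsic::tvm_access_ptr,
*                       {TypeAnnotation(Float(32)), buf,
*                        make_const(Int(32), 0), make_const(Int(32), 16),
*                        make_const(Int(32), 3)},
*                       Call::Intrinsic);
* \endcode
*/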
/*!
* \brief Create a function-local static handle that initializes to nullptr.
* Can be used to cache function-local static resources.
*/
constexpr const char* tvm_static_handle = "tvm_static_handle";
/*!
* \brief Return a unique context id, used as a hint for workspace separation.
* Different context ids guarantee non-overlapping workspaces.
*/
constexpr const char* tvm_context_id = "tvm_context_id";
/*!
* \brief tvm_tuple is not an actual function and cannot be code-generated.
* It is used to represent a tuple structure in the value field of AttrStmt,
* for the sake of giving a hint to optimization.
*
* Handle tvm_tuple(value0, value1, ..., value_n);
*/
constexpr const char* tvm_tuple = "tvm_tuple";
/*!
* \brief See pseudo code
*
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
* return arr[index]->field;
* }
* \sa TVMStructFieldKind
*/
constexpr const char* tvm_struct_get = "tvm_struct_get";
/*!
* \brief See pseudo code
*
* Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
* arr[index]->field = value;
* }
* \sa TVMStructFieldKind
*/
constexpr const char* tvm_struct_set = "tvm_struct_set";
/*!
* \brief See pseudo code
*
* bool tvm_handle_is_null(void* handle) {
* return handle == nullptr;
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pseudo code
*
* void tvm_throw_last_error() {
* throw TVMGetLastError();
* }
*/
constexpr const char* tvm_throw_last_error = "tvm_throw_last_error";
/*!
* \brief See pseudo code
*
* dtype in {shape, array, arg_value, arg_tcode}
*
* Handle tvm_stack_alloca(string dtype, int num) {
* return new on stack dtype[num];
* }
*/
constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
/*!
* \brief Allocate a shape tuple on stack, return the handle.
*
* Handle tvm_stack_make_shape(list args) {
* ret = alloca stack int64_t[len(args)];
* for i in range(len(args)):
* ret[i] = args[i]
* return &ret[0];
* }
*/
constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
/*!
* \brief Allocate a NDArray(DLTensor) on stack, return the handle.
*
* Type tvm_stack_make_array(Expr data,
* Expr shape,
* Expr strides,
* Expr ndim,
* Expr dtype,
* Expr elem_offset) {
* ret = alloca stack DLTensor();
* ret->data = data;
* ret->shape = shape;
* ret->strides = strides != 0 ? strides : nullptr;
* ret->ndim = ndim;
* ret->dtype = dtype.type();
* ret->byte_offset = elem_offset * sizeof(dtype);
* return ret;
* }
*/
constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
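/*!
* Example (a hedged sketch): these two intrinsics are typically emitted
* together during lowering, first building a shape tuple and then wrapping
* it into a DLTensor; the integer literals here are illustrative.
* \code
* Expr shape = Call::make(Handle(), intrinsic::tvm_stack_make_shape,
*                         {make_const(Int(64), 1), make_const(Int(64), 16)},
*                         Call::Intrinsic);
* \endcode
*/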
/*!
* \brief See pseudo code
*
* int tvm_call_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
* \brief See pseudo code.
* Mark the content as a thread-local context; it can be optimized
* by invoking the call only once at thread start.
*
* Do not allow nesting (getting a thread context from another).
*
* Handle tvm_thread_context(Expr call) {
* return call;
* }
*/
constexpr const char* tvm_thread_context = "tvm_thread_context";
/*!
* \brief Lowered version of call packed; the space for values and
* type codes is explicitly allocated.
*
* int tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* }
*/
constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
/*!
* \brief See pseudo code
*
* int tvm_storage_sync(std::string storage_scope) {
* __sync(storage_scope);
* return 0;
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
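/*!
* Example (a hedged sketch, following the pattern used for intrinsic calls
* elsewhere in tvm): request a synchronization of the "shared" scope.
* \code
* Stmt sync = Evaluate::make(
*     Call::make(Int(32), intrinsic::tvm_storage_sync,
*                {StringImm::make("shared")}, Call::Intrinsic));
* \endcode
*/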
/*!
* \brief See pseudo code
*
* Type tvm_warp_shuffle(Type value, warp_id) {
* return (value passed in by warp indicated by warp_id);
* }
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
/*!
* \brief Initialize the global barrier.
* Call this at the beginning of a kernel that needs the global barrier.
*/
constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
* \brief See pseudo code
*
* void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
* Var reduce_temp0, .., Var thread_idx1, ...) {
* // constrained such that the other thread_idx values remain the same.
* // reduce_temp is used to save intermediate result.
* reduce_temp0, ... = reduce(combiner, source0, ..., cond
* over [thread_idx1, thread_idx2] passed by any caller)
* }
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
} // namespace intrinsic
// Reuse IR node definitions from HalideIR
using HalideIR::Internal::IntImm;
using HalideIR::Internal::UIntImm;
using HalideIR::Internal::FloatImm;
using HalideIR::Internal::StringImm;
using HalideIR::Internal::Cast;
using HalideIR::Internal::Add;
using HalideIR::Internal::Sub;
using HalideIR::Internal::Mul;
using HalideIR::Internal::Div;
using HalideIR::Internal::Mod;
using HalideIR::Internal::Min;
using HalideIR::Internal::Max;
using HalideIR::Internal::EQ;
using HalideIR::Internal::NE;
using HalideIR::Internal::LT;
using HalideIR::Internal::LE;
using HalideIR::Internal::GT;
using HalideIR::Internal::GE;
using HalideIR::Internal::And;
using HalideIR::Internal::Or;
using HalideIR::Internal::Not;
using HalideIR::Internal::Select;
using HalideIR::Internal::Load;
using HalideIR::Internal::Ramp;
using HalideIR::Internal::Broadcast;
using HalideIR::Internal::Call;
using HalideIR::Internal::Let;
using HalideIR::Internal::LetStmt;
using HalideIR::Internal::AttrStmt;
using HalideIR::Internal::AssertStmt;
using HalideIR::Internal::ProducerConsumer;
using HalideIR::Internal::For;
using HalideIR::Internal::Store;
using HalideIR::Internal::Provide;
using HalideIR::Internal::Allocate;
using HalideIR::Internal::Free;
using HalideIR::Internal::Realize;
using HalideIR::Internal::Prefetch;
using HalideIR::Internal::Block;
using HalideIR::Internal::IfThenElse;
using HalideIR::Internal::Evaluate;
using HalideIR::Internal::Shuffle;
/*!
* \brief Create a type annotation expression
* \param dtype The data type
* \return An expression with the given dtype.
*/
inline Expr TypeAnnotation(Type dtype) {
return ir::Call::make(dtype,
"type_annotation", {},
ir::Call::PureIntrinsic);
}
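/*!
* Example (illustrative): the resulting call carries only a type and is
* used, e.g., as the dtype argument of intrinsic::tvm_access_ptr above.
* \code
* Expr tag = TypeAnnotation(Float(32));  // zero-argument pure intrinsic
* \endcode
*/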
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_IR_H_