[CODEGEN/PASS] Improve callpacked lowering, allow pass array callback. (#110)

* [CODEGEN/PASS] Improve callpacked lowering, allow pass array callback.

* fix cython
Tianqi Chen 2017-04-29 19:43:04 -07:00 committed by GitHub
Parent d45b6d4b84
Commit 9ba40dc0fe
32 changed files with 1239 additions and 451 deletions

View file

@ -46,6 +46,7 @@ tvm.ir_pass
tvm.ir_pass.VectorizeLoop
tvm.ir_pass.UnrollLoop
tvm.ir_pass.StorageSync
tvm.ir_pass.StorageRewrite
tvm.ir_pass.MakeAPI
tvm.ir_pass.SplitHostDevice
tvm.ir_pass.InjectVirtualThread
@ -53,6 +54,8 @@ tvm.ir_pass
tvm.ir_pass.RemoveNoOp
tvm.ir_pass.SplitPipeline
tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.LowerIntrin
tvm.ir_pass.LowerPackedCall
tvm.ir_pass.NarrowChannelAccess
.. automodule:: tvm.ir_pass

View file

@ -245,7 +245,6 @@ Expr max(Expr source, Array<IterVar> axis);
*/
Expr min(Expr source, Array<IterVar> axis);
// print functions for expr
std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)

View file

@ -140,6 +140,10 @@ constexpr const char* volatile_scope = "volatile_scope";
constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation context for global malloc in host. */
constexpr const char* device_context_id = "device_context_id";
/*! \brief The device type. */
constexpr const char* device_context_type = "device_context_type";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
@ -167,25 +171,24 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics are there to enable interaction with the TVM runtime.
/*!
* \brief See pseudo code
*
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
* assert(arg_type_id[i] == typeid(Type));
* return args[i];
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
* return arr[index]->field;
* }
* \sa TVMStructFieldKind
*/
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
constexpr const char* tvm_struct_get = "tvm_struct_get";
/*!
* \brief See pseudo code
*
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
* return arr->field;
* Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
* arr[index]->field = value;
* }
* \sa TVMArrayFieldKind
* \sa TVMStructFieldKind
*/
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
constexpr const char* tvm_struct_set = "tvm_struct_set";
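A minimal Python sketch of what these two intrinsics look like when built by hand (normally LowerPackedCall emits them; the kind constant mirrors the position of kArrNDim in the TVMStructFieldKind enum defined below, and tvm.Var/make.Call are this era's node constructors):

    import tvm
    from tvm import make as _make
    from tvm.expr import Call as _Call

    K_ARR_NDIM = 4  # mirrors intrinsic::kArrNDim by its position in the enum
    arr = tvm.Var("arr", "handle")  # assumed to hold a TVMArray* at runtime
    # read arr[0].ndim via tvm_struct_get
    ndim = _make.Call("int32", "tvm_struct_get",
                      [arr, 0, K_ARR_NDIM], _Call.PureIntrinsic, None, 0)
    # tvm_struct_set is the write analogue, wrapped in an Evaluate statement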
/*!
* \brief See pseudo code
*
@ -194,6 +197,48 @@ constexpr const char* tvm_array_get_field = "tvm_array_get_field";
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pseudo code
*
* dtype in {shape, array, arg_value, arg_tcode}
*
* Handle tvm_stack_alloca(string dtype, int num) {
* return new on stack dtype[num];
* }
* \sa TVMStructFieldKind
*/
constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
/*!
* \brief Allocate a shape tuple on stack, return the handle.
*
* Handle tvm_stack_make_shape(list args) {
* ret = alloca stack int64_t[len(args)];
* for i in range(len(args)):
* ret[i] = args[i]
* return &ret[0];
* }
*/
constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
/*!
* \brief Allocate an NDArray(DLTensor) on stack, return the handle.
*
* Type tvm_stack_make_array(Expr data,
* Expr shape,
* Expr strides,
* Expr ndim,
* Expr dtype,
* Expr byte_offset) {
* ret = alloca stack DLTensor();
* ret->data = data;
* ret->shape = shape;
* ret->strides = strides != 0 ? strides : nullptr;
* ret->ndim = ndim;
* ret->dtype = dtype.type();
* ret->byte_offset = byte_offset;
* return ret;
* }
*/
constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
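The Python helper added by this commit (_pack_buffer in python/tvm/intrin.py, further down) emits exactly these nodes; a minimal hand-built sketch of the shape part:

    import tvm
    from tvm import make as _make
    from tvm.expr import Call as _Call

    # pack a static (16, 16) shape tuple onto the stack; yields a handle Expr
    shape = _make.Call("handle", "tvm_stack_make_shape", [16, 16],
                       _Call.Intrinsic, None, 0)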
/*!
* \brief See pseudo code
*
@ -205,6 +250,23 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
* }
*/
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
* \brief Lowered version of call packed; the space for the values and
* type codes is explicitly allocated.
*
* int tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* }
*/
constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
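A hedged end-to-end sketch of the new flow, modeled on this commit's Python tests (my_copy is an illustrative name; create_schedule and the stackvm target are assumed from this era's API):

    import numpy as np
    import tvm

    @tvm.register_func("my_copy")
    def my_copy(src, dst):
        # src/dst arrive as borrowed NDArray views, valid only during the call
        src.copyto(dst)

    n = tvm.convert(16)
    A = tvm.placeholder((n,), name="A")

    def extern_generator(ins, outs):
        # emits tvm_call_packed; LowerPackedCall rewrites it into
        # tvm_call_packed_lowered with explicit value/tcode stacks
        return tvm.call_packed("my_copy", ins[0], outs[0])

    B = tvm.extern(A.shape, [A], extern_generator, name="B")
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "stackvm")
    a = tvm.nd.array(np.arange(16, dtype=np.float32))
    b = tvm.nd.array(np.zeros(16, dtype=np.float32))
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), a.asnumpy())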
/*!
* \brief See pseudo code
*
@ -231,16 +293,24 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
kData = 0,
kNDim = 1,
kShape = 2,
kStrides = 3,
kTypeCode = 4,
kTypeBits = 5,
kTypeLanes = 6,
kByteOffset = 7
/*! \brief The kind of structure field info */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic

View file

@ -251,6 +251,13 @@ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerPackedCall(LoweredFunc f);
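A hedged usage sketch from the Python side; fapi stands in for a LoweredFunc produced by tvm.ir_pass.MakeAPI, and only the host split needs the pass (this mirrors the python/tvm/build.py change below):

    import tvm
    # fapi: a LoweredFunc from tvm.ir_pass.MakeAPI (assumed built elsewhere)
    fsplits = [s for s in tvm.ir_pass.SplitHostDevice(fapi)]
    fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])  # host function only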
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.

View file

@ -10,7 +10,7 @@ from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase
from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
@ -188,6 +188,7 @@ def _handle_return_func(x):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
@ -195,7 +196,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
_CLASS_MODULE = None
_CLASS_FUNCTION = None

View file

@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray
from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array
print("TVM: Initializing cython mode...")
@ -29,7 +29,10 @@ cdef int tvm_callback(TVMValue* args,
tcode == kFuncHandle or
tcode == kModuleHandle):
CALL(TVMCbArgToReturn(&value, tcode))
pyargs.append(make_ret(value, tcode))
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
pyargs.append(_make_array(ctypes_handle(value.v_handle), True))
try:
rv = local_pyfunc(*pyargs)
except Exception:
@ -64,7 +67,9 @@ def convert_to_tvm_func(object pyfunc):
<void*>(pyfunc),
tvm_callback_finalize,
&chandle))
return _CLASS_FUNCTION(ctypes_handle(chandle), False)
ret = _CLASS_FUNCTION(None, False)
(<FunctionBase>ret).chandle = chandle
return ret
cdef inline void make_arg(object arg,

View file

@ -198,9 +198,9 @@ def sync(ctx):
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle"]
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle):
def __init__(self, handle, is_view=False):
"""Initialize the function with handle
Parameters
@ -209,9 +209,11 @@ class NDArrayBase(object):
the handle to the underlying C++ TVMArray
"""
self.handle = handle
self.is_view = is_view
def __del__(self):
check_call(_LIB.TVMArrayFree(self.handle))
if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle))
@property
def shape(self):
@ -302,6 +304,10 @@ class NDArrayBase(object):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
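The view flag is what lets callbacks hand arrays to Python safely: an ARRAY_HANDLE argument is wrapped as a borrowed view, so its destructor must not free the caller's buffer. A hedged sketch (get_global_func is assumed from this era's FFI):

    import numpy as np
    import tvm

    @tvm.register_func("peek_first")
    def peek_first(arr):
        # arr is an NDArray view: is_view=True, so __del__ skips TVMArrayFree
        assert arr.is_view
        return float(arr.asnumpy()[0])

    a = tvm.nd.array(np.arange(4, dtype=np.float32))
    fpeek = tvm.get_global_func("peek_first")
    print(fpeek(a))  # a remains valid after the callback returns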
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):

View file

@ -146,7 +146,7 @@ def build(sch,
warp_size = 32 if target == "cuda" else 1
fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
if len(fsplits) > 1:
if not target_host:
target_host = "llvm" if codegen.enabled("llvm") else "stackvm"

View file

@ -1,21 +1,56 @@
"""Intrinsics and math functions in TVM."""
from __future__ import absolute_import as _abs
from .expr import Call as _Call
from . import make as _make
from ._ffi.function import register_func as _register_func
from .api import convert
from . import make as _make
from .api import convert, const
from .expr import Call as _Call
from .schedule import Buffer as _Buffer
def _pack_buffer(buf):
"""Build intrinsics that packs the buffer.
"""
assert buf.shape
shape = _make.Call("handle", "tvm_stack_make_shape", buf.shape,
_Call.Intrinsic, None, 0)
strides = _make.Call("handle", "tvm_stack_make_shape", buf.strides,
_Call.Intrinsic, None, 0) if buf.strides else 0
pack_args = [buf.data,
shape,
strides,
len(buf.shape),
const(0, dtype=buf.dtype),
buf.byte_offset]
return _make.Call("handle", "tvm_stack_make_array",
pack_args, _Call.Intrinsic, None, 0)
def call_packed(*args):
"""Build expression by call an external packed function
"""Build expression by call an external packed function.
The argument to packed function can be Expr or Buffer.
The argument is the corresponding POD type when Expr is presented.
When the argument is Buffer, the corresponding PackedFunc
will recieve an TVMArrayHandle whose content is valid during the callback period.
If the PackedFunc is a python callback, then the corresponding argument is NDArray.
Parameters
----------
args : list
args : list of Expr or Buffer.
Positional arguments.
Returns
-------
call : Expr
The call expression.
See Also
--------
tvm.extern : Create tensor with extern function call.
"""
call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
return _make.Call(
"int32", "tvm_call_packed", args, _Call.Intrinsic, None, 0)
"int32", "tvm_call_packed", call_args, _Call.Intrinsic, None, 0)
def call_pure_intrin(dtype, func_name, *args):
@ -34,6 +69,11 @@ def call_pure_intrin(dtype, func_name, *args):
args : list
Positional arguments.
Returns
-------
call : Expr
The call expression.
"""
args = convert(args)
return _make.Call(
@ -53,6 +93,11 @@ def call_pure_extern(dtype, func_name, *args):
args : list
Positional arguments.
Returns
-------
call : Expr
The call expression.
"""
return _make.Call(
dtype, func_name, convert(args), _Call.PureExtern, None, 0)

View file

@ -36,7 +36,6 @@ TVM_REGISTER_API("_const")
}
});
TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ir::StringImm::make(args[0]);

View file

@ -76,5 +76,6 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerPackedCall);
} // namespace ir
} // namespace tvm

View file

@ -89,8 +89,7 @@ void CodeGenC::PrintSSAAssign(
// Print a reference expression to a buffer.
std::string CodeGenC::GetBufferRef(
const Variable* buffer,
Type t, Expr index) {
Type t, const Variable* buffer, Expr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
@ -151,6 +150,58 @@ std::string CodeGenC::GetBufferRef(
return os.str();
}
// Print a reference expression to a struct location.
std::string CodeGenC::GetStructRef(
Type t, const Expr& buffer, const Expr& index, int kind) {
if (kind < intrinsic::kArrKindBound_) {
std::ostringstream os;
os << "(((TVMArray*)";
this->PrintExpr(buffer, os);
os << ")";
if (kind == intrinsic::kArrAddr) {
os << " + ";
this->PrintExpr(index, os);
os << ")";
return os.str();
}
os << '[';
this->PrintExpr(index, os);
os << "].";
// other case: get fields.
switch (kind) {
case intrinsic::kArrData: os << "data"; break;
case intrinsic::kArrShape: os << "shape"; break;
case intrinsic::kArrStrides: os << "strides"; break;
case intrinsic::kArrNDim: os << "ndim"; break;
case intrinsic::kArrTypeCode: os << "dtype.code"; break;
case intrinsic::kArrTypeBits: os << "dtype.bits"; break;
case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break;
case intrinsic::kArrDeviceId: os << "ctx.device_id"; break;
case intrinsic::kArrDeviceType: os << "ctx.device_type"; break;
default: LOG(FATAL) << "unknown field code";
}
os << ')';
return os.str();
} else {
CHECK_LT(kind, intrinsic::kTVMValueKindBound_);
std::ostringstream os;
os << "(((TVMValue*)";
this->PrintExpr(buffer, os);
os << ")[" << index << "].";
if (t.is_handle()) {
os << "v_handle";
} else if (t.is_float()) {
os << "v_float64";
} else if (t.is_int()) {
os << "v_int64";
} else {
LOG(FATAL) << "donot know how to handle type" << t;
}
os << ")";
return os.str();
}
}
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
auto it = handle_data_type_.find(buf_var);
@ -182,15 +233,15 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
<< " = " << value << ";\n";
}
std::string CodeGenC::GetVecLoad(const Variable* buffer,
Type t, Expr base) {
return GetBufferRef(buffer, t, base);
std::string CodeGenC::GetVecLoad(
Type t, const Variable* buffer, Expr base) {
return GetBufferRef(t, buffer, base);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
std::string ref = GetBufferRef(buffer, t, base);
std::string ref = GetBufferRef(t, buffer, base);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
}
@ -430,42 +481,11 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
<< " + ";
this->PrintExpr(l->index, os);
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
if (!op->type.is_handle()) {
os << '(';
this->PrintType(op->type, os);
os << ')';
}
os << "(((TVMArg*)";
this->PrintExpr(op->args[0], os);
os << ")[" << op->args[2] << "].";
if (op->type.is_handle()) {
os << "v_handle";
} else if (op->type.is_float()) {
os << "v_double";
} else if (op->type.is_int() || op->type.is_uint()) {
os << "v_long";
} else {
LOG(FATAL) << "donot know how to handle type" << op->type;
}
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
os << "(((TVMArray*)";
this->PrintExpr(op->args[0], os);
os << ")->";
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: os << "data"; break;
case intrinsic::kShape: os << "shape"; break;
case intrinsic::kStrides: os << "strides"; break;
case intrinsic::kNDim: os << "ndim"; break;
case intrinsic::kTypeCode: os << "dtype.type_code"; break;
case intrinsic::kTypeBits: os << "dtype.bits"; break;
case intrinsic::kTypeLanes: os << "dtype.lanes"; break;
default: LOG(FATAL) << "unknown field code";
}
os << ')';
os << GetStructRef(
op->type, op->args[0], op->args[1],
op->args[2].as<IntImm>()->value);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
os << "(";
@ -513,12 +533,12 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes();
// declare type.
if (op->type.lanes() == 1) {
std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index);
std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
os << ref;
} else {
Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base);
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
os << ref;
} else {
// load separately.
@ -552,7 +572,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(op->buffer_var.get(), t, op->index);
std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
} else {
@ -744,14 +764,25 @@ void CodeGenC::VisitStmt_(const Block *op) {
void CodeGenC::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
this->PrintStorageSync(call);
} else {
std::string vid = this->PrintExpr(op->value);
this->PrintIndent();
this->stream << "(void)" << vid << ";\n";
if (call) {
if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
this->PrintStorageSync(call); return;
} else if (call->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(call->args.size(), 4);
std::string value = PrintExpr(call->args[3]);
std::string ref = GetStructRef(
call->args[3].type(),
call->args[0],
call->args[1],
call->args[2].as<IntImm>()->value);
this->PrintIndent();
this->stream << ref << " = " << value << ";\n";
return;
}
}
std::string vid = this->PrintExpr(op->value);
this->PrintIndent();
this->stream << "(void)" << vid << ";\n";
}
void CodeGenC::VisitStmt_(const ProducerConsumer *op) {

View file

@ -133,8 +133,7 @@ class CodeGenC :
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual std::string GetVecLoad(const Variable* buffer,
Type t, Expr base);
virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base);
// print vector store
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
@ -146,11 +145,13 @@ class CodeGenC :
virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value);
protected:
// Print reference to struct location
std::string GetStructRef(
Type t, const Expr& buffer, const Expr& index, int kind);
// print reference to a buffer as type t in index.
std::string GetBufferRef(const Variable* buffer,
Type t, Expr index);
std::string GetBufferRef(
Type t, const Variable* buffer, Expr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.

View file

@ -95,8 +95,8 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
os << GetVarID(buffer) << " + ";
PrintExpr(base, os);
}
std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer,
Type t, Expr base) {
std::string CodeGenOpenCL::GetVecLoad(
Type t, const Variable* buffer, Expr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);

View file

@ -24,11 +24,11 @@ class CodeGenOpenCL : public CodeGenC {
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
std::string GetVecLoad(const Variable* buffer,
Type t, Expr base) final;
std::string GetVecLoad(Type t, const Variable* buffer,
Expr base) final;
void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) final; // NOLINT(*)
Type t, Expr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const Variable* buffer, Type t,
Expr base, std::ostream& os); // NOLINT(*)

View file

@ -7,6 +7,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm {
@ -89,7 +90,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
void CodeGenLLVM::InitTarget(const std::string& target) {
llvm::TargetMachine* tm;
std::string target_triple;
std::tie(tm, target_triple) = LLVMGetTarget(target);
std::tie(tm, target_triple) = GetLLVMTarget(target);
module_->setTargetTriple(target_triple);
module_->setDataLayout(tm->createDataLayout());
data_layout_.reset(new llvm::DataLayout(module_.get()));
@ -318,6 +319,74 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
return builder_->CreateInBoundsGEP(buffer, index);
}
llvm::Value* CodeGenLLVM::CreateStructRefPtr(
Type t, llvm::Value* buf, llvm::Value* index, int kind) {
if (kind < intrinsic::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
} else {
CHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo());
}
}
switch (kind) {
case intrinsic::kArrAddr: {
return builder_->CreateInBoundsGEP(buf, index);
}
case intrinsic::kArrData: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)});
}
case intrinsic::kArrShape: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)});
}
case intrinsic::kArrStrides: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)});
}
case intrinsic::kArrNDim: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)});
}
case intrinsic::kArrTypeCode: {
return builder_->CreateInBoundsGEP(
buf, {index, ConstInt32(3), ConstInt32(0)});
}
case intrinsic::kArrTypeBits: {
return builder_->CreateInBoundsGEP(
buf, {index, ConstInt32(3), ConstInt32(1)});
}
case intrinsic::kArrTypeLanes: {
return builder_->CreateInBoundsGEP(
buf, {index, ConstInt32(3), ConstInt32(2)});
}
case intrinsic::kArrByteOffset: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)});
}
case intrinsic::kArrDeviceId: {
return builder_->CreateInBoundsGEP(
buf, {index, ConstInt32(1), ConstInt32(0)});
}
case intrinsic::kArrDeviceType: {
return builder_->CreateInBoundsGEP(
buf, {index, ConstInt32(1), ConstInt32(1)});
}
case intrinsic::kTVMValueContent: {
CHECK_EQ(t.lanes(), 1);
CHECK(t.is_handle() || t.bits() == 64);
if (t.is_int()) {
buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, index);
} else if (t.is_float()) {
buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, index);
} else {
CHECK(t.is_handle());
buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
buf = builder_->CreateInBoundsGEP(buf, index);
return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
}
}
default: LOG(FATAL) << "unknown field code"; return nullptr;
}
}
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
if (value->getType() == target) return value;
@ -394,39 +463,33 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
}
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
CHECK_GE(op->args.size(), 1U);
CHECK_EQ(op->args.size(), 5U);
std::string func_name = op->args[0].as<StringImm>()->value;
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
llvm::Value* targs = builder_->CreateAlloca(
t_tvm_value_, ConstInt32(nargs));
llvm::Value* tcodes = builder_->CreateAlloca(
t_int_, ConstInt32(nargs));
for (unsigned i = 0; i < nargs; ++i) {
Expr expr = op->args[i + 1];
Type t = expr.type();
CHECK_EQ(t.lanes(), 1);
// Always pass via 64 bit value.
// For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
Type api_type = t.with_bits(64);
llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
llvm::Value* store_ptr = builder_->CreatePointerCast(
builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
LLVMType(api_type)->getPointerTo());
builder_->CreateAlignedStore(value, store_ptr, 8);
builder_->CreateAlignedStore(
ConstInt32(t.code()),
builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
int64_t begin = op->args[3].as<IntImm>()->value;
int64_t end = op->args[4].as<IntImm>()->value;
int64_t nargs = end - begin;
CHECK_GE(nargs, 0);
llvm::Value* stack_value = MakeValue(op->args[1]);
llvm::Value* stack_tcode = MakeValue(op->args[2]);
llvm::Value* arg_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(
stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin));
llvm::Value* arg_tcode = CreateBufferPtr(
Int(32), stack_tcode, ConstInt32(begin));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(
stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
llvm::Value* ret_tcode = CreateBufferPtr(
Int(32), stack_tcode, ConstInt32(end));
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
{handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, ret_tcode}));
Type r_type = op->type;
Type r_api_type = op->type.with_bits(64);
Type r_api_type = ir::APIType(r_type);
llvm::Value* rvalue =
builder_->CreateAlignedLoad(
builder_->CreatePointerCast(
@ -649,62 +712,48 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
llvm::Value* ptr = MakeValue(op->args[0]);
return builder_->CreateICmpEQ(
ptr, llvm::Constant::getNullValue(ptr->getType()));
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
CHECK_EQ(op->type.lanes(), 1);
llvm::Value* args = builder_->CreatePointerCast(
MakeValue(op->args[0]), t_tvm_value_->getPointerTo());
llvm::Value* ptr = builder_->CreateInBoundsGEP(
args, MakeValue(op->args[2]));
// always pass via 64 bit pointers
// For handle type, Handle(64) will simply become 32 bit void*
Type value_type = op->type.with_bits(64);
ptr = builder_->CreatePointerCast(
ptr, LLVMType(value_type)->getPointerTo());
llvm::Value* value = builder_->CreateAlignedLoad(ptr, 8);
// cast to the desired type
if (value_type != op->type) {
value = CreateCast(value_type, op->type, value);
int kind = op->args[2].as<IntImm>()->value;
llvm::Value* ref = this->CreateStructRefPtr(
op->type, MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
if (kind == intrinsic::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
} else {
return builder_->CreateLoad(ref);
}
return value;
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
} else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImm>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(
op->args[3].type(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
CHECK(kind != intrinsic::kArrAddr);
if (value->getType()->isPointerTy()) {
value = builder_->CreatePointerCast(
value, ref->getType()->getPointerElementType());
}
builder_->CreateStore(value, ref);
return ConstInt32(0);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
llvm::Value* arr = builder_->CreatePointerCast(
MakeValue(op->args[0]), t_tvm_array_->getPointerTo());
llvm::Constant* zero = ConstInt32(0);
llvm::Value* ret = nullptr;
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(0)}); break;
}
case intrinsic::kShape: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(4)}); break;
}
case intrinsic::kStrides: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(5)}); break;
}
case intrinsic::kNDim: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
}
case intrinsic::kTypeCode: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(3), ConstInt32(0)}); break;
}
case intrinsic::kTypeBits: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(3), ConstInt32(1)}); break;
}
case intrinsic::kTypeLanes: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(3), ConstInt32(2)}); break;
}
case intrinsic::kByteOffset: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(6)}); break;
}
default: LOG(FATAL) << "unknown field code";
const std::string& type = op->args[0].as<StringImm>()->value;
llvm::Value* num = MakeValue(op->args[1]);
if (type == "shape") {
return builder_->CreateAlloca(t_tvm_shape_index_, num);
} else if (type == "arg_value") {
return builder_->CreateAlloca(t_tvm_value_, num);
} else if (type == "arg_tcode") {
return builder_->CreateAlloca(t_int_, num);
} else if (type == "array") {
return builder_->CreateAlloca(t_tvm_array_, num);
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
return builder_->CreateLoad(ret);
} else if (op->is_intrinsic(Call::null_handle)) {
return llvm::Constant::getNullValue(t_void_p_);
} else {
LOG(FATAL) << "Unknown intrinstic " << op->name;
}
@ -1180,9 +1229,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
@ -1194,7 +1242,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
}
}
void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial) {
@ -1263,8 +1310,10 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
this->VisitStmt(op->body);
} else {
this->VisitStmt(op->body);
}
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {

View file

@ -190,6 +190,7 @@ class CodeGenLLVM :
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str);
// Vector concatenation.

View file

@ -37,7 +37,7 @@ void InitializeLLVM() {
}
std::pair<llvm::TargetMachine*, std::string>
LLVMGetTarget(const std::string& target_str) {
GetLLVMTarget(const std::string& target_str) {
// setup target triple
std::string target_triple;
CHECK_EQ(target_str.substr(0, 4), "llvm");

View file

@ -55,7 +55,7 @@ void InitializeLLVM();
* \return Pair of target machine and target triple.
*/
std::pair<llvm::TargetMachine*, std::string>
LLVMGetTarget(const std::string& target_str);
GetLLVMTarget(const std::string& target_str);
} // namespace codegen
} // namespace tvm

View file

@ -94,7 +94,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
void Init(const Array<LoweredFunc>& funcs, std::string target) {
InitializeLLVM();
std::tie(tm_, target_triple_) = LLVMGetTarget(target);
std::tie(tm_, target_triple_) = GetLLVMTarget(target);
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;

View file

@ -70,25 +70,6 @@ int CodeGenStackVM::AllocVarID(const Variable* v) {
return vid;
}
void CodeGenStackVM::PushCallPacked(
int fid, const std::vector<int>& arg_type_codes) {
StackVM::Code code;
// CALL_PACKED_FUNC
code.op_code = StackVM::CALL_PACKED_FUNC;
vm_.code.push_back(code);
// num_args
code.v_int = static_cast<int>(arg_type_codes.size());
vm_.code.push_back(code);
// fid
code.v_int = fid;
vm_.code.push_back(code);
// type codes.
for (int tcode : arg_type_codes) {
code.v_int = tcode;
vm_.code.push_back(code);
}
}
int CodeGenStackVM::GetVarID(const Variable* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
@ -97,26 +78,33 @@ int CodeGenStackVM::GetVarID(const Variable* v) const {
}
void CodeGenStackVM::VisitExpr_(const Load* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
if (op->type == UInt(32) && op->index.as<IntImm>()) {
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
if (const IntImm* index = op->index.as<IntImm>()) {
this->PushOp(code, index->value);
} else {
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->PushOp(StackVM::GetLoad(Type2TVMType(op->type)));
this->PushOp(code, 0);
}
}
void CodeGenStackVM::VisitStmt_(const Store* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->Push(op->value);
this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type())));
this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetStore(Type2TVMType(op->value.type()));
if (const IntImm* index = op->index.as<IntImm>()) {
this->Push(op->value);
this->PushOp(code, index->value);
} else {
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->Push(op->value);
this->PushOp(code, 0);
}
}
void CodeGenStackVM::VisitStmt_(const Allocate* op) {
@ -141,41 +129,29 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
} else if (op->is_intrinsic(Call::null_handle)) {
this->PushOp(StackVM::PUSH_I64, 0);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value;
this->Push(op->args[0]);
this->Push(op->args[1]);
this->Push(op->args[2]);
if (op->type.is_handle()) {
this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE);
} else if (op->type.is_float()) {
this->PushOp(StackVM::TVM_LOAD_ARG_FP64);
} else if (op->type.is_int() || op->type.is_uint()) {
this->PushOp(StackVM::TVM_LOAD_ARG_INT64);
} else {
LOG(FATAL) << "donot know how to handle type" << op->type;
}
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
this->Push(op->args[0]);
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break;
case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break;
case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break;
case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break;
case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
case intrinsic::kByteOffset: PushOp(StackVM::TVM_ARRAY_GET_BYTE_OFFSET); break;
default: LOG(FATAL) << "unknown field code";
}
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
CHECK_GE(op->args.size(), 1U);
const IntImm* index = op->args[1].as<IntImm>();
CHECK(index != nullptr);
StackVM::Code code;
code.op_code = StackVM::TVM_STRUCT_GET;
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = kind;
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
CHECK_GE(op->args.size(), 5U);
const StringImm* s = op->args[0].as<StringImm>();
CHECK(s != nullptr) << "tvm_call_packed_lowered expects the first argument to be the function name";
for (size_t i = 1; i < op->args.size(); ++i) {
this->Push(op->args[i]);
}
this->Push(op->args[1]);
this->Push(op->args[2]);
int begin = op->args[3].as<IntImm>()->value;
int end = op->args[4].as<IntImm>()->value;
// find the function id.
const std::string& func_name = s->value;
auto it = extern_fun_idmap_.find(func_name);
@ -187,16 +163,39 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
vm_.extern_func_name.push_back(func_name);
extern_fun_idmap_[func_name] = fid;
}
// get the argument type code.
std::vector<int> arg_type_codes;
for (size_t i = 1; i < op->args.size(); ++i) {
Type t = op->args[i].type();
int code = t.code();
int lanes = t.lanes();
CHECK_EQ(lanes, 1);
arg_type_codes.push_back(code);
// CALL_PACKED_FUNC
StackVM::Code code;
code.op_code = StackVM::CALL_PACKED_LOWERED;
vm_.code.push_back(code);
code.v_int = fid;
vm_.code.push_back(code);
code.v_int = begin;
vm_.code.push_back(code);
code.v_int = end;
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImm>()->value;
const IntImm* num = op->args[1].as<IntImm>();
CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
size_t unit = sizeof(TVMValue);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
} else if (type == "arg_value") {
size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
} else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(TVMArray) + unit - 1) / unit;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
this->PushCallPacked(fid, arg_type_codes);
// add stack size to be safe.
vm_.stack_size += size;
this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
this->Push(op->args[0]);
@ -389,10 +388,26 @@ void CodeGenStackVM::VisitStmt_(const Block *op) {
if (op->rest.defined()) this->Push(op->rest);
}
void CodeGenStackVM::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
this->Push(op->value);
this->PushOp(StackVM::POP);
void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
if (is_const(ev->value)) return;
const Call* op = ev->value.as<Call>();
if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
this->Push(op->args[0]);
this->Push(op->args[3]);
const IntImm* index = op->args[1].as<IntImm>();
CHECK(index != nullptr);
StackVM::Code code;
code.op_code = StackVM::TVM_STRUCT_SET;
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = op->args[2].as<IntImm>()->value;
vm_.code.push_back(code);
} else {
this->Push(ev->value);
this->PushOp(StackVM::POP);
}
}
void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {

View file

@ -55,13 +55,6 @@ class CodeGenStackVM
* \return operand_index, indicating location of operand
*/
int64_t PushOp(StackVM::OpCode opcode, int operand);
/*!
* \brief Push a call packed function.
* \param fid The function id.
* \param arg_type_codes The type codes of arguments.
*/
void PushCallPacked(int fid,
const std::vector<int>& arg_type_codes);
/*!
* \brief Set the relative jump offset to be offset.
* \param operand_index The indexed returned by PushOp.

View file

@ -4,6 +4,7 @@
* \file stack_vm.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/ir.h>
#include "./stack_vm.h"
namespace tvm {
@ -21,58 +22,50 @@ StackVM::State* StackVM::ThreadLocalState() {
sp -= 1; pc += 1; \
}
#define STACK_VM_CMPOP(OP, FIELD) \
{ \
#define STACK_VM_CMPOP(OP, FIELD) \
{ \
stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
sp -= 1; pc += 1; \
sp -= 1; pc += 1; \
}
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \
{ \
stack[sp].FIELD = static_cast<DST_TYPE>( \
*static_cast<SRC_TYPE*>(stack[sp].v_handle)); \
pc += 1; \
int index = code[pc + 1].v_int; \
stack[sp]FIELD = static_cast<DST_TYPE>( \
static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]); \
pc += 2; \
}
#define STACK_VM_STORE(FIELD, DST_TYPE) \
{ \
*static_cast<DST_TYPE*>(stack[sp - 1].v_handle) = \
static_cast<DST_TYPE>(stack[sp].FIELD); \
sp -= 2; pc += 1; \
int index = code[pc + 1].v_int; \
static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] = \
static_cast<DST_TYPE>(stack[sp]FIELD); \
sp -= 2; pc += 2; \
}
#define STACK_VM_TVM_LOAD_ARG(OP, TYPE) \
{ \
TVMValue* args = static_cast<TVMValue*>(stack[sp - 2].v_handle); \
int64_t index = stack[sp].v_int64; \
int tc = static_cast<int*>(stack[sp - 1].v_handle)[index]; \
CHECK(OP) \
<< " argument " << index << " is expected to be " << TYPE; \
stack[sp - 2] = args[index]; \
sp -= 2; \
pc += 1; \
}
#define STACK_VM_TVM_ARRARY_GET(FIELD, TYPE, SFIELD) \
{ \
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle); \
stack[sp].FIELD = (TYPE)(arr->SFIELD); \
pc += 1; \
}
#define STACK_VM_PRINT_CODE0(CODE) \
case CODE: { \
#define STACK_VM_PRINT_CODE0(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \
}
#define STACK_VM_PRINT_CODE1(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
<< "[" << pc + 1 << "]" << std::endl; \
<< "[" << pc + 1 << "]" << std::endl; \
return pc + 2; \
}
#define STACK_VM_PRINT_CODE2(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE \
<< " " << code[pc + 1].v_int \
<< " " << code[pc + 2].v_int << "\n" \
<< "[" << pc + 1 << "]" << std::endl \
<< "[" << pc + 2 << "]" << std::endl; \
return pc + 3; \
}
#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \
@ -110,13 +103,18 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
STACK_VM_PRINT_CODE0(LT_F64);
STACK_VM_PRINT_CODE0(LE_F64);
// addressing load
STACK_VM_PRINT_CODE0(ADDR_LOAD_UINT32);
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT32);
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT64);
STACK_VM_PRINT_CODE0(ADDR_LOAD_FP64);
STACK_VM_PRINT_CODE0(ADDR_LOAD_HANDLE);
STACK_VM_PRINT_CODE0(ADDR_STORE_INT64);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT32);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT64);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_FP64);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_HANDLE);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_TVMVALUE);
STACK_VM_PRINT_CODE1(ARRAY_STORE_UINT32);
STACK_VM_PRINT_CODE1(ARRAY_STORE_INT32);
STACK_VM_PRINT_CODE1(ARRAY_STORE_INT64);
STACK_VM_PRINT_CODE1(ARRAY_STORE_FP64);
STACK_VM_PRINT_CODE1(ARRAY_STORE_HANDLE);
STACK_VM_PRINT_CODE1(ARRAY_STORE_TVMVALUE);
STACK_VM_PRINT_CODE0(NOT);
STACK_VM_PRINT_CODE0(ADDR_ADD);
// stack ops
@ -132,32 +130,24 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
STACK_VM_PRINT_JUMP(RJUMP);
STACK_VM_PRINT_CODE1(ASSERT_SP);
// Intrinsics
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_INT64);
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_FP64);
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_HANDLE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_DATA);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_SHAPE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_STRIDES);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_NDIM);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_BYTE_OFFSET);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES);
STACK_VM_PRINT_CODE2(TVM_STRUCT_GET);
STACK_VM_PRINT_CODE2(TVM_STRUCT_SET);
// Allocate data by 8 bytes.
STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE);
// packed function.
case CALL_PACKED_FUNC: {
int num_args = code[pc + 1].v_int;
case CALL_PACKED_LOWERED: {
int call_fid = code[pc + 1].v_int;
int begin = code[pc + 2].v_int;
int end = code[pc + 3].v_int;
os << "[" << pc << "]\tCALL_PACKED_FUNC "
<< " num_args=" << num_args
<< " fid=" << code[pc + 2].v_int;
os << " type_codes:";
for (int i = 0; i < num_args; ++i) {
os << ' ' << code[pc + 3 + i].v_int;
}
<< " fid=" << call_fid
<< " begin=" << begin
<< " end=" << end;
os << '\n';
for (int i = 0; i < num_args + 2; ++i) {
os << "[" << pc + 1 << "]" << std::endl;
for (int i = 0; i < 3; ++i) {
os << "[" << pc + 1 + i << "]" << std::endl;
}
return pc + 3 + num_args;
return pc + 4;
}
}
LOG(FATAL) << "unknown op code " << code[pc].op_code;
@ -193,6 +183,7 @@ void StackVM::operator()(const runtime::TVMArgs& args) const {
void StackVM::Run(State* s) const {
int64_t sp = s->sp;
int64_t pc = s->pc;
int64_t alloca_sp = s->sp;
std::vector<TVMValue>& stack = s->stack;
std::vector<TVMValue>& heap = s->heap;
s->extern_func.clear();
@ -223,23 +214,26 @@ void StackVM::Run(State* s) const {
case LT_F64: STACK_VM_CMPOP(<, v_float64); break;
case LE_F64: STACK_VM_CMPOP(<=, v_float64); break;
// addressing
case ADDR_LOAD_UINT32: STACK_VM_LOAD(v_int64, int64_t, uint32_t); break;
case ADDR_LOAD_INT32: STACK_VM_LOAD(v_int64, int64_t, int32_t); break;
case ADDR_LOAD_INT64: STACK_VM_LOAD(v_int64, int64_t, int64_t); break;
case ADDR_LOAD_FP64: STACK_VM_LOAD(v_float64, double, double); break;
case ADDR_LOAD_HANDLE: STACK_VM_LOAD(v_handle, void*, void*); break;
case ADDR_STORE_INT64: STACK_VM_STORE(v_int64, int64_t); break;
case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break;
case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break;
case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break;
case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break;
case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break;
case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break;
// store
case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break;
case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break;
case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break;
case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break;
case ARRAY_STORE_HANDLE: STACK_VM_STORE(.v_handle, void*); break;
case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break;
// add
case ADDR_ADD: {
stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*)
sp = sp - 1;
pc = pc + 1;
break;
}
case ARRAY_LOAD_UINT32: {
stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int]; // NOLINT(*)
pc = pc + 2;
break;
}
case NOT: {
stack[sp].v_int64 = !stack[sp].v_int64;
pc += 1;
@ -282,21 +276,6 @@ void StackVM::Run(State* s) const {
pc += 2;
break;
}
case CALL_PACKED_FUNC: {
// call packed function.
int num_args = code[pc + 1].v_int;
int call_fid = code[pc + 2].v_int;
static_assert(sizeof(Code) == sizeof(int) &&
alignof(Code) == alignof(int), "asusmption");
const int* type_codes = &(code[pc].v_int) + 3;
runtime::TVMRetValue rv;
GetExtern(s, call_fid).CallPacked(
runtime::TVMArgs(&stack[sp + 1 - num_args], type_codes, num_args), &rv);
sp = sp + 1 - num_args;
stack[sp] = rv.value();
pc += 3 + num_args;
break;
}
case ASSERT: {
CHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
sp -= 1;
@ -331,41 +310,145 @@ void StackVM::Run(State* s) const {
pc += 2;
break;
}
case TVM_LOAD_ARG_INT64: {
STACK_VM_TVM_LOAD_ARG(tc == kInt, "int"); break;
case CALL_PACKED_LOWERED: {
// call packed function.
TVMValue* value_stack = static_cast<TVMValue*>(stack[sp - 1].v_handle);
int* type_stack = static_cast<int*>(stack[sp].v_handle);
int call_fid = code[pc + 1].v_int;
int begin = code[pc + 2].v_int;
int end = code[pc + 3].v_int;
int num_args = end - begin;
static_assert(sizeof(Code) == sizeof(int) &&
alignof(Code) == alignof(int), "asusmption");
runtime::TVMRetValue rv;
GetExtern(s, call_fid).CallPacked(
runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
sp = sp - 1;
stack[sp] = rv.value();
pc += 4;
break;
}
case TVM_LOAD_ARG_FP64: {
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
// intrinsics
case TVM_STRUCT_GET: {
using namespace ir;
int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int;
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle);
switch (kind) {
case intrinsic::kArrData: {
stack[sp].v_handle = arr[index].data; break;
}
case intrinsic::kArrShape: {
stack[sp].v_handle = arr[index].shape; break;
}
case intrinsic::kArrStrides: {
stack[sp].v_handle = arr[index].strides; break;
}
case intrinsic::kArrNDim: {
stack[sp].v_int64 = arr[index].ndim; break;
}
case intrinsic::kArrTypeCode: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.code); break;
}
case intrinsic::kArrTypeBits: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.bits); break;
}
case intrinsic::kArrTypeLanes: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.lanes); break;
}
case intrinsic::kArrByteOffset: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].byte_offset); break;
}
case intrinsic::kArrDeviceId: {
stack[sp].v_int64 = arr[index].ctx.device_id; break;
}
case intrinsic::kArrDeviceType: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].ctx.device_type); break;
}
case intrinsic::kArrAddr: {
stack[sp].v_handle = arr + index; break;
}
case intrinsic::kTVMValueContent: {
stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
}
default: LOG(FATAL) << "unhandled get " << kind;
}
pc = pc + 3;
break;
}
case TVM_LOAD_ARG_HANDLE: {
STACK_VM_TVM_LOAD_ARG(
tc == kHandle || tc == kNull || tc == kArrayHandle, "handle"); break;
case TVM_STRUCT_SET: {
using namespace ir;
int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int;
TVMArray* arr = static_cast<TVMArray*>(stack[sp - 1].v_handle);
switch (kind) {
case intrinsic::kArrData: {
arr[index].data = stack[sp].v_handle; break;
}
case intrinsic::kArrShape: {
arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
break;
}
case intrinsic::kArrStrides: {
arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
break;
}
case intrinsic::kArrNDim: {
arr[index].ndim = static_cast<int>(stack[sp].v_int64);
break;
}
case intrinsic::kArrTypeCode: {
arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrTypeBits: {
arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrTypeLanes: {
arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrByteOffset: {
arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrDeviceId: {
arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64);
break;
}
case intrinsic::kArrDeviceType: {
arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
break;
}
case intrinsic::kTVMValueContent: {
static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
}
default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
}
sp -= 2;
pc += 3;
break;
}
case TVM_ARRAY_GET_DATA: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
}
case TVM_ARRAY_GET_SHAPE: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, shape); break;
}
case TVM_ARRAY_GET_STRIDES: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, strides); break;
}
case TVM_ARRAY_GET_NDIM: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
}
case TVM_ARRAY_GET_BYTE_OFFSET: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, byte_offset); break;
}
case TVM_ARRAY_GET_TYPE_CODE: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
}
case TVM_ARRAY_GET_TYPE_BITS: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
}
case TVM_ARRAY_GET_TYPE_LANES: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.lanes); break;
// alloca
case TVM_STACK_ALLOCA_BY_8BYTE: {
static_assert(sizeof(TVMValue) == 8, "invariance");
int num = code[pc + 1].v_int;
void* addr = &stack[sp] + 1;
sp = sp + num + 1;
alloca_sp = sp - 1;
stack[sp].v_handle = addr;
pc = pc + 2;
break;
}
}
CHECK_GE(sp, alloca_sp) << "touch allocated space";
CHECK_LT(sp, stack_cap) << "Stack overflow";
}
}

View file

@ -55,24 +55,33 @@ class StackVM {
EQ_F64,
LT_F64,
LE_F64,
// load operation
ADDR_LOAD_UINT32,
ADDR_LOAD_INT32,
ADDR_LOAD_INT64,
ADDR_LOAD_FP64,
ADDR_LOAD_HANDLE,
// store operations
// *(stack[sp - 1].v_handle) = stack[sp].v_int64
// sp = sp - 2;
ADDR_STORE_INT64,
/*!
* \brief Quick routine to load uint32 from constant offset.
* \brief Routine to load data from address with const offset.
* \code
* stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int];
* stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int];
* pc = pc + 2;
* \endcode
*/
ARRAY_LOAD_UINT32,
ARRAY_LOAD_INT32,
ARRAY_LOAD_INT64,
ARRAY_LOAD_FP64,
ARRAY_LOAD_HANDLE,
ARRAY_LOAD_TVMVALUE,
/*!
* \brief Routine to store data to address with const offset.
* \code
* ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp];
* pc = pc + 2;
* sp = sp - 2;
* \endcode
*/
ARRAY_STORE_UINT32,
ARRAY_STORE_INT32,
ARRAY_STORE_INT64,
ARRAY_STORE_FP64,
ARRAY_STORE_HANDLE,
ARRAY_STORE_TVMVALUE,
// logical ops
NOT,
/*!
@ -128,20 +137,6 @@ class StackVM {
* \endcode
*/
SELECT,
/*!
* \brief call an extern packed function
* \code
* num_args = stack[sp].v_int64;
* call_fid = code[pc + 1].v_int;
* f = extern_func[call_fid];
* int* type_codes = &(code[pc + 2].v_int)
* stack[sp - num_args] = f(&stack[sp - num_args], type_codes, num_args);
* sp = sp - num_args;
* // The type codes are hidden in the code space.
* pc = pc + 2 + num_args
* \endcode
*/
CALL_PACKED_FUNC,
/*!
* \brief Assert condition is true.
* \code
@ -189,18 +184,56 @@ class StackVM {
* \code
*/
ASSERT_SP,
// Intrinsics for API function,
TVM_LOAD_ARG_INT64,
TVM_LOAD_ARG_FP64,
TVM_LOAD_ARG_HANDLE,
TVM_ARRAY_GET_DATA,
TVM_ARRAY_GET_SHAPE,
TVM_ARRAY_GET_STRIDES,
TVM_ARRAY_GET_NDIM,
TVM_ARRAY_GET_TYPE_CODE,
TVM_ARRAY_GET_TYPE_BITS,
TVM_ARRAY_GET_TYPE_LANES,
TVM_ARRAY_GET_BYTE_OFFSET
/*!
* \brief call an extern packed function
* \code
* value_stack = stack[sp - 1].v_handle;
* type_stack = stack[sp - 0].v_handle;
* call_fid = code[pc + 1].v_int;
* begin = code[pc + 2].v_int;
* end = code[pc + 3].v_int;
* num_args = end - begin;
* f = extern_func[call_fid];
* stack[sp - 1] = f(value_stack[begin:end], type_stack[begin:end], num_args);
* sp = sp - 1;
* // The type codes are hidden in the code space.
* pc = pc + 4
* \endcode
*/
CALL_PACKED_LOWERED,
// Allocate things on stack
/*!
* \brief allocate data from stack.
* \code
* num = code[pc + 1].v_int;
* void* addr = &stack[sp] + 1;
* sp = sp + num + 1;
* stack[sp].v_handle = addr;
* pc = pc + 2;
* \endcode
*/
TVM_STACK_ALLOCA_BY_8BYTE,
/*!
* \brief get data from structure.
* \code
* index = code[pc + 1].v_int;
* field = code[pc + 2].v_int;
* stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field;
* pc = pc + 3
* \endcode
*/
TVM_STRUCT_GET,
/*!
* \brief set data into structure.
* \code
* index = code[pc + 1].v_int;
* field = code[pc + 2].v_int;
* ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp];
* pc = pc + 3
* sp = sp - 1
* \endcode
*/
TVM_STRUCT_SET
};
/*! \brief The code structure */
union Code {
@ -276,23 +309,23 @@ class StackVM {
*/
static OpCode GetLoad(TVMType t) {
CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) return ADDR_LOAD_HANDLE;
if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
if (t.code == kInt) {
switch (t.bits) {
case 32 : return ADDR_LOAD_INT32;
case 64 : return ADDR_LOAD_INT64;
case 32 : return ARRAY_LOAD_INT32;
case 64 : return ARRAY_LOAD_INT64;
}
} else if (t.code == kUInt) {
switch (t.bits) {
case 32 : return ADDR_LOAD_UINT32;
case 32 : return ARRAY_LOAD_UINT32;
}
} else if (t.code == kFloat) {
switch (t.bits) {
case 64 : return ADDR_LOAD_FP64;
case 64 : return ARRAY_LOAD_FP64;
}
}
LOG(FATAL) << "Cannot load type " << t;
return ADDR_LOAD_FP64;
return ARRAY_LOAD_FP64;
}
/*!
* \brief Get store opcode for type t
@ -301,13 +334,23 @@ class StackVM {
*/
static OpCode GetStore(TVMType t) {
CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) return ARRAY_STORE_HANDLE;
if (t.code == kInt) {
switch (t.bits) {
case 64 : return ADDR_STORE_INT64;
case 32 : return ARRAY_STORE_INT32;
case 64 : return ARRAY_STORE_INT64;
}
} else if (t.code == kUInt) {
switch (t.bits) {
case 32 : return ARRAY_STORE_UINT32;
}
} else if (t.code == kFloat) {
switch (t.bits) {
case 64 : return ARRAY_STORE_FP64;
}
}
LOG(FATAL) << "Cannot store type " << t;
return ADDR_LOAD_FP64;
return ARRAY_STORE_FP64;
}
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)

View file

@ -85,6 +85,70 @@ inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
return body;
}
/*!
* \brief Get a field value from a struct
* \param dtype The data type.
* \param handle the struct handle.
* \param index the offset index.
* \param kind The data kind.
* \return the get expression.
*/
inline Expr TVMStructGet(
Type dtype, Var handle, int index,
intrinsic::TVMStructFieldKind kind) {
Array<Expr> args = {
handle,
make_const(Int(32), index),
make_const(Int(32), kind)};
return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
}
/*!
 * \brief Get the address of handle + offset.
 * \param handle the array handle.
 * \param dtype The data type.
 * \param offset the offset index.
 * \return The address expression.
 */
inline Expr AddressOffset(Var handle, Type dtype, int offset) {
return Call::make(
Handle(), Call::address_of,
{Load::make(dtype, handle, make_const(Int(32), offset))}, Call::PureIntrinsic);
}
/*!
* \brief Set value into struct.
* \param handle the struct handle.
* \param index the offset index.
* \param kind The data kind.
* \param value The value to be set.
* \return the set stmt.
*/
inline Stmt TVMStructSet(
Var handle, int index,
intrinsic::TVMStructFieldKind kind, Expr value) {
  Array<Expr> args = {
handle,
make_const(Int(32), index),
make_const(Int(32), kind),
value};
return Evaluate::make(
Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
}
/*!
 * \brief Get the type used to pass values through the TVM PackedFunc API.
* \param t The original type.
* \return The corresponding API type.
*/
inline Type APIType(Type t) {
if (t.is_handle()) return t;
CHECK_EQ(t.lanes(), 1)
<< "Cannot pass vector type through packed API.";
if (t.is_uint() || t.is_int()) return Int(64);
CHECK(t.is_float());
return Float(64);
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
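As a quick illustration of the APIType rule, here is a hypothetical Python mirror (ignoring the vector-lanes check): every value crossing the PackedFunc boundary travels as int64, float64, or a handle.

def api_type(dtype):
    # mirrors ir::APIType: handles pass through, ints and floats widen
    if dtype == "handle":
        return "handle"
    if dtype.startswith(("int", "uint")):
        return "int64"
    assert dtype.startswith("float"), "cannot pass this type through packed API"
    return "float64"

assert api_type("int32") == "int64"
assert api_type("uint8") == "int64"
assert api_type("float32") == "float64"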


@ -0,0 +1,231 @@
/*!
* Copyright (c) 2017 by Contributors
* Lower calls to packed function.
* \file lower_packed_call.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./ir_util.h"
namespace tvm {
namespace ir {
inline Expr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
return make_const(Int(32), static_cast<int>(index));
}
inline Expr StackAlloca(std::string type, size_t num) {
Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
}
// Calculate the statistics of packed function calls.
// This information is needed during codegen.
class PackedCallBuilder : public IRMutator {
public:
Stmt Build(Stmt stmt) {
stack_shape_ = Var("stack_shape", Handle());
stack_array_ = Var("stack_array", Handle());
stack_value_ = Var("stack_value", Handle());
stack_tcode_ = Var("stack_tcode", Handle());
stmt = this->Mutate(stmt);
if (max_shape_stack_ != 0) {
stmt = LetStmt::make(
stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
}
if (max_array_stack_ != 0) {
stmt = LetStmt::make(
stack_array_, StackAlloca("array", max_array_stack_), stmt);
}
if (max_arg_stack_ != 0) {
stmt = LetStmt::make(
stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
stmt = LetStmt::make(
stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
}
return stmt;
}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_array_stack_, 0);
CHECK_EQ(run_arg_stack_, 0);
while (prep_seq_.size() != 0) {
stmt = Block::make(prep_seq_.back(), stmt);
prep_seq_.pop_back();
}
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
if (op->attr_key == attr::device_context_id) {
CHECK(!device_id_.defined());
device_id_ = op->value;
return Mutate(op->body);
} else if (op->attr_key == attr::device_context_type) {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e);
} else {
return IRMutator::Mutate_(op, e);
}
}
Expr Convert(Type t, Expr e) {
if (e.type() != t) {
return Cast::make(t, e);
} else {
return e;
}
}
  // Lower tvm_stack_make_shape.
Expr MakeShape(const Call* op, const Expr& e) {
size_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
Store::make(stack_shape_, Convert(Int(64), op->args[i]),
ConstInt32(stack_begin + i)));
}
return AddressOffset(stack_shape_, Int(64), stack_begin);
}
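In Python terms, MakeShape behaves roughly like the hypothetical sketch below: the dimensions of one tvm_stack_make_shape call are stored as int64 into the shared shape stack starting at the current running offset, and the address of the first slot is returned.

stack_shape = [0] * 8      # sized later by the max_shape_stack_ high-water mark

def make_shape(stack_begin, dims):
    for i, d in enumerate(dims):
        stack_shape[stack_begin + i] = int(d)    # Store with Int(64) conversion
    return ("address_of", "stack_shape", stack_begin)   # AddressOffset(...)

handle = make_shape(0, (1024, 32))
assert stack_shape[:2] == [1024, 32] and handle[2] == 0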
  // Lower tvm_stack_make_array.
Expr MakeArray(const Call* op, const Expr& e) {
size_t idx = run_array_stack_;
run_array_stack_ += 1;
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
Expr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(Handle());
}
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
Type dtype = op->args[4].type();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
make_const(UInt(8), static_cast<int>(dtype.code()))));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
make_const(UInt(8), dtype.bits())));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
make_const(UInt(16), dtype.lanes())));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
Convert(Int(64), op->args[5])));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
Convert(Int(32), device_id_)));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
Convert(Int(32), device_type_)));
return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
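The net effect of MakeArray on one stack_array slot, rendered as a hypothetical Python dict for a compact 1-D float32 CPU tensor (field names follow DLTensor; kDLFloat is 2 and the CPU device type is 1):

slot = {
    "data": "<data handle>",        # kArrData
    "shape": "<shape handle>",      # kArrShape
    "strides": None,                # kArrStrides: NULL for compact tensors
    "ndim": 1,                      # kArrNDim
    "dtype": {"code": 2, "bits": 32, "lanes": 1},   # kArrType*
    "byte_offset": 0,               # kArrByteOffset, as int64
    "ctx": {"device_type": 1, "device_id": 0},      # from the enclosing AttrStmts
}
# The tvm_stack_make_array call itself evaluates to the slot's address (kArrAddr).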
  // Lower tvm_call_packed.
Expr MakeCallPacked(const Call* op, const Expr& e) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
    // Mutate the arguments first; nested buffer intrinsics are lowered there.
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
Expr arg = op->args[i];
Type t = arg.type();
Type api_type = APIType(t);
if (t != api_type) {
arg = Cast::make(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(
stack_value_, static_cast<int>(arg_stack_begin + i - 1),
intrinsic::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
prep_seq_.emplace_back(
Store::make(stack_tcode_,
ConstInt32(arg_tcode),
stack_index));
}
    // Update the high-water marks and restore the running stacks.
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
run_shape_stack_ = restore_shape_stack;
run_array_stack_ = restore_array_stack;
run_arg_stack_ = arg_stack_begin;
Array<Expr> packed_args = {
op->args[0],
stack_value_,
stack_tcode_,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)
};
return Call::make(
Int(32), intrinsic::tvm_call_packed_lowered,
packed_args, Call::Intrinsic);
}
private:
bool IsArrayHandle(const Expr& arg) {
    // Array handles are the result of tvm_struct_get with kArrAddr.
if (const Call* buf = arg.as<Call>()) {
if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
return true;
}
}
return false;
}
  // The preparation sequence to be emitted before the current statement.
std::vector<Stmt> prep_seq_;
Expr device_type_;
Expr device_id_;
// Var handle for each stack.
Var stack_shape_;
Var stack_array_;
Var stack_tcode_;
Var stack_value_;
// The running statistics
uint64_t run_shape_stack_{0};
uint64_t run_array_stack_{0};
uint64_t run_arg_stack_{0};
// statistics of stacks
uint64_t max_shape_stack_{0};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
};
LoweredFunc LowerPackedCall(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = PackedCallBuilder().Build(n->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
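The stack-reuse discipline in MakeCallPacked can be paraphrased by this small, hypothetical Python sketch: each call reserves slots, the running peak is recorded, and the slots are released afterwards, so sibling calls share one allocation sized by the high-water mark.

class StackCounter:
    def __init__(self):
        self.run = 0     # slots reserved by calls still being lowered
        self.max = 0     # high-water mark -> size of the final alloca

    def reserve(self, n):
        begin = self.run
        self.run += n
        self.max = max(self.max, self.run)
        return begin

    def release(self, begin):
        self.run = begin

args = StackCounter()
# two sibling packed calls with 3 and 5 arguments: slots are reused, so a
# single 5-slot arg_value/arg_tcode stack serves both
first = args.reserve(3); args.release(first)
second = args.reserve(5); args.release(second)
assert args.max == 5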


@ -4,6 +4,7 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/buffer.h>
#include <vector>
@ -15,11 +16,8 @@
namespace tvm {
namespace ir {
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) {
return Call::make(
t, intrinsic::tvm_array_get_field,
{arr, IntImm::make(Int(32), kind)},
Call::PureIntrinsic);
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
inline Stmt AssertNull(Var handle, std::string msg) {
@ -55,15 +53,25 @@ LoweredFunc MakeAPI(Stmt body,
std::unordered_set<const Variable*> visited;
// the handle data types
Map<Var, Expr> handle_data_type;
// The device context
Var device_id, device_type;
// ---------------------------
  // local function definitions
// load i-th argument as type t
auto f_arg_value = [&](Type t, int i) {
Array<Expr> call_args{
v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)};
return Call::make(
t, intrinsic::tvm_api_load_arg, call_args,
Array<Expr> call_args{v_packed_args,
IntImm::make(Int(32), i),
IntImm::make(Int(32), intrinsic::kTVMValueContent)};
    // load the 64-bit version
Type api_type = APIType(t);
Expr res = Call::make(
api_type, intrinsic::tvm_struct_get, call_args,
Call::PureIntrinsic);
// cast to the target version.
if (api_type != t) {
res = Cast::make(t, res);
}
return res;
};
// get declaration of argument i
auto f_arg_decl = [&](int i) {
@ -107,8 +115,32 @@ LoweredFunc MakeAPI(Stmt body,
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i);
if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmt::make(
v_arg, f_arg_value(v_arg.type(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", Int(32));
seq_init.emplace_back(LetStmt::make(
tcode, Load::make(
Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i)), nop));
Type t = v_arg.type();
if (t.is_handle()) {
std::ostringstream msg;
msg << "Expect argument " << i << " to be pointer";
seq_check.emplace_back(
AssertStmt::make(tcode == kHandle ||
tcode == kArrayHandle ||
tcode == kNull, msg.str()));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << "Expect argument " << i << " to be int";
seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str()));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << "Expect argument " << i << " to be float";
seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str()));
}
} else {
args.push_back(v_arg);
}
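Rendered as hypothetical Python (the type-code names are stand-ins for the numeric TVM codes), the per-argument checks emitted above amount to:

def check_packed_arg(i, tcode, declared):
    if declared == "handle":
        assert tcode in ("kHandle", "kArrayHandle", "kNull"), \
            "Expect argument %d to be pointer" % i
    elif declared.startswith(("int", "uint")):
        assert tcode == "kInt", "Expect argument %d to be int" % i
    else:
        assert tcode == "kFloat", "Expect argument %d to be float" % i

check_packed_arg(0, "kArrayHandle", "handle")
check_packed_arg(1, "kInt", "int32")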
@ -121,7 +153,7 @@ LoweredFunc MakeAPI(Stmt body,
<< "api_args can only be Buffer or Var";
Buffer buf(api_args[i].node_);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kNDim);
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kArrNDim);
std::ostringstream ndim_err_msg;
ndim_err_msg << "arg_" << i
<< ".ndim is expected to equal "
@ -135,15 +167,15 @@ LoweredFunc MakeAPI(Stmt body,
Type dtype = buf->dtype;
std::ostringstream type_err_msg;
type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) ==
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) ==
TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) ==
TVMArrayGet(UInt(16), v_arg, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// Data Field
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kArrData),
v_arg->name_hint + ".data")) {
Var vptr(buf->data);
handle_data_type.Set(vptr, make_const(buf->dtype, 0));
@ -152,20 +184,22 @@ LoweredFunc MakeAPI(Stmt body,
Var v_shape(v_arg->name_hint + ".shape", Handle());
handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buf->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
f_push(buf->shape[k],
cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_shape, IntImm::make(Int(32), k))),
Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k))),
field_name.str());
}
// strides field
Var v_strides(v_arg->name_hint + ".strides", Handle());
handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kArrStrides),
nop));
if (buf->strides.size() == 0) {
std::ostringstream stride_err_msg;
stride_err_msg << "arg_" << i << ".strides:"
@ -177,13 +211,22 @@ LoweredFunc MakeAPI(Stmt body,
field_name << v_strides->name_hint << '[' << k << ']';
f_push(buf->strides[k],
cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_strides, IntImm::make(Int(32), k))),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k))),
field_name.str());
}
}
// Byte_offset field.
f_push(buf->byte_offset, TVMArrayGet(UInt(64), v_arg, intrinsic::kByteOffset),
f_push(buf->byte_offset,
TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
v_arg->name_hint + ".byte_offset");
// device info.
f_push(device_id,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
v_arg->name_hint + ".device_id");
f_push(device_type,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceType),
v_arg->name_hint + ".device_type");
}
}
@ -192,6 +235,16 @@ LoweredFunc MakeAPI(Stmt body,
n->args = args;
n->handle_data_type = handle_data_type;
n->is_packed_func = num_unpacked_args == 0;
// Set device context
if (visited.count(device_id.get())) {
Expr node = StringImm::make("default");
CHECK(visited.count(device_type.get()));
seq_init.push_back(AttrStmt::make(
node, attr::device_context_id, device_id, nop));
seq_init.push_back(AttrStmt::make(
node, attr::device_context_type, device_type, nop));
}
n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f->body, f->args);


@ -70,6 +70,17 @@ class StorageAccessPatternFinder : public IRVisitor {
linear_seq_.push_back(e);
}
}
void Visit_(const Evaluate* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void Visit_(const Load* op) final {
    // Add read access.
IRVisitor::Visit_(op);
@ -79,14 +90,14 @@ class StorageAccessPatternFinder : public IRVisitor {
CHECK_LT(it->second, scope_.size())
<< "Load memory in places other than store.";
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kRead, GetScope(buf)));
AccessEntry(buf, op->index, kRead, GetScope(buf)));
}
}
void Visit_(const Variable* buf) final {
    // A direct reference to the variable counts as a read.
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size());
CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second].access.emplace_back(
AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
}


@ -25,7 +25,8 @@ def test_add_pipeline():
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(host):


@ -34,5 +34,78 @@ def test_add_pipeline():
check_llvm()
def test_pack_buffer_simple():
nn = 1024
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
def extern_generator(ins, outs):
"""Manually write the IR for the extern function, add pipeline."""
return tvm.call_packed("my_extern_array_func1", ins[0], outs[0])
C = tvm.extern(A.shape, [A], extern_generator, name='C')
s = tvm.create_schedule(C.op)
@tvm.register_func
def my_extern_array_func1(aa, bb):
aa.copyto(bb)
def check_target(target):
if not tvm.codegen.enabled(target):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], target)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
f(a, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy())
check_target("stackvm")
check_target("llvm")
def test_pack_buffer_intermediate():
nn = 1024
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
def extern_generator(ins, outs):
"""Manually write the IR for the extern function, add pipeline."""
return tvm.call_packed("my_extern_array_func2", ins[0], outs[0])
C = tvm.extern(B.shape, [B], extern_generator, name='C')
s = tvm.create_schedule(C.op)
def check_target(target):
if not tvm.codegen.enabled(target):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], target)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
@tvm.register_func
def my_extern_array_func2(aa, bb):
assert aa.shape == a.shape
np.testing.assert_allclose(
aa.asnumpy(), a.asnumpy() + 1)
aa.copyto(bb)
f(a, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + 1)
check_target("llvm")
if __name__ == "__main__":
test_pack_buffer_simple()
test_pack_buffer_intermediate()
test_add_pipeline()


@ -9,7 +9,6 @@ def run_jit(fapi, check):
s = f.get_source()
check(f)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
@ -21,6 +20,7 @@ def test_stack_vm_basic():
Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
run_jit(fapi, lambda f: f(a))
@ -42,6 +42,7 @@ def test_stack_vm_loop():
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
@ -64,6 +65,7 @@ def test_stack_vm_cond():
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)


@ -38,7 +38,8 @@ def test_add_pipeline():
px, x = s[C].split(C.op.axis[0], nparts=1)
s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd")
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
print(fsplits[1].body)
print("------")