[CODEGEN/PASS] Improve callpacked lowering, allow pass array callback. (#110)
* [CODEGEN/PASS] Improve callpacked lowering, allow pass array callback. * fix cython
This commit is contained in:
Родитель
d45b6d4b84
Коммит
9ba40dc0fe
|
@ -46,6 +46,7 @@ tvm.ir_pass
|
|||
tvm.ir_pass.VectorizeLoop
|
||||
tvm.ir_pass.UnrollLoop
|
||||
tvm.ir_pass.StorageSync
|
||||
tvm.ir_pass.StorageRewrite
|
||||
tvm.ir_pass.MakeAPI
|
||||
tvm.ir_pass.SplitHostDevice
|
||||
tvm.ir_pass.InjectVirtualThread
|
||||
|
@ -53,6 +54,8 @@ tvm.ir_pass
|
|||
tvm.ir_pass.RemoveNoOp
|
||||
tvm.ir_pass.SplitPipeline
|
||||
tvm.ir_pass.LowerThreadAllreduce
|
||||
tvm.ir_pass.LowerIntrin
|
||||
tvm.ir_pass.LowerPackedCall
|
||||
tvm.ir_pass.NarrowChannelAccess
|
||||
|
||||
.. automodule:: tvm.ir_pass
|
||||
|
|
|
@ -245,7 +245,6 @@ Expr max(Expr source, Array<IterVar> axis);
|
|||
*/
|
||||
Expr min(Expr source, Array<IterVar> axis);
|
||||
|
||||
|
||||
// print functions for expr
|
||||
std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
|
||||
|
||||
|
|
108
include/tvm/ir.h
108
include/tvm/ir.h
|
@ -140,6 +140,10 @@ constexpr const char* volatile_scope = "volatile_scope";
|
|||
constexpr const char* storage_scope = "storage_scope";
|
||||
/*! \brief Mark storage scope of realization */
|
||||
constexpr const char* realize_scope = "realize_scope";
|
||||
/*! \brief The allocation context for global malloc in host. */
|
||||
constexpr const char* device_context_id = "device_context_id";
|
||||
/*! \brief The device type. */
|
||||
constexpr const char* device_context_type = "device_context_type";
|
||||
/*! \brief Mark of loop scope */
|
||||
constexpr const char* loop_scope = "loop_scope";
|
||||
/*! \brief Mark of reduce scope */
|
||||
|
@ -167,25 +171,24 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
|
|||
|
||||
/*! \brief namespace of TVM Intrinsic functions */
|
||||
namespace intrinsic {
|
||||
// Most of the intrinsics is to enab
|
||||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
|
||||
* assert(arg_type_id[i] == typeid(Type));
|
||||
* return args[i];
|
||||
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
|
||||
* return arr[index]->field;
|
||||
* }
|
||||
* \sa TVMStructFieldKind
|
||||
*/
|
||||
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
|
||||
constexpr const char* tvm_struct_get = "tvm_struct_get";
|
||||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
|
||||
* return arr->field;
|
||||
* Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
|
||||
* arr[index]->field = value;
|
||||
* }
|
||||
* \sa TVMArrayFieldKind
|
||||
* \sa TVMStructFieldKind
|
||||
*/
|
||||
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
|
||||
constexpr const char* tvm_struct_set = "tvm_struct_set";
|
||||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
|
@ -194,6 +197,48 @@ constexpr const char* tvm_array_get_field = "tvm_array_get_field";
|
|||
* }
|
||||
*/
|
||||
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
|
||||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
* dtype in {shape, array, arg_value, arg_tcode}
|
||||
*
|
||||
* Handle tvm_stack_alloca(string dtype, int num) {
|
||||
* return new on stack dtype[num];
|
||||
* }
|
||||
* \sa TVMStructFieldKind
|
||||
*/
|
||||
constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
|
||||
/*!
|
||||
* \brief Allocate a shape tuple on stack, return the handle.
|
||||
*
|
||||
* Handle tvm_stack_make_shape(list args) {
|
||||
* ret = alloca stack int64_t[len(args)];
|
||||
* for i in range(len(args)):
|
||||
* ret[i] = args[i]
|
||||
* return &ret[0];
|
||||
* }
|
||||
*/
|
||||
constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
|
||||
/*!
|
||||
* \brief Allocate a NDArray(DLTensor) on stack, return the handle.
|
||||
*
|
||||
* Type tvm_stack_make_array(Expr data,
|
||||
* Expr shape,
|
||||
* Expr strides,
|
||||
* Expr ndim,
|
||||
* Expr dtype,
|
||||
* Expr byte_offset) {
|
||||
* ret = alloca stack DLTensor();
|
||||
* ret->data = data;
|
||||
* ret->shape = shape;
|
||||
* ret->strides = strides != 0 ? strides : nullptr;
|
||||
* ret->ndim = ndim;
|
||||
* ret->dtype = dtype.type();
|
||||
* ret->byte_offset = byte_offset;
|
||||
* return ret;
|
||||
* }
|
||||
*/
|
||||
constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
|
||||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
|
@ -205,6 +250,23 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
|
|||
* }
|
||||
*/
|
||||
constexpr const char* tvm_call_packed = "tvm_call_packed";
|
||||
/*!
|
||||
* \brief Lowered version of call packed, the space of value and
|
||||
* type codes are explicitly allocated.
|
||||
*
|
||||
* int tvm_call_packed_lowered(name,
|
||||
* TVMValue* value_stack,
|
||||
* int* tcode_stack,
|
||||
* int begin,
|
||||
* int end) {
|
||||
* ModuleNode* env = GetCurrentEnv();
|
||||
* const PackedFunc* f = env->GetFuncFromEnv(name);
|
||||
* f->CallPacked(TVMArgs(value_stack[begin:end],
|
||||
* tcode_stack[begin:end]),
|
||||
* TVMRetValue(value_stack + end, tcode_stack + end));
|
||||
* }
|
||||
*/
|
||||
constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
|
||||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
|
@ -231,16 +293,24 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
|
|||
*/
|
||||
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
|
||||
|
||||
/*! \brief The field id of each field in array */
|
||||
enum TVMArrayFieldKind {
|
||||
kData = 0,
|
||||
kNDim = 1,
|
||||
kShape = 2,
|
||||
kStrides = 3,
|
||||
kTypeCode = 4,
|
||||
kTypeBits = 5,
|
||||
kTypeLanes = 6,
|
||||
kByteOffset = 7
|
||||
/*! \brief The kind of structre field info */
|
||||
enum TVMStructFieldKind : int {
|
||||
// array head address
|
||||
kArrAddr,
|
||||
kArrData,
|
||||
kArrShape,
|
||||
kArrStrides,
|
||||
kArrNDim,
|
||||
kArrTypeCode,
|
||||
kArrTypeBits,
|
||||
kArrTypeLanes,
|
||||
kArrByteOffset,
|
||||
kArrDeviceId,
|
||||
kArrDeviceType,
|
||||
kArrKindBound_,
|
||||
// TVMValue field
|
||||
kTVMValueContent,
|
||||
kTVMValueKindBound_
|
||||
};
|
||||
} // namespace intrinsic
|
||||
|
||||
|
|
|
@ -251,6 +251,13 @@ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
|
|||
*/
|
||||
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
|
||||
|
||||
/*!
|
||||
* \brief Lower packed function call.
|
||||
* \param f The function to be lowered.
|
||||
* \return Transformed function.
|
||||
*/
|
||||
LoweredFunc LowerPackedCall(LoweredFunc f);
|
||||
|
||||
/*!
|
||||
* \brief Lower intrinsic function calls.
|
||||
* \param f The device function to be lowered.
|
||||
|
|
|
@ -10,7 +10,7 @@ from numbers import Number, Integral
|
|||
from ..base import _LIB, check_call
|
||||
from ..base import c_str, string_types
|
||||
from ..node_generic import convert_to_node, NodeGeneric
|
||||
from ..ndarray import TVMType, TVMByteArray, NDArrayBase
|
||||
from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array
|
||||
from .types import TVMValue, TypeCode
|
||||
from .types import TVMPackedCFunc, TVMCFuncFinalizer
|
||||
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
|
||||
|
@ -188,6 +188,7 @@ def _handle_return_func(x):
|
|||
handle = FunctionHandle(handle)
|
||||
return _CLASS_FUNCTION(handle, False)
|
||||
|
||||
|
||||
# setup return handle for function type
|
||||
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
|
||||
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
|
||||
|
@ -195,7 +196,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
|
|||
_handle_return_func, TypeCode.FUNC_HANDLE)
|
||||
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
|
||||
_return_module, TypeCode.MODULE_HANDLE)
|
||||
|
||||
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
|
||||
|
||||
_CLASS_MODULE = None
|
||||
_CLASS_FUNCTION = None
|
||||
|
|
|
@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
|
|||
from numbers import Number, Integral
|
||||
from ..base import string_types
|
||||
from ..node_generic import convert_to_node, NodeGeneric
|
||||
from ..ndarray import NDArrayBase, TVMType, TVMByteArray
|
||||
from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array
|
||||
|
||||
print("TVM: Initializing cython mode...")
|
||||
|
||||
|
@ -29,7 +29,10 @@ cdef int tvm_callback(TVMValue* args,
|
|||
tcode == kFuncHandle or
|
||||
tcode == kModuleHandle):
|
||||
CALL(TVMCbArgToReturn(&value, tcode))
|
||||
pyargs.append(make_ret(value, tcode))
|
||||
if tcode != kArrayHandle:
|
||||
pyargs.append(make_ret(value, tcode))
|
||||
else:
|
||||
pyargs.append(_make_array(ctypes_handle(value.v_handle), True))
|
||||
try:
|
||||
rv = local_pyfunc(*pyargs)
|
||||
except Exception:
|
||||
|
@ -64,7 +67,9 @@ def convert_to_tvm_func(object pyfunc):
|
|||
<void*>(pyfunc),
|
||||
tvm_callback_finalize,
|
||||
&chandle))
|
||||
return _CLASS_FUNCTION(ctypes_handle(chandle), False)
|
||||
ret = _CLASS_FUNCTION(None, False)
|
||||
(<FunctionBase>ret).chandle = chandle
|
||||
return ret
|
||||
|
||||
|
||||
cdef inline void make_arg(object arg,
|
||||
|
|
|
@ -198,9 +198,9 @@ def sync(ctx):
|
|||
|
||||
class NDArrayBase(object):
|
||||
"""A simple Device/CPU Array object in runtime."""
|
||||
__slots__ = ["handle"]
|
||||
__slots__ = ["handle", "is_view"]
|
||||
# pylint: disable=no-member
|
||||
def __init__(self, handle):
|
||||
def __init__(self, handle, is_view=False):
|
||||
"""Initialize the function with handle
|
||||
|
||||
Parameters
|
||||
|
@ -209,9 +209,11 @@ class NDArrayBase(object):
|
|||
the handle to the underlying C++ TVMArray
|
||||
"""
|
||||
self.handle = handle
|
||||
self.is_view = is_view
|
||||
|
||||
def __del__(self):
|
||||
check_call(_LIB.TVMArrayFree(self.handle))
|
||||
if not self.is_view:
|
||||
check_call(_LIB.TVMArrayFree(self.handle))
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
@ -302,6 +304,10 @@ class NDArrayBase(object):
|
|||
raise ValueError("Unsupported target type %s" % str(type(target)))
|
||||
return target
|
||||
|
||||
def _make_array(handle, is_view):
|
||||
handle = ctypes.cast(handle, TVMArrayHandle)
|
||||
return _CLASS_NDARRAY(handle, is_view)
|
||||
|
||||
_CLASS_NDARRAY = None
|
||||
|
||||
def _set_class_ndarray(cls):
|
||||
|
|
|
@ -146,7 +146,7 @@ def build(sch,
|
|||
warp_size = 32 if target == "cuda" else 1
|
||||
fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
|
||||
fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
|
||||
|
||||
fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
|
||||
if len(fsplits) > 1:
|
||||
if not target_host:
|
||||
target_host = "llvm" if codegen.enabled("llvm") else "stackvm"
|
||||
|
|
|
@ -1,21 +1,56 @@
|
|||
"""Intrinsics and math functions in TVM."""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
from .expr import Call as _Call
|
||||
from . import make as _make
|
||||
from ._ffi.function import register_func as _register_func
|
||||
from .api import convert
|
||||
from . import make as _make
|
||||
from .api import convert, const
|
||||
from .expr import Call as _Call
|
||||
from .schedule import Buffer as _Buffer
|
||||
|
||||
def _pack_buffer(buf):
|
||||
"""Build intrinsics that packs the buffer.
|
||||
"""
|
||||
assert buf.shape
|
||||
shape = _make.Call("handle", "tvm_stack_make_shape", buf.shape,
|
||||
_Call.Intrinsic, None, 0)
|
||||
strides = _make.Call("handle", "tvm_stack_make_shape", buf.strides,
|
||||
_Call.Intrinsic, None, 0) if buf.strides else 0
|
||||
pack_args = [buf.data,
|
||||
shape,
|
||||
strides,
|
||||
len(buf.shape),
|
||||
const(0, dtype=buf.dtype),
|
||||
buf.byte_offset]
|
||||
return _make.Call("handle", "tvm_stack_make_array",
|
||||
pack_args, _Call.Intrinsic, None, 0)
|
||||
|
||||
def call_packed(*args):
|
||||
"""Build expression by call an external packed function
|
||||
"""Build expression by call an external packed function.
|
||||
|
||||
The argument to packed function can be Expr or Buffer.
|
||||
The argument is the corresponding POD type when Expr is presented.
|
||||
|
||||
When the argument is Buffer, the corresponding PackedFunc
|
||||
will recieve an TVMArrayHandle whose content is valid during the callback period.
|
||||
If the PackedFunc is a python callback, then the corresponding argument is NDArray.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : list
|
||||
args : list of Expr or Buffer.
|
||||
Positional arguments.
|
||||
|
||||
Returns
|
||||
-------
|
||||
call : Expr
|
||||
The call expression.
|
||||
|
||||
See Also
|
||||
--------
|
||||
tvm.extern : Create tensor with extern function call.
|
||||
"""
|
||||
call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
|
||||
return _make.Call(
|
||||
"int32", "tvm_call_packed", args, _Call.Intrinsic, None, 0)
|
||||
"int32", "tvm_call_packed", call_args, _Call.Intrinsic, None, 0)
|
||||
|
||||
|
||||
def call_pure_intrin(dtype, func_name, *args):
|
||||
|
@ -34,6 +69,11 @@ def call_pure_intrin(dtype, func_name, *args):
|
|||
|
||||
args : list
|
||||
Positional arguments.
|
||||
|
||||
Returns
|
||||
-------
|
||||
call : Expr
|
||||
The call expression.
|
||||
"""
|
||||
args = convert(args)
|
||||
return _make.Call(
|
||||
|
@ -53,6 +93,11 @@ def call_pure_extern(dtype, func_name, *args):
|
|||
|
||||
args : list
|
||||
Positional arguments.
|
||||
|
||||
Returns
|
||||
-------
|
||||
call : Expr
|
||||
The call expression.
|
||||
"""
|
||||
return _make.Call(
|
||||
dtype, func_name, convert(args), _Call.PureExtern, None, 0)
|
||||
|
|
|
@ -36,7 +36,6 @@ TVM_REGISTER_API("_const")
|
|||
}
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API("_str")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = ir::StringImm::make(args[0]);
|
||||
|
|
|
@ -76,5 +76,6 @@ REGISTER_PASS2(SplitPipeline);
|
|||
REGISTER_PASS1(NarrowChannelAccess);
|
||||
REGISTER_PASS2(LowerThreadAllreduce);
|
||||
REGISTER_PASS2(LowerIntrin);
|
||||
REGISTER_PASS1(LowerPackedCall);
|
||||
} // namespace ir
|
||||
} // namespace tvm
|
||||
|
|
|
@ -89,8 +89,7 @@ void CodeGenC::PrintSSAAssign(
|
|||
|
||||
// Print a reference expression to a buffer.
|
||||
std::string CodeGenC::GetBufferRef(
|
||||
const Variable* buffer,
|
||||
Type t, Expr index) {
|
||||
Type t, const Variable* buffer, Expr index) {
|
||||
std::ostringstream os;
|
||||
std::string vid = GetVarID(buffer);
|
||||
std::string scope;
|
||||
|
@ -151,6 +150,58 @@ std::string CodeGenC::GetBufferRef(
|
|||
return os.str();
|
||||
}
|
||||
|
||||
// Print a reference expression to a buffer.
|
||||
std::string CodeGenC::GetStructRef(
|
||||
Type t, const Expr& buffer, const Expr& index, int kind) {
|
||||
if (kind < intrinsic::kArrKindBound_) {
|
||||
std::ostringstream os;
|
||||
os << "(((TVMArray*)";
|
||||
this->PrintExpr(buffer, os);
|
||||
os << ")";
|
||||
if (kind == intrinsic::kArrAddr) {
|
||||
os << " + ";
|
||||
this->PrintExpr(index, os);
|
||||
os << ")";
|
||||
return os.str();
|
||||
}
|
||||
os << '[';
|
||||
this->PrintExpr(index, os);
|
||||
os << "].";
|
||||
// other case: get fields.
|
||||
switch (kind) {
|
||||
case intrinsic::kArrData: os << "data"; break;
|
||||
case intrinsic::kArrShape: os << "shape"; break;
|
||||
case intrinsic::kArrStrides: os << "strides"; break;
|
||||
case intrinsic::kArrNDim: os << "ndim"; break;
|
||||
case intrinsic::kArrTypeCode: os << "dtype.code"; break;
|
||||
case intrinsic::kArrTypeBits: os << "dtype.bits"; break;
|
||||
case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break;
|
||||
case intrinsic::kArrDeviceId: os << "ctx.device_id"; break;
|
||||
case intrinsic::kArrDeviceType: os << "ctx.device_type"; break;
|
||||
default: LOG(FATAL) << "unknown field code";
|
||||
}
|
||||
os << ')';
|
||||
return os.str();
|
||||
} else {
|
||||
CHECK_LT(kind, intrinsic::kTVMValueKindBound_);
|
||||
std::ostringstream os;
|
||||
os << "(((TVMValue*)";
|
||||
this->PrintExpr(buffer, os);
|
||||
os << ")[" << index << "].";
|
||||
if (t.is_handle()) {
|
||||
os << "v_handle";
|
||||
} else if (t.is_float()) {
|
||||
os << "v_float64";
|
||||
} else if (t.is_int()) {
|
||||
os << "v_int64";
|
||||
} else {
|
||||
LOG(FATAL) << "donot know how to handle type" << t;
|
||||
}
|
||||
os << ")";
|
||||
return os.str();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
|
||||
auto it = handle_data_type_.find(buf_var);
|
||||
|
@ -182,15 +233,15 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
|
|||
<< " = " << value << ";\n";
|
||||
}
|
||||
|
||||
std::string CodeGenC::GetVecLoad(const Variable* buffer,
|
||||
Type t, Expr base) {
|
||||
return GetBufferRef(buffer, t, base);
|
||||
std::string CodeGenC::GetVecLoad(
|
||||
Type t, const Variable* buffer, Expr base) {
|
||||
return GetBufferRef(t, buffer, base);
|
||||
}
|
||||
|
||||
void CodeGenC::PrintVecStore(const Variable* buffer,
|
||||
Type t, Expr base,
|
||||
const std::string& value) {
|
||||
std::string ref = GetBufferRef(buffer, t, base);
|
||||
std::string ref = GetBufferRef(t, buffer, base);
|
||||
this->PrintIndent();
|
||||
stream << ref << " = " << value << ";\n";
|
||||
}
|
||||
|
@ -430,42 +481,11 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
|
|||
<< " + ";
|
||||
this->PrintExpr(l->index, os);
|
||||
os << ')';
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
|
||||
CHECK_EQ(op->args.size(), 3U);
|
||||
if (!op->type.is_handle()) {
|
||||
os << '(';
|
||||
this->PrintType(op->type, os);
|
||||
os << ')';
|
||||
}
|
||||
os << "(((TVMArg*)";
|
||||
this->PrintExpr(op->args[0], os);
|
||||
os << ")[" << op->args[2] << "].";
|
||||
if (op->type.is_handle()) {
|
||||
os << "v_handle";
|
||||
} else if (op->type.is_float()) {
|
||||
os << "v_double";
|
||||
} else if (op->type.is_int() || op->type.is_uint()) {
|
||||
os << "v_long";
|
||||
} else {
|
||||
LOG(FATAL) << "donot know how to handle type" << op->type;
|
||||
}
|
||||
os << ")";
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
os << "(((TVMArray*)";
|
||||
this->PrintExpr(op->args[0], os);
|
||||
os << ")->";
|
||||
switch (op->args[1].as<IntImm>()->value) {
|
||||
case intrinsic::kData: os << "data"; break;
|
||||
case intrinsic::kShape: os << "shape"; break;
|
||||
case intrinsic::kStrides: os << "strides"; break;
|
||||
case intrinsic::kNDim: os << "ndim"; break;
|
||||
case intrinsic::kTypeCode: os << "dtype.type_code"; break;
|
||||
case intrinsic::kTypeBits: os << "dtype.bits"; break;
|
||||
case intrinsic::kTypeLanes: os << "dtype.lanes"; break;
|
||||
default: LOG(FATAL) << "unknown field code";
|
||||
}
|
||||
os << ')';
|
||||
os << GetStructRef(
|
||||
op->type, op->args[0], op->args[1],
|
||||
op->args[2].as<IntImm>()->value);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
|
||||
CHECK_EQ(op->args.size(), 1U);
|
||||
os << "(";
|
||||
|
@ -513,12 +533,12 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
|
|||
int lanes = op->type.lanes();
|
||||
// delcare type.
|
||||
if (op->type.lanes() == 1) {
|
||||
std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index);
|
||||
std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
|
||||
os << ref;
|
||||
} else {
|
||||
Expr base;
|
||||
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
|
||||
std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base);
|
||||
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
|
||||
os << ref;
|
||||
} else {
|
||||
// load seperately.
|
||||
|
@ -552,7 +572,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
|
|||
Type t = op->value.type();
|
||||
if (t.lanes() == 1) {
|
||||
std::string value = this->PrintExpr(op->value);
|
||||
std::string ref = this->GetBufferRef(op->buffer_var.get(), t, op->index);
|
||||
std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
|
||||
this->PrintIndent();
|
||||
stream << ref << " = " << value << ";\n";
|
||||
} else {
|
||||
|
@ -744,14 +764,25 @@ void CodeGenC::VisitStmt_(const Block *op) {
|
|||
void CodeGenC::VisitStmt_(const Evaluate *op) {
|
||||
if (is_const(op->value)) return;
|
||||
const Call* call = op->value.as<Call>();
|
||||
|
||||
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
|
||||
this->PrintStorageSync(call);
|
||||
} else {
|
||||
std::string vid = this->PrintExpr(op->value);
|
||||
this->PrintIndent();
|
||||
this->stream << "(void)" << vid << ";\n";
|
||||
if (call) {
|
||||
if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
|
||||
this->PrintStorageSync(call); return;
|
||||
} else if (call->is_intrinsic(intrinsic::tvm_struct_set)) {
|
||||
CHECK_EQ(call->args.size(), 4);
|
||||
std::string value = PrintExpr(call->args[3]);
|
||||
std::string ref = GetStructRef(
|
||||
call->args[3].type(),
|
||||
call->args[0],
|
||||
call->args[1],
|
||||
call->args[2].as<IntImm>()->value);
|
||||
this->PrintIndent();
|
||||
this->stream << ref << " = " << value << ";\n";
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::string vid = this->PrintExpr(op->value);
|
||||
this->PrintIndent();
|
||||
this->stream << "(void)" << vid << ";\n";
|
||||
}
|
||||
|
||||
void CodeGenC::VisitStmt_(const ProducerConsumer *op) {
|
||||
|
|
|
@ -133,8 +133,7 @@ class CodeGenC :
|
|||
const std::string&op, Type op_type,
|
||||
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
|
||||
// print vector load
|
||||
virtual std::string GetVecLoad(const Variable* buffer,
|
||||
Type t, Expr base);
|
||||
virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base);
|
||||
// print vector store
|
||||
virtual void PrintVecStore(const Variable* buffer,
|
||||
Type t, Expr base,
|
||||
|
@ -146,11 +145,13 @@ class CodeGenC :
|
|||
virtual void PrintVecElemStore(
|
||||
const std::string& vec, Type t, int i, const std::string& value);
|
||||
|
||||
|
||||
protected:
|
||||
// Print reference to struct location
|
||||
std::string GetStructRef(
|
||||
Type t, const Expr& buffer, const Expr& index, int kind);
|
||||
// print reference to a buffer as type t in index.
|
||||
std::string GetBufferRef(const Variable* buffer,
|
||||
Type t, Expr index);
|
||||
std::string GetBufferRef(
|
||||
Type t, const Variable* buffer, Expr index);
|
||||
/*!
|
||||
* \brief If buffer is allocated as type t.
|
||||
* \param buf_var The buffer variable.
|
||||
|
|
|
@ -95,8 +95,8 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
|
|||
os << GetVarID(buffer) << " + ";
|
||||
PrintExpr(base, os);
|
||||
}
|
||||
std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer,
|
||||
Type t, Expr base) {
|
||||
std::string CodeGenOpenCL::GetVecLoad(
|
||||
Type t, const Variable* buffer, Expr base) {
|
||||
std::ostringstream os;
|
||||
os << "vload" << t.lanes() << "(0, ";
|
||||
PrintVecAddr(buffer, t, base, os);
|
||||
|
|
|
@ -24,11 +24,11 @@ class CodeGenOpenCL : public CodeGenC {
|
|||
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
|
||||
void PrintStorageSync(const Call* op) final; // NOLINT(*)
|
||||
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
|
||||
std::string GetVecLoad(const Variable* buffer,
|
||||
Type t, Expr base) final;
|
||||
std::string GetVecLoad(Type t, const Variable* buffer,
|
||||
Expr base) final;
|
||||
void PrintVecStore(const Variable* buffer,
|
||||
Type t, Expr base,
|
||||
const std::string& value) final; // NOLINT(*)
|
||||
Type t, Expr base,
|
||||
const std::string& value) final; // NOLINT(*)
|
||||
// the address of load/store
|
||||
void PrintVecAddr(const Variable* buffer, Type t,
|
||||
Expr base, std::ostream& os); // NOLINT(*)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include "./codegen_llvm.h"
|
||||
#include "../../pass/ir_util.h"
|
||||
#include "../../arithmetic/compute_expr.h"
|
||||
|
||||
namespace tvm {
|
||||
|
@ -89,7 +90,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
|
|||
void CodeGenLLVM::InitTarget(const std::string& target) {
|
||||
llvm::TargetMachine* tm;
|
||||
std::string target_triple;
|
||||
std::tie(tm, target_triple) = LLVMGetTarget(target);
|
||||
std::tie(tm, target_triple) = GetLLVMTarget(target);
|
||||
module_->setTargetTriple(target_triple);
|
||||
module_->setDataLayout(tm->createDataLayout());
|
||||
data_layout_.reset(new llvm::DataLayout(module_.get()));
|
||||
|
@ -318,6 +319,74 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
|
|||
return builder_->CreateInBoundsGEP(buffer, index);
|
||||
}
|
||||
|
||||
llvm::Value* CodeGenLLVM::CreateStructRefPtr(
|
||||
Type t, llvm::Value* buf, llvm::Value* index, int kind) {
|
||||
if (kind < intrinsic::kArrKindBound_) {
|
||||
if (buf->getType() == t_void_p_) {
|
||||
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
|
||||
} else {
|
||||
CHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo());
|
||||
}
|
||||
}
|
||||
switch (kind) {
|
||||
case intrinsic::kArrAddr: {
|
||||
return builder_->CreateInBoundsGEP(buf, index);
|
||||
}
|
||||
case intrinsic::kArrData: {
|
||||
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)});
|
||||
}
|
||||
case intrinsic::kArrShape: {
|
||||
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)});
|
||||
}
|
||||
case intrinsic::kArrStrides: {
|
||||
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)});
|
||||
}
|
||||
case intrinsic::kArrNDim: {
|
||||
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)});
|
||||
}
|
||||
case intrinsic::kArrTypeCode: {
|
||||
return builder_->CreateInBoundsGEP(
|
||||
buf, {index, ConstInt32(3), ConstInt32(0)});
|
||||
}
|
||||
case intrinsic::kArrTypeBits: {
|
||||
return builder_->CreateInBoundsGEP(
|
||||
buf, {index, ConstInt32(3), ConstInt32(1)});
|
||||
}
|
||||
case intrinsic::kArrTypeLanes: {
|
||||
return builder_->CreateInBoundsGEP(
|
||||
buf, {index, ConstInt32(3), ConstInt32(2)});
|
||||
}
|
||||
case intrinsic::kArrByteOffset: {
|
||||
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)});
|
||||
}
|
||||
case intrinsic::kArrDeviceId: {
|
||||
return builder_->CreateInBoundsGEP(
|
||||
buf, {index, ConstInt32(1), ConstInt32(0)});
|
||||
}
|
||||
case intrinsic::kArrDeviceType: {
|
||||
return builder_->CreateInBoundsGEP(
|
||||
buf, {index, ConstInt32(1), ConstInt32(1)});
|
||||
}
|
||||
case intrinsic::kTVMValueContent: {
|
||||
CHECK_EQ(t.lanes(), 1);
|
||||
CHECK(t.is_handle() || t.bits() == 64);
|
||||
if (t.is_int()) {
|
||||
buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
|
||||
return builder_->CreateInBoundsGEP(buf, index);
|
||||
} else if (t.is_float()) {
|
||||
buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
|
||||
return builder_->CreateInBoundsGEP(buf, index);
|
||||
} else {
|
||||
CHECK(t.is_handle());
|
||||
buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
|
||||
buf = builder_->CreateInBoundsGEP(buf, index);
|
||||
return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
|
||||
}
|
||||
}
|
||||
default: LOG(FATAL) << "unknown field code"; return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
|
||||
llvm::Type * target = LLVMType(to);
|
||||
if (value->getType() == target) return value;
|
||||
|
@ -394,39 +463,33 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
|
|||
}
|
||||
|
||||
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
|
||||
CHECK_GE(op->args.size(), 1U);
|
||||
CHECK_EQ(op->args.size(), 5U);
|
||||
std::string func_name = op->args[0].as<StringImm>()->value;
|
||||
llvm::Value* handle = GetPackedFuncHandle(func_name);
|
||||
// call the function
|
||||
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
|
||||
llvm::Value* targs = builder_->CreateAlloca(
|
||||
t_tvm_value_, ConstInt32(nargs));
|
||||
llvm::Value* tcodes = builder_->CreateAlloca(
|
||||
t_int_, ConstInt32(nargs));
|
||||
for (unsigned i = 0; i < nargs; ++i) {
|
||||
Expr expr = op->args[i + 1];
|
||||
Type t = expr.type();
|
||||
CHECK_EQ(t.lanes(), 1);
|
||||
// Always pass via 64 bit value.
|
||||
// For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
|
||||
Type api_type = t.with_bits(64);
|
||||
llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
|
||||
llvm::Value* store_ptr = builder_->CreatePointerCast(
|
||||
builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
|
||||
LLVMType(api_type)->getPointerTo());
|
||||
builder_->CreateAlignedStore(value, store_ptr, 8);
|
||||
builder_->CreateAlignedStore(
|
||||
ConstInt32(t.code()),
|
||||
builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
|
||||
}
|
||||
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
|
||||
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
|
||||
int64_t begin = op->args[3].as<IntImm>()->value;
|
||||
int64_t end = op->args[4].as<IntImm>()->value;
|
||||
int64_t nargs = end - begin;
|
||||
CHECK_GE(nargs, 0);
|
||||
llvm::Value* stack_value = MakeValue(op->args[1]);
|
||||
llvm::Value* stack_tcode = MakeValue(op->args[2]);
|
||||
llvm::Value* arg_value = builder_->CreateInBoundsGEP(
|
||||
builder_->CreatePointerCast(
|
||||
stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin));
|
||||
llvm::Value* arg_tcode = CreateBufferPtr(
|
||||
Int(32), stack_tcode, ConstInt32(begin));
|
||||
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
|
||||
builder_->CreatePointerCast(
|
||||
stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
|
||||
llvm::Value* ret_tcode = CreateBufferPtr(
|
||||
Int(32), stack_tcode, ConstInt32(end));
|
||||
CheckCallSuccess(
|
||||
builder_->CreateCall(
|
||||
f_tvm_func_call_,
|
||||
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
|
||||
{handle, arg_value, arg_tcode, ConstInt32(nargs),
|
||||
ret_value, ret_tcode}));
|
||||
Type r_type = op->type;
|
||||
Type r_api_type = op->type.with_bits(64);
|
||||
Type r_api_type = ir::APIType(r_type);
|
||||
llvm::Value* rvalue =
|
||||
builder_->CreateAlignedLoad(
|
||||
builder_->CreatePointerCast(
|
||||
|
@ -649,62 +712,48 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
|
|||
llvm::Value* ptr = MakeValue(op->args[0]);
|
||||
return builder_->CreateICmpEQ(
|
||||
ptr, llvm::Constant::getNullValue(ptr->getType()));
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
|
||||
CHECK_EQ(op->args.size(), 3U);
|
||||
CHECK_EQ(op->type.lanes(), 1);
|
||||
llvm::Value* args = builder_->CreatePointerCast(
|
||||
MakeValue(op->args[0]), t_tvm_value_->getPointerTo());
|
||||
llvm::Value* ptr = builder_->CreateInBoundsGEP(
|
||||
args, MakeValue(op->args[2]));
|
||||
// always pass via 64 bit pointers
|
||||
// For handle type, Handle(64) will simply become 32 bit void*
|
||||
Type value_type = op->type.with_bits(64);
|
||||
ptr = builder_->CreatePointerCast(
|
||||
ptr, LLVMType(value_type)->getPointerTo());
|
||||
llvm::Value* value = builder_->CreateAlignedLoad(ptr, 8);
|
||||
// cast to the desired type
|
||||
if (value_type != op->type) {
|
||||
value = CreateCast(value_type, op->type, value);
|
||||
int kind = op->args[2].as<IntImm>()->value;
|
||||
llvm::Value* ref = this->CreateStructRefPtr(
|
||||
op->type, MakeValue(op->args[0]),
|
||||
MakeValue(op->args[1]), kind);
|
||||
if (kind == intrinsic::kArrAddr) {
|
||||
return builder_->CreatePointerCast(ref, t_void_p_);
|
||||
} else {
|
||||
return builder_->CreateLoad(ref);
|
||||
}
|
||||
return value;
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
|
||||
CHECK_EQ(op->args.size(), 4U);
|
||||
int kind = op->args[2].as<IntImm>()->value;
|
||||
llvm::Value* value = MakeValue(op->args[3]);
|
||||
llvm::Value* ref = this->CreateStructRefPtr(
|
||||
op->args[3].type(), MakeValue(op->args[0]),
|
||||
MakeValue(op->args[1]), kind);
|
||||
CHECK(kind != intrinsic::kArrAddr);
|
||||
if (value->getType()->isPointerTy()) {
|
||||
value = builder_->CreatePointerCast(
|
||||
value, ref->getType()->getPointerElementType());
|
||||
}
|
||||
builder_->CreateStore(value, ref);
|
||||
return ConstInt32(0);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
llvm::Value* arr = builder_->CreatePointerCast(
|
||||
MakeValue(op->args[0]), t_tvm_array_->getPointerTo());
|
||||
llvm::Constant* zero = ConstInt32(0);
|
||||
llvm::Value* ret = nullptr;
|
||||
switch (op->args[1].as<IntImm>()->value) {
|
||||
case intrinsic::kData: {
|
||||
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(0)}); break;
|
||||
}
|
||||
case intrinsic::kShape: {
|
||||
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(4)}); break;
|
||||
}
|
||||
case intrinsic::kStrides: {
|
||||
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(5)}); break;
|
||||
}
|
||||
case intrinsic::kNDim: {
|
||||
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
|
||||
}
|
||||
case intrinsic::kTypeCode: {
|
||||
ret = builder_->CreateInBoundsGEP(
|
||||
arr, {zero, ConstInt32(3), ConstInt32(0)}); break;
|
||||
}
|
||||
case intrinsic::kTypeBits: {
|
||||
ret = builder_->CreateInBoundsGEP(
|
||||
arr, {zero, ConstInt32(3), ConstInt32(1)}); break;
|
||||
}
|
||||
case intrinsic::kTypeLanes: {
|
||||
ret = builder_->CreateInBoundsGEP(
|
||||
arr, {zero, ConstInt32(3), ConstInt32(2)}); break;
|
||||
}
|
||||
case intrinsic::kByteOffset: {
|
||||
ret = builder_->CreateInBoundsGEP(
|
||||
arr, {zero, ConstInt32(6)}); break;
|
||||
}
|
||||
default: LOG(FATAL) << "unknown field code";
|
||||
const std::string& type = op->args[0].as<StringImm>()->value;
|
||||
llvm::Value* num = MakeValue(op->args[1]);
|
||||
if (type == "shape") {
|
||||
return builder_->CreateAlloca(t_tvm_shape_index_, num);
|
||||
} else if (type == "arg_value") {
|
||||
return builder_->CreateAlloca(t_tvm_value_, num);
|
||||
} else if (type == "arg_tcode") {
|
||||
return builder_->CreateAlloca(t_int_, num);
|
||||
} else if (type == "array") {
|
||||
return builder_->CreateAlloca(t_tvm_array_, num);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown stack alloca type " << type;
|
||||
}
|
||||
return builder_->CreateLoad(ret);
|
||||
} else if (op->is_intrinsic(Call::null_handle)) {
|
||||
return llvm::Constant::getNullValue(t_void_p_);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown intrinstic " << op->name;
|
||||
}
|
||||
|
@ -1180,9 +1229,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
|
||||
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
|
||||
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
|
||||
return CreateCallPacked(op);
|
||||
} else if (op->call_type == Call::Intrinsic ||
|
||||
op->call_type == Call::PureIntrinsic) {
|
||||
|
@ -1194,7 +1242,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void CodeGenLLVM::VisitStmt_(const For* op) {
|
||||
CHECK(is_zero(op->min));
|
||||
if (op->for_type == ForType::Serial) {
|
||||
|
@ -1263,8 +1310,10 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
|
|||
const Variable* v = op->node.as<Variable>();
|
||||
CHECK(v);
|
||||
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
|
||||
this->VisitStmt(op->body);
|
||||
} else {
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
|
||||
|
|
|
@ -190,6 +190,7 @@ class CodeGenLLVM :
|
|||
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
|
||||
llvm::Value* GetConstString(const std::string& str);
|
||||
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
|
||||
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
|
||||
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
|
||||
llvm::Value* GetPackedFuncHandle(const std::string& str);
|
||||
// Vector concatenation.
|
||||
|
|
|
@ -37,7 +37,7 @@ void InitializeLLVM() {
|
|||
}
|
||||
|
||||
std::pair<llvm::TargetMachine*, std::string>
|
||||
LLVMGetTarget(const std::string& target_str) {
|
||||
GetLLVMTarget(const std::string& target_str) {
|
||||
// setup target triple
|
||||
std::string target_triple;
|
||||
CHECK_EQ(target_str.substr(0, 4), "llvm");
|
||||
|
|
|
@ -55,7 +55,7 @@ void InitializeLLVM();
|
|||
* \return Pair of target machine and target triple.
|
||||
*/
|
||||
std::pair<llvm::TargetMachine*, std::string>
|
||||
LLVMGetTarget(const std::string& target_str);
|
||||
GetLLVMTarget(const std::string& target_str);
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
|
|
@ -94,7 +94,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
|
|||
|
||||
void Init(const Array<LoweredFunc>& funcs, std::string target) {
|
||||
InitializeLLVM();
|
||||
std::tie(tm_, target_triple_) = LLVMGetTarget(target);
|
||||
std::tie(tm_, target_triple_) = GetLLVMTarget(target);
|
||||
CHECK_NE(funcs.size(), 0U);
|
||||
ctx_ = std::make_shared<llvm::LLVMContext>();
|
||||
CodeGenLLVM cg;
|
||||
|
|
|
@ -70,25 +70,6 @@ int CodeGenStackVM::AllocVarID(const Variable* v) {
|
|||
return vid;
|
||||
}
|
||||
|
||||
void CodeGenStackVM::PushCallPacked(
|
||||
int fid, const std::vector<int>& arg_type_codes) {
|
||||
StackVM::Code code;
|
||||
// CALL_PACKED_FUNC
|
||||
code.op_code = StackVM::CALL_PACKED_FUNC;
|
||||
vm_.code.push_back(code);
|
||||
// num_args
|
||||
code.v_int = static_cast<int>(arg_type_codes.size());
|
||||
vm_.code.push_back(code);
|
||||
// fid
|
||||
code.v_int = fid;
|
||||
vm_.code.push_back(code);
|
||||
// type codes.
|
||||
for (int tcode : arg_type_codes) {
|
||||
code.v_int = tcode;
|
||||
vm_.code.push_back(code);
|
||||
}
|
||||
}
|
||||
|
||||
int CodeGenStackVM::GetVarID(const Variable* v) const {
|
||||
auto it = var_idmap_.find(v);
|
||||
CHECK(it != var_idmap_.end())
|
||||
|
@ -97,26 +78,33 @@ int CodeGenStackVM::GetVarID(const Variable* v) const {
|
|||
}
|
||||
|
||||
void CodeGenStackVM::VisitExpr_(const Load* op) {
|
||||
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
|
||||
if (op->type == UInt(32) && op->index.as<IntImm>()) {
|
||||
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
|
||||
this->Push(op->buffer_var);
|
||||
StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
|
||||
if (const IntImm* index = op->index.as<IntImm>()) {
|
||||
this->PushOp(code, op->index.as<IntImm>()->value);
|
||||
} else {
|
||||
this->Push(op->index);
|
||||
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
this->PushOp(StackVM::GetLoad(Type2TVMType(op->type)));
|
||||
this->PushOp(code, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenStackVM::VisitStmt_(const Store* op) {
|
||||
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
|
||||
this->Push(op->index);
|
||||
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
this->Push(op->value);
|
||||
this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type())));
|
||||
this->Push(op->buffer_var);
|
||||
StackVM::OpCode code = StackVM::GetStore(Type2TVMType(op->value.type()));
|
||||
if (const IntImm* index = op->index.as<IntImm>()) {
|
||||
this->Push(op->value);
|
||||
this->PushOp(code, op->index.as<IntImm>()->value);
|
||||
} else {
|
||||
this->Push(op->index);
|
||||
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
this->Push(op->value);
|
||||
this->PushOp(code, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenStackVM::VisitStmt_(const Allocate* op) {
|
||||
|
@ -141,41 +129,29 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
|
|||
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
|
||||
} else if (op->is_intrinsic(Call::null_handle)) {
|
||||
this->PushOp(StackVM::PUSH_I64, 0);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
|
||||
CHECK_EQ(op->args.size(), 3U);
|
||||
int kind = op->args[2].as<IntImm>()->value;
|
||||
this->Push(op->args[0]);
|
||||
this->Push(op->args[1]);
|
||||
this->Push(op->args[2]);
|
||||
if (op->type.is_handle()) {
|
||||
this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE);
|
||||
} else if (op->type.is_float()) {
|
||||
this->PushOp(StackVM::TVM_LOAD_ARG_FP64);
|
||||
} else if (op->type.is_int() || op->type.is_uint()) {
|
||||
this->PushOp(StackVM::TVM_LOAD_ARG_INT64);
|
||||
} else {
|
||||
LOG(FATAL) << "donot know how to handle type" << op->type;
|
||||
}
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
this->Push(op->args[0]);
|
||||
switch (op->args[1].as<IntImm>()->value) {
|
||||
case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break;
|
||||
case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break;
|
||||
case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break;
|
||||
case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break;
|
||||
case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
|
||||
case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
|
||||
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
|
||||
case intrinsic::kByteOffset: PushOp(StackVM::TVM_ARRAY_GET_BYTE_OFFSET); break;
|
||||
default: LOG(FATAL) << "unknown field code";
|
||||
}
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
|
||||
CHECK_GE(op->args.size(), 1U);
|
||||
const IntImm* index = op->args[1].as<IntImm>();
|
||||
CHECK(index != nullptr);
|
||||
StackVM::Code code;
|
||||
code.op_code = StackVM::TVM_STRUCT_GET;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = index->value;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = kind;
|
||||
vm_.code.push_back(code);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
|
||||
CHECK_GE(op->args.size(), 5U);
|
||||
const StringImm* s = op->args[0].as<StringImm>();
|
||||
CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
this->Push(op->args[i]);
|
||||
}
|
||||
this->Push(op->args[1]);
|
||||
this->Push(op->args[2]);
|
||||
int begin = op->args[3].as<IntImm>()->value;
|
||||
int end = op->args[4].as<IntImm>()->value;
|
||||
// find the fuction id.
|
||||
const std::string& func_name = s->value;
|
||||
auto it = extern_fun_idmap_.find(func_name);
|
||||
|
@ -187,16 +163,39 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
|
|||
vm_.extern_func_name.push_back(func_name);
|
||||
extern_fun_idmap_[func_name] = fid;
|
||||
}
|
||||
// get the argument type code.
|
||||
std::vector<int> arg_type_codes;
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
Type t = op->args[i].type();
|
||||
int code = t.code();
|
||||
int lanes = t.lanes();
|
||||
CHECK_EQ(lanes, 1);
|
||||
arg_type_codes.push_back(code);
|
||||
// CALL_PACKED_FUNC
|
||||
StackVM::Code code;
|
||||
code.op_code = StackVM::CALL_PACKED_LOWERED;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = fid;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = begin;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = end;
|
||||
vm_.code.push_back(code);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
const std::string& type = op->args[0].as<StringImm>()->value;
|
||||
const IntImm* num = op->args[1].as<IntImm>();
|
||||
CHECK(num != nullptr);
|
||||
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
|
||||
static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
|
||||
size_t unit = sizeof(TVMValue);
|
||||
size_t size = 0;
|
||||
if (type == "shape") {
|
||||
size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
|
||||
} else if (type == "arg_value") {
|
||||
size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
|
||||
} else if (type == "arg_tcode") {
|
||||
size = (num->value * sizeof(int) + unit - 1) / unit;
|
||||
} else if (type == "array") {
|
||||
size = (num->value * sizeof(TVMArray) + unit - 1) / unit;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown stack alloca type " << type;
|
||||
}
|
||||
this->PushCallPacked(fid, arg_type_codes);
|
||||
// add stack size to be safe.
|
||||
vm_.stack_size += size;
|
||||
this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
|
||||
CHECK_EQ(op->args.size(), 1U);
|
||||
this->Push(op->args[0]);
|
||||
|
@ -389,10 +388,26 @@ void CodeGenStackVM::VisitStmt_(const Block *op) {
|
|||
if (op->rest.defined()) this->Push(op->rest);
|
||||
}
|
||||
|
||||
void CodeGenStackVM::VisitStmt_(const Evaluate *op) {
|
||||
if (is_const(op->value)) return;
|
||||
this->Push(op->value);
|
||||
this->PushOp(StackVM::POP);
|
||||
void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
|
||||
if (is_const(ev->value)) return;
|
||||
const Call* op = ev->value.as<Call>();
|
||||
if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
|
||||
CHECK_EQ(op->args.size(), 4U);
|
||||
this->Push(op->args[0]);
|
||||
this->Push(op->args[3]);
|
||||
const IntImm* index = op->args[1].as<IntImm>();
|
||||
CHECK(index != nullptr);
|
||||
StackVM::Code code;
|
||||
code.op_code = StackVM::TVM_STRUCT_SET;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = index->value;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = op->args[2].as<IntImm>()->value;
|
||||
vm_.code.push_back(code);
|
||||
} else {
|
||||
this->Push(ev->value);
|
||||
this->PushOp(StackVM::POP);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
|
||||
|
|
|
@ -55,13 +55,6 @@ class CodeGenStackVM
|
|||
* \return operand_index, indicating location of operand
|
||||
*/
|
||||
int64_t PushOp(StackVM::OpCode opcode, int operand);
|
||||
/*!
|
||||
* \brief Push a call packed function.
|
||||
* \param fid The function id.
|
||||
* \param arg_type_codes The type codes of arguments.
|
||||
*/
|
||||
void PushCallPacked(int fid,
|
||||
const std::vector<int>& arg_type_codes);
|
||||
/*!
|
||||
* \brief Set the relative jump offset to be offset.
|
||||
* \param operand_index The indexed returned by PushOp.
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
* \file stack_vm.cc
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/ir.h>
|
||||
#include "./stack_vm.h"
|
||||
|
||||
namespace tvm {
|
||||
|
@ -21,58 +22,50 @@ StackVM::State* StackVM::ThreadLocalState() {
|
|||
sp -= 1; pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_CMPOP(OP, FIELD) \
|
||||
{ \
|
||||
#define STACK_VM_CMPOP(OP, FIELD) \
|
||||
{ \
|
||||
stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
|
||||
sp -= 1; pc += 1; \
|
||||
sp -= 1; pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \
|
||||
{ \
|
||||
stack[sp].FIELD = static_cast<DST_TYPE>( \
|
||||
*static_cast<SRC_TYPE*>(stack[sp].v_handle)); \
|
||||
pc += 1; \
|
||||
int index = code[pc + 1].v_int; \
|
||||
stack[sp]FIELD = static_cast<DST_TYPE>( \
|
||||
static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]); \
|
||||
pc += 2; \
|
||||
}
|
||||
|
||||
#define STACK_VM_STORE(FIELD, DST_TYPE) \
|
||||
{ \
|
||||
*static_cast<DST_TYPE*>(stack[sp - 1].v_handle) = \
|
||||
static_cast<DST_TYPE>(stack[sp].FIELD); \
|
||||
sp -= 2; pc += 1; \
|
||||
int index = code[pc + 1].v_int; \
|
||||
static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] = \
|
||||
static_cast<DST_TYPE>(stack[sp]FIELD); \
|
||||
sp -= 2; pc += 2; \
|
||||
}
|
||||
|
||||
#define STACK_VM_TVM_LOAD_ARG(OP, TYPE) \
|
||||
{ \
|
||||
TVMValue* args = static_cast<TVMValue*>(stack[sp - 2].v_handle); \
|
||||
int64_t index = stack[sp].v_int64; \
|
||||
int tc = static_cast<int*>(stack[sp - 1].v_handle)[index]; \
|
||||
CHECK(OP) \
|
||||
<< " argument " << index << " is expected to be " << TYPE; \
|
||||
stack[sp - 2] = args[index]; \
|
||||
sp -= 2; \
|
||||
pc += 1; \
|
||||
}
|
||||
|
||||
|
||||
#define STACK_VM_TVM_ARRARY_GET(FIELD, TYPE, SFIELD) \
|
||||
{ \
|
||||
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle); \
|
||||
stack[sp].FIELD = (TYPE)(arr->SFIELD); \
|
||||
pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_CODE0(CODE) \
|
||||
case CODE: { \
|
||||
#define STACK_VM_PRINT_CODE0(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_CODE1(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
|
||||
<< "[" << pc + 1 << "]" << std::endl; \
|
||||
<< "[" << pc + 1 << "]" << std::endl; \
|
||||
return pc + 2; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_CODE2(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE \
|
||||
<< " " << code[pc + 1].v_int \
|
||||
<< " " << code[pc + 2].v_int << "\n" \
|
||||
<< "[" << pc + 1 << "]" << std::endl \
|
||||
<< "[" << pc + 2 << "]" << std::endl; \
|
||||
return pc + 3; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \
|
||||
|
@ -110,13 +103,18 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
|
|||
STACK_VM_PRINT_CODE0(LT_F64);
|
||||
STACK_VM_PRINT_CODE0(LE_F64);
|
||||
// addressing load
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_UINT32);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT32);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT64);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_FP64);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_HANDLE);
|
||||
STACK_VM_PRINT_CODE0(ADDR_STORE_INT64);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT32);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT64);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_FP64);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_HANDLE);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_TVMVALUE);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_STORE_UINT32);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_STORE_INT32);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_STORE_INT64);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_STORE_FP64);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_STORE_HANDLE);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_STORE_TVMVALUE);
|
||||
STACK_VM_PRINT_CODE0(NOT);
|
||||
STACK_VM_PRINT_CODE0(ADDR_ADD);
|
||||
// stack ops
|
||||
|
@ -132,32 +130,24 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
|
|||
STACK_VM_PRINT_JUMP(RJUMP);
|
||||
STACK_VM_PRINT_CODE1(ASSERT_SP);
|
||||
// Intrinsics
|
||||
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_INT64);
|
||||
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_FP64);
|
||||
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_HANDLE);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_DATA);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_SHAPE);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_STRIDES);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_NDIM);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_BYTE_OFFSET);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES);
|
||||
STACK_VM_PRINT_CODE2(TVM_STRUCT_GET);
|
||||
STACK_VM_PRINT_CODE2(TVM_STRUCT_SET);
|
||||
// Allocate data by 8 bytes.
|
||||
STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE);
|
||||
// packed function.
|
||||
case CALL_PACKED_FUNC: {
|
||||
int num_args = code[pc + 1].v_int;
|
||||
case CALL_PACKED_LOWERED: {
|
||||
int call_fid = code[pc + 1].v_int;
|
||||
int begin = code[pc + 2].v_int;
|
||||
int end = code[pc + 3].v_int;
|
||||
os << "[" << pc << "]\tCALL_PACKED_FUNC "
|
||||
<< " num_args=" << num_args
|
||||
<< " fid=" << code[pc + 2].v_int;
|
||||
os << " type_codes:";
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
os << ' ' << code[pc + 3 + i].v_int;
|
||||
}
|
||||
<< " fid=" << call_fid
|
||||
<< " begin=" << begin
|
||||
<< " end=" << end;
|
||||
os << '\n';
|
||||
for (int i = 0; i < num_args + 2; ++i) {
|
||||
os << "[" << pc + 1 << "]" << std::endl;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
os << "[" << pc + 1 + i << "]" << std::endl;
|
||||
}
|
||||
return pc + 3 + num_args;
|
||||
return pc + 4;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "unknown op code " << code[pc].op_code;
|
||||
|
@ -193,6 +183,7 @@ void StackVM::operator()(const runtime::TVMArgs& args) const {
|
|||
void StackVM::Run(State* s) const {
|
||||
int64_t sp = s->sp;
|
||||
int64_t pc = s->pc;
|
||||
int64_t alloca_sp = s->sp;
|
||||
std::vector<TVMValue>& stack = s->stack;
|
||||
std::vector<TVMValue>& heap = s->heap;
|
||||
s->extern_func.clear();
|
||||
|
@ -223,23 +214,26 @@ void StackVM::Run(State* s) const {
|
|||
case LT_F64: STACK_VM_CMPOP(<, v_float64); break;
|
||||
case LE_F64: STACK_VM_CMPOP(<=, v_float64); break;
|
||||
// addressing
|
||||
case ADDR_LOAD_UINT32: STACK_VM_LOAD(v_int64, int64_t, uint32_t); break;
|
||||
case ADDR_LOAD_INT32: STACK_VM_LOAD(v_int64, int64_t, int32_t); break;
|
||||
case ADDR_LOAD_INT64: STACK_VM_LOAD(v_int64, int64_t, int64_t); break;
|
||||
case ADDR_LOAD_FP64: STACK_VM_LOAD(v_float64, double, double); break;
|
||||
case ADDR_LOAD_HANDLE: STACK_VM_LOAD(v_handle, void*, void*); break;
|
||||
case ADDR_STORE_INT64: STACK_VM_STORE(v_int64, int64_t); break;
|
||||
case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break;
|
||||
case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break;
|
||||
case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break;
|
||||
case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break;
|
||||
case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break;
|
||||
case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break;
|
||||
// store
|
||||
case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break;
|
||||
case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break;
|
||||
case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break;
|
||||
case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break;
|
||||
case ARRAY_STORE_HANDLE: STACK_VM_STORE(.v_handle, void*); break;
|
||||
case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break;
|
||||
// add
|
||||
case ADDR_ADD: {
|
||||
stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*)
|
||||
sp = sp - 1;
|
||||
pc = pc + 1;
|
||||
break;
|
||||
}
|
||||
case ARRAY_LOAD_UINT32: {
|
||||
stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int]; // NOLINT(*)
|
||||
pc = pc + 2;
|
||||
break;
|
||||
}
|
||||
case NOT: {
|
||||
stack[sp].v_int64 = !stack[sp].v_int64;
|
||||
pc += 1;
|
||||
|
@ -282,21 +276,6 @@ void StackVM::Run(State* s) const {
|
|||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case CALL_PACKED_FUNC: {
|
||||
// call packed function.
|
||||
int num_args = code[pc + 1].v_int;
|
||||
int call_fid = code[pc + 2].v_int;
|
||||
static_assert(sizeof(Code) == sizeof(int) &&
|
||||
alignof(Code) == alignof(int), "asusmption");
|
||||
const int* type_codes = &(code[pc].v_int) + 3;
|
||||
runtime::TVMRetValue rv;
|
||||
GetExtern(s, call_fid).CallPacked(
|
||||
runtime::TVMArgs(&stack[sp + 1 - num_args], type_codes, num_args), &rv);
|
||||
sp = sp + 1 - num_args;
|
||||
stack[sp] = rv.value();
|
||||
pc += 3 + num_args;
|
||||
break;
|
||||
}
|
||||
case ASSERT: {
|
||||
CHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
|
||||
sp -= 1;
|
||||
|
@ -331,41 +310,145 @@ void StackVM::Run(State* s) const {
|
|||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case TVM_LOAD_ARG_INT64: {
|
||||
STACK_VM_TVM_LOAD_ARG(tc == kInt, "int"); break;
|
||||
case CALL_PACKED_LOWERED: {
|
||||
// call packed function.
|
||||
TVMValue* value_stack = static_cast<TVMValue*>(stack[sp - 1].v_handle);
|
||||
int* type_stack = static_cast<int*>(stack[sp].v_handle);
|
||||
int call_fid = code[pc + 1].v_int;
|
||||
int begin = code[pc + 2].v_int;
|
||||
int end = code[pc + 3].v_int;
|
||||
int num_args = end - begin;
|
||||
static_assert(sizeof(Code) == sizeof(int) &&
|
||||
alignof(Code) == alignof(int), "asusmption");
|
||||
runtime::TVMRetValue rv;
|
||||
GetExtern(s, call_fid).CallPacked(
|
||||
runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
|
||||
sp = sp - 1;
|
||||
stack[sp] = rv.value();
|
||||
pc += 4;
|
||||
break;
|
||||
}
|
||||
case TVM_LOAD_ARG_FP64: {
|
||||
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
|
||||
// intrinsics
|
||||
case TVM_STRUCT_GET: {
|
||||
using namespace ir;
|
||||
int index = code[pc + 1].v_int;
|
||||
int kind = code[pc + 2].v_int;
|
||||
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle);
|
||||
switch (kind) {
|
||||
case intrinsic::kArrData: {
|
||||
stack[sp].v_handle = arr[index].data; break;
|
||||
}
|
||||
case intrinsic::kArrShape: {
|
||||
stack[sp].v_handle = arr[index].shape; break;
|
||||
}
|
||||
case intrinsic::kArrStrides: {
|
||||
stack[sp].v_handle = arr[index].strides; break;
|
||||
}
|
||||
case intrinsic::kArrNDim: {
|
||||
stack[sp].v_int64 = arr[index].ndim; break;
|
||||
}
|
||||
case intrinsic::kArrTypeCode: {
|
||||
stack[sp].v_int64 = static_cast<int64_t>(
|
||||
arr[index].dtype.code); break;
|
||||
}
|
||||
case intrinsic::kArrTypeBits: {
|
||||
stack[sp].v_int64 = static_cast<int64_t>(
|
||||
arr[index].dtype.bits); break;
|
||||
}
|
||||
case intrinsic::kArrTypeLanes: {
|
||||
stack[sp].v_int64 = static_cast<int64_t>(
|
||||
arr[index].dtype.lanes); break;
|
||||
}
|
||||
case intrinsic::kArrByteOffset: {
|
||||
stack[sp].v_int64 = static_cast<int64_t>(
|
||||
arr[index].byte_offset); break;
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrDeviceId: {
|
||||
stack[sp].v_int64 = arr[index].ctx.device_id; break;
|
||||
}
|
||||
case intrinsic::kArrDeviceType: {
|
||||
stack[sp].v_int64 = static_cast<int64_t>(
|
||||
arr[index].ctx.device_type); break;
|
||||
}
|
||||
case intrinsic::kArrAddr: {
|
||||
stack[sp].v_handle = arr + index; break;
|
||||
}
|
||||
case intrinsic::kTVMValueContent: {
|
||||
stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
|
||||
}
|
||||
default: LOG(FATAL) << "unhandled get " << kind;
|
||||
}
|
||||
pc = pc + 3;
|
||||
break;
|
||||
}
|
||||
case TVM_LOAD_ARG_HANDLE: {
|
||||
STACK_VM_TVM_LOAD_ARG(
|
||||
tc == kHandle || tc == kNull || tc == kArrayHandle, "handle"); break;
|
||||
case TVM_STRUCT_SET: {
|
||||
using namespace ir;
|
||||
int index = code[pc + 1].v_int;
|
||||
int kind = code[pc + 2].v_int;
|
||||
TVMArray* arr = static_cast<TVMArray*>(stack[sp - 1].v_handle);
|
||||
switch (kind) {
|
||||
case intrinsic::kArrData: {
|
||||
arr[index].data = stack[sp].v_handle; break;
|
||||
}
|
||||
case intrinsic::kArrShape: {
|
||||
arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrStrides: {
|
||||
arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrNDim: {
|
||||
arr[index].ndim = static_cast<int>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrTypeCode: {
|
||||
arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrTypeBits: {
|
||||
arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrTypeLanes: {
|
||||
arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrByteOffset: {
|
||||
arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrDeviceId: {
|
||||
arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kArrDeviceType: {
|
||||
arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
|
||||
break;
|
||||
}
|
||||
case intrinsic::kTVMValueContent: {
|
||||
static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
|
||||
}
|
||||
default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
|
||||
}
|
||||
sp -= 2;
|
||||
pc += 3;
|
||||
break;
|
||||
}
|
||||
case TVM_ARRAY_GET_DATA: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_SHAPE: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, shape); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_STRIDES: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, strides); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_NDIM: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_BYTE_OFFSET: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, byte_offset); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_CODE: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_BITS: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_LANES: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.lanes); break;
|
||||
// alloca
|
||||
case TVM_STACK_ALLOCA_BY_8BYTE: {
|
||||
static_assert(sizeof(TVMValue) == 8, "invariance");
|
||||
int num = code[pc + 1].v_int;
|
||||
void* addr = &stack[sp] + 1;
|
||||
sp = sp + num + 1;
|
||||
alloca_sp = sp - 1;
|
||||
stack[sp].v_handle = addr;
|
||||
pc = pc + 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK_GE(sp, alloca_sp) << "touch allocated space";
|
||||
CHECK_LT(sp, stack_cap) << "Stack overflow";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,24 +55,33 @@ class StackVM {
|
|||
EQ_F64,
|
||||
LT_F64,
|
||||
LE_F64,
|
||||
// load operation
|
||||
ADDR_LOAD_UINT32,
|
||||
ADDR_LOAD_INT32,
|
||||
ADDR_LOAD_INT64,
|
||||
ADDR_LOAD_FP64,
|
||||
ADDR_LOAD_HANDLE,
|
||||
// store operations
|
||||
// *(stack[sp - 1].v_andle) = stack[sp].v_int64
|
||||
// sp = sp - 2;
|
||||
ADDR_STORE_INT64,
|
||||
/*!
|
||||
* \brief Quick routine to load uint32 from constant offset.
|
||||
* \brief Routine to load data from address with const offset.
|
||||
* \code
|
||||
* stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int];
|
||||
* stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int];
|
||||
* pc = pc + 2;
|
||||
* \endcode
|
||||
*/
|
||||
ARRAY_LOAD_UINT32,
|
||||
ARRAY_LOAD_INT32,
|
||||
ARRAY_LOAD_INT64,
|
||||
ARRAY_LOAD_FP64,
|
||||
ARRAY_LOAD_HANDLE,
|
||||
ARRAY_LOAD_TVMVALUE,
|
||||
/*!
|
||||
* \brief Routine to store data from constant offset.
|
||||
* \code
|
||||
* ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp];
|
||||
* pc = pc + 2;
|
||||
* sp = sp - 2;
|
||||
* \endcode
|
||||
*/
|
||||
ARRAY_STORE_UINT32,
|
||||
ARRAY_STORE_INT32,
|
||||
ARRAY_STORE_INT64,
|
||||
ARRAY_STORE_FP64,
|
||||
ARRAY_STORE_HANDLE,
|
||||
ARRAY_STORE_TVMVALUE,
|
||||
// logical ops
|
||||
NOT,
|
||||
/*!
|
||||
|
@ -128,20 +137,6 @@ class StackVM {
|
|||
* \endcode
|
||||
*/
|
||||
SELECT,
|
||||
/*!
|
||||
* \brief call an extern packed function
|
||||
* \code
|
||||
* num_args = stack[sp].v_int64;
|
||||
* call_fid = code[pc + 1].v_int;
|
||||
* f = extern_func[call_fid];
|
||||
* int* type_codes = &(code[pc + 2].v_int)
|
||||
* stack[sp - num_args] = f(&stack[sp - num_args], type_codes, num_args);
|
||||
* sp = sp - num_args;
|
||||
* // The type codes are hidden in the code space.
|
||||
* pc = pc + 2 + num_args
|
||||
* \endcode
|
||||
*/
|
||||
CALL_PACKED_FUNC,
|
||||
/*!
|
||||
* \brief Assert condition is true.
|
||||
* \code
|
||||
|
@ -189,18 +184,56 @@ class StackVM {
|
|||
* \code
|
||||
*/
|
||||
ASSERT_SP,
|
||||
// Intrinsics for API function,
|
||||
TVM_LOAD_ARG_INT64,
|
||||
TVM_LOAD_ARG_FP64,
|
||||
TVM_LOAD_ARG_HANDLE,
|
||||
TVM_ARRAY_GET_DATA,
|
||||
TVM_ARRAY_GET_SHAPE,
|
||||
TVM_ARRAY_GET_STRIDES,
|
||||
TVM_ARRAY_GET_NDIM,
|
||||
TVM_ARRAY_GET_TYPE_CODE,
|
||||
TVM_ARRAY_GET_TYPE_BITS,
|
||||
TVM_ARRAY_GET_TYPE_LANES,
|
||||
TVM_ARRAY_GET_BYTE_OFFSET
|
||||
/*!
|
||||
* \brief call an extern packed function
|
||||
* \code
|
||||
* value_stack = stack[sp - 1].v_handle;
|
||||
* type_stack = stack[sp - 0].v_handle;
|
||||
* call_fid = code[pc + 1].v_int;
|
||||
* begin = code[pc + 2].v_int;
|
||||
* end = code[pc + 3].v_int;
|
||||
* num_args = end - begin - 1;
|
||||
* f = extern_func[call_fid];
|
||||
* stack[sp - 1] = f(&value_stack[begin:end-1], type_stack[begin:end-1], num_args);
|
||||
* sp = sp - 1;
|
||||
* // The type codes are hidden in the code space.
|
||||
* pc = pc + 4
|
||||
* \endcode
|
||||
*/
|
||||
CALL_PACKED_LOWERED,
|
||||
// Allocate things on stack
|
||||
/*!
|
||||
* \brief allocate data from stack.
|
||||
* \code
|
||||
* num = code[pc + 1].v_int;
|
||||
* void* addr = &stack[sp];
|
||||
* sp = sp + num;
|
||||
* stack[sp].v_handle = addr;
|
||||
* pc = pc + 1;
|
||||
* \endcode
|
||||
*/
|
||||
TVM_STACK_ALLOCA_BY_8BYTE,
|
||||
/*!
|
||||
* \brief get data from structure.
|
||||
* \code
|
||||
* index = code[pc + 1].v_int;
|
||||
* field = code[pc + 2].v_int;
|
||||
* stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field;
|
||||
* pc = pc + 3
|
||||
* \endcode
|
||||
*/
|
||||
TVM_STRUCT_GET,
|
||||
/*!
|
||||
* \brief set data into structure.
|
||||
* \code
|
||||
* index = code[pc + 1].v_int;
|
||||
* field = code[pc + 2].v_int;
|
||||
* ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp];
|
||||
* pc = pc + 3
|
||||
* sp = sp - 1
|
||||
* \endcode
|
||||
*/
|
||||
TVM_STRUCT_SET
|
||||
};
|
||||
/*! \brief The code structure */
|
||||
union Code {
|
||||
|
@ -276,23 +309,23 @@ class StackVM {
|
|||
*/
|
||||
static OpCode GetLoad(TVMType t) {
|
||||
CHECK_EQ(t.lanes, 1U);
|
||||
if (t.code == kHandle) return ADDR_LOAD_HANDLE;
|
||||
if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
|
||||
if (t.code == kInt) {
|
||||
switch (t.bits) {
|
||||
case 32 : return ADDR_LOAD_INT32;
|
||||
case 64 : return ADDR_LOAD_INT64;
|
||||
case 32 : return ARRAY_LOAD_INT32;
|
||||
case 64 : return ARRAY_LOAD_INT64;
|
||||
}
|
||||
} else if (t.code == kUInt) {
|
||||
switch (t.bits) {
|
||||
case 32 : return ADDR_LOAD_UINT32;
|
||||
case 32 : return ARRAY_LOAD_UINT32;
|
||||
}
|
||||
} else if (t.code == kFloat) {
|
||||
switch (t.bits) {
|
||||
case 64 : return ADDR_LOAD_FP64;
|
||||
case 64 : return ARRAY_LOAD_FP64;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Cannot load type " << t;
|
||||
return ADDR_LOAD_FP64;
|
||||
return ARRAY_LOAD_FP64;
|
||||
}
|
||||
/*!
|
||||
* \brief Get store opcode for type t
|
||||
|
@ -301,13 +334,23 @@ class StackVM {
|
|||
*/
|
||||
static OpCode GetStore(TVMType t) {
|
||||
CHECK_EQ(t.lanes, 1U);
|
||||
if (t.code == kHandle) return ARRAY_STORE_HANDLE;
|
||||
if (t.code == kInt) {
|
||||
switch (t.bits) {
|
||||
case 64 : return ADDR_STORE_INT64;
|
||||
case 32 : return ARRAY_STORE_INT32;
|
||||
case 64 : return ARRAY_STORE_INT64;
|
||||
}
|
||||
} else if (t.code == kUInt) {
|
||||
switch (t.bits) {
|
||||
case 32 : return ARRAY_STORE_UINT32;
|
||||
}
|
||||
} else if (t.code == kFloat) {
|
||||
switch (t.bits) {
|
||||
case 64 : return ARRAY_STORE_FP64;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Cannot store type " << t;
|
||||
return ADDR_LOAD_FP64;
|
||||
return ARRAY_STORE_FP64;
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
|
||||
|
||||
|
|
|
@ -85,6 +85,70 @@ inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
|
|||
return body;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get construct from struct
|
||||
* \param dtype The data type.
|
||||
* \param handle the struct handle.
|
||||
* \param index the offset index.
|
||||
* \param kind The data kind.
|
||||
* \return the get expression.
|
||||
*/
|
||||
inline Expr TVMStructGet(
|
||||
Type dtype, Var handle, int index,
|
||||
intrinsic::TVMStructFieldKind kind) {
|
||||
Array<Expr> args ={
|
||||
handle,
|
||||
make_const(Int(32), index),
|
||||
make_const(Int(32), kind)};
|
||||
return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Address of handle + offset
|
||||
* \param handle the array handle.
|
||||
* \param dtype The data type.
|
||||
* \param offset the offset index.
|
||||
*/
|
||||
inline Expr AddressOffset(Var handle, Type dtype, int offset) {
|
||||
return Call::make(
|
||||
Handle(), Call::address_of,
|
||||
{Load::make(dtype, handle, make_const(Int(32), offset))}, Call::PureIntrinsic);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Set value into struct.
|
||||
* \param handle the struct handle.
|
||||
* \param index the offset index.
|
||||
* \param kind The data kind.
|
||||
* \param value The value to be set.
|
||||
* \return the set stmt.
|
||||
*/
|
||||
inline Stmt TVMStructSet(
|
||||
Var handle, int index,
|
||||
intrinsic::TVMStructFieldKind kind, Expr value) {
|
||||
Array<Expr> args ={
|
||||
handle,
|
||||
make_const(Int(32), index),
|
||||
make_const(Int(32), kind),
|
||||
value};
|
||||
return Evaluate::make(
|
||||
Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get the type that is passed around TVM PackedFunc API.
|
||||
* \param t The original type.
|
||||
* \return The corresponding API type.
|
||||
*/
|
||||
inline Type APIType(Type t) {
|
||||
if (t.is_handle()) return t;
|
||||
CHECK_EQ(t.lanes(), 1)
|
||||
<< "Cannot pass vector type through packed API.";
|
||||
if (t.is_uint() || t.is_int()) return Int(64);
|
||||
CHECK(t.is_float());
|
||||
return Float(64);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace tvm
|
||||
#endif // TVM_PASS_IR_UTIL_H_
|
||||
|
|
|
@ -0,0 +1,231 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* Lower calls to packed function.
|
||||
* \file lower_packed_call.cc
|
||||
*/
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_mutator.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <unordered_set>
|
||||
#include "./ir_util.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace ir {
|
||||
|
||||
inline Expr ConstInt32(size_t index) {
|
||||
CHECK_LE(index, std::numeric_limits<int>::max());
|
||||
return make_const(Int(32), static_cast<int>(index));
|
||||
}
|
||||
|
||||
inline Expr StackAlloca(std::string type, size_t num) {
|
||||
Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
|
||||
return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
|
||||
}
|
||||
|
||||
// Calculate the statistics of packed function.
|
||||
// These information are needed during codegen.
|
||||
class PackedCallBuilder : public IRMutator {
|
||||
public:
|
||||
Stmt Build(Stmt stmt) {
|
||||
stack_shape_ = Var("stack_shape", Handle());
|
||||
stack_array_ = Var("stack_array", Handle());
|
||||
stack_value_ = Var("stack_value", Handle());
|
||||
stack_tcode_ = Var("stack_tcode", Handle());
|
||||
stmt = this->Mutate(stmt);
|
||||
if (max_shape_stack_ != 0) {
|
||||
stmt = LetStmt::make(
|
||||
stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
|
||||
}
|
||||
if (max_array_stack_ != 0) {
|
||||
stmt = LetStmt::make(
|
||||
stack_array_, StackAlloca("array", max_array_stack_), stmt);
|
||||
}
|
||||
if (max_arg_stack_ != 0) {
|
||||
stmt = LetStmt::make(
|
||||
stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
|
||||
stmt = LetStmt::make(
|
||||
stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
|
||||
}
|
||||
return stmt;
|
||||
}
|
||||
Stmt Mutate(Stmt stmt) final {
|
||||
stmt = IRMutator::Mutate(stmt);
|
||||
CHECK_EQ(run_shape_stack_, 0);
|
||||
CHECK_EQ(run_array_stack_, 0);
|
||||
CHECK_EQ(run_arg_stack_, 0);
|
||||
while (prep_seq_.size() != 0) {
|
||||
stmt = Block::make(prep_seq_.back(), stmt);
|
||||
prep_seq_.pop_back();
|
||||
}
|
||||
return stmt;
|
||||
}
|
||||
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
|
||||
if (op->attr_key == attr::device_context_id) {
|
||||
CHECK(!device_id_.defined());
|
||||
device_id_ = op->value;
|
||||
return Mutate(op->body);
|
||||
} else if (op->attr_key == attr::device_context_type) {
|
||||
CHECK(!device_type_.defined());
|
||||
device_type_ = op->value;
|
||||
return Mutate(op->body);
|
||||
} else {
|
||||
return IRMutator::Mutate_(op, s);
|
||||
}
|
||||
}
|
||||
Expr Mutate_(const Call* op, const Expr &e) final {
|
||||
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
|
||||
return MakeCallPacked(op, e);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
|
||||
return MakeShape(op, e);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
|
||||
return MakeArray(op, e);
|
||||
} else {
|
||||
return IRMutator::Mutate_(op, e);
|
||||
}
|
||||
}
|
||||
Expr Convert(Type t, Expr e) {
|
||||
if (e.type() != t) {
|
||||
return Cast::make(t, e);
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
// call shape
|
||||
Expr MakeShape(const Call* op, const Expr& e) {
|
||||
size_t stack_begin = run_shape_stack_;
|
||||
run_shape_stack_ += op->args.size();
|
||||
Expr expr = IRMutator::Mutate_(op, e);
|
||||
op = expr.as<Call>();
|
||||
for (size_t i = 0; i < op->args.size(); ++i) {
|
||||
prep_seq_.emplace_back(
|
||||
Store::make(stack_shape_, Convert(Int(64), op->args[i]),
|
||||
ConstInt32(stack_begin +i)));
|
||||
}
|
||||
return AddressOffset(stack_shape_, Int(64), stack_begin);
|
||||
}
|
||||
// make array
|
||||
Expr MakeArray(const Call* op, const Expr& e) {
|
||||
size_t idx = run_array_stack_;
|
||||
run_array_stack_ += 1;
|
||||
Expr expr = IRMutator::Mutate_(op, e);
|
||||
op = expr.as<Call>();
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
|
||||
Expr strides = op->args[2];
|
||||
if (!strides.defined() || is_zero(strides)) {
|
||||
strides = make_zero(Handle());
|
||||
}
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
|
||||
Type dtype = op->args[4].type();
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
|
||||
make_const(UInt(8), static_cast<int>(dtype.code()))));
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
|
||||
make_const(UInt(8), dtype.bits())));
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
|
||||
make_const(UInt(16), dtype.lanes())));
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
|
||||
Convert(Int(64), op->args[5])));
|
||||
CHECK(device_type_.defined()) << "Unknown device type in current IR";
|
||||
CHECK(device_id_.defined()) << "Unknown device id in current IR";
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
|
||||
Convert(Int(32), device_id_)));
|
||||
prep_seq_.emplace_back(
|
||||
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
|
||||
Convert(Int(32), device_type_)));
|
||||
return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
|
||||
}
|
||||
// call packled.
|
||||
Expr MakeCallPacked(const Call* op, const Expr& e) {
|
||||
size_t restore_shape_stack = run_shape_stack_;
|
||||
size_t restore_array_stack = run_array_stack_;
|
||||
size_t arg_stack_begin = run_arg_stack_;
|
||||
run_arg_stack_ += op->args.size();
|
||||
// Specially handle the buffer packed intrinsic
|
||||
Expr expr = IRMutator::Mutate_(op, e);
|
||||
op = expr.as<Call>();
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
|
||||
Expr arg = op->args[i];
|
||||
Type t = arg.type();
|
||||
Type api_type = APIType(t);
|
||||
if (t != api_type) {
|
||||
arg = Cast::make(api_type, arg);
|
||||
}
|
||||
prep_seq_.emplace_back(TVMStructSet(
|
||||
stack_value_, static_cast<int>(arg_stack_begin + i - 1),
|
||||
intrinsic::kTVMValueContent, arg));
|
||||
int arg_tcode = api_type.code();
|
||||
if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
|
||||
prep_seq_.emplace_back(
|
||||
Store::make(stack_tcode_,
|
||||
ConstInt32(arg_tcode),
|
||||
stack_index));
|
||||
}
|
||||
// UPDATE stack value
|
||||
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
|
||||
max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
|
||||
max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
|
||||
run_shape_stack_ = restore_shape_stack;
|
||||
run_array_stack_ = restore_array_stack;
|
||||
run_arg_stack_ = arg_stack_begin;
|
||||
Array<Expr> packed_args = {
|
||||
op->args[0],
|
||||
stack_value_,
|
||||
stack_tcode_,
|
||||
ConstInt32(arg_stack_begin),
|
||||
ConstInt32(arg_stack_begin + op->args.size() - 1)
|
||||
};
|
||||
return Call::make(
|
||||
Int(32), intrinsic::tvm_call_packed_lowered,
|
||||
packed_args, Call::Intrinsic);
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsArrayHandle(const Expr& arg) {
|
||||
// specially set array handle.
|
||||
if (const Call* buf = arg.as<Call>()) {
|
||||
if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
|
||||
buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// The prepration sequence to be emitted.
|
||||
std::vector<Stmt> prep_seq_;
|
||||
Expr device_type_;
|
||||
Expr device_id_;
|
||||
// Var handle for each stack.
|
||||
Var stack_shape_;
|
||||
Var stack_array_;
|
||||
Var stack_tcode_;
|
||||
Var stack_value_;
|
||||
// The running statistics
|
||||
uint64_t run_shape_stack_{0};
|
||||
uint64_t run_array_stack_{0};
|
||||
uint64_t run_arg_stack_{0};
|
||||
// statistics of stacks
|
||||
uint64_t max_shape_stack_{0};
|
||||
uint64_t max_array_stack_{0};
|
||||
uint64_t max_arg_stack_{0};
|
||||
};
|
||||
|
||||
LoweredFunc LowerPackedCall(LoweredFunc f) {
|
||||
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
|
||||
n->body = PackedCallBuilder().Build(n->body);
|
||||
return LoweredFunc(n);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace tvm
|
|
@ -4,6 +4,7 @@
|
|||
*/
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_visitor.h>
|
||||
#include <tvm/buffer.h>
|
||||
|
||||
#include <vector>
|
||||
|
@ -15,11 +16,8 @@
|
|||
namespace tvm {
|
||||
namespace ir {
|
||||
|
||||
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) {
|
||||
return Call::make(
|
||||
t, intrinsic::tvm_array_get_field,
|
||||
{arr, IntImm::make(Int(32), kind)},
|
||||
Call::PureIntrinsic);
|
||||
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
|
||||
return TVMStructGet(t, arr, 0, kind);
|
||||
}
|
||||
|
||||
inline Stmt AssertNull(Var handle, std::string msg) {
|
||||
|
@ -55,15 +53,25 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
std::unordered_set<const Variable*> visited;
|
||||
// the handle data types
|
||||
Map<Var, Expr> handle_data_type;
|
||||
// The device context
|
||||
Var device_id, device_type;
|
||||
// ---------------------------
|
||||
// local function defintiions
|
||||
// load i-th argument as type t
|
||||
auto f_arg_value = [&](Type t, int i) {
|
||||
Array<Expr> call_args{
|
||||
v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)};
|
||||
return Call::make(
|
||||
t, intrinsic::tvm_api_load_arg, call_args,
|
||||
Array<Expr> call_args{v_packed_args,
|
||||
IntImm::make(Int(32), i),
|
||||
IntImm::make(Int(32), intrinsic::kTVMValueContent)};
|
||||
// load 64 bit version
|
||||
Type api_type = APIType(t);
|
||||
Expr res = Call::make(
|
||||
api_type, intrinsic::tvm_struct_get, call_args,
|
||||
Call::PureIntrinsic);
|
||||
// cast to the target version.
|
||||
if (api_type != t) {
|
||||
res = Cast::make(t, res);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
// get declaration of argument i
|
||||
auto f_arg_decl = [&](int i) {
|
||||
|
@ -107,8 +115,32 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
|
||||
Var v_arg = f_arg_decl(i);
|
||||
if (i < num_packed_args) {
|
||||
// Value loads
|
||||
seq_init.emplace_back(LetStmt::make(
|
||||
v_arg, f_arg_value(v_arg.type(), i), nop));
|
||||
// type code checks
|
||||
Var tcode(v_arg->name_hint + ".code", Int(32));
|
||||
seq_init.emplace_back(LetStmt::make(
|
||||
tcode, Load::make(
|
||||
Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i)), nop));
|
||||
Type t = v_arg.type();
|
||||
if (t.is_handle()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Expect argument " << i << " to be pointer";
|
||||
seq_check.emplace_back(
|
||||
AssertStmt::make(tcode == kHandle ||
|
||||
tcode == kArrayHandle ||
|
||||
tcode == kNull, msg.str()));
|
||||
} else if (t.is_int() || t.is_uint()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Expect argument " << i << " to be int";
|
||||
seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str()));
|
||||
} else {
|
||||
CHECK(t.is_float());
|
||||
std::ostringstream msg;
|
||||
msg << "Expect argument " << i << " to be float";
|
||||
seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str()));
|
||||
}
|
||||
} else {
|
||||
args.push_back(v_arg);
|
||||
}
|
||||
|
@ -121,7 +153,7 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
<< "api_args can only be Buffer or Var";
|
||||
Buffer buf(api_args[i].node_);
|
||||
// dimension checks
|
||||
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kNDim);
|
||||
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kArrNDim);
|
||||
std::ostringstream ndim_err_msg;
|
||||
ndim_err_msg << "arg_" << i
|
||||
<< ".ndim is expected to equal "
|
||||
|
@ -135,15 +167,15 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
Type dtype = buf->dtype;
|
||||
std::ostringstream type_err_msg;
|
||||
type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
|
||||
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) ==
|
||||
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeCode) ==
|
||||
UIntImm::make(UInt(8), dtype.code()) &&
|
||||
TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) ==
|
||||
TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeBits) ==
|
||||
UIntImm::make(UInt(8), dtype.bits()) &&
|
||||
TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) ==
|
||||
TVMArrayGet(UInt(16), v_arg, intrinsic::kArrTypeLanes) ==
|
||||
UIntImm::make(UInt(16), dtype.lanes()));
|
||||
seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
|
||||
// Data Field
|
||||
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
|
||||
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kArrData),
|
||||
v_arg->name_hint + ".data")) {
|
||||
Var vptr(buf->data);
|
||||
handle_data_type.Set(vptr, make_const(buf->dtype, 0));
|
||||
|
@ -152,20 +184,22 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
Var v_shape(v_arg->name_hint + ".shape", Handle());
|
||||
handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0));
|
||||
seq_init.emplace_back(LetStmt::make(
|
||||
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
|
||||
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kArrShape), nop));
|
||||
for (size_t k = 0; k < buf->shape.size(); ++k) {
|
||||
std::ostringstream field_name;
|
||||
field_name << v_shape->name_hint << '[' << k << ']';
|
||||
f_push(buf->shape[k],
|
||||
cast(buf->shape[k].type(),
|
||||
Load::make(tvm_shape_type, v_shape, IntImm::make(Int(32), k))),
|
||||
Load::make(tvm_shape_type, v_shape,
|
||||
IntImm::make(Int(32), k))),
|
||||
field_name.str());
|
||||
}
|
||||
// strides field
|
||||
Var v_strides(v_arg->name_hint + ".strides", Handle());
|
||||
handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0));
|
||||
seq_init.emplace_back(LetStmt::make(
|
||||
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
|
||||
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kArrStrides),
|
||||
nop));
|
||||
if (buf->strides.size() == 0) {
|
||||
std::ostringstream stride_err_msg;
|
||||
stride_err_msg << "arg_" << i << ".strides:"
|
||||
|
@ -177,13 +211,22 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
field_name << v_strides->name_hint << '[' << k << ']';
|
||||
f_push(buf->strides[k],
|
||||
cast(buf->shape[k].type(),
|
||||
Load::make(tvm_shape_type, v_strides, IntImm::make(Int(32), k))),
|
||||
Load::make(tvm_shape_type, v_strides,
|
||||
IntImm::make(Int(32), k))),
|
||||
field_name.str());
|
||||
}
|
||||
}
|
||||
// Byte_offset field.
|
||||
f_push(buf->byte_offset, TVMArrayGet(UInt(64), v_arg, intrinsic::kByteOffset),
|
||||
f_push(buf->byte_offset,
|
||||
TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
|
||||
v_arg->name_hint + ".byte_offset");
|
||||
// device info.
|
||||
f_push(device_id,
|
||||
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
|
||||
v_arg->name_hint + ".device_id");
|
||||
f_push(device_type,
|
||||
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceType),
|
||||
v_arg->name_hint + ".device_type");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,6 +235,16 @@ LoweredFunc MakeAPI(Stmt body,
|
|||
n->args = args;
|
||||
n->handle_data_type = handle_data_type;
|
||||
n->is_packed_func = num_unpacked_args == 0;
|
||||
|
||||
// Set device context
|
||||
if (visited.count(device_id.get())) {
|
||||
Expr node = StringImm::make("default");
|
||||
CHECK(visited.count(device_type.get()));
|
||||
seq_init.push_back(AttrStmt::make(
|
||||
node, attr::device_context_id, device_id, nop));
|
||||
seq_init.push_back(AttrStmt::make(
|
||||
node, attr::device_context_type, device_type, nop));
|
||||
}
|
||||
n->body = MergeNest({seq_init, seq_check}, body);
|
||||
LoweredFunc f(n);
|
||||
Array<Var> undefined = UndefinedVars(f->body, f->args);
|
||||
|
|
|
@ -70,6 +70,17 @@ class StorageAccessPatternFinder : public IRVisitor {
|
|||
linear_seq_.push_back(e);
|
||||
}
|
||||
}
|
||||
void Visit_(const Evaluate* op) final {
|
||||
scope_.push_back(StmtEntry());
|
||||
// visit subexpr
|
||||
IRVisitor::Visit_(op);
|
||||
StmtEntry e = scope_.back();
|
||||
scope_.pop_back();
|
||||
if (e.access.size() != 0) {
|
||||
e.stmt = op;
|
||||
linear_seq_.push_back(e);
|
||||
}
|
||||
}
|
||||
void Visit_(const Load* op) final {
|
||||
// Add write access.
|
||||
IRVisitor::Visit_(op);
|
||||
|
@ -79,14 +90,14 @@ class StorageAccessPatternFinder : public IRVisitor {
|
|||
CHECK_LT(it->second, scope_.size())
|
||||
<< "Load memory in places other than store.";
|
||||
scope_[it->second].access.emplace_back(
|
||||
AccessEntry(buf, op->index, kRead, GetScope(buf)));
|
||||
AccessEntry(buf, op->index, kRead, GetScope(buf)));
|
||||
}
|
||||
}
|
||||
void Visit_(const Variable* buf) final {
|
||||
// Directly reference to the variable count as a read.
|
||||
auto it = alloc_scope_level_.find(buf);
|
||||
if (it != alloc_scope_level_.end()) {
|
||||
CHECK_LT(it->second, scope_.size());
|
||||
CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
|
||||
scope_[it->second].access.emplace_back(
|
||||
AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
|
||||
}
|
||||
|
|
|
@ -25,7 +25,8 @@ def test_add_pipeline():
|
|||
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
|
||||
stmt = tvm.ir_pass.Simplify(stmt)
|
||||
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
|
||||
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
|
||||
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
|
||||
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
|
||||
|
||||
def check_target(device, host="stackvm"):
|
||||
if not tvm.codegen.enabled(host):
|
||||
|
|
|
@ -34,5 +34,78 @@ def test_add_pipeline():
|
|||
check_llvm()
|
||||
|
||||
|
||||
def test_pack_buffer_simple():
|
||||
nn = 1024
|
||||
n = tvm.convert(nn)
|
||||
A = tvm.placeholder((n,), name='A')
|
||||
def extern_generator(ins, outs):
|
||||
"""Manually write the IR for the extern function, add pipeline."""
|
||||
return tvm.call_packed("my_extern_array_func1", ins[0], outs[0])
|
||||
|
||||
C = tvm.extern(A.shape, [A], extern_generator, name='C')
|
||||
s = tvm.create_schedule(C.op)
|
||||
|
||||
@tvm.register_func
|
||||
def my_extern_array_func1(aa, bb):
|
||||
aa.copyto(bb)
|
||||
|
||||
|
||||
def check_target(target):
|
||||
if not tvm.codegen.enabled(target):
|
||||
return
|
||||
# build and invoke the kernel.
|
||||
f = tvm.build(s, [A, C], target)
|
||||
ctx = tvm.cpu(0)
|
||||
# launch the kernel.
|
||||
n = nn
|
||||
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
|
||||
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
|
||||
|
||||
f(a, c)
|
||||
np.testing.assert_allclose(
|
||||
c.asnumpy(), a.asnumpy())
|
||||
check_target("stackvm")
|
||||
check_target("llvm")
|
||||
|
||||
|
||||
def test_pack_buffer_intermediate():
|
||||
nn = 1024
|
||||
n = tvm.convert(nn)
|
||||
A = tvm.placeholder((n,), name='A')
|
||||
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
|
||||
def extern_generator(ins, outs):
|
||||
"""Manually write the IR for the extern function, add pipeline."""
|
||||
return tvm.call_packed("my_extern_array_func2", ins[0], outs[0])
|
||||
|
||||
C = tvm.extern(B.shape, [B], extern_generator, name='C')
|
||||
s = tvm.create_schedule(C.op)
|
||||
|
||||
def check_target(target):
|
||||
if not tvm.codegen.enabled(target):
|
||||
return
|
||||
# build and invoke the kernel.
|
||||
f = tvm.build(s, [A, C], target)
|
||||
ctx = tvm.cpu(0)
|
||||
# launch the kernel.
|
||||
n = nn
|
||||
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
|
||||
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
|
||||
|
||||
@tvm.register_func
|
||||
def my_extern_array_func2(aa, bb):
|
||||
assert aa.shape == a.shape
|
||||
np.testing.assert_allclose(
|
||||
aa.asnumpy(), a.asnumpy() + 1)
|
||||
aa.copyto(bb)
|
||||
|
||||
f(a, c)
|
||||
np.testing.assert_allclose(
|
||||
c.asnumpy(), a.asnumpy() + 1)
|
||||
|
||||
check_target("llvm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pack_buffer_simple()
|
||||
test_pack_buffer_intermediate()
|
||||
test_add_pipeline()
|
||||
|
|
|
@ -9,7 +9,6 @@ def run_jit(fapi, check):
|
|||
s = f.get_source()
|
||||
check(f)
|
||||
|
||||
|
||||
def test_stack_vm_basic():
|
||||
a = tvm.nd.array(np.zeros(10, dtype='float32'))
|
||||
@tvm.register_func
|
||||
|
@ -21,6 +20,7 @@ def test_stack_vm_basic():
|
|||
Ab = tvm.decl_buffer((n, ), tvm.float32)
|
||||
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
|
||||
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
|
||||
fapi = tvm.ir_pass.LowerPackedCall(fapi)
|
||||
run_jit(fapi, lambda f: f(a))
|
||||
|
||||
|
||||
|
@ -42,6 +42,7 @@ def test_stack_vm_loop():
|
|||
|
||||
stmt = ib.get()
|
||||
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
|
||||
fapi = tvm.ir_pass.LowerPackedCall(fapi)
|
||||
a = tvm.nd.array(np.zeros(10, dtype=dtype))
|
||||
def check(f):
|
||||
f(a)
|
||||
|
@ -64,6 +65,7 @@ def test_stack_vm_cond():
|
|||
|
||||
stmt = ib.get()
|
||||
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0)
|
||||
fapi = tvm.ir_pass.LowerPackedCall(fapi)
|
||||
def check(f):
|
||||
a = tvm.nd.array(np.zeros(10, dtype=dtype))
|
||||
f(a)
|
||||
|
|
|
@ -38,7 +38,8 @@ def test_add_pipeline():
|
|||
px, x = s[C].split(C.op.axis[0], nparts=1)
|
||||
s[C].bind(px, tvm.thread_axis("pipeline"))
|
||||
fapi = lower(s, [A, B, C], "myadd")
|
||||
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
|
||||
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
|
||||
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
|
||||
print(fsplits[1].body)
|
||||
print("------")
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче