[API/JIT] Enable registerable global function, introduce StackVM intepreter (#25)
This commit is contained in:
Родитель
01a7ce0cb6
Коммит
4242b9cff5
|
@ -10,7 +10,7 @@
|
|||
#include "./base.h"
|
||||
#include "./expr.h"
|
||||
#include "./module.h"
|
||||
#include "./runtime/runtime.h"
|
||||
#include "./runtime/packed_func.h"
|
||||
|
||||
|
||||
namespace tvm {
|
||||
|
|
|
@ -81,11 +81,13 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
|
|||
/*!
|
||||
* \brief See pesudo code
|
||||
*
|
||||
* bool tvm_print(VType value) {
|
||||
* LOG(INFO) << value;
|
||||
* int tvm_call_global(name, TVMValue* args) {
|
||||
* PackedFunc f = PackedFunc::GetGlobal(name);
|
||||
* f (args, type_code_of(args), len(args));
|
||||
* return 0;
|
||||
* }
|
||||
*/
|
||||
constexpr const char* tvm_print = "tvm_print";
|
||||
constexpr const char* tvm_call_global = "tvm_call_global";
|
||||
|
||||
/*! \brief The field id of each field in array */
|
||||
enum TVMArrayFieldKind {
|
||||
|
|
|
@ -20,9 +20,6 @@ namespace tvm {
|
|||
// Internal node container of lowered function.
|
||||
class LoweredFuncNode;
|
||||
|
||||
// Internal node container of module.
|
||||
class ModuleNode;
|
||||
|
||||
/*!
|
||||
* \brief LoweredFunc represents function after lowering.
|
||||
* This is the final IR representation before codegen.
|
||||
|
|
|
@ -161,7 +161,7 @@ TVM_DLL const char *TVMGetLastError(void);
|
|||
* \param option_vals Additional option values to pass
|
||||
* \param num_options Number of options to be passed into it.
|
||||
* \param out_code 1: success, 0: already initialized
|
||||
* \return Whether the function is successful.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMDeviceInit(int dev_mask,
|
||||
const char** option_keys,
|
||||
|
@ -188,7 +188,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx,
|
|||
* \param dtype The array data type.
|
||||
* \param ctx The ctx this array sits on.
|
||||
* \param out The output handle.
|
||||
* \return Whether the function is successful.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
|
||||
tvm_index_t ndim,
|
||||
|
@ -198,6 +198,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
|
|||
/*!
|
||||
* \brief Free the TVM Array.
|
||||
* \param handle The array handle to be freed.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
|
||||
|
||||
|
@ -206,6 +207,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
|
|||
* \param from The array to be copied from.
|
||||
* \param to The target space.
|
||||
* \param stream The stream where the copy happens, can be NULL.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
|
||||
TVMArrayHandle to,
|
||||
|
@ -214,13 +216,14 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
|
|||
* \brief Wait until all computations on stream completes.
|
||||
* \param ctx The ctx to be synchronized.
|
||||
* \param stream The stream to be synchronized.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
|
||||
|
||||
/*!
|
||||
* \brief Free the function when it is no longer needed.
|
||||
* \param func The function handle
|
||||
* \return whether
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
|
||||
|
||||
|
@ -239,6 +242,57 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
|
|||
TVMValue* args,
|
||||
int* type_codes,
|
||||
int num_args);
|
||||
|
||||
/*!
|
||||
* \brief C type of packed function.
|
||||
*
|
||||
* \param args The arguments
|
||||
* \param type_codes The type codes of the arguments
|
||||
* \param num_args Number of arguments.
|
||||
* \param resource_handle The handle additional resouce handle from fron-end.
|
||||
*/
|
||||
typedef void (*TVMPackedCFunc)(
|
||||
TVMValue* args, int* type_codes, int num_args, void* resource_handle);
|
||||
|
||||
/*!
|
||||
* \brief C callback to free the resource handle in C packed function.
|
||||
* \param resource_handle The handle additional resouce handle from fron-end.
|
||||
*/
|
||||
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
|
||||
|
||||
/*!
|
||||
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
|
||||
*
|
||||
* The resource_handle will be managed by TVM API, until the function is no longer used.
|
||||
*
|
||||
* \param func The packed C function.
|
||||
* \param resource_handle The resource handle from front-end, can be NULL.
|
||||
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
|
||||
* \param out the result function handle.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
|
||||
void* resource_handle,
|
||||
TVMPackedCFuncFinalizer fin,
|
||||
TVMFunctionHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief Register the function to runtime's global table.
|
||||
*
|
||||
* The registered function then can be pulled by the backend by the name.
|
||||
*
|
||||
* \param name The name of the function.
|
||||
* \param f The function to be registered.
|
||||
*/
|
||||
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
|
||||
|
||||
/*!
|
||||
* \brief Get a global function.
|
||||
*
|
||||
* \param name The name of the function.
|
||||
* \param out the result function pointer.
|
||||
*/
|
||||
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
|
||||
} // TVM_EXTERN_C
|
||||
|
||||
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
|
||||
|
|
|
@ -1,35 +1,43 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file runtime.h
|
||||
* \file packed_func.h
|
||||
* \brief Runtime related c++ class.
|
||||
*/
|
||||
#ifndef TVM_RUNTIME_RUNTIME_H_
|
||||
#define TVM_RUNTIME_RUNTIME_H_
|
||||
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
|
||||
#define TVM_RUNTIME_PACKED_FUNC_H_
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "./c_runtime_api.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace runtime {
|
||||
|
||||
/*!
|
||||
* \brief Packed function is a runtime function
|
||||
* whose argument type_codes are erased by packed format.
|
||||
* \brief Packed function is a type-erased function.
|
||||
* The arguments are passed by packed format.
|
||||
*
|
||||
* This is an useful unified interface to call generated functions.
|
||||
* This is an useful unified interface to call generated functions,
|
||||
* It is the unified function function type of TVM.
|
||||
* It corresponds to TVMFunctionHandle in C runtime API.
|
||||
*/
|
||||
class PackedFunc {
|
||||
public:
|
||||
/*! \brief The internal std::function */
|
||||
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
|
||||
/*! \brief default constructor */
|
||||
PackedFunc() {}
|
||||
/*!
|
||||
* \brief constructing a packed function from a std::function.
|
||||
* \param body the internal container of packed function.
|
||||
*/
|
||||
explicit PackedFunc(FType body) : body_(body) {}
|
||||
/*!
|
||||
* \brief invoke the packed function by directly passing in arguments.
|
||||
* \brief Call packed function by directly passing in unpacked format.
|
||||
* \param args Arguments to be passed.
|
||||
* \tparam Args arguments to be passed.
|
||||
* \return The first return value.
|
||||
*/
|
||||
template<typename... Args>
|
||||
inline void operator()(Args&& ...args) const;
|
||||
|
@ -41,9 +49,25 @@ class PackedFunc {
|
|||
*/
|
||||
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
|
||||
/*! \return the internal body function */
|
||||
inline FType body() const {
|
||||
return body_;
|
||||
}
|
||||
inline FType body() const;
|
||||
/*!
|
||||
* \brief Register f as into global function table
|
||||
* \param name The name of the function.
|
||||
* \param f The function to be registered.
|
||||
* \return Reference to the registered function.
|
||||
* \note The returned reference is valid until the end of the program
|
||||
*/
|
||||
static const PackedFunc& RegisterGlobal(const std::string& name, PackedFunc f);
|
||||
/*!
|
||||
* \brief Get the global function by name.
|
||||
* \param name The name of the function.
|
||||
* \return reference to the registered function.
|
||||
*/
|
||||
static const PackedFunc& GetGlobal(const std::string& name);
|
||||
/*!
|
||||
* \brief Get the names of currently registered global function.
|
||||
*/
|
||||
static std::vector<std::string> ListGlobalNames();
|
||||
|
||||
private:
|
||||
/*! \brief internal container of packed function */
|
||||
|
@ -56,6 +80,10 @@ inline void PackedFunc::CallPacked(
|
|||
body_(args, type_codes, num_args);
|
||||
}
|
||||
|
||||
inline PackedFunc::FType PackedFunc::body() const {
|
||||
return body_;
|
||||
}
|
||||
|
||||
template<bool stop, std::size_t I, typename F, typename ...Args>
|
||||
struct for_each_dispatcher_ {
|
||||
static inline void run(const std::tuple<Args...>& args, F f) {
|
||||
|
@ -124,4 +152,4 @@ inline void PackedFunc::operator()(Args&& ...args) const {
|
|||
}
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
#endif // TVM_RUNTIME_RUNTIME_H_
|
||||
#endif // TVM_RUNTIME_PACKED_FUNC_H_
|
|
@ -1,6 +1,6 @@
|
|||
# coding: utf-8
|
||||
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
|
||||
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
|
||||
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements
|
||||
"""Symbolic configuration API."""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
|
@ -13,7 +13,7 @@ from .._base import c_str, py_str, string_types
|
|||
from .._base import check_call, ctypes2docstring
|
||||
from .. import _api_internal
|
||||
from . import _runtime_api
|
||||
from ._types import TVMValue, TypeCode
|
||||
from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer
|
||||
|
||||
# type definitions
|
||||
APIFuncHandle = ctypes.c_void_p
|
||||
|
@ -57,6 +57,13 @@ def _return_func(x):
|
|||
return _runtime_api._function_cls(handle)
|
||||
|
||||
|
||||
def _return_handle(x):
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, ctypes.c_void_p):
|
||||
handle = ctypes.c_void_p(handle)
|
||||
return handle
|
||||
|
||||
|
||||
RET_SWITCH = {
|
||||
TypeCode.NULL: lambda x: None,
|
||||
TypeCode.INT: lambda x: x.v_int64,
|
||||
|
@ -66,6 +73,15 @@ RET_SWITCH = {
|
|||
TypeCode.FUNC_HANDLE: _return_func
|
||||
}
|
||||
|
||||
PACK_ARG_SWITCH = {
|
||||
TypeCode.NULL: lambda x: None,
|
||||
TypeCode.INT: lambda x: x.v_int64,
|
||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||
TypeCode.STR: lambda x: py_str(x.v_str),
|
||||
TypeCode.HANDLE: lambda x: _return_handle,
|
||||
}
|
||||
|
||||
|
||||
class SliceBase(object):
|
||||
"""base class of slice object"""
|
||||
pass
|
||||
|
@ -159,10 +175,53 @@ def const(value, dtype=None):
|
|||
return _api_internal._const(value, dtype)
|
||||
|
||||
|
||||
def _ctypes_free_resource(rhandle):
|
||||
"""callback to free resources when it it not needed."""
|
||||
pyobj = ctypes.cast(rhandle, ctypes.py_object)
|
||||
ctypes.pythonapi.Py_DecRef(pyobj)
|
||||
|
||||
# Global callback that is always alive
|
||||
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
|
||||
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
|
||||
|
||||
def convert_to_tvm_func(pyfunc):
|
||||
"""Convert a python function to TVM function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pyfunc : python function
|
||||
The python function to be converted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tvmfunc: tvm.nd.Function
|
||||
The converted tvm function.
|
||||
"""
|
||||
local_pyfunc = pyfunc
|
||||
def cfun(args, type_codes, num_args, _):
|
||||
""" ctypes function """
|
||||
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
||||
pyargs = [PACK_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
||||
local_pyfunc(*pyargs)
|
||||
handle = FunctionHandle()
|
||||
f = TVMPackedCFunc(cfun)
|
||||
# NOTE: We will need to use python-api to increase ref count of the f
|
||||
# TVM_FREE_PYOBJ will be called after it is no longer needed.
|
||||
pyobj = ctypes.py_object(f)
|
||||
ctypes.pythonapi.Py_IncRef(pyobj)
|
||||
check_call(_LIB.TVMFuncCreateFromCFunc(
|
||||
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
|
||||
return _runtime_api._function_cls(handle)
|
||||
|
||||
|
||||
def convert(value):
|
||||
"""Convert a value to expression."""
|
||||
if isinstance(value, Number):
|
||||
if isinstance(value, (NodeBase, _runtime_api.FunctionBase)):
|
||||
return value
|
||||
elif isinstance(value, Number):
|
||||
return const(value)
|
||||
elif isinstance(value, string_types):
|
||||
return _api_internal._str(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value = [convert(x) for x in value]
|
||||
return _api_internal._Array(*value)
|
||||
|
@ -176,10 +235,11 @@ def convert(value):
|
|||
return _api_internal._Map(*vlist)
|
||||
elif isinstance(value, SliceBase):
|
||||
return value.tensor(*value.indices)
|
||||
elif callable(value):
|
||||
return convert_to_tvm_func(value)
|
||||
else:
|
||||
if not isinstance(value, NodeBase):
|
||||
raise ValueError("don't know how to handle type %s" % type(value))
|
||||
return value
|
||||
raise ValueError("don't know how to handle type %s" % type(value))
|
||||
return value
|
||||
|
||||
|
||||
def _push_arg(arg):
|
||||
|
@ -270,6 +330,59 @@ def register_node(type_key=None):
|
|||
NODE_TYPE[cls.__name__] = cls
|
||||
return cls
|
||||
|
||||
|
||||
def register_func(func_name, f=None):
|
||||
"""Register global function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func_name : str or function
|
||||
The function name
|
||||
|
||||
f : function
|
||||
The function to be registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fregister : function
|
||||
Register function if f is not specified.
|
||||
"""
|
||||
if callable(func_name):
|
||||
f = func_name
|
||||
func_name = f.__name__
|
||||
|
||||
if not isinstance(func_name, str):
|
||||
raise ValueError("expect string function name")
|
||||
def register(myf):
|
||||
"""internal register function"""
|
||||
if not isinstance(myf, _runtime_api.FunctionBase):
|
||||
myf = convert_to_tvm_func(myf)
|
||||
check_call(_LIB.TVMFuncRegisterGlobal(
|
||||
c_str(func_name), myf.handle))
|
||||
if f:
|
||||
register(f)
|
||||
else:
|
||||
return register
|
||||
|
||||
|
||||
def get_global_func(name):
|
||||
"""Get a global function by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the global function
|
||||
|
||||
Returns
|
||||
-------
|
||||
func : tvm.nd.Function
|
||||
The function to be returned.
|
||||
"""
|
||||
handle = FunctionHandle()
|
||||
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
|
||||
return _runtime_api._function_cls(handle)
|
||||
|
||||
|
||||
def _init_api_module(root_namespace):
|
||||
"""List and add all the functions to current module."""
|
||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||
|
|
|
@ -70,3 +70,16 @@ class TVMType(ctypes.Structure):
|
|||
if self.lanes != 1:
|
||||
x += "x%d" % self.lanes
|
||||
return x
|
||||
|
||||
|
||||
TVMPackedCFunc = ctypes.CFUNCTYPE(
|
||||
None,
|
||||
ctypes.POINTER(TVMValue),
|
||||
ctypes.POINTER(ctypes.c_int),
|
||||
ctypes.c_int,
|
||||
ctypes.c_void_p)
|
||||
|
||||
|
||||
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
|
||||
None,
|
||||
ctypes.c_void_p)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# pylint: disable=protected-access, no-member, invalid-name
|
||||
# pylint: disable=redefined-builtin, undefined-variable
|
||||
# pylint: disable=redefined-builtin, undefined-variable, unused-import
|
||||
"""Functions defined in TVM."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from numbers import Integral as _Integral
|
||||
from ._ctypes._api import _init_api_module, convert
|
||||
from ._ctypes._api import _init_api_module, convert, register_func, get_global_func
|
||||
from . import _api_internal
|
||||
from . import make as _make
|
||||
from . import expr as _expr
|
||||
|
|
|
@ -52,5 +52,10 @@ TVM_REGISTER_API(_codegen_DummyHelloFunction)
|
|||
*ret = runtime::PackedFunc(DummyHelloFunction);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_BuildStackVM)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = BuildStackVM(args.at(0));
|
||||
});
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
|
|
@ -46,7 +46,8 @@ TVM_REGISTER_API(_make_Call)
|
|||
args.at(1),
|
||||
args.at(2),
|
||||
static_cast<Call::CallType>(args.at(3).operator int()),
|
||||
args.at(4));
|
||||
args.at(4),
|
||||
args.at(5));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_make_Allocate)
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
* \file c_api_lang.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/buffer.h>
|
||||
#include <tvm/schedule.h>
|
||||
|
@ -27,6 +28,13 @@ TVM_REGISTER_API(_const)
|
|||
.add_argument("src", "Number", "source number")
|
||||
.add_argument("dtype", "str", "data type");
|
||||
|
||||
|
||||
TVM_REGISTER_API(_str)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = ir::StringImm::make(args.at(0));
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_Array)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
std::vector<std::shared_ptr<Node> > data;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <tvm/base.h>
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/c_api.h>
|
||||
#include <tvm/runtime/runtime.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <memory>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
|
|
@ -0,0 +1,497 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file codegen_stack_vm.cc
|
||||
*/
|
||||
#include <limits>
|
||||
#include "./codegen_stack_vm.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
using namespace ir;
|
||||
|
||||
runtime::PackedFunc BuildStackVM(LoweredFunc func) {
|
||||
StackVM vm = codegen::CodeGenStackVM().Compile(func);
|
||||
auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) {
|
||||
LOG(INFO) << "Run stack VM";
|
||||
StackVM::State* s = StackVM::ThreadLocalState();
|
||||
s->sp = 0;
|
||||
s->pc = 0;
|
||||
if (s->heap.size() < vm.heap_size) {
|
||||
s->heap.resize(vm.heap_size);
|
||||
}
|
||||
s->heap[0].v_handle = (void*)args; // NOLINT(*)
|
||||
s->heap[1].v_handle = (void*)type_codes; // NOLINT(*)
|
||||
s->heap[2].v_int64 = num_args;
|
||||
vm.Run(s);
|
||||
};
|
||||
return runtime::PackedFunc(f);
|
||||
}
|
||||
|
||||
TVMValue TVMPrint(const TVMValue* args, int num_args) {
|
||||
CHECK_EQ(num_args, 2);
|
||||
int tcode = static_cast<int>(args[1].v_int64);
|
||||
int code = (tcode >> (8 * 3)) & 255;
|
||||
int bits = (tcode >> (8 * 2)) & 255;
|
||||
int lanes = tcode & ((1 << 16) - 1);
|
||||
Type t((halide_type_code_t)code, bits, lanes);
|
||||
if (t.is_handle()) {
|
||||
LOG(INFO) << t << ": " << args[0].v_handle;
|
||||
} else if (t.is_float()) {
|
||||
LOG(INFO) << t << ": " << args[0].v_float64;
|
||||
} else {
|
||||
LOG(INFO) << t << ": " << args[0].v_int64;
|
||||
}
|
||||
TVMValue r; r.v_int64 = 0;
|
||||
return r;
|
||||
}
|
||||
|
||||
CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*)
|
||||
static FType inst; return inst;
|
||||
}
|
||||
|
||||
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
|
||||
for (size_t i = 0; i < f->args.size(); ++i) {
|
||||
Var v = f->args[i];
|
||||
int vid = AllocVarID(v.get());
|
||||
CHECK_EQ(static_cast<size_t>(vid), i);
|
||||
}
|
||||
this->Push(f->body);
|
||||
return std::move(vm_);
|
||||
}
|
||||
|
||||
void CodeGenStackVM::Push(const Stmt& n) {
|
||||
static const FType& f = vtable();
|
||||
f(n, this);
|
||||
if (debug_) {
|
||||
this->PushOp(StackVM::ASSERT_SP, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenStackVM::Push(const Expr& n) {
|
||||
static const FType& f = vtable();
|
||||
f(n, this);
|
||||
}
|
||||
|
||||
void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
|
||||
StackVM::Code code;
|
||||
code.op_code = opcode;
|
||||
vm_.code.push_back(code);
|
||||
}
|
||||
|
||||
void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
|
||||
CHECK(operand >= std::numeric_limits<int>::min() &&
|
||||
operand <= std::numeric_limits<int>::max());
|
||||
vm_.code.at(operand_index).v_int = static_cast<int>(operand);
|
||||
}
|
||||
|
||||
int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
|
||||
int64_t pc = static_cast<int64_t>(vm_.code.size());
|
||||
StackVM::Code code;
|
||||
code.op_code = opcode;
|
||||
vm_.code.push_back(code);
|
||||
code.v_int = operand;
|
||||
vm_.code.push_back(code);
|
||||
return pc + 1;
|
||||
}
|
||||
|
||||
int CodeGenStackVM::GetStrID(const std::string& key) {
|
||||
auto it = str_idmap_.find(key);
|
||||
if (it != str_idmap_.end()) return it->second;
|
||||
int sid = static_cast<int>(vm_.str_data.size());
|
||||
vm_.str_data.push_back(key);
|
||||
str_idmap_[key] = sid;
|
||||
return sid;
|
||||
}
|
||||
|
||||
int CodeGenStackVM::AllocVarID(const Variable* v) {
|
||||
CHECK(!var_idmap_.count(v));
|
||||
int vid = static_cast<int>(vm_.heap_size);
|
||||
CHECK_EQ(vm_.heap_size, var_idmap_.size());
|
||||
vm_.heap_id_name.push_back(v->name_hint);
|
||||
++vm_.heap_size;
|
||||
var_idmap_[v] = vid;
|
||||
return vid;
|
||||
}
|
||||
|
||||
int CodeGenStackVM::GetGlobalFuncID(std::string name) {
|
||||
auto it = fun_idmap_.find(name);
|
||||
if (it != fun_idmap_.end()) return it->second;
|
||||
using runtime::PackedFunc;
|
||||
PackedFunc f = PackedFunc::GetGlobal(name);
|
||||
auto extern_f = [f](const TVMValue* args, int num_args) {
|
||||
CHECK_EQ(num_args % 2, 0);
|
||||
num_args = num_args / 2;
|
||||
std::vector<int> type_codes(std::max(num_args, 1));
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
int tcode = static_cast<int>(args[num_args + i].v_int64);
|
||||
int code = (tcode >> (8 * 3)) & 255;
|
||||
type_codes[i] = code;
|
||||
}
|
||||
f.CallPacked(args, &type_codes[0], num_args);
|
||||
TVMValue r; r.v_int64 = 0;
|
||||
return r;
|
||||
};
|
||||
int fid = static_cast<int>(vm_.extern_func.size());
|
||||
vm_.extern_func.push_back(extern_f);
|
||||
fun_idmap_[name] = fid;
|
||||
|
||||
return fid;
|
||||
}
|
||||
|
||||
int CodeGenStackVM::GetVarID(const Variable* v) const {
|
||||
auto it = var_idmap_.find(v);
|
||||
CHECK(it != var_idmap_.end())
|
||||
<< "Find undefined Variable " << v->name_hint;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void CodeGenStackVM::Push_(const ir::Load* op) {
|
||||
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
|
||||
if (op->type == UInt(32) && op->index.as<IntImm>()) {
|
||||
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
|
||||
} else {
|
||||
this->Push(op->index);
|
||||
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
this->PushOp(StackVM::GetLoad(op->type));
|
||||
}
|
||||
}
|
||||
void CodeGenStackVM::Push_(const ir::Store* op) {
|
||||
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
|
||||
this->Push(op->index);
|
||||
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
this->Push(op->value);
|
||||
this->PushOp(StackVM::GetStore(op->value.type()));
|
||||
}
|
||||
|
||||
void CodeGenStackVM::Push_(const ir::Allocate* op) {
|
||||
CHECK(!is_zero(op->condition));
|
||||
int vid = AllocVarID(op->buffer_var.get());
|
||||
if (op->new_expr.defined()) {
|
||||
// Prefer global static allocation for the program
|
||||
CHECK_EQ(op->free_function, "nop");
|
||||
this->Push(op->new_expr);
|
||||
this->PushOp(StackVM::STORE_HEAP, vid);
|
||||
} else {
|
||||
LOG(FATAL) << "Dynamic allocation not supported";
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenStackVM::Push_(const ir::Call* op) {
|
||||
if (op->is_intrinsic(Call::address_of)) {
|
||||
const Load *l = op->args[0].as<Load>();
|
||||
CHECK(op->args.size() == 1 && l);
|
||||
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
|
||||
this->Push(l->index);
|
||||
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
|
||||
this->PushOp(StackVM::MUL_I64);
|
||||
this->PushOp(StackVM::ADDR_ADD);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
|
||||
CHECK_EQ(op->args.size(), 3U);
|
||||
this->Push(op->args[0]);
|
||||
this->Push(op->args[1]);
|
||||
this->Push(op->args[2]);
|
||||
if (op->type.is_handle()) {
|
||||
this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE);
|
||||
} else if (op->type.is_float()) {
|
||||
this->PushOp(StackVM::TVM_LOAD_ARG_FP64);
|
||||
} else if (op->type.is_int() || op->type.is_uint()) {
|
||||
this->PushOp(StackVM::TVM_LOAD_ARG_INT64);
|
||||
} else {
|
||||
LOG(FATAL) << "donot know how to handle type" << op->type;
|
||||
}
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
this->Push(op->args[0]);
|
||||
switch (op->args[1].as<IntImm>()->value) {
|
||||
case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break;
|
||||
case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break;
|
||||
case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break;
|
||||
case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break;
|
||||
case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
|
||||
case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
|
||||
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
|
||||
default: LOG(FATAL) << "unknown field code";
|
||||
}
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_call_global)) {
|
||||
CHECK_GE(op->args.size(), 1U);
|
||||
const StringImm* s = op->args[0].as<StringImm>();
|
||||
CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
this->Push(op->args[i]);
|
||||
}
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
Type t = op->args[i].type();
|
||||
int code = t.code();
|
||||
int bits = t.bits();
|
||||
int lanes = t.lanes();
|
||||
int tcode = (code << (8 * 3)) | (bits << 16) | lanes;
|
||||
this->PushOp(StackVM::PUSH_I64, tcode);
|
||||
}
|
||||
int num_args = static_cast<int>((op->args.size() - 1) * 2);
|
||||
this->PushOp(StackVM::PUSH_I64, num_args);
|
||||
this->PushOp(StackVM::CALL_EXTERN, GetGlobalFuncID(s->value));
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
|
||||
CHECK_EQ(op->args.size(), 1U);
|
||||
this->Push(op->args[0]);
|
||||
this->PushOp(StackVM::PUSH_I64, 0);
|
||||
this->PushOp(StackVM::EQ_I64);
|
||||
} else {
|
||||
this->HandleUnknownCall(op);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenStackVM::HandleUnknownCall(const ir::Call* op) {
|
||||
LOG(FATAL) << "donot know how to handle call " << op->name;
|
||||
}
|
||||
|
||||
inline void PushBinary(StackVM::OpCode op_int64,
|
||||
const Expr& a,
|
||||
const Expr& b,
|
||||
CodeGenStackVM* p) {
|
||||
p->Push(a);
|
||||
p->Push(b);
|
||||
Type t = a.type();
|
||||
if (t.is_int()) {
|
||||
p->PushOp(op_int64);
|
||||
} else if (t.is_uint()) {
|
||||
if (t.bits() <= 32) {
|
||||
p->PushOp(op_int64);
|
||||
} else {
|
||||
LOG(FATAL) << "Cannot handle uint64_t in StackVM";
|
||||
}
|
||||
} else {
|
||||
p->PushOp(StackVM::CodeI64ToF64(op_int64));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void PushCast(Type dst,
|
||||
Type src,
|
||||
CodeGenStackVM* p) {
|
||||
if (dst.is_int()) {
|
||||
if (src.is_int()) return;
|
||||
if (src.is_uint() && src.bits() <= 32) return;
|
||||
} else if (dst.is_uint() && dst.bits() <= 32) {
|
||||
if (src.is_int()) return;
|
||||
if (src.is_uint() && src.bits() <= 32) return;
|
||||
} else if (dst.is_float()) {
|
||||
if (src.is_float()) return;
|
||||
}
|
||||
LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
|
||||
}
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
|
||||
.set_dispatch<StringImm>([](const StringImm *op, CodeGenStackVM *p) {
|
||||
int sid = p->GetStrID(op->value);
|
||||
p->PushOp(StackVM::PUSH_I64, sid);
|
||||
})
|
||||
.set_dispatch<IntImm>([](const IntImm *op, CodeGenStackVM *p) {
|
||||
CHECK(op->value >= std::numeric_limits<int>::min() &&
|
||||
op->value <= std::numeric_limits<int>::max())
|
||||
<< "Int constant exceed bound";
|
||||
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
|
||||
})
|
||||
.set_dispatch<UIntImm>([](const UIntImm *op, CodeGenStackVM *p) {
|
||||
CHECK(op->value <= std::numeric_limits<int>::max())
|
||||
<< "Int constant exceed bound";
|
||||
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
|
||||
})
|
||||
.set_dispatch<FloatImm>([](const FloatImm *op, CodeGenStackVM *p) {
|
||||
LOG(FATAL) << "Float Imm is not supported";
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
|
||||
.set_dispatch<Variable>([](const Variable *op, CodeGenStackVM* p) {
|
||||
int vid = p->GetVarID(op);
|
||||
p->PushOp(StackVM::LOAD_HEAP, vid);
|
||||
})
|
||||
.set_dispatch<Cast>([](const Cast *op, CodeGenStackVM* p) {
|
||||
p->Push(op->value);
|
||||
PushCast(op->type, op->value.type(), p);
|
||||
})
|
||||
.set_dispatch<Add>([](const Add *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::ADD_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<Sub>([](const Sub *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::SUB_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<Mul>([](const Mul *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::MUL_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<Div>([](const Div *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::DIV_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<Mod>([](const Mod *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::MOD_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<Min>([](const Min *op, CodeGenStackVM* p) {
|
||||
p->Push(op->a);
|
||||
p->Push(op->b);
|
||||
p->PushOp(StackVM::PUSH_VALUE, -1);
|
||||
p->PushOp(StackVM::PUSH_VALUE, -1);
|
||||
p->PushOp(StackVM::LT_I64);
|
||||
p->PushOp(StackVM::SELECT);
|
||||
})
|
||||
.set_dispatch<Max>([](const Max *op, CodeGenStackVM* p) {
|
||||
p->Push(op->a);
|
||||
p->Push(op->b);
|
||||
p->PushOp(StackVM::PUSH_VALUE, 0);
|
||||
p->PushOp(StackVM::PUSH_VALUE, -2);
|
||||
p->PushOp(StackVM::LT_I64);
|
||||
p->PushOp(StackVM::SELECT);
|
||||
})
|
||||
.set_dispatch<EQ>([](const EQ *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::EQ_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<LE>([](const LE *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::LE_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<NE>([](const NE *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::EQ_I64, op->a, op->b, p);
|
||||
p->PushOp(StackVM::NOT);
|
||||
})
|
||||
.set_dispatch<LT>([](const LT *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::LT_I64, op->a, op->b, p);
|
||||
})
|
||||
.set_dispatch<GE>([](const GE *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::LT_I64, op->a, op->b, p);
|
||||
p->PushOp(StackVM::NOT);
|
||||
})
|
||||
.set_dispatch<GT>([](const GT *op, CodeGenStackVM* p) {
|
||||
PushBinary(StackVM::LE_I64, op->a, op->b, p);
|
||||
p->PushOp(StackVM::NOT);
|
||||
})
|
||||
.set_dispatch<And>([](const And *op, CodeGenStackVM* p) {
|
||||
p->Push(op->a);
|
||||
int64_t pc_jump = p->GetPC();
|
||||
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
|
||||
p->PushOp(StackVM::POP);
|
||||
p->Push(op->b);
|
||||
int64_t diff = p->GetPC() - pc_jump;
|
||||
p->SetOperand(opr_index, diff);
|
||||
})
|
||||
.set_dispatch<Or>([](const Or *op, CodeGenStackVM* p) {
|
||||
p->Push(op->a);
|
||||
int64_t pc_jump = p->GetPC();
|
||||
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_TRUE, 0);
|
||||
p->Push(op->b);
|
||||
int64_t diff = p->GetPC() - pc_jump;
|
||||
p->SetOperand(opr_index, diff);
|
||||
})
|
||||
.set_dispatch<Not>([](const Not* op, CodeGenStackVM* p) {
|
||||
p->PushOp(StackVM::NOT);
|
||||
});
|
||||
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
|
||||
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenStackVM* p) {
|
||||
p->Push(op->body);
|
||||
})
|
||||
.set_dispatch<For>([](const For *op, CodeGenStackVM* p) {
|
||||
CHECK(is_zero(op->min));
|
||||
int vid = p->AllocVarID(op->loop_var.get());
|
||||
p->PushOp(StackVM::PUSH_I64, 0);
|
||||
int64_t loop_head = p->GetPC();
|
||||
p->PushOp(StackVM::STORE_HEAP, vid);
|
||||
p->PushOp(StackVM::LOAD_HEAP, vid);
|
||||
p->Push(op->extent);
|
||||
p->PushOp(StackVM::LT_I64);
|
||||
int64_t label_fjump = p->GetPC();
|
||||
int64_t foward_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
|
||||
p->PushOp(StackVM::POP);
|
||||
p->Push(op->body);
|
||||
p->PushOp(StackVM::LOAD_HEAP, vid);
|
||||
p->PushOp(StackVM::PUSH_I64, 1);
|
||||
p->PushOp(StackVM::ADD_I64);
|
||||
int64_t label_bjump = p->GetPC();
|
||||
int64_t backward_jump = p->PushOp(StackVM::RJUMP, 0);
|
||||
int64_t loop_end = p->GetPC();
|
||||
p->PushOp(StackVM::POP);
|
||||
p->SetOperand(foward_jump, loop_end - label_fjump);
|
||||
p->SetOperand(backward_jump, loop_head - label_bjump);
|
||||
})
|
||||
.set_dispatch<Block>([](const Block *op, CodeGenStackVM* p) {
|
||||
p->Push(op->first);
|
||||
if (op->rest.defined()) p->Push(op->rest);
|
||||
})
|
||||
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenStackVM* p) {
|
||||
if (is_const(op->value)) return;
|
||||
p->Push(op->value);
|
||||
p->PushOp(StackVM::POP);
|
||||
})
|
||||
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenStackVM* p) {
|
||||
p->Push(op->condition);
|
||||
int64_t label_ejump = p->GetPC();
|
||||
int64_t else_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
|
||||
p->PushOp(StackVM::POP);
|
||||
p->Push(op->then_case);
|
||||
if (op->else_case.defined()) {
|
||||
int64_t label_then_jump = p->GetPC();
|
||||
int64_t then_jump = p->PushOp(StackVM::RJUMP, 0);
|
||||
int64_t else_begin = p->GetPC();
|
||||
p->SetOperand(else_jump, else_begin - label_ejump);
|
||||
p->PushOp(StackVM::POP);
|
||||
p->Push(op->else_case);
|
||||
int64_t if_end = p->GetPC();
|
||||
p->SetOperand(then_jump, if_end - label_then_jump);
|
||||
} else {
|
||||
int64_t if_end = p->GetPC();
|
||||
p->SetOperand(else_jump, if_end - label_ejump);
|
||||
p->PushOp(StackVM::POP);
|
||||
}
|
||||
})
|
||||
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) {
|
||||
p->Push(op->value);
|
||||
int64_t vid = p->AllocVarID(op->var.get());
|
||||
p->PushOp(StackVM::STORE_HEAP, vid);
|
||||
p->Push(op->body);
|
||||
})
|
||||
.set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) {
|
||||
LOG(FATAL) << "Ramp is not supported";
|
||||
})
|
||||
.set_dispatch<Broadcast>([](const Broadcast *op, CodeGenStackVM* p) {
|
||||
LOG(FATAL) << "Broadcast is not supported";
|
||||
})
|
||||
.set_dispatch<Select>([](const Select *op, CodeGenStackVM* p) {
|
||||
p->Push(op->true_value);
|
||||
p->Push(op->false_value);
|
||||
p->Push(op->condition);
|
||||
p->PushOp(StackVM::SELECT);
|
||||
})
|
||||
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenStackVM* p) {
|
||||
if (op->message.as<StringImm>()) {
|
||||
int sid = p->GetStrID(op->message.as<StringImm>()->value);
|
||||
p->Push(op->condition);
|
||||
p->PushOp(StackVM::ASSERT, sid);
|
||||
}
|
||||
})
|
||||
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenStackVM* p) {
|
||||
p->Push(op->body);
|
||||
})
|
||||
.set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) {
|
||||
p->Push(op->value);
|
||||
int64_t vid = p->AllocVarID(op->var.get());
|
||||
p->PushOp(StackVM::STORE_HEAP, vid);
|
||||
p->Push(op->body);
|
||||
})
|
||||
.set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) {
|
||||
p->Push_(op);
|
||||
})
|
||||
.set_dispatch<Store>([](const Store *op, CodeGenStackVM* p) {
|
||||
p->Push_(op);
|
||||
})
|
||||
.set_dispatch<Allocate>([](const Allocate *op, CodeGenStackVM* p) {
|
||||
p->Push_(op);
|
||||
})
|
||||
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
|
||||
p->Push_(op);
|
||||
});
|
||||
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
|
@ -0,0 +1,110 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file codegen_stack_vm.h
|
||||
* \brief Codegen into Simple Stack VM.
|
||||
*/
|
||||
#ifndef TVM_CODEGEN_CODEGEN_STACK_VM_H_
|
||||
#define TVM_CODEGEN_CODEGEN_STACK_VM_H_
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/module.h>
|
||||
#include <tvm/codegen.h>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "../jit/stack_vm.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
using jit::StackVM;
|
||||
|
||||
/*!
|
||||
* \brief A base class to generate a stack VM.
|
||||
* This module is used to generate host wrapper
|
||||
* into device function when only device JIT is available.
|
||||
*/
|
||||
class CodeGenStackVM {
|
||||
public:
|
||||
/*!
|
||||
* \brief Generate a stack VM representing
|
||||
* \param f The function to be compiled
|
||||
* \note Only call compile once,
|
||||
* create a new codegen object each time.
|
||||
*/
|
||||
StackVM Compile(LoweredFunc f);
|
||||
/*! \brief Push stmt to generate new code */
|
||||
void Push(const Stmt& n);
|
||||
/*! \brief Push expr to generate new code */
|
||||
void Push(const Expr& n);
|
||||
/*!
|
||||
* \brief Push the opcode to the code.
|
||||
* \param opcode The code to be pushed.
|
||||
*/
|
||||
void PushOp(StackVM::OpCode opcode);
|
||||
/*!
|
||||
* \brief Push the opcode and operand to the code.
|
||||
* \param opcode The opcode.
|
||||
* \param operand The operand to be pushed.
|
||||
* \return operand_index, indicating location of operand
|
||||
*/
|
||||
int64_t PushOp(StackVM::OpCode opcode, int operand);
|
||||
/*!
|
||||
* \brief Set the relative jump offset to be offset.
|
||||
* \param operand_index The indexed returned by PushOp.
|
||||
* \param operand The operand to be set.
|
||||
*/
|
||||
void SetOperand(int64_t operand_index, int64_t operand);
|
||||
/*! \return The current program pointer */
|
||||
int64_t GetPC() const {
|
||||
return static_cast<int64_t>(vm_.code.size());
|
||||
}
|
||||
/*!
|
||||
* \brief Get string id in vm
|
||||
* \param key The string to get id.
|
||||
* \return the id of the string.
|
||||
*/
|
||||
int GetStrID(const std::string& key);
|
||||
/*!
|
||||
* \brief Push the function to the VM and get a id.
|
||||
* \param f The function to be pushed.
|
||||
*/
|
||||
int GetGlobalFuncID(std::string name);
|
||||
/*!
|
||||
* \brief Allocate a variable name for a newly defined var.
|
||||
* \param v The variable.
|
||||
* \return the heap index of the var.
|
||||
*/
|
||||
int AllocVarID(const Variable* v);
|
||||
/*!
|
||||
* \brief Get a variable name.
|
||||
* \param v The variable.
|
||||
* \return the heap index of the var.
|
||||
*/
|
||||
int GetVarID(const Variable* v) const;
|
||||
// overloadable functions
|
||||
virtual void Push_(const ir::Load* op);
|
||||
virtual void Push_(const ir::Store* op);
|
||||
virtual void Push_(const ir::Allocate* op);
|
||||
virtual void Push_(const ir::Call* op);
|
||||
virtual void HandleUnknownCall(const ir::Call* op);
|
||||
/*! \brief function to to print normal code */
|
||||
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>;
|
||||
// vtable to print code
|
||||
static FType& vtable(); // NOLINT(*)
|
||||
|
||||
private:
|
||||
bool debug_{false};
|
||||
/*! \brief The vm to be generated */
|
||||
StackVM vm_;
|
||||
/*! \brief id of each variable */
|
||||
std::unordered_map<const Variable*, int> var_idmap_;
|
||||
/*! \brief id of each string */
|
||||
std::unordered_map<std::string, int> str_idmap_;
|
||||
/*! \brief id of each function */
|
||||
std::unordered_map<std::string, int> fun_idmap_;
|
||||
};
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
#endif // TVM_CODEGEN_CODEGEN_STACK_VM_H_
|
|
@ -0,0 +1,334 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* Implementation stack VM.
|
||||
* \file stack_vm.cc
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include "./stack_vm.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace jit {
|
||||
|
||||
typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore;
|
||||
|
||||
StackVM::State* StackVM::ThreadLocalState() {
|
||||
return StackVMStateStore::Get();
|
||||
}
|
||||
|
||||
#define STACK_VM_BINOP(OP, FIELD) \
|
||||
{ \
|
||||
stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \
|
||||
sp -= 1; pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_CMPOP(OP, FIELD) \
|
||||
{ \
|
||||
stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
|
||||
sp -= 1; pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \
|
||||
{ \
|
||||
stack[sp].FIELD = static_cast<DST_TYPE>( \
|
||||
*static_cast<SRC_TYPE*>(stack[sp].v_handle)); \
|
||||
pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_STORE(FIELD, DST_TYPE) \
|
||||
{ \
|
||||
*static_cast<DST_TYPE*>(stack[sp - 1].v_handle) = \
|
||||
static_cast<DST_TYPE>(stack[sp].FIELD); \
|
||||
sp -= 2; pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_TVM_LOAD_ARG(OP, TYPE) \
|
||||
{ \
|
||||
TVMValue* args = static_cast<TVMValue*>(stack[sp - 2].v_handle); \
|
||||
int64_t index = stack[sp].v_int64; \
|
||||
int tc = static_cast<int*>(stack[sp - 1].v_handle)[index]; \
|
||||
CHECK(OP) \
|
||||
<< " argument " << index << " is expected to be " << TYPE; \
|
||||
stack[sp - 2] = args[index]; \
|
||||
sp -= 2; \
|
||||
pc += 1; \
|
||||
}
|
||||
|
||||
|
||||
#define STACK_VM_TVM_ARRARY_GET(FIELD, TYPE, SFIELD) \
|
||||
{ \
|
||||
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle); \
|
||||
stack[sp].FIELD = (TYPE)(arr->SFIELD); \
|
||||
pc += 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_CODE0(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_CODE1(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
|
||||
<< "[" << pc + 1 << "]" << std::endl; \
|
||||
return pc + 2; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \
|
||||
<< " " << heap_id_name[code[pc + 1].v_int] << "\n" \
|
||||
<< "[" << pc + 1 << "]" << std::endl; \
|
||||
return pc + 2; \
|
||||
}
|
||||
|
||||
#define STACK_VM_PRINT_JUMP(CODE) \
|
||||
case CODE: { \
|
||||
os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \
|
||||
<< " to " << pc + code[pc + 1].v_int << '\n' \
|
||||
<< "[" << pc + 1 << "]" << std::endl; \
|
||||
return pc + 2; \
|
||||
}
|
||||
|
||||
|
||||
int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
|
||||
switch (code[pc].op_code) {
|
||||
// int
|
||||
STACK_VM_PRINT_CODE0(ADD_I64);
|
||||
STACK_VM_PRINT_CODE0(SUB_I64);
|
||||
STACK_VM_PRINT_CODE0(MUL_I64);
|
||||
STACK_VM_PRINT_CODE0(MOD_I64);
|
||||
STACK_VM_PRINT_CODE0(DIV_I64);
|
||||
STACK_VM_PRINT_CODE0(EQ_I64);
|
||||
STACK_VM_PRINT_CODE0(LT_I64);
|
||||
STACK_VM_PRINT_CODE0(LE_I64);
|
||||
// floats
|
||||
STACK_VM_PRINT_CODE0(ADD_F64);
|
||||
STACK_VM_PRINT_CODE0(SUB_F64);
|
||||
STACK_VM_PRINT_CODE0(MUL_F64);
|
||||
STACK_VM_PRINT_CODE0(DIV_F64);
|
||||
STACK_VM_PRINT_CODE0(EQ_F64);
|
||||
STACK_VM_PRINT_CODE0(LT_F64);
|
||||
STACK_VM_PRINT_CODE0(LE_F64);
|
||||
// addressing load
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_UINT32);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT32);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT64);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_FP64);
|
||||
STACK_VM_PRINT_CODE0(ADDR_LOAD_HANDLE);
|
||||
STACK_VM_PRINT_CODE0(ADDR_STORE_INT64);
|
||||
STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
|
||||
STACK_VM_PRINT_CODE0(NOT);
|
||||
STACK_VM_PRINT_CODE0(ADDR_ADD);
|
||||
// stack ops
|
||||
STACK_VM_PRINT_CODE1(PUSH_I64);
|
||||
STACK_VM_PRINT_CODE1(PUSH_VALUE);
|
||||
STACK_VM_PRINT_CODE0(POP);
|
||||
STACK_VM_PRINT_CODE0(SELECT);
|
||||
STACK_VM_PRINT_HEAP_ACCESS(STORE_HEAP);
|
||||
STACK_VM_PRINT_HEAP_ACCESS(LOAD_HEAP);
|
||||
STACK_VM_PRINT_CODE1(CALL_EXTERN);
|
||||
STACK_VM_PRINT_CODE1(ASSERT);
|
||||
STACK_VM_PRINT_JUMP(RJUMP_IF_TRUE);
|
||||
STACK_VM_PRINT_JUMP(RJUMP_IF_FALSE);
|
||||
STACK_VM_PRINT_JUMP(RJUMP);
|
||||
STACK_VM_PRINT_CODE1(ASSERT_SP);
|
||||
// Intrinsics
|
||||
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_INT64);
|
||||
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_FP64);
|
||||
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_HANDLE);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_DATA);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_SHAPE);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_STRIDES);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_NDIM);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS);
|
||||
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES);
|
||||
}
|
||||
LOG(FATAL) << "unknown op code " << code[pc].op_code;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*)
|
||||
int64_t pc = 0;
|
||||
const int64_t code_size = static_cast<int64_t>(vm.code.size());
|
||||
os << "Program dump: code-size=" << code_size << '\n'
|
||||
<< "----------begin-----------------\n";
|
||||
while (pc < code_size) {
|
||||
pc = vm.PrintCode(os, pc);
|
||||
}
|
||||
os << "----------end--------------------\n";
|
||||
return os;
|
||||
}
|
||||
|
||||
void StackVM::Run(State* s) const {
|
||||
int64_t sp = s->sp;
|
||||
int64_t pc = s->pc;
|
||||
std::vector<TVMValue>& stack = s->stack;
|
||||
std::vector<TVMValue>& heap = s->heap;
|
||||
|
||||
if (stack.size() < stack_size) {
|
||||
stack.resize(stack_size);
|
||||
}
|
||||
int64_t stack_cap = static_cast<int64_t>(stack_size - 4);
|
||||
if (heap.size() < heap_size) {
|
||||
heap.resize(heap_size);
|
||||
}
|
||||
const int64_t code_size = static_cast<int64_t>(code.size());
|
||||
|
||||
while (pc < code_size) {
|
||||
switch (code[pc].op_code) {
|
||||
case ADD_I64: STACK_VM_BINOP(+, v_int64); break;
|
||||
case SUB_I64: STACK_VM_BINOP(-, v_int64); break;
|
||||
case MUL_I64: STACK_VM_BINOP(*, v_int64); break;
|
||||
case DIV_I64: STACK_VM_BINOP(/, v_int64); break;
|
||||
case MOD_I64: STACK_VM_BINOP(%, v_int64); break;
|
||||
case EQ_I64: STACK_VM_CMPOP(==, v_int64); break;
|
||||
case LT_I64: STACK_VM_CMPOP(<, v_int64); break;
|
||||
case LE_I64: STACK_VM_CMPOP(<=, v_int64); break;
|
||||
case ADD_F64: STACK_VM_BINOP(+, v_float64); break;
|
||||
case SUB_F64: STACK_VM_BINOP(-, v_float64); break;
|
||||
case MUL_F64: STACK_VM_BINOP(*, v_float64); break;
|
||||
case DIV_F64: STACK_VM_BINOP(/, v_float64); break;
|
||||
case EQ_F64: STACK_VM_CMPOP(==, v_float64); break;
|
||||
case LT_F64: STACK_VM_CMPOP(<, v_float64); break;
|
||||
case LE_F64: STACK_VM_CMPOP(<=, v_float64); break;
|
||||
// addressing
|
||||
case ADDR_LOAD_UINT32: STACK_VM_LOAD(v_int64, int64_t, uint32_t); break;
|
||||
case ADDR_LOAD_INT32: STACK_VM_LOAD(v_int64, int64_t, int32_t); break;
|
||||
case ADDR_LOAD_INT64: STACK_VM_LOAD(v_int64, int64_t, int64_t); break;
|
||||
case ADDR_LOAD_FP64: STACK_VM_LOAD(v_float64, double, double); break;
|
||||
case ADDR_LOAD_HANDLE: STACK_VM_LOAD(v_handle, void*, void*); break;
|
||||
case ADDR_STORE_INT64: STACK_VM_STORE(v_int64, int64_t); break;
|
||||
case ADDR_ADD: {
|
||||
stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*)
|
||||
sp = sp - 1;
|
||||
pc = pc + 1;
|
||||
break;
|
||||
}
|
||||
case ARRAY_LOAD_UINT32: {
|
||||
stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int]; // NOLINT(*)
|
||||
pc = pc + 2;
|
||||
break;
|
||||
}
|
||||
case NOT: {
|
||||
stack[sp].v_int64 = !stack[sp].v_int64;
|
||||
pc += 1;
|
||||
break;
|
||||
}
|
||||
case PUSH_I64: {
|
||||
stack[sp + 1].v_int64 = code[pc + 1].v_int;
|
||||
sp += 1;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case PUSH_VALUE: {
|
||||
int relpos = code[pc + 1].v_int;
|
||||
CHECK_LE(relpos, 0);
|
||||
stack[sp + 1] = stack[sp + relpos];
|
||||
sp += 1;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case POP: {
|
||||
sp -= 1;
|
||||
pc += 1;
|
||||
break;
|
||||
}
|
||||
case SELECT: {
|
||||
stack[sp - 2] = (stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]);
|
||||
sp -= 2;
|
||||
pc += 1;
|
||||
break;
|
||||
}
|
||||
case LOAD_HEAP: {
|
||||
stack[sp + 1] = heap[code[pc + 1].v_int];
|
||||
sp += 1;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case STORE_HEAP: {
|
||||
heap[code[pc + 1].v_int] = stack[sp];
|
||||
sp -= 1;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case CALL_EXTERN: {
|
||||
int num_args = static_cast<int>(stack[sp].v_int64);
|
||||
int call_fid = code[pc + 1].v_int;
|
||||
stack[sp - num_args] = extern_func[call_fid](
|
||||
&stack[sp - num_args], num_args);
|
||||
sp = sp - num_args;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case ASSERT: {
|
||||
CHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
|
||||
sp -= 1;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case RJUMP_IF_TRUE: {
|
||||
if (stack[sp].v_int64) {
|
||||
pc += code[pc + 1].v_int;
|
||||
} else {
|
||||
pc += 2;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RJUMP_IF_FALSE: {
|
||||
if (!stack[sp].v_int64) {
|
||||
pc += code[pc + 1].v_int;
|
||||
} else {
|
||||
pc += 2;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RJUMP: {
|
||||
pc += code[pc + 1].v_int;
|
||||
break;
|
||||
}
|
||||
case ASSERT_SP: {
|
||||
int64_t expected = code[pc + 1].v_int;
|
||||
CHECK_EQ(sp, expected)
|
||||
<< "sp assertion failed, expected="
|
||||
<< expected << " now=" << sp << ", pc=" << pc;
|
||||
pc += 2;
|
||||
break;
|
||||
}
|
||||
case TVM_LOAD_ARG_INT64: {
|
||||
STACK_VM_TVM_LOAD_ARG(tc == kInt, "int"); break;
|
||||
}
|
||||
case TVM_LOAD_ARG_FP64: {
|
||||
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
|
||||
}
|
||||
case TVM_LOAD_ARG_HANDLE: {
|
||||
STACK_VM_TVM_LOAD_ARG(tc == kHandle || tc == kNull, "handle"); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_DATA: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_SHAPE: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, shape); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_STRIDES: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, strides); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_NDIM: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_CODE: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.type_code); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_BITS: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_LANES: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.lanes); break;
|
||||
}
|
||||
}
|
||||
CHECK_LT(sp, stack_cap) << "Stack overflow";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace tvm
|
|
@ -0,0 +1,298 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file stack_vm.h
|
||||
* \brief A simple stack-based virtual machine.
|
||||
*
|
||||
* This can be used to interepret host side code
|
||||
* to setup calls into device functions
|
||||
* when only JIT for device is available(via NVRTC or OpenCL).
|
||||
*/
|
||||
#ifndef TVM_JIT_STACK_VM_H_
|
||||
#define TVM_JIT_STACK_VM_H_
|
||||
|
||||
#include <tvm/base.h>
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tvm {
|
||||
namespace jit {
|
||||
|
||||
/*!
|
||||
* \brief A simple stack-based virtual machine.
|
||||
*/
|
||||
class StackVM {
|
||||
public:
|
||||
/*!
|
||||
* \brief The opcode of stack vm
|
||||
* \note Notation
|
||||
* - sp Stack pointer
|
||||
* - pc Program pointer
|
||||
*/
|
||||
enum OpCode {
|
||||
// integer ops
|
||||
ADD_I64,
|
||||
SUB_I64,
|
||||
MUL_I64,
|
||||
DIV_I64,
|
||||
MOD_I64,
|
||||
EQ_I64,
|
||||
LT_I64,
|
||||
LE_I64,
|
||||
// floating ops
|
||||
ADD_F64,
|
||||
SUB_F64,
|
||||
MUL_F64,
|
||||
DIV_F64,
|
||||
EQ_F64,
|
||||
LT_F64,
|
||||
LE_F64,
|
||||
// load operation
|
||||
ADDR_LOAD_UINT32,
|
||||
ADDR_LOAD_INT32,
|
||||
ADDR_LOAD_INT64,
|
||||
ADDR_LOAD_FP64,
|
||||
ADDR_LOAD_HANDLE,
|
||||
// store operations
|
||||
// *(stack[sp - 1].v_andle) = stack[sp].v_int64
|
||||
// sp = sp - 2;
|
||||
ADDR_STORE_INT64,
|
||||
/*!
|
||||
* \brief Quick routine to load uint32 from constant offset.
|
||||
* \code
|
||||
* stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int];
|
||||
* pc = pc + 2;
|
||||
* \endcode
|
||||
*/
|
||||
ARRAY_LOAD_UINT32,
|
||||
// logical ops
|
||||
NOT,
|
||||
/*!
|
||||
* \brief Add address by an offset.
|
||||
* \code
|
||||
* stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64);
|
||||
* sp = sp - 1;
|
||||
* \endcode
|
||||
*/
|
||||
ADDR_ADD,
|
||||
/*!
|
||||
* \brief push integer fetched from next pc position into stack
|
||||
* \code
|
||||
* stack[sp + 1].v_int64 = code[pc + 1].v_int;
|
||||
* pc = pc + 2;
|
||||
* sp = sp + 1;
|
||||
* \endcode
|
||||
*/
|
||||
PUSH_I64,
|
||||
/*!
|
||||
* \brief push a value given relative index on the stack
|
||||
* \code
|
||||
* stack[sp + 1] = stack[sp + code[pc + 1].v_int];
|
||||
* pc = pc + 2;
|
||||
* sp = sp + 1;
|
||||
* \endcode
|
||||
*/
|
||||
PUSH_VALUE,
|
||||
/*!
|
||||
* \brief Load data from heap to top of stack
|
||||
* \code
|
||||
* stack[sp + 1] = heap[code[pc + 1].v_int];
|
||||
* pc = pc + 2;
|
||||
* sp = sp + 1;
|
||||
* \endcode
|
||||
*/
|
||||
LOAD_HEAP,
|
||||
/*!
|
||||
* \brief Store data to heap
|
||||
* \code
|
||||
* heap[code[pc + 1].v_int] = stack[sp];
|
||||
* sp = sp - 1;
|
||||
* \endcode
|
||||
*/
|
||||
STORE_HEAP,
|
||||
/*! \brief pop value from top of the stack */
|
||||
POP,
|
||||
/*!
|
||||
* \brief select based on operands.
|
||||
* \code
|
||||
* stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]
|
||||
* sp = sp - 2;
|
||||
* \endcode
|
||||
*/
|
||||
SELECT,
|
||||
/*!
|
||||
* \brief call an extern function
|
||||
* \code
|
||||
* num_args = stack[sp].v_int64;
|
||||
* call_fid = code[pc + 1].v_int;
|
||||
* f = extern_func[call_fid];
|
||||
* stack[sp - num_args] = f(&stack[sp - num_args], num_args);
|
||||
* sp = sp - num_args;
|
||||
* \endcode
|
||||
*/
|
||||
CALL_EXTERN,
|
||||
/*!
|
||||
* \brief Assert condition is true.
|
||||
* \code
|
||||
* CHECK(stack[sp]) << str_data[code[pc + 1].v_int];
|
||||
* sp = sp - 1;
|
||||
* \endcode
|
||||
*/
|
||||
ASSERT,
|
||||
/*!
|
||||
* \brief Relative Jump if the condition is true,
|
||||
* Does not change the stack status.
|
||||
* \code
|
||||
* if (stack[sp]) {
|
||||
* pc += code[pc + 1].v_int
|
||||
* } else {
|
||||
* pc = pc + 2;
|
||||
* }
|
||||
* \endcode
|
||||
*/
|
||||
RJUMP_IF_TRUE,
|
||||
/*!
|
||||
* \brief Relative Jump if the condition is true,
|
||||
* Does not change the stack status.
|
||||
* \code
|
||||
* if (stack[sp]) {
|
||||
* pc += code[pc + 1].v_int
|
||||
* } else {
|
||||
* pc = pc + 2;
|
||||
* }
|
||||
* \endcode
|
||||
*/
|
||||
RJUMP_IF_FALSE,
|
||||
/*!
|
||||
* \brief Relative jump to a location.
|
||||
* \code
|
||||
* pc += code[pc + 1].v_int;
|
||||
* \endcode
|
||||
*/
|
||||
RJUMP,
|
||||
/*!
|
||||
* \brief debug instruction.
|
||||
* \code
|
||||
* CHECK_EQ(sp, code[pc + 1]).v_int;
|
||||
* pc += 2;
|
||||
* \code
|
||||
*/
|
||||
ASSERT_SP,
|
||||
// Intrinsics for API function,
|
||||
TVM_LOAD_ARG_INT64,
|
||||
TVM_LOAD_ARG_FP64,
|
||||
TVM_LOAD_ARG_HANDLE,
|
||||
TVM_ARRAY_GET_DATA,
|
||||
TVM_ARRAY_GET_SHAPE,
|
||||
TVM_ARRAY_GET_STRIDES,
|
||||
TVM_ARRAY_GET_NDIM,
|
||||
TVM_ARRAY_GET_TYPE_CODE,
|
||||
TVM_ARRAY_GET_TYPE_BITS,
|
||||
TVM_ARRAY_GET_TYPE_LANES
|
||||
};
|
||||
/*! \brief The code structure */
|
||||
union Code {
|
||||
OpCode op_code;
|
||||
int v_int;
|
||||
};
|
||||
/*! \brief The state object of StackVM */
|
||||
struct State {
|
||||
/*! \brief The execution stack */
|
||||
std::vector<TVMValue> stack;
|
||||
/*! \brief The global heap space */
|
||||
std::vector<TVMValue> heap;
|
||||
/*! \brief stack pointer */
|
||||
int64_t sp{0};
|
||||
/*! \brief program counter */
|
||||
int64_t pc{0};
|
||||
};
|
||||
/*! \brief execute the stack vm with given state */
|
||||
void Run(State* state) const;
|
||||
/*!
|
||||
* \brief Print instruction at location pc
|
||||
* \param os The ostream
|
||||
* \param pc The pc
|
||||
* \return the pc to next instruction.
|
||||
*/
|
||||
int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*)
|
||||
/*! \brief Get thread local state of the stack VM */
|
||||
static State* ThreadLocalState();
|
||||
/*! \brief extern function that will mutate the state */
|
||||
using ExternFunc = std::function<TVMValue (const TVMValue* args, int num_args)>;
|
||||
/*! \brief The instructions */
|
||||
std::vector<Code> code;
|
||||
/*! \brief constant error messages */
|
||||
std::vector<std::string> str_data;
|
||||
/*! \brief Extern functions */
|
||||
std::vector<ExternFunc> extern_func;
|
||||
/*! \brief name of each heap id*/
|
||||
std::vector<std::string> heap_id_name;
|
||||
/*! \brief The memory size needed */
|
||||
size_t heap_size{0};
|
||||
/*! \brief The stack size required */
|
||||
size_t stack_size{1024};
|
||||
/*!
|
||||
* \brief Convert I64 opcode to F64 Ones
|
||||
* \param code The op code.
|
||||
* \return the F64 op code.
|
||||
*/
|
||||
static OpCode CodeI64ToF64(OpCode code) {
|
||||
switch (code) {
|
||||
case ADD_I64: return ADD_F64;
|
||||
case SUB_I64: return SUB_F64;
|
||||
case MUL_I64: return MUL_F64;
|
||||
case DIV_I64: return DIV_F64;
|
||||
case EQ_I64: return EQ_F64;
|
||||
case LT_I64: return LT_F64;
|
||||
case LE_I64: return LE_F64;
|
||||
case MOD_I64: LOG(FATAL) << "cannot handle mod for float";
|
||||
default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief Get load opcode for type t
|
||||
* \param t the type code.
|
||||
* \return The load opcode
|
||||
*/
|
||||
static OpCode GetLoad(Type t) {
|
||||
CHECK_EQ(t.lanes(), 1);
|
||||
if (t.is_handle()) return ADDR_LOAD_HANDLE;
|
||||
if (t.is_int()) {
|
||||
switch (t.bits()) {
|
||||
case 32 : return ADDR_LOAD_INT32;
|
||||
case 64 : return ADDR_LOAD_INT64;
|
||||
}
|
||||
} else if (t.is_uint()) {
|
||||
switch (t.bits()) {
|
||||
case 32 : return ADDR_LOAD_UINT32;
|
||||
}
|
||||
} else if (t.is_float()) {
|
||||
switch (t.bits()) {
|
||||
case 64 : return ADDR_LOAD_FP64;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Cannot load type " << t;
|
||||
return ADDR_LOAD_FP64;
|
||||
}
|
||||
/*!
|
||||
* \brief Get store opcode for type t
|
||||
* \param t the type code.
|
||||
* \return The load opcode
|
||||
*/
|
||||
static OpCode GetStore(Type t) {
|
||||
CHECK_EQ(t.lanes(), 1);
|
||||
if (t.is_int()) {
|
||||
switch (t.bits()) {
|
||||
case 64 : return ADDR_STORE_INT64;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Cannot store type " << t;
|
||||
return ADDR_LOAD_FP64;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace tvm
|
||||
#endif // TVM_JIT_STACK_VM_H_
|
|
@ -4,7 +4,7 @@
|
|||
* \brief Device specific implementations
|
||||
*/
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <tvm/runtime/runtime.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <algorithm>
|
||||
#include "./runtime_base.h"
|
||||
#include "./device_api.h"
|
||||
|
@ -170,7 +170,7 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
|
|||
|
||||
int TVMFuncFree(TVMFunctionHandle func) {
|
||||
API_BEGIN();
|
||||
delete static_cast<PackedFunc::FType*>(func);
|
||||
delete static_cast<PackedFunc*>(func);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
@ -179,7 +179,35 @@ int TVMFuncCall(TVMFunctionHandle func,
|
|||
int* arg_type_codes,
|
||||
int num_args) {
|
||||
API_BEGIN();
|
||||
(*static_cast<const PackedFunc::FType*>(func))(
|
||||
(*static_cast<const PackedFunc*>(func)).CallPacked(
|
||||
args, arg_type_codes, num_args);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
|
||||
void* resource_handle,
|
||||
TVMPackedCFuncFinalizer fin,
|
||||
TVMFunctionHandle *out) {
|
||||
API_BEGIN();
|
||||
if (fin == nullptr) {
|
||||
*out = new PackedFunc(
|
||||
[func, resource_handle](const TVMValue* args,
|
||||
const int* type_codes,
|
||||
int num_args) {
|
||||
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
|
||||
num_args, resource_handle);
|
||||
});
|
||||
} else {
|
||||
// wrap it in a shared_ptr, with fin as deleter.
|
||||
// so fin will be called when the lambda went out of scope.
|
||||
std::shared_ptr<void> rpack(resource_handle, fin);
|
||||
*out = new PackedFunc(
|
||||
[func, rpack](const TVMValue* args,
|
||||
const int* type_codes,
|
||||
int num_args) {
|
||||
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
|
||||
num_args, rpack.get());
|
||||
});
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file packed_func_registry.cc
|
||||
* \brief The global registry of packed function.
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "./runtime_base.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace runtime {
|
||||
|
||||
struct PackedFuncRegistry {
|
||||
// map storing the functions.
|
||||
// We delibrately used raw pointer
|
||||
// This is because PackedFunc can contain callbacks into the host languge(python)
|
||||
// and the resource can become invalid because of indeterminstic order of destruction.
|
||||
// The resources will only be recycled during program exit.
|
||||
std::unordered_map<std::string, PackedFunc*> fmap;
|
||||
|
||||
static PackedFuncRegistry* Global() {
|
||||
static PackedFuncRegistry inst;
|
||||
return &inst;
|
||||
}
|
||||
};
|
||||
|
||||
const PackedFunc& PackedFunc::RegisterGlobal(
|
||||
const std::string& name, PackedFunc f) {
|
||||
PackedFuncRegistry* r = PackedFuncRegistry::Global();
|
||||
auto it = r->fmap.find(name);
|
||||
CHECK(it == r->fmap.end())
|
||||
<< "Global PackedFunc " << name << " is already registered";
|
||||
PackedFunc* fp = new PackedFunc(f);
|
||||
r->fmap[name] = fp;
|
||||
return *fp;
|
||||
}
|
||||
|
||||
const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
|
||||
PackedFuncRegistry* r = PackedFuncRegistry::Global();
|
||||
auto it = r->fmap.find(name);
|
||||
CHECK(it != r->fmap.end())
|
||||
<< "Global PackedFunc " << name << " is not registered";
|
||||
return *(it->second);
|
||||
}
|
||||
|
||||
std::vector<std::string> PackedFunc::ListGlobalNames() {
|
||||
PackedFuncRegistry* r = PackedFuncRegistry::Global();
|
||||
std::vector<std::string> keys;
|
||||
keys.reserve(r->fmap.size());
|
||||
for (const auto &kv : r->fmap) {
|
||||
keys.push_back(kv.first);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
|
||||
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
|
||||
using tvm::runtime::PackedFunc;
|
||||
API_BEGIN();
|
||||
PackedFunc::RegisterGlobal(name, *static_cast<PackedFunc*>(f));
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
|
||||
using tvm::runtime::PackedFunc;
|
||||
API_BEGIN();
|
||||
*out = new PackedFunc(PackedFunc::GetGlobal(name));
|
||||
API_END();
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
#include <dmlc/logging.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <tvm/runtime/runtime.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
|
||||
TEST(PackedFunc, Basic) {
|
||||
using namespace tvm::runtime;
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
def tvm_call_global(*args):
|
||||
args = tvm.convert(args)
|
||||
return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)
|
||||
|
||||
|
||||
def test_stack_vm_basic():
|
||||
a = tvm.nd.array(np.zeros(10, dtype='float32'))
|
||||
@tvm.register_func
|
||||
def tvm_call_back_get_shape(shape0):
|
||||
assert shape0 == a.shape[0]
|
||||
|
||||
n = tvm.Var('n')
|
||||
Ab = tvm.Buffer((n, ), tvm.float32)
|
||||
stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
|
||||
print(stmt)
|
||||
fapi = tvm.codegen.MakeAPI(stmt, "print_shape", [Ab], 1)
|
||||
print(fapi.body)
|
||||
f = tvm.codegen.BuildStackVM(fapi)
|
||||
f(a)
|
||||
|
||||
|
||||
@tvm.register_func
|
||||
def tvm_stack_vm_print(*x):
|
||||
print(x)
|
||||
|
||||
|
||||
def test_stack_vm_loop():
|
||||
dtype = 'int64'
|
||||
n = tvm.Var('n')
|
||||
Ab = tvm.Buffer((n, ), dtype)
|
||||
i = tvm.Var('i')
|
||||
# for i in 0 to n-1:
|
||||
stmt = tvm.make.For(
|
||||
i, 0, n - 1, 0, 0,
|
||||
tvm.make.Block(
|
||||
tvm.make.Store(Ab.ptr,
|
||||
tvm.make.Load(dtype, Ab.ptr, i) + 1,
|
||||
i + 1),
|
||||
tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i))))
|
||||
print(stmt)
|
||||
fapi = tvm.codegen.MakeAPI(stmt, "ramp", [Ab], 1)
|
||||
f = tvm.codegen.BuildStackVM(fapi)
|
||||
a = tvm.nd.array(np.zeros(10, dtype=dtype))
|
||||
f(a)
|
||||
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
|
||||
|
||||
|
||||
def test_stack_vm_cond():
|
||||
dtype = 'int64'
|
||||
n = tvm.Var('n')
|
||||
Ab = tvm.Buffer((n, ), dtype)
|
||||
i = tvm.Var('i')
|
||||
# for i in 0 to n-1:
|
||||
stmt = tvm.make.For(
|
||||
i, 0, n - 1, 0, 0,
|
||||
tvm.make.IfThenElse(
|
||||
tvm.make.EQ(i, 4),
|
||||
tvm.make.Store(Ab.ptr,
|
||||
tvm.make.Load(dtype, Ab.ptr, i) + 1, i + 1),
|
||||
tvm.make.Store(Ab.ptr,
|
||||
tvm.make.Load(dtype, Ab.ptr, i) + 2, i + 1)))
|
||||
print(stmt)
|
||||
fapi = tvm.codegen.MakeAPI(stmt, "test", [Ab], 1)
|
||||
f = tvm.codegen.BuildStackVM(fapi)
|
||||
a = tvm.nd.array(np.zeros(10, dtype=dtype))
|
||||
f(a)
|
||||
y = np.arange(a.shape[0]) * 2
|
||||
y[5:] -= 1
|
||||
np.testing.assert_equal(a.asnumpy(), y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_stack_vm_cond()
|
|
@ -1,17 +0,0 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def test_function():
|
||||
ctx = tvm.cpu(0)
|
||||
x = np.random.randint(0, 10, size=(3, 4))
|
||||
x = np.array(x)
|
||||
y = tvm.nd.array(x, ctx=ctx)
|
||||
|
||||
f = tvm.codegen.DummyHelloFunction()
|
||||
f(y, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_function()
|
|
@ -0,0 +1,40 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
def test_function():
|
||||
ctx = tvm.cpu(0)
|
||||
x = np.random.randint(0, 10, size=(3, 4))
|
||||
x = np.array(x)
|
||||
y = tvm.nd.array(x, ctx=ctx)
|
||||
|
||||
f = tvm.codegen.DummyHelloFunction()
|
||||
f(y, 10)
|
||||
|
||||
|
||||
def test_get_global():
|
||||
targs = (10, 10.0, "hello")
|
||||
# register into global function table
|
||||
@tvm.register_func
|
||||
def my_packed_func(*args):
|
||||
assert(tuple(args) == targs)
|
||||
# get it out from global function table
|
||||
f = tvm.get_global_func("my_packed_func")
|
||||
assert isinstance(f, tvm.nd.Function)
|
||||
f(*targs)
|
||||
|
||||
|
||||
def test_convert():
|
||||
# convert a function to tvm function
|
||||
targs = (10, 10.0, "hello", 10)
|
||||
def myfunc(*args):
|
||||
assert(tuple(args) == targs)
|
||||
|
||||
f = tvm.convert(myfunc)
|
||||
assert isinstance(f, tvm.nd.Function)
|
||||
f(*targs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_function()
|
||||
test_convert()
|
||||
test_get_global()
|
Загрузка…
Ссылка в новой задаче