[API/JIT] Enable registerable global function, introduce StackVM intepreter (#25)

This commit is contained in:
Tianqi Chen 2017-01-24 23:47:07 -08:00 коммит произвёл GitHub
Родитель 01a7ce0cb6
Коммит 4242b9cff5
22 изменённых файлов: 1713 добавлений и 53 удалений

Просмотреть файл

@ -10,7 +10,7 @@
#include "./base.h"
#include "./expr.h"
#include "./module.h"
#include "./runtime/runtime.h"
#include "./runtime/packed_func.h"
namespace tvm {

Просмотреть файл

@ -81,11 +81,13 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* bool tvm_print(VType value) {
* LOG(INFO) << value;
* int tvm_call_global(name, TVMValue* args) {
* PackedFunc f = PackedFunc::GetGlobal(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_print = "tvm_print";
constexpr const char* tvm_call_global = "tvm_call_global";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {

Просмотреть файл

@ -20,9 +20,6 @@ namespace tvm {
// Internal node container of lowered function.
class LoweredFuncNode;
// Internal node container of module.
class ModuleNode;
/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.

Просмотреть файл

@ -161,7 +161,7 @@ TVM_DLL const char *TVMGetLastError(void);
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return Whether the function is successful.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
@ -188,7 +188,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx,
* \param dtype The array data type.
* \param ctx The ctx this array sits on.
* \param out The output handle.
* \return Whether the function is successful.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
@ -198,6 +198,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
@ -206,6 +207,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
@ -214,13 +216,14 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return whether
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
@ -239,6 +242,57 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* type_codes,
int num_args);
/*!
* \brief C type of packed function.
*
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param resource_handle The handle additional resouce handle from fron-end.
*/
typedef void (*TVMPackedCFunc)(
TVMValue* args, int* type_codes, int num_args, void* resource_handle);
/*!
* \brief C callback to free the resource handle in C packed function.
* \param resource_handle The handle additional resouce handle from fron-end.
*/
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out);
/*!
* \brief Register the function to runtime's global table.
*
* The registered function then can be pulled by the backend by the name.
*
* \param name The name of the function.
* \param f The function to be registered.
*/
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
/*!
* \brief Get a global function.
*
* \param name The name of the function.
* \param out the result function pointer.
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
} // TVM_EXTERN_C
#endif // TVM_RUNTIME_C_RUNTIME_API_H_

Просмотреть файл

@ -1,35 +1,43 @@
/*!
* Copyright (c) 2016 by Contributors
* \file runtime.h
* \file packed_func.h
* \brief Runtime related c++ class.
*/
#ifndef TVM_RUNTIME_RUNTIME_H_
#define TVM_RUNTIME_RUNTIME_H_
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
#include <functional>
#include <tuple>
#include <vector>
#include <string>
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Packed function is a runtime function
* whose argument type_codes are erased by packed format.
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
*
* This is an useful unified interface to call generated functions.
* This is an useful unified interface to call generated functions,
* It is the unified function function type of TVM.
* It corresponds to TVMFunctionHandle in C runtime API.
*/
class PackedFunc {
public:
/*! \brief The internal std::function */
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
/*! \brief default constructor */
PackedFunc() {}
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief invoke the packed function by directly passing in arguments.
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
* \return The first return value.
*/
template<typename... Args>
inline void operator()(Args&& ...args) const;
@ -41,9 +49,25 @@ class PackedFunc {
*/
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
/*! \return the internal body function */
inline FType body() const {
return body_;
}
inline FType body() const;
/*!
* \brief Register f as into global function table
* \param name The name of the function.
* \param f The function to be registered.
* \return Reference to the registered function.
* \note The returned reference is valid until the end of the program
*/
static const PackedFunc& RegisterGlobal(const std::string& name, PackedFunc f);
/*!
* \brief Get the global function by name.
* \param name The name of the function.
* \return reference to the registered function.
*/
static const PackedFunc& GetGlobal(const std::string& name);
/*!
* \brief Get the names of currently registered global function.
*/
static std::vector<std::string> ListGlobalNames();
private:
/*! \brief internal container of packed function */
@ -56,6 +80,10 @@ inline void PackedFunc::CallPacked(
body_(args, type_codes, num_args);
}
inline PackedFunc::FType PackedFunc::body() const {
return body_;
}
template<bool stop, std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_ {
static inline void run(const std::tuple<Args...>& args, F f) {
@ -124,4 +152,4 @@ inline void PackedFunc::operator()(Args&& ...args) const {
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RUNTIME_H_
#endif // TVM_RUNTIME_PACKED_FUNC_H_

Просмотреть файл

@ -1,6 +1,6 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
@ -13,7 +13,7 @@ from .._base import c_str, py_str, string_types
from .._base import check_call, ctypes2docstring
from .. import _api_internal
from . import _runtime_api
from ._types import TVMValue, TypeCode
from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer
# type definitions
APIFuncHandle = ctypes.c_void_p
@ -57,6 +57,13 @@ def _return_func(x):
return _runtime_api._function_cls(handle)
def _return_handle(x):
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
return handle
RET_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.INT: lambda x: x.v_int64,
@ -66,6 +73,15 @@ RET_SWITCH = {
TypeCode.FUNC_HANDLE: _return_func
}
PACK_ARG_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.HANDLE: lambda x: _return_handle,
}
class SliceBase(object):
"""base class of slice object"""
pass
@ -159,10 +175,53 @@ def const(value, dtype=None):
return _api_internal._const(value, dtype)
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = [PACK_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
local_pyfunc(*pyargs)
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return _runtime_api._function_cls(handle)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, Number):
if isinstance(value, (NodeBase, _runtime_api.FunctionBase)):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value]
return _api_internal._Array(*value)
@ -176,10 +235,11 @@ def convert(value):
return _api_internal._Map(*vlist)
elif isinstance(value, SliceBase):
return value.tensor(*value.indices)
elif callable(value):
return convert_to_tvm_func(value)
else:
if not isinstance(value, NodeBase):
raise ValueError("don't know how to handle type %s" % type(value))
return value
raise ValueError("don't know how to handle type %s" % type(value))
return value
def _push_arg(arg):
@ -270,6 +330,59 @@ def register_node(type_key=None):
NODE_TYPE[cls.__name__] = cls
return cls
def register_func(func_name, f=None):
"""Register global function
Parameters
----------
func_name : str or function
The function name
f : function
The function to be registered.
Returns
-------
fregister : function
Register function if f is not specified.
"""
if callable(func_name):
f = func_name
func_name = f.__name__
if not isinstance(func_name, str):
raise ValueError("expect string function name")
def register(myf):
"""internal register function"""
if not isinstance(myf, _runtime_api.FunctionBase):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle))
if f:
register(f)
else:
return register
def get_global_func(name):
"""Get a global function by name
Parameters
----------
name : str
The name of the global function
Returns
-------
func : tvm.nd.Function
The function to be returned.
"""
handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
return _runtime_api._function_cls(handle)
def _init_api_module(root_namespace):
"""List and add all the functions to current module."""
plist = ctypes.POINTER(ctypes.c_char_p)()

Просмотреть файл

@ -70,3 +70,16 @@ class TVMType(ctypes.Structure):
if self.lanes != 1:
x += "x%d" % self.lanes
return x
TVMPackedCFunc = ctypes.CFUNCTYPE(
None,
ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p)
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
None,
ctypes.c_void_p)

Просмотреть файл

@ -1,9 +1,9 @@
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable
# pylint: disable=redefined-builtin, undefined-variable, unused-import
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs
from numbers import Integral as _Integral
from ._ctypes._api import _init_api_module, convert
from ._ctypes._api import _init_api_module, convert, register_func, get_global_func
from . import _api_internal
from . import make as _make
from . import expr as _expr

Просмотреть файл

@ -52,5 +52,10 @@ TVM_REGISTER_API(_codegen_DummyHelloFunction)
*ret = runtime::PackedFunc(DummyHelloFunction);
});
TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = BuildStackVM(args.at(0));
});
} // namespace codegen
} // namespace tvm

Просмотреть файл

@ -46,7 +46,8 @@ TVM_REGISTER_API(_make_Call)
args.at(1),
args.at(2),
static_cast<Call::CallType>(args.at(3).operator int()),
args.at(4));
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Allocate)

Просмотреть файл

@ -4,6 +4,7 @@
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
@ -27,6 +28,13 @@ TVM_REGISTER_API(_const)
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(_str)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ir::StringImm::make(args.at(0));
});
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;

Просмотреть файл

@ -9,7 +9,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/c_api.h>
#include <tvm/runtime/runtime.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <limits>
#include <string>

Просмотреть файл

@ -0,0 +1,497 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_stack_vm.cc
*/
#include <limits>
#include "./codegen_stack_vm.h"
namespace tvm {
namespace codegen {
using namespace ir;
runtime::PackedFunc BuildStackVM(LoweredFunc func) {
StackVM vm = codegen::CodeGenStackVM().Compile(func);
auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) {
LOG(INFO) << "Run stack VM";
StackVM::State* s = StackVM::ThreadLocalState();
s->sp = 0;
s->pc = 0;
if (s->heap.size() < vm.heap_size) {
s->heap.resize(vm.heap_size);
}
s->heap[0].v_handle = (void*)args; // NOLINT(*)
s->heap[1].v_handle = (void*)type_codes; // NOLINT(*)
s->heap[2].v_int64 = num_args;
vm.Run(s);
};
return runtime::PackedFunc(f);
}
TVMValue TVMPrint(const TVMValue* args, int num_args) {
CHECK_EQ(num_args, 2);
int tcode = static_cast<int>(args[1].v_int64);
int code = (tcode >> (8 * 3)) & 255;
int bits = (tcode >> (8 * 2)) & 255;
int lanes = tcode & ((1 << 16) - 1);
Type t((halide_type_code_t)code, bits, lanes);
if (t.is_handle()) {
LOG(INFO) << t << ": " << args[0].v_handle;
} else if (t.is_float()) {
LOG(INFO) << t << ": " << args[0].v_float64;
} else {
LOG(INFO) << t << ": " << args[0].v_int64;
}
TVMValue r; r.v_int64 = 0;
return r;
}
CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*)
static FType inst; return inst;
}
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
int vid = AllocVarID(v.get());
CHECK_EQ(static_cast<size_t>(vid), i);
}
this->Push(f->body);
return std::move(vm_);
}
void CodeGenStackVM::Push(const Stmt& n) {
static const FType& f = vtable();
f(n, this);
if (debug_) {
this->PushOp(StackVM::ASSERT_SP, 0);
}
}
void CodeGenStackVM::Push(const Expr& n) {
static const FType& f = vtable();
f(n, this);
}
void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
StackVM::Code code;
code.op_code = opcode;
vm_.code.push_back(code);
}
void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
CHECK(operand >= std::numeric_limits<int>::min() &&
operand <= std::numeric_limits<int>::max());
vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}
int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
int64_t pc = static_cast<int64_t>(vm_.code.size());
StackVM::Code code;
code.op_code = opcode;
vm_.code.push_back(code);
code.v_int = operand;
vm_.code.push_back(code);
return pc + 1;
}
int CodeGenStackVM::GetStrID(const std::string& key) {
auto it = str_idmap_.find(key);
if (it != str_idmap_.end()) return it->second;
int sid = static_cast<int>(vm_.str_data.size());
vm_.str_data.push_back(key);
str_idmap_[key] = sid;
return sid;
}
int CodeGenStackVM::AllocVarID(const Variable* v) {
CHECK(!var_idmap_.count(v));
int vid = static_cast<int>(vm_.heap_size);
CHECK_EQ(vm_.heap_size, var_idmap_.size());
vm_.heap_id_name.push_back(v->name_hint);
++vm_.heap_size;
var_idmap_[v] = vid;
return vid;
}
int CodeGenStackVM::GetGlobalFuncID(std::string name) {
auto it = fun_idmap_.find(name);
if (it != fun_idmap_.end()) return it->second;
using runtime::PackedFunc;
PackedFunc f = PackedFunc::GetGlobal(name);
auto extern_f = [f](const TVMValue* args, int num_args) {
CHECK_EQ(num_args % 2, 0);
num_args = num_args / 2;
std::vector<int> type_codes(std::max(num_args, 1));
for (int i = 0; i < num_args; ++i) {
int tcode = static_cast<int>(args[num_args + i].v_int64);
int code = (tcode >> (8 * 3)) & 255;
type_codes[i] = code;
}
f.CallPacked(args, &type_codes[0], num_args);
TVMValue r; r.v_int64 = 0;
return r;
};
int fid = static_cast<int>(vm_.extern_func.size());
vm_.extern_func.push_back(extern_f);
fun_idmap_[name] = fid;
return fid;
}
int CodeGenStackVM::GetVarID(const Variable* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
return it->second;
}
void CodeGenStackVM::Push_(const ir::Load* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
if (op->type == UInt(32) && op->index.as<IntImm>()) {
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
} else {
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->PushOp(StackVM::GetLoad(op->type));
}
}
void CodeGenStackVM::Push_(const ir::Store* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->Push(op->value);
this->PushOp(StackVM::GetStore(op->value.type()));
}
void CodeGenStackVM::Push_(const ir::Allocate* op) {
CHECK(!is_zero(op->condition));
int vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
this->Push(op->new_expr);
this->PushOp(StackVM::STORE_HEAP, vid);
} else {
LOG(FATAL) << "Dynamic allocation not supported";
}
}
void CodeGenStackVM::Push_(const ir::Call* op) {
if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
this->Push(l->index);
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
CHECK_EQ(op->args.size(), 3U);
this->Push(op->args[0]);
this->Push(op->args[1]);
this->Push(op->args[2]);
if (op->type.is_handle()) {
this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE);
} else if (op->type.is_float()) {
this->PushOp(StackVM::TVM_LOAD_ARG_FP64);
} else if (op->type.is_int() || op->type.is_uint()) {
this->PushOp(StackVM::TVM_LOAD_ARG_INT64);
} else {
LOG(FATAL) << "donot know how to handle type" << op->type;
}
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
this->Push(op->args[0]);
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break;
case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break;
case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break;
case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break;
case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
default: LOG(FATAL) << "unknown field code";
}
} else if (op->is_intrinsic(intrinsic::tvm_call_global)) {
CHECK_GE(op->args.size(), 1U);
const StringImm* s = op->args[0].as<StringImm>();
CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
for (size_t i = 1; i < op->args.size(); ++i) {
this->Push(op->args[i]);
}
for (size_t i = 1; i < op->args.size(); ++i) {
Type t = op->args[i].type();
int code = t.code();
int bits = t.bits();
int lanes = t.lanes();
int tcode = (code << (8 * 3)) | (bits << 16) | lanes;
this->PushOp(StackVM::PUSH_I64, tcode);
}
int num_args = static_cast<int>((op->args.size() - 1) * 2);
this->PushOp(StackVM::PUSH_I64, num_args);
this->PushOp(StackVM::CALL_EXTERN, GetGlobalFuncID(s->value));
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
this->Push(op->args[0]);
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_I64);
} else {
this->HandleUnknownCall(op);
}
}
void CodeGenStackVM::HandleUnknownCall(const ir::Call* op) {
LOG(FATAL) << "donot know how to handle call " << op->name;
}
inline void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
const Expr& b,
CodeGenStackVM* p) {
p->Push(a);
p->Push(b);
Type t = a.type();
if (t.is_int()) {
p->PushOp(op_int64);
} else if (t.is_uint()) {
if (t.bits() <= 32) {
p->PushOp(op_int64);
} else {
LOG(FATAL) << "Cannot handle uint64_t in StackVM";
}
} else {
p->PushOp(StackVM::CodeI64ToF64(op_int64));
}
}
inline void PushCast(Type dst,
Type src,
CodeGenStackVM* p) {
if (dst.is_int()) {
if (src.is_int()) return;
if (src.is_uint() && src.bits() <= 32) return;
} else if (dst.is_uint() && dst.bits() <= 32) {
if (src.is_int()) return;
if (src.is_uint() && src.bits() <= 32) return;
} else if (dst.is_float()) {
if (src.is_float()) return;
}
LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
}
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<StringImm>([](const StringImm *op, CodeGenStackVM *p) {
int sid = p->GetStrID(op->value);
p->PushOp(StackVM::PUSH_I64, sid);
})
.set_dispatch<IntImm>([](const IntImm *op, CodeGenStackVM *p) {
CHECK(op->value >= std::numeric_limits<int>::min() &&
op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
})
.set_dispatch<UIntImm>([](const UIntImm *op, CodeGenStackVM *p) {
CHECK(op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
})
.set_dispatch<FloatImm>([](const FloatImm *op, CodeGenStackVM *p) {
LOG(FATAL) << "Float Imm is not supported";
});
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Variable>([](const Variable *op, CodeGenStackVM* p) {
int vid = p->GetVarID(op);
p->PushOp(StackVM::LOAD_HEAP, vid);
})
.set_dispatch<Cast>([](const Cast *op, CodeGenStackVM* p) {
p->Push(op->value);
PushCast(op->type, op->value.type(), p);
})
.set_dispatch<Add>([](const Add *op, CodeGenStackVM* p) {
PushBinary(StackVM::ADD_I64, op->a, op->b, p);
})
.set_dispatch<Sub>([](const Sub *op, CodeGenStackVM* p) {
PushBinary(StackVM::SUB_I64, op->a, op->b, p);
})
.set_dispatch<Mul>([](const Mul *op, CodeGenStackVM* p) {
PushBinary(StackVM::MUL_I64, op->a, op->b, p);
})
.set_dispatch<Div>([](const Div *op, CodeGenStackVM* p) {
PushBinary(StackVM::DIV_I64, op->a, op->b, p);
})
.set_dispatch<Mod>([](const Mod *op, CodeGenStackVM* p) {
PushBinary(StackVM::MOD_I64, op->a, op->b, p);
})
.set_dispatch<Min>([](const Min *op, CodeGenStackVM* p) {
p->Push(op->a);
p->Push(op->b);
p->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::LT_I64);
p->PushOp(StackVM::SELECT);
})
.set_dispatch<Max>([](const Max *op, CodeGenStackVM* p) {
p->Push(op->a);
p->Push(op->b);
p->PushOp(StackVM::PUSH_VALUE, 0);
p->PushOp(StackVM::PUSH_VALUE, -2);
p->PushOp(StackVM::LT_I64);
p->PushOp(StackVM::SELECT);
})
.set_dispatch<EQ>([](const EQ *op, CodeGenStackVM* p) {
PushBinary(StackVM::EQ_I64, op->a, op->b, p);
})
.set_dispatch<LE>([](const LE *op, CodeGenStackVM* p) {
PushBinary(StackVM::LE_I64, op->a, op->b, p);
})
.set_dispatch<NE>([](const NE *op, CodeGenStackVM* p) {
PushBinary(StackVM::EQ_I64, op->a, op->b, p);
p->PushOp(StackVM::NOT);
})
.set_dispatch<LT>([](const LT *op, CodeGenStackVM* p) {
PushBinary(StackVM::LT_I64, op->a, op->b, p);
})
.set_dispatch<GE>([](const GE *op, CodeGenStackVM* p) {
PushBinary(StackVM::LT_I64, op->a, op->b, p);
p->PushOp(StackVM::NOT);
})
.set_dispatch<GT>([](const GT *op, CodeGenStackVM* p) {
PushBinary(StackVM::LE_I64, op->a, op->b, p);
p->PushOp(StackVM::NOT);
})
.set_dispatch<And>([](const And *op, CodeGenStackVM* p) {
p->Push(op->a);
int64_t pc_jump = p->GetPC();
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::POP);
p->Push(op->b);
int64_t diff = p->GetPC() - pc_jump;
p->SetOperand(opr_index, diff);
})
.set_dispatch<Or>([](const Or *op, CodeGenStackVM* p) {
p->Push(op->a);
int64_t pc_jump = p->GetPC();
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_TRUE, 0);
p->Push(op->b);
int64_t diff = p->GetPC() - pc_jump;
p->SetOperand(opr_index, diff);
})
.set_dispatch<Not>([](const Not* op, CodeGenStackVM* p) {
p->PushOp(StackVM::NOT);
});
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenStackVM* p) {
p->Push(op->body);
})
.set_dispatch<For>([](const For *op, CodeGenStackVM* p) {
CHECK(is_zero(op->min));
int vid = p->AllocVarID(op->loop_var.get());
p->PushOp(StackVM::PUSH_I64, 0);
int64_t loop_head = p->GetPC();
p->PushOp(StackVM::STORE_HEAP, vid);
p->PushOp(StackVM::LOAD_HEAP, vid);
p->Push(op->extent);
p->PushOp(StackVM::LT_I64);
int64_t label_fjump = p->GetPC();
int64_t foward_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::POP);
p->Push(op->body);
p->PushOp(StackVM::LOAD_HEAP, vid);
p->PushOp(StackVM::PUSH_I64, 1);
p->PushOp(StackVM::ADD_I64);
int64_t label_bjump = p->GetPC();
int64_t backward_jump = p->PushOp(StackVM::RJUMP, 0);
int64_t loop_end = p->GetPC();
p->PushOp(StackVM::POP);
p->SetOperand(foward_jump, loop_end - label_fjump);
p->SetOperand(backward_jump, loop_head - label_bjump);
})
.set_dispatch<Block>([](const Block *op, CodeGenStackVM* p) {
p->Push(op->first);
if (op->rest.defined()) p->Push(op->rest);
})
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenStackVM* p) {
if (is_const(op->value)) return;
p->Push(op->value);
p->PushOp(StackVM::POP);
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenStackVM* p) {
p->Push(op->condition);
int64_t label_ejump = p->GetPC();
int64_t else_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::POP);
p->Push(op->then_case);
if (op->else_case.defined()) {
int64_t label_then_jump = p->GetPC();
int64_t then_jump = p->PushOp(StackVM::RJUMP, 0);
int64_t else_begin = p->GetPC();
p->SetOperand(else_jump, else_begin - label_ejump);
p->PushOp(StackVM::POP);
p->Push(op->else_case);
int64_t if_end = p->GetPC();
p->SetOperand(then_jump, if_end - label_then_jump);
} else {
int64_t if_end = p->GetPC();
p->SetOperand(else_jump, if_end - label_ejump);
p->PushOp(StackVM::POP);
}
})
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) {
p->Push(op->value);
int64_t vid = p->AllocVarID(op->var.get());
p->PushOp(StackVM::STORE_HEAP, vid);
p->Push(op->body);
})
.set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) {
LOG(FATAL) << "Ramp is not supported";
})
.set_dispatch<Broadcast>([](const Broadcast *op, CodeGenStackVM* p) {
LOG(FATAL) << "Broadcast is not supported";
})
.set_dispatch<Select>([](const Select *op, CodeGenStackVM* p) {
p->Push(op->true_value);
p->Push(op->false_value);
p->Push(op->condition);
p->PushOp(StackVM::SELECT);
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenStackVM* p) {
if (op->message.as<StringImm>()) {
int sid = p->GetStrID(op->message.as<StringImm>()->value);
p->Push(op->condition);
p->PushOp(StackVM::ASSERT, sid);
}
})
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenStackVM* p) {
p->Push(op->body);
})
.set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) {
p->Push(op->value);
int64_t vid = p->AllocVarID(op->var.get());
p->PushOp(StackVM::STORE_HEAP, vid);
p->Push(op->body);
})
.set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) {
p->Push_(op);
})
.set_dispatch<Store>([](const Store *op, CodeGenStackVM* p) {
p->Push_(op);
})
.set_dispatch<Allocate>([](const Allocate *op, CodeGenStackVM* p) {
p->Push_(op);
})
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
p->Push_(op);
});
} // namespace codegen
} // namespace tvm

Просмотреть файл

@ -0,0 +1,110 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_stack_vm.h
* \brief Codegen into Simple Stack VM.
*/
#ifndef TVM_CODEGEN_CODEGEN_STACK_VM_H_
#define TVM_CODEGEN_CODEGEN_STACK_VM_H_
#include <tvm/ir.h>
#include <tvm/module.h>
#include <tvm/codegen.h>
#include <string>
#include <unordered_map>
#include "../jit/stack_vm.h"
namespace tvm {
namespace codegen {
using jit::StackVM;
/*!
* \brief A base class to generate a stack VM.
* This module is used to generate host wrapper
* into device function when only device JIT is available.
*/
class CodeGenStackVM {
public:
/*!
* \brief Generate a stack VM representing
* \param f The function to be compiled
* \note Only call compile once,
* create a new codegen object each time.
*/
StackVM Compile(LoweredFunc f);
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
void Push(const Expr& n);
/*!
* \brief Push the opcode to the code.
* \param opcode The code to be pushed.
*/
void PushOp(StackVM::OpCode opcode);
/*!
* \brief Push the opcode and operand to the code.
* \param opcode The opcode.
* \param operand The operand to be pushed.
* \return operand_index, indicating location of operand
*/
int64_t PushOp(StackVM::OpCode opcode, int operand);
/*!
* \brief Set the relative jump offset to be offset.
* \param operand_index The indexed returned by PushOp.
* \param operand The operand to be set.
*/
void SetOperand(int64_t operand_index, int64_t operand);
/*! \return The current program pointer */
int64_t GetPC() const {
return static_cast<int64_t>(vm_.code.size());
}
/*!
* \brief Get string id in vm
* \param key The string to get id.
* \return the id of the string.
*/
int GetStrID(const std::string& key);
/*!
* \brief Push the function to the VM and get a id.
* \param f The function to be pushed.
*/
int GetGlobalFuncID(std::string name);
/*!
* \brief Allocate a variable name for a newly defined var.
* \param v The variable.
* \return the heap index of the var.
*/
int AllocVarID(const Variable* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the heap index of the var.
*/
int GetVarID(const Variable* v) const;
// overloadable functions
virtual void Push_(const ir::Load* op);
virtual void Push_(const ir::Store* op);
virtual void Push_(const ir::Allocate* op);
virtual void Push_(const ir::Call* op);
virtual void HandleUnknownCall(const ir::Call* op);
/*! \brief function to to print normal code */
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>;
// vtable to print code
static FType& vtable(); // NOLINT(*)
private:
bool debug_{false};
/*! \brief The vm to be generated */
StackVM vm_;
/*! \brief id of each variable */
std::unordered_map<const Variable*, int> var_idmap_;
/*! \brief id of each string */
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each function */
std::unordered_map<std::string, int> fun_idmap_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_STACK_VM_H_

334
src/jit/stack_vm.cc Normal file
Просмотреть файл

@ -0,0 +1,334 @@
/*!
* Copyright (c) 2017 by Contributors
* Implementation stack VM.
* \file stack_vm.cc
*/
#include <dmlc/thread_local.h>
#include "./stack_vm.h"
namespace tvm {
namespace jit {
typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore;
StackVM::State* StackVM::ThreadLocalState() {
return StackVMStateStore::Get();
}
#define STACK_VM_BINOP(OP, FIELD) \
{ \
stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \
sp -= 1; pc += 1; \
}
#define STACK_VM_CMPOP(OP, FIELD) \
{ \
stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
sp -= 1; pc += 1; \
}
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \
{ \
stack[sp].FIELD = static_cast<DST_TYPE>( \
*static_cast<SRC_TYPE*>(stack[sp].v_handle)); \
pc += 1; \
}
#define STACK_VM_STORE(FIELD, DST_TYPE) \
{ \
*static_cast<DST_TYPE*>(stack[sp - 1].v_handle) = \
static_cast<DST_TYPE>(stack[sp].FIELD); \
sp -= 2; pc += 1; \
}
#define STACK_VM_TVM_LOAD_ARG(OP, TYPE) \
{ \
TVMValue* args = static_cast<TVMValue*>(stack[sp - 2].v_handle); \
int64_t index = stack[sp].v_int64; \
int tc = static_cast<int*>(stack[sp - 1].v_handle)[index]; \
CHECK(OP) \
<< " argument " << index << " is expected to be " << TYPE; \
stack[sp - 2] = args[index]; \
sp -= 2; \
pc += 1; \
}
#define STACK_VM_TVM_ARRARY_GET(FIELD, TYPE, SFIELD) \
{ \
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle); \
stack[sp].FIELD = (TYPE)(arr->SFIELD); \
pc += 1; \
}
#define STACK_VM_PRINT_CODE0(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \
}
#define STACK_VM_PRINT_CODE1(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
<< "[" << pc + 1 << "]" << std::endl; \
return pc + 2; \
}
#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \
<< " " << heap_id_name[code[pc + 1].v_int] << "\n" \
<< "[" << pc + 1 << "]" << std::endl; \
return pc + 2; \
}
#define STACK_VM_PRINT_JUMP(CODE) \
case CODE: { \
os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \
<< " to " << pc + code[pc + 1].v_int << '\n' \
<< "[" << pc + 1 << "]" << std::endl; \
return pc + 2; \
}
int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
switch (code[pc].op_code) {
// int
STACK_VM_PRINT_CODE0(ADD_I64);
STACK_VM_PRINT_CODE0(SUB_I64);
STACK_VM_PRINT_CODE0(MUL_I64);
STACK_VM_PRINT_CODE0(MOD_I64);
STACK_VM_PRINT_CODE0(DIV_I64);
STACK_VM_PRINT_CODE0(EQ_I64);
STACK_VM_PRINT_CODE0(LT_I64);
STACK_VM_PRINT_CODE0(LE_I64);
// floats
STACK_VM_PRINT_CODE0(ADD_F64);
STACK_VM_PRINT_CODE0(SUB_F64);
STACK_VM_PRINT_CODE0(MUL_F64);
STACK_VM_PRINT_CODE0(DIV_F64);
STACK_VM_PRINT_CODE0(EQ_F64);
STACK_VM_PRINT_CODE0(LT_F64);
STACK_VM_PRINT_CODE0(LE_F64);
// addressing load
STACK_VM_PRINT_CODE0(ADDR_LOAD_UINT32);
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT32);
STACK_VM_PRINT_CODE0(ADDR_LOAD_INT64);
STACK_VM_PRINT_CODE0(ADDR_LOAD_FP64);
STACK_VM_PRINT_CODE0(ADDR_LOAD_HANDLE);
STACK_VM_PRINT_CODE0(ADDR_STORE_INT64);
STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
STACK_VM_PRINT_CODE0(NOT);
STACK_VM_PRINT_CODE0(ADDR_ADD);
// stack ops
STACK_VM_PRINT_CODE1(PUSH_I64);
STACK_VM_PRINT_CODE1(PUSH_VALUE);
STACK_VM_PRINT_CODE0(POP);
STACK_VM_PRINT_CODE0(SELECT);
STACK_VM_PRINT_HEAP_ACCESS(STORE_HEAP);
STACK_VM_PRINT_HEAP_ACCESS(LOAD_HEAP);
STACK_VM_PRINT_CODE1(CALL_EXTERN);
STACK_VM_PRINT_CODE1(ASSERT);
STACK_VM_PRINT_JUMP(RJUMP_IF_TRUE);
STACK_VM_PRINT_JUMP(RJUMP_IF_FALSE);
STACK_VM_PRINT_JUMP(RJUMP);
STACK_VM_PRINT_CODE1(ASSERT_SP);
// Intrinsics
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_INT64);
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_FP64);
STACK_VM_PRINT_CODE0(TVM_LOAD_ARG_HANDLE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_DATA);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_SHAPE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_STRIDES);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_NDIM);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES);
}
LOG(FATAL) << "unknown op code " << code[pc].op_code;
return 0;
}
std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*)
int64_t pc = 0;
const int64_t code_size = static_cast<int64_t>(vm.code.size());
os << "Program dump: code-size=" << code_size << '\n'
<< "----------begin-----------------\n";
while (pc < code_size) {
pc = vm.PrintCode(os, pc);
}
os << "----------end--------------------\n";
return os;
}
void StackVM::Run(State* s) const {
int64_t sp = s->sp;
int64_t pc = s->pc;
std::vector<TVMValue>& stack = s->stack;
std::vector<TVMValue>& heap = s->heap;
if (stack.size() < stack_size) {
stack.resize(stack_size);
}
int64_t stack_cap = static_cast<int64_t>(stack_size - 4);
if (heap.size() < heap_size) {
heap.resize(heap_size);
}
const int64_t code_size = static_cast<int64_t>(code.size());
while (pc < code_size) {
switch (code[pc].op_code) {
case ADD_I64: STACK_VM_BINOP(+, v_int64); break;
case SUB_I64: STACK_VM_BINOP(-, v_int64); break;
case MUL_I64: STACK_VM_BINOP(*, v_int64); break;
case DIV_I64: STACK_VM_BINOP(/, v_int64); break;
case MOD_I64: STACK_VM_BINOP(%, v_int64); break;
case EQ_I64: STACK_VM_CMPOP(==, v_int64); break;
case LT_I64: STACK_VM_CMPOP(<, v_int64); break;
case LE_I64: STACK_VM_CMPOP(<=, v_int64); break;
case ADD_F64: STACK_VM_BINOP(+, v_float64); break;
case SUB_F64: STACK_VM_BINOP(-, v_float64); break;
case MUL_F64: STACK_VM_BINOP(*, v_float64); break;
case DIV_F64: STACK_VM_BINOP(/, v_float64); break;
case EQ_F64: STACK_VM_CMPOP(==, v_float64); break;
case LT_F64: STACK_VM_CMPOP(<, v_float64); break;
case LE_F64: STACK_VM_CMPOP(<=, v_float64); break;
// addressing
case ADDR_LOAD_UINT32: STACK_VM_LOAD(v_int64, int64_t, uint32_t); break;
case ADDR_LOAD_INT32: STACK_VM_LOAD(v_int64, int64_t, int32_t); break;
case ADDR_LOAD_INT64: STACK_VM_LOAD(v_int64, int64_t, int64_t); break;
case ADDR_LOAD_FP64: STACK_VM_LOAD(v_float64, double, double); break;
case ADDR_LOAD_HANDLE: STACK_VM_LOAD(v_handle, void*, void*); break;
case ADDR_STORE_INT64: STACK_VM_STORE(v_int64, int64_t); break;
case ADDR_ADD: {
stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*)
sp = sp - 1;
pc = pc + 1;
break;
}
case ARRAY_LOAD_UINT32: {
stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int]; // NOLINT(*)
pc = pc + 2;
break;
}
case NOT: {
stack[sp].v_int64 = !stack[sp].v_int64;
pc += 1;
break;
}
case PUSH_I64: {
stack[sp + 1].v_int64 = code[pc + 1].v_int;
sp += 1;
pc += 2;
break;
}
case PUSH_VALUE: {
int relpos = code[pc + 1].v_int;
CHECK_LE(relpos, 0);
stack[sp + 1] = stack[sp + relpos];
sp += 1;
pc += 2;
break;
}
case POP: {
sp -= 1;
pc += 1;
break;
}
case SELECT: {
stack[sp - 2] = (stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]);
sp -= 2;
pc += 1;
break;
}
case LOAD_HEAP: {
stack[sp + 1] = heap[code[pc + 1].v_int];
sp += 1;
pc += 2;
break;
}
case STORE_HEAP: {
heap[code[pc + 1].v_int] = stack[sp];
sp -= 1;
pc += 2;
break;
}
case CALL_EXTERN: {
int num_args = static_cast<int>(stack[sp].v_int64);
int call_fid = code[pc + 1].v_int;
stack[sp - num_args] = extern_func[call_fid](
&stack[sp - num_args], num_args);
sp = sp - num_args;
pc += 2;
break;
}
case ASSERT: {
CHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
sp -= 1;
pc += 2;
break;
}
case RJUMP_IF_TRUE: {
if (stack[sp].v_int64) {
pc += code[pc + 1].v_int;
} else {
pc += 2;
}
break;
}
case RJUMP_IF_FALSE: {
if (!stack[sp].v_int64) {
pc += code[pc + 1].v_int;
} else {
pc += 2;
}
break;
}
case RJUMP: {
pc += code[pc + 1].v_int;
break;
}
case ASSERT_SP: {
int64_t expected = code[pc + 1].v_int;
CHECK_EQ(sp, expected)
<< "sp assertion failed, expected="
<< expected << " now=" << sp << ", pc=" << pc;
pc += 2;
break;
}
case TVM_LOAD_ARG_INT64: {
STACK_VM_TVM_LOAD_ARG(tc == kInt, "int"); break;
}
case TVM_LOAD_ARG_FP64: {
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
}
case TVM_LOAD_ARG_HANDLE: {
STACK_VM_TVM_LOAD_ARG(tc == kHandle || tc == kNull, "handle"); break;
}
case TVM_ARRAY_GET_DATA: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
}
case TVM_ARRAY_GET_SHAPE: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, shape); break;
}
case TVM_ARRAY_GET_STRIDES: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, strides); break;
}
case TVM_ARRAY_GET_NDIM: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
}
case TVM_ARRAY_GET_TYPE_CODE: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.type_code); break;
}
case TVM_ARRAY_GET_TYPE_BITS: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
}
case TVM_ARRAY_GET_TYPE_LANES: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.lanes); break;
}
}
CHECK_LT(sp, stack_cap) << "Stack overflow";
}
}
} // namespace jit
} // namespace tvm

298
src/jit/stack_vm.h Normal file
Просмотреть файл

@ -0,0 +1,298 @@
/*!
* Copyright (c) 2016 by Contributors
* \file stack_vm.h
* \brief A simple stack-based virtual machine.
*
* This can be used to interepret host side code
* to setup calls into device functions
* when only JIT for device is available(via NVRTC or OpenCL).
*/
#ifndef TVM_JIT_STACK_VM_H_
#define TVM_JIT_STACK_VM_H_
#include <tvm/base.h>
#include <tvm/runtime/c_runtime_api.h>
#include <string>
#include <vector>
namespace tvm {
namespace jit {
/*!
* \brief A simple stack-based virtual machine.
*/
class StackVM {
public:
/*!
* \brief The opcode of stack vm
* \note Notation
* - sp Stack pointer
* - pc Program pointer
*/
enum OpCode {
// integer ops
ADD_I64,
SUB_I64,
MUL_I64,
DIV_I64,
MOD_I64,
EQ_I64,
LT_I64,
LE_I64,
// floating ops
ADD_F64,
SUB_F64,
MUL_F64,
DIV_F64,
EQ_F64,
LT_F64,
LE_F64,
// load operation
ADDR_LOAD_UINT32,
ADDR_LOAD_INT32,
ADDR_LOAD_INT64,
ADDR_LOAD_FP64,
ADDR_LOAD_HANDLE,
// store operations
// *(stack[sp - 1].v_andle) = stack[sp].v_int64
// sp = sp - 2;
ADDR_STORE_INT64,
/*!
* \brief Quick routine to load uint32 from constant offset.
* \code
* stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int];
* pc = pc + 2;
* \endcode
*/
ARRAY_LOAD_UINT32,
// logical ops
NOT,
/*!
* \brief Add address by an offset.
* \code
* stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64);
* sp = sp - 1;
* \endcode
*/
ADDR_ADD,
/*!
* \brief push integer fetched from next pc position into stack
* \code
* stack[sp + 1].v_int64 = code[pc + 1].v_int;
* pc = pc + 2;
* sp = sp + 1;
* \endcode
*/
PUSH_I64,
/*!
* \brief push a value given relative index on the stack
* \code
* stack[sp + 1] = stack[sp + code[pc + 1].v_int];
* pc = pc + 2;
* sp = sp + 1;
* \endcode
*/
PUSH_VALUE,
/*!
* \brief Load data from heap to top of stack
* \code
* stack[sp + 1] = heap[code[pc + 1].v_int];
* pc = pc + 2;
* sp = sp + 1;
* \endcode
*/
LOAD_HEAP,
/*!
* \brief Store data to heap
* \code
* heap[code[pc + 1].v_int] = stack[sp];
* sp = sp - 1;
* \endcode
*/
STORE_HEAP,
/*! \brief pop value from top of the stack */
POP,
/*!
* \brief select based on operands.
* \code
* stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]
* sp = sp - 2;
* \endcode
*/
SELECT,
/*!
* \brief call an extern function
* \code
* num_args = stack[sp].v_int64;
* call_fid = code[pc + 1].v_int;
* f = extern_func[call_fid];
* stack[sp - num_args] = f(&stack[sp - num_args], num_args);
* sp = sp - num_args;
* \endcode
*/
CALL_EXTERN,
/*!
* \brief Assert condition is true.
* \code
* CHECK(stack[sp]) << str_data[code[pc + 1].v_int];
* sp = sp - 1;
* \endcode
*/
ASSERT,
/*!
* \brief Relative Jump if the condition is true,
* Does not change the stack status.
* \code
* if (stack[sp]) {
* pc += code[pc + 1].v_int
* } else {
* pc = pc + 2;
* }
* \endcode
*/
RJUMP_IF_TRUE,
/*!
* \brief Relative Jump if the condition is true,
* Does not change the stack status.
* \code
* if (stack[sp]) {
* pc += code[pc + 1].v_int
* } else {
* pc = pc + 2;
* }
* \endcode
*/
RJUMP_IF_FALSE,
/*!
* \brief Relative jump to a location.
* \code
* pc += code[pc + 1].v_int;
* \endcode
*/
RJUMP,
/*!
* \brief debug instruction.
* \code
* CHECK_EQ(sp, code[pc + 1]).v_int;
* pc += 2;
* \code
*/
ASSERT_SP,
// Intrinsics for API function,
TVM_LOAD_ARG_INT64,
TVM_LOAD_ARG_FP64,
TVM_LOAD_ARG_HANDLE,
TVM_ARRAY_GET_DATA,
TVM_ARRAY_GET_SHAPE,
TVM_ARRAY_GET_STRIDES,
TVM_ARRAY_GET_NDIM,
TVM_ARRAY_GET_TYPE_CODE,
TVM_ARRAY_GET_TYPE_BITS,
TVM_ARRAY_GET_TYPE_LANES
};
/*! \brief The code structure */
union Code {
OpCode op_code;
int v_int;
};
/*! \brief The state object of StackVM */
struct State {
/*! \brief The execution stack */
std::vector<TVMValue> stack;
/*! \brief The global heap space */
std::vector<TVMValue> heap;
/*! \brief stack pointer */
int64_t sp{0};
/*! \brief program counter */
int64_t pc{0};
};
/*! \brief execute the stack vm with given state */
void Run(State* state) const;
/*!
* \brief Print instruction at location pc
* \param os The ostream
* \param pc The pc
* \return the pc to next instruction.
*/
int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*)
/*! \brief Get thread local state of the stack VM */
static State* ThreadLocalState();
/*! \brief extern function that will mutate the state */
using ExternFunc = std::function<TVMValue (const TVMValue* args, int num_args)>;
/*! \brief The instructions */
std::vector<Code> code;
/*! \brief constant error messages */
std::vector<std::string> str_data;
/*! \brief Extern functions */
std::vector<ExternFunc> extern_func;
/*! \brief name of each heap id*/
std::vector<std::string> heap_id_name;
/*! \brief The memory size needed */
size_t heap_size{0};
/*! \brief The stack size required */
size_t stack_size{1024};
/*!
* \brief Convert I64 opcode to F64 Ones
* \param code The op code.
* \return the F64 op code.
*/
static OpCode CodeI64ToF64(OpCode code) {
switch (code) {
case ADD_I64: return ADD_F64;
case SUB_I64: return SUB_F64;
case MUL_I64: return MUL_F64;
case DIV_I64: return DIV_F64;
case EQ_I64: return EQ_F64;
case LT_I64: return LT_F64;
case LE_I64: return LE_F64;
case MOD_I64: LOG(FATAL) << "cannot handle mod for float";
default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
}
}
/*!
* \brief Get load opcode for type t
* \param t the type code.
* \return The load opcode
*/
static OpCode GetLoad(Type t) {
CHECK_EQ(t.lanes(), 1);
if (t.is_handle()) return ADDR_LOAD_HANDLE;
if (t.is_int()) {
switch (t.bits()) {
case 32 : return ADDR_LOAD_INT32;
case 64 : return ADDR_LOAD_INT64;
}
} else if (t.is_uint()) {
switch (t.bits()) {
case 32 : return ADDR_LOAD_UINT32;
}
} else if (t.is_float()) {
switch (t.bits()) {
case 64 : return ADDR_LOAD_FP64;
}
}
LOG(FATAL) << "Cannot load type " << t;
return ADDR_LOAD_FP64;
}
/*!
* \brief Get store opcode for type t
* \param t the type code.
* \return The load opcode
*/
static OpCode GetStore(Type t) {
CHECK_EQ(t.lanes(), 1);
if (t.is_int()) {
switch (t.bits()) {
case 64 : return ADDR_STORE_INT64;
}
}
LOG(FATAL) << "Cannot store type " << t;
return ADDR_LOAD_FP64;
}
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
};
} // namespace jit
} // namespace tvm
#endif // TVM_JIT_STACK_VM_H_

Просмотреть файл

@ -4,7 +4,7 @@
* \brief Device specific implementations
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/runtime.h>
#include <tvm/runtime/packed_func.h>
#include <algorithm>
#include "./runtime_base.h"
#include "./device_api.h"
@ -170,7 +170,7 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc::FType*>(func);
delete static_cast<PackedFunc*>(func);
API_END();
}
@ -179,7 +179,35 @@ int TVMFuncCall(TVMFunctionHandle func,
int* arg_type_codes,
int num_args) {
API_BEGIN();
(*static_cast<const PackedFunc::FType*>(func))(
(*static_cast<const PackedFunc*>(func)).CallPacked(
args, arg_type_codes, num_args);
API_END();
}
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](const TVMValue* args,
const int* type_codes,
int num_args) {
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
num_args, resource_handle);
});
} else {
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](const TVMValue* args,
const int* type_codes,
int num_args) {
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
num_args, rpack.get());
});
}
API_END();
}

Просмотреть файл

@ -0,0 +1,73 @@
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func_registry.cc
* \brief The global registry of packed function.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
#include <memory>
#include "./runtime_base.h"
namespace tvm {
namespace runtime {
struct PackedFuncRegistry {
// map storing the functions.
// We delibrately used raw pointer
// This is because PackedFunc can contain callbacks into the host languge(python)
// and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit.
std::unordered_map<std::string, PackedFunc*> fmap;
static PackedFuncRegistry* Global() {
static PackedFuncRegistry inst;
return &inst;
}
};
const PackedFunc& PackedFunc::RegisterGlobal(
const std::string& name, PackedFunc f) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
CHECK(it == r->fmap.end())
<< "Global PackedFunc " << name << " is already registered";
PackedFunc* fp = new PackedFunc(f);
r->fmap[name] = fp;
return *fp;
}
const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
CHECK(it != r->fmap.end())
<< "Global PackedFunc " << name << " is not registered";
return *(it->second);
}
std::vector<std::string> PackedFunc::ListGlobalNames() {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
std::vector<std::string> keys;
keys.reserve(r->fmap.size());
for (const auto &kv : r->fmap) {
keys.push_back(kv.first);
}
return keys;
}
} // namespace runtime
} // namespace tvm
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
using tvm::runtime::PackedFunc;
API_BEGIN();
PackedFunc::RegisterGlobal(name, *static_cast<PackedFunc*>(f));
API_END();
}
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
using tvm::runtime::PackedFunc;
API_BEGIN();
*out = new PackedFunc(PackedFunc::GetGlobal(name));
API_END();
}

Просмотреть файл

@ -1,6 +1,6 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/runtime.h>
#include <tvm/runtime/packed_func.h>
TEST(PackedFunc, Basic) {
using namespace tvm::runtime;

Просмотреть файл

@ -0,0 +1,76 @@
import tvm
import numpy as np
def tvm_call_global(*args):
args = tvm.convert(args)
return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
def tvm_call_back_get_shape(shape0):
assert shape0 == a.shape[0]
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "print_shape", [Ab], 1)
print(fapi.body)
f = tvm.codegen.BuildStackVM(fapi)
f(a)
@tvm.register_func
def tvm_stack_vm_print(*x):
print(x)
def test_stack_vm_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.Block(
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 1,
i + 1),
tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i))))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "ramp", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
def test_stack_vm_cond():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.IfThenElse(
tvm.make.EQ(i, 4),
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 1, i + 1),
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 2, i + 1)))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "test", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
if __name__ == "__main__":
test_stack_vm_cond()

Просмотреть файл

@ -1,17 +0,0 @@
import tvm
import numpy as np
def test_function():
ctx = tvm.cpu(0)
x = np.random.randint(0, 10, size=(3, 4))
x = np.array(x)
y = tvm.nd.array(x, ctx=ctx)
f = tvm.codegen.DummyHelloFunction()
f(y, 10)
if __name__ == "__main__":
test_function()

Просмотреть файл

@ -0,0 +1,40 @@
import tvm
import numpy as np
def test_function():
ctx = tvm.cpu(0)
x = np.random.randint(0, 10, size=(3, 4))
x = np.array(x)
y = tvm.nd.array(x, ctx=ctx)
f = tvm.codegen.DummyHelloFunction()
f(y, 10)
def test_get_global():
targs = (10, 10.0, "hello")
# register into global function table
@tvm.register_func
def my_packed_func(*args):
assert(tuple(args) == targs)
# get it out from global function table
f = tvm.get_global_func("my_packed_func")
assert isinstance(f, tvm.nd.Function)
f(*targs)
def test_convert():
# convert a function to tvm function
targs = (10, 10.0, "hello", 10)
def myfunc(*args):
assert(tuple(args) == targs)
f = tvm.convert(myfunc)
assert isinstance(f, tvm.nd.Function)
f(*targs)
if __name__ == "__main__":
test_function()
test_convert()
test_get_global()