[EXT] Allow easy extraction of extern module (#926)
This commit is contained in:
Родитель
433756b9fd
Коммит
d99bcaf156
|
@ -22,12 +22,11 @@ struct extension_class_info<tvm_ext::IntVector> {
|
|||
} // namespace tvm
|
||||
} // namespace runtime
|
||||
|
||||
|
||||
namespace tvm_ext {
|
||||
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
|
||||
namespace tvm_ext {
|
||||
|
||||
TVM_REGISTER_EXT_TYPE(IntVector);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
|
||||
|
@ -66,3 +65,18 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
|
|||
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
|
||||
});
|
||||
} // namespace tvm_ext
|
||||
|
||||
// This callback approach allows extension allows tvm to extract
|
||||
// This way can be helpful when we want to use a header only
|
||||
// minimum version of TVM Runtime.
|
||||
extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) {
|
||||
const PackedFunc& fregister =
|
||||
*static_cast<PackedFunc*>(pregister);
|
||||
auto mul = [](TVMArgs args, TVMRetValue *rv) {
|
||||
int x = args[0];
|
||||
int y = args[1];
|
||||
*rv = x * y;
|
||||
};
|
||||
fregister("mul", PackedFunc(mul));
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -44,8 +44,14 @@ def test_ext_vec():
|
|||
|
||||
tvm.convert(ivec_cb)(ivec)
|
||||
|
||||
def test_extract_ext():
|
||||
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
|
||||
assert fdict["mul"](3, 4) == 12
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_ext_dev()
|
||||
test_ext_vec()
|
||||
test_bind_add()
|
||||
test_sym_add()
|
||||
test_extract_ext()
|
||||
|
|
|
@ -24,6 +24,13 @@
|
|||
#define TVM_EXTERN_C
|
||||
#endif
|
||||
|
||||
// Macros to do weak linking
|
||||
#ifdef _MSC_VER
|
||||
#define TVM_WEAK __declspec(selectany)
|
||||
#else
|
||||
#define TVM_WEAK __attribute__((weak))
|
||||
#endif
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten/emscripten.h>
|
||||
#define TVM_DLL EMSCRIPTEN_KEEPALIVE
|
||||
|
@ -313,6 +320,17 @@ typedef int (*TVMPackedCFunc)(
|
|||
*/
|
||||
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
|
||||
|
||||
/*!
|
||||
* \brief Signature for extension function declarer.
|
||||
*
|
||||
* TVM call this function to get the extension functions
|
||||
* The declarer will call register_func to register function and their name.
|
||||
*
|
||||
* \param resource_func_handle The register function
|
||||
* \return 0 if success, -1 if failure happens
|
||||
*/
|
||||
typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle);
|
||||
|
||||
/*!
|
||||
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
|
||||
*
|
||||
|
|
|
@ -38,8 +38,14 @@ class Module {
|
|||
* \param query_imports Whether also query dependency modules.
|
||||
* \return The result function.
|
||||
* This function will return PackedFunc(nullptr) if function do not exist.
|
||||
* \note Implemented in packed_func.cc
|
||||
*/
|
||||
TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
|
||||
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
|
||||
/*! \return internal container */
|
||||
inline ModuleNode* operator->();
|
||||
/*! \return internal container */
|
||||
inline const ModuleNode* operator->() const;
|
||||
// The following functions requires link with runtime.
|
||||
/*!
|
||||
* \brief Import another module into this module.
|
||||
* \param other The module to be imported.
|
||||
|
@ -57,10 +63,6 @@ class Module {
|
|||
*/
|
||||
TVM_DLL static Module LoadFromFile(const std::string& file_name,
|
||||
const std::string& format = "");
|
||||
/*! \return internal container */
|
||||
inline ModuleNode* operator->();
|
||||
/*! \return internal container */
|
||||
inline const ModuleNode* operator->() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<ModuleNode> node_;
|
||||
|
|
|
@ -24,6 +24,11 @@ struct Type;
|
|||
struct Expr;
|
||||
}
|
||||
|
||||
// Whether use TVM runtime in header only mode.
|
||||
#ifndef TVM_RUNTIME_HEADER_ONLY
|
||||
#define TVM_RUNTIME_HEADER_ONLY 0
|
||||
#endif
|
||||
|
||||
namespace tvm {
|
||||
// Forward declare NodeRef and Node for extensions.
|
||||
// This header works fine without depend on NodeRef
|
||||
|
@ -564,11 +569,15 @@ class TVMRetValue : public TVMPODValue_ {
|
|||
SwitchToPOD(other.type_code());
|
||||
value_ = other.value_;
|
||||
} else {
|
||||
#if TVM_RUNTIME_HEADER_ONLY
|
||||
LOG(FATAL) << "Header only mode do not support ext type";
|
||||
#else
|
||||
this->Clear();
|
||||
type_code_ = other.type_code();
|
||||
value_.v_handle =
|
||||
(*(ExtTypeVTable::Get(other.type_code())->clone))(
|
||||
other.value().v_handle);
|
||||
#endif
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -600,7 +609,11 @@ class TVMRetValue : public TVMPODValue_ {
|
|||
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
|
||||
}
|
||||
if (type_code_ > kExtBegin) {
|
||||
#if TVM_RUNTIME_HEADER_ONLY
|
||||
LOG(FATAL) << "Header only mode do not support ext type";
|
||||
#else
|
||||
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
|
||||
#endif
|
||||
}
|
||||
type_code_ = kNull;
|
||||
}
|
||||
|
@ -882,6 +895,20 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
|
|||
vt.destroy = ExtTypeInfo<T>::destroy;
|
||||
return ExtTypeVTable::RegisterInternal(code, vt);
|
||||
}
|
||||
|
||||
// Implement Module::GetFunction
|
||||
// Put implementation in this file so we have seen the PackedFunc
|
||||
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
|
||||
PackedFunc pf = node_->GetFunction(name, node_);
|
||||
if (pf != nullptr) return pf;
|
||||
if (query_imports) {
|
||||
for (const Module& m : node_->imports_) {
|
||||
pf = m.node_->GetFunction(name, m.node_);
|
||||
if (pf != nullptr) return pf;
|
||||
}
|
||||
}
|
||||
return pf;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
#endif // TVM_RUNTIME_PACKED_FUNC_H_
|
||||
|
|
|
@ -234,6 +234,31 @@ def list_global_func_names():
|
|||
return fnames
|
||||
|
||||
|
||||
def extract_ext_funcs(finit):
|
||||
"""
|
||||
Extract the extension PackedFuncs from a C module.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
finit : ctypes function
|
||||
a ctypes that takes signature of TVMExtensionDeclarer
|
||||
|
||||
Returns
|
||||
-------
|
||||
fdict : dict of str to Function
|
||||
The extracted functions
|
||||
"""
|
||||
fdict = {}
|
||||
def _list(name, func):
|
||||
fdict[name] = func
|
||||
myf = convert_to_tvm_func(_list)
|
||||
ret = finit(myf.handle)
|
||||
_ = myf
|
||||
if ret != 0:
|
||||
raise RuntimeError("cannot initialize with %s" % finit)
|
||||
return fdict
|
||||
|
||||
|
||||
def _get_api(f):
|
||||
flocal = f
|
||||
flocal.is_global = True
|
||||
|
|
|
@ -8,7 +8,7 @@ from ._ffi.base import string_types
|
|||
from ._ffi.node import register_node, NodeBase
|
||||
from ._ffi.node import convert_to_node as _convert_to_node
|
||||
from ._ffi.function import Function
|
||||
from ._ffi.function import _init_api, register_func, get_global_func
|
||||
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
|
||||
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
|
||||
from ._ffi.runtime_ctypes import TVMType
|
||||
from . import _api_internal
|
||||
|
|
|
@ -23,16 +23,16 @@ from . import target as _target
|
|||
from . import make
|
||||
|
||||
class DumpIR(object):
|
||||
"""Dump IR for each pass.
|
||||
With it, you can dump ir just like gcc/llvm.
|
||||
"""
|
||||
Dump IR for each pass.
|
||||
With it, you can dump ir just like gcc/llvm.
|
||||
|
||||
How to use:
|
||||
-----------
|
||||
.. code-block:: python
|
||||
|
||||
with tvm.build_config(dump_pass_ir=True)
|
||||
run()
|
||||
How to use:
|
||||
-----------
|
||||
.. code-block:: python
|
||||
|
||||
with tvm.build_config(dump_pass_ir=True)
|
||||
run()
|
||||
"""
|
||||
scope_level = 0
|
||||
def __init__(self):
|
||||
|
@ -40,9 +40,9 @@ class DumpIR(object):
|
|||
self._recover_list = []
|
||||
|
||||
def decorate(self, func):
|
||||
''' decorate the pass function'''
|
||||
""" decorate the pass function"""
|
||||
def dump(*args, **kwargs):
|
||||
'''dump function'''
|
||||
"""dump function"""
|
||||
retv = func(*args, **kwargs)
|
||||
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
|
||||
return retv
|
||||
|
@ -59,7 +59,7 @@ class DumpIR(object):
|
|||
return dump
|
||||
|
||||
def decorate_irpass(self):
|
||||
'''decorate ir_pass and ScheduleOps'''
|
||||
"""decorate ir_pass and ScheduleOps"""
|
||||
self._old_sgpass = schedule.ScheduleOps
|
||||
schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
|
||||
vset = vars(ir_pass)
|
||||
|
@ -71,7 +71,7 @@ class DumpIR(object):
|
|||
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
|
||||
|
||||
def decorate_custompass(self):
|
||||
''' decorate add_lower_pass pass in BuildConfig'''
|
||||
""" decorate add_lower_pass pass in BuildConfig"""
|
||||
cfg = BuildConfig.current
|
||||
self._old_custom_pass = cfg.add_lower_pass
|
||||
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
|
||||
|
@ -79,7 +79,7 @@ class DumpIR(object):
|
|||
BuildConfig.current.add_lower_pass = pass_list
|
||||
|
||||
def enter(self):
|
||||
'''only decorate outermost nest'''
|
||||
"""only decorate outermost nest"""
|
||||
if DumpIR.scope_level > 0:
|
||||
return
|
||||
self.decorate_irpass()
|
||||
|
@ -88,7 +88,7 @@ class DumpIR(object):
|
|||
DumpIR.scope_level += 1
|
||||
|
||||
def exit(self):
|
||||
'''recover outermost nest'''
|
||||
"""recover outermost nest"""
|
||||
if DumpIR.scope_level > 1:
|
||||
return
|
||||
# recover decorated functions
|
||||
|
|
|
@ -13,19 +13,6 @@
|
|||
namespace tvm {
|
||||
namespace runtime {
|
||||
|
||||
PackedFunc Module::GetFunction(
|
||||
const std::string& name, bool query_imports) {
|
||||
PackedFunc pf = node_->GetFunction(name, node_);
|
||||
if (pf != nullptr) return pf;
|
||||
if (query_imports) {
|
||||
for (const Module& m : node_->imports_) {
|
||||
pf = m.node_->GetFunction(name, m.node_);
|
||||
if (pf != nullptr) return pf;
|
||||
}
|
||||
}
|
||||
return pf;
|
||||
}
|
||||
|
||||
void Module::Import(Module other) {
|
||||
// specially handle rpc
|
||||
if (!std::strcmp((*this)->type_key(), "rpc")) {
|
||||
|
|
|
@ -6,6 +6,7 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
|
|||
|
||||
# Test TVM
|
||||
make cython || exit -1
|
||||
make cython3 || exit -1
|
||||
|
||||
# Test extern package package
|
||||
cd apps/extension
|
||||
|
|
Загрузка…
Ссылка в новой задаче