[RUNTIME][DEBUG]Support remote debugging (#1866)
This commit is contained in:
Родитель
0b4cc05050
Коммит
47b8c36dcf
|
@ -5,8 +5,9 @@ import tempfile
|
|||
import shutil
|
||||
from datetime import datetime
|
||||
from tvm._ffi.base import string_types
|
||||
from tvm.contrib import graph_runtime
|
||||
from tvm._ffi.function import get_global_func
|
||||
from tvm.contrib import graph_runtime
|
||||
from tvm.rpc import base as rpc_base
|
||||
from . import debug_result
|
||||
|
||||
_DUMP_ROOT_PREFIX = "tvmdbg_"
|
||||
|
@ -49,8 +50,12 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
|
|||
|
||||
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
|
||||
if num_rpc_ctx == len(ctx):
|
||||
raise NotSupportedError("Remote graph debugging is not supported.")
|
||||
|
||||
libmod = rpc_base._ModuleHandle(libmod)
|
||||
try:
|
||||
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.remote_create")
|
||||
except ValueError:
|
||||
raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \
|
||||
"config.cmake and rebuild TVM to enable debug mode")
|
||||
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
|
||||
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
|
||||
|
||||
|
|
|
@ -146,5 +146,18 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
|
|||
<< args.num_args;
|
||||
*rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
|
||||
});
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
|
||||
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
|
||||
"graph_runtime.remote_create is "
|
||||
"at least 4, but it has "
|
||||
<< args.num_args;
|
||||
void* mhandle = args[1];
|
||||
const auto& contexts = GetAllContext(args);
|
||||
*rv = GraphRuntimeDebugCreate(
|
||||
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
|
||||
});
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
|
|
|
@ -2,6 +2,8 @@ import os
|
|||
import tvm
|
||||
import numpy as np
|
||||
import json
|
||||
from tvm import rpc
|
||||
from tvm.contrib import util
|
||||
from tvm.contrib.debugger import debug_runtime as graph_runtime
|
||||
|
||||
def test_graph_simple():
|
||||
|
@ -70,7 +72,32 @@ def test_graph_simple():
|
|||
#verify dump root delete after cleanup
|
||||
assert(not os.path.exists(directory))
|
||||
|
||||
def check_remote():
|
||||
if not tvm.module.enabled("llvm"):
|
||||
print("Skip because llvm is not enabled")
|
||||
return
|
||||
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
|
||||
server = rpc.Server("localhost")
|
||||
remote = rpc.connect(server.host, server.port)
|
||||
temp = util.tempdir()
|
||||
ctx = remote.cpu(0)
|
||||
path_dso = temp.relpath("dev_lib.so")
|
||||
mlib.export_library(path_dso)
|
||||
remote.upload(path_dso)
|
||||
mlib = remote.load_module("dev_lib.so")
|
||||
try:
|
||||
mod = graph_runtime.create(graph, mlib, remote.cpu(0))
|
||||
except ValueError:
|
||||
print("Skip because debug graph_runtime not enabled")
|
||||
return
|
||||
a = np.random.uniform(size=(n,)).astype(A.dtype)
|
||||
mod.run(x=tvm.nd.array(a, ctx))
|
||||
out = tvm.nd.empty((n,), ctx=ctx)
|
||||
out = mod.get_output(0, out)
|
||||
np.testing.assert_equal(out.asnumpy(), a + 1)
|
||||
|
||||
check_verify()
|
||||
check_remote()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_graph_simple()
|
||||
|
|
Загрузка…
Ссылка в новой задаче