diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 986a7b16..25d17d52 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -5,8 +5,9 @@ import tempfile import shutil from datetime import datetime from tvm._ffi.base import string_types -from tvm.contrib import graph_runtime from tvm._ffi.function import get_global_func +from tvm.contrib import graph_runtime +from tvm.rpc import base as rpc_base from . import debug_result _DUMP_ROOT_PREFIX = "tvmdbg_" @@ -49,8 +50,12 @@ def create(graph_json_str, libmod, ctx, dump_root=None): ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): - raise NotSupportedError("Remote graph debugging is not supported.") - + libmod = rpc_base._ModuleHandle(libmod) + try: + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.remote_create") + except ValueError: + raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \ + "config.cmake and rebuild TVM to enable debug mode") func_obj = fcreate(graph_json_str, libmod, *device_type_id) return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root) diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 7faee442..452a4840 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -146,5 +146,18 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") << args.num_args; *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); }); + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create") + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for " + "graph_runtime.remote_create is " + "at least 4, but it has " + << args.num_args; + void* mhandle = args[1]; + const auto& contexts = GetAllContext(args); + *rv = GraphRuntimeDebugCreate( + args[0], *static_cast(mhandle), contexts); + }); + } // namespace runtime } // namespace tvm diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index ab6b7299..b9d8b689 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -2,6 +2,8 @@ import os import tvm import numpy as np import json +from tvm import rpc +from tvm.contrib import util from tvm.contrib.debugger import debug_runtime as graph_runtime def test_graph_simple(): @@ -70,7 +72,32 @@ def test_graph_simple(): #verify dump root delete after cleanup assert(not os.path.exists(directory)) + def check_remote(): + if not tvm.module.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mlib = tvm.build(s, [A, B], "llvm", name="myadd") + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + temp = util.tempdir() + ctx = remote.cpu(0) + path_dso = temp.relpath("dev_lib.so") + mlib.export_library(path_dso) + remote.upload(path_dso) + mlib = remote.load_module("dev_lib.so") + try: + mod = graph_runtime.create(graph, mlib, remote.cpu(0)) + except ValueError: + print("Skip because debug graph_runtime not enabled") + return + a = np.random.uniform(size=(n,)).astype(A.dtype) + mod.run(x=tvm.nd.array(a, ctx)) + out = tvm.nd.empty((n,), ctx=ctx) + out = mod.get_output(0, out) + np.testing.assert_equal(out.asnumpy(), a + 1) + check_verify() + check_remote() if __name__ == "__main__": test_graph_simple()