diff --git a/Makefile b/Makefile index 453415de..44a500d2 100644 --- a/Makefile +++ b/Makefile @@ -151,6 +151,10 @@ ifeq ($(USE_GRAPH_RUNTIME), 1) RUNTIME_DEP += $(GRAPH_OBJ) endif +ifeq ($(USE_GRAPH_RUNTIME_DEBUG), 1) + CFLAGS += -DTVM_GRAPH_RUNTIME_DEBUG +endif + include make/contrib/cblas.mk include make/contrib/random.mk include make/contrib/nnpack.mk diff --git a/make/config.mk b/make/config.mk index 256771ac..eee96ac1 100644 --- a/make/config.mk +++ b/make/config.mk @@ -50,6 +50,9 @@ USE_RPC = 1 # Whether enable tiny embedded graph runtime. USE_GRAPH_RUNTIME = 1 +# Whether enable additional graph debug functions +USE_GRAPH_RUNTIME_DEBUG = 0 + # whether build with LLVM support # Requires LLVM version >= 4.0 # Set LLVM_CONFIG to your version, uncomment to build with llvm support diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 7e919586..ddabac00 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -72,6 +72,10 @@ class GraphModule(object): self._set_input = module["set_input"] self._run = module["run"] self._get_output = module["get_output"] + try: + self._debug_get_output = module["debug_get_output"] + except AttributeError: + pass self._load_params = module["load_params"] self.ctx = ctx @@ -121,6 +125,23 @@ class GraphModule(object): self._get_output(index, out) return out + def debug_get_output(self, node, out): + """Run graph upto node and get the output to out + + Parameters + ---------- + node : int / str + The node index or name + + out : NDArray + The output array container + """ + if hasattr(self, '_debug_get_output'): + self._debug_get_output(node, out) + else: + raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0") + return out + def load_params(self, params_bytes): """Load parameters from serialized byte array of parameter dict. diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index d244fe5f..bf07a8c3 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -107,7 +107,44 @@ class GraphRuntime : public ModuleNode { uint32_t eid = this->entry_id(outputs_[index]); TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr)); } +#ifdef TVM_GRAPH_RUNTIME_DEBUG + /*! + * \brief Get the node index given the name of node. + * \param name The name of the node. + * \return The index of node. + */ + int GetNodeIndex(const std::string& name) { + for (uint32_t nid = 0; nid< nodes_.size(); ++nid) { + if (nodes_[nid].name == name) { + return static_cast(nid); + } + } + LOG(FATAL) << "cannot find " << name << " among nodex"; + return -1; + } + /*! + * \brief Copy index-th node to data_out. + * + * This method will do a partial run of the the graph + * from begining upto the index-th node and return output of index-th node. + * This is costly operation and suggest to use only for debug porpose. + * + * \param index: The index of the node. + * \param data_out the node data. + */ + void DebugGetNodeOutput(int index, DLTensor* data_out) { + CHECK_LT(static_cast(index), nodes_.size()); + uint32_t eid = index; + + for (size_t i = 0; i < op_execs_.size(); ++i) { + if (static_cast(i) == index) break; + if (op_execs_[i]) op_execs_[i](); + } + + TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr)); + } +#endif /*! * \brief Load parameters from binary stream * \param strm The input stream. @@ -556,6 +593,16 @@ PackedFunc GraphRuntime::GetFunction( return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->GetOutput(args[0], args[1]); }); +#ifdef TVM_GRAPH_RUNTIME_DEBUG + } else if (name == "debug_get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (args[0].type_code() == kStr) { + this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); + } else { + this->DebugGetNodeOutput(args[0], args[1]); + } + }); +#endif } else if (name == "run") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run();