[DEBUG] get_node_output : To retrieve out put of any node - for debug purpose. (#820)
This commit is contained in:
Родитель
ba8d00c2da
Коммит
eb8077fff8
4
Makefile
4
Makefile
|
@ -151,6 +151,10 @@ ifeq ($(USE_GRAPH_RUNTIME), 1)
|
|||
RUNTIME_DEP += $(GRAPH_OBJ)
|
||||
endif
|
||||
|
||||
ifeq ($(USE_GRAPH_RUNTIME_DEBUG), 1)
|
||||
CFLAGS += -DTVM_GRAPH_RUNTIME_DEBUG
|
||||
endif
|
||||
|
||||
include make/contrib/cblas.mk
|
||||
include make/contrib/random.mk
|
||||
include make/contrib/nnpack.mk
|
||||
|
|
|
@ -50,6 +50,9 @@ USE_RPC = 1
|
|||
# Whether enable tiny embedded graph runtime.
|
||||
USE_GRAPH_RUNTIME = 1
|
||||
|
||||
# Whether enable additional graph debug functions
|
||||
USE_GRAPH_RUNTIME_DEBUG = 0
|
||||
|
||||
# whether build with LLVM support
|
||||
# Requires LLVM version >= 4.0
|
||||
# Set LLVM_CONFIG to your version, uncomment to build with llvm support
|
||||
|
|
|
@ -72,6 +72,10 @@ class GraphModule(object):
|
|||
self._set_input = module["set_input"]
|
||||
self._run = module["run"]
|
||||
self._get_output = module["get_output"]
|
||||
try:
|
||||
self._debug_get_output = module["debug_get_output"]
|
||||
except AttributeError:
|
||||
pass
|
||||
self._load_params = module["load_params"]
|
||||
self.ctx = ctx
|
||||
|
||||
|
@ -121,6 +125,23 @@ class GraphModule(object):
|
|||
self._get_output(index, out)
|
||||
return out
|
||||
|
||||
def debug_get_output(self, node, out):
|
||||
"""Run graph upto node and get the output to out
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node : int / str
|
||||
The node index or name
|
||||
|
||||
out : NDArray
|
||||
The output array container
|
||||
"""
|
||||
if hasattr(self, '_debug_get_output'):
|
||||
self._debug_get_output(node, out)
|
||||
else:
|
||||
raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
|
||||
return out
|
||||
|
||||
def load_params(self, params_bytes):
|
||||
"""Load parameters from serialized byte array of parameter dict.
|
||||
|
||||
|
|
|
@ -107,7 +107,44 @@ class GraphRuntime : public ModuleNode {
|
|||
uint32_t eid = this->entry_id(outputs_[index]);
|
||||
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
|
||||
}
|
||||
#ifdef TVM_GRAPH_RUNTIME_DEBUG
|
||||
/*!
|
||||
* \brief Get the node index given the name of node.
|
||||
* \param name The name of the node.
|
||||
* \return The index of node.
|
||||
*/
|
||||
int GetNodeIndex(const std::string& name) {
|
||||
for (uint32_t nid = 0; nid< nodes_.size(); ++nid) {
|
||||
if (nodes_[nid].name == name) {
|
||||
return static_cast<int>(nid);
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "cannot find " << name << " among nodex";
|
||||
return -1;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Copy index-th node to data_out.
|
||||
*
|
||||
* This method will do a partial run of the the graph
|
||||
* from begining upto the index-th node and return output of index-th node.
|
||||
* This is costly operation and suggest to use only for debug porpose.
|
||||
*
|
||||
* \param index: The index of the node.
|
||||
* \param data_out the node data.
|
||||
*/
|
||||
void DebugGetNodeOutput(int index, DLTensor* data_out) {
|
||||
CHECK_LT(static_cast<size_t>(index), nodes_.size());
|
||||
uint32_t eid = index;
|
||||
|
||||
for (size_t i = 0; i < op_execs_.size(); ++i) {
|
||||
if (static_cast<int>(i) == index) break;
|
||||
if (op_execs_[i]) op_execs_[i]();
|
||||
}
|
||||
|
||||
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
|
||||
}
|
||||
#endif
|
||||
/*!
|
||||
* \brief Load parameters from binary stream
|
||||
* \param strm The input stream.
|
||||
|
@ -556,6 +593,16 @@ PackedFunc GraphRuntime::GetFunction(
|
|||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
||||
this->GetOutput(args[0], args[1]);
|
||||
});
|
||||
#ifdef TVM_GRAPH_RUNTIME_DEBUG
|
||||
} else if (name == "debug_get_output") {
|
||||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
||||
if (args[0].type_code() == kStr) {
|
||||
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
|
||||
} else {
|
||||
this->DebugGetNodeOutput(args[0], args[1]);
|
||||
}
|
||||
});
|
||||
#endif
|
||||
} else if (name == "run") {
|
||||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
||||
this->Run();
|
||||
|
|
Загрузка…
Ссылка в новой задаче