Add safe destructor
This commit is contained in:
Родитель
3e693f53e0
Коммит
35277c2f95
|
@ -105,6 +105,15 @@ class Node {
|
|||
protected:
|
||||
// node ref can see this
|
||||
friend class NodeRef;
|
||||
/*!
|
||||
* \brief optional: safe destruction function
|
||||
* Can be called in destructor of composite types.
|
||||
* This can be used to avoid stack overflow when
|
||||
* recursive destruction long graph(1M nodes),
|
||||
*
|
||||
* It is totally OK to not call this in destructor.
|
||||
*/
|
||||
void Destroy();
|
||||
/*! \brief the node type enum */
|
||||
NodeType node_type_{kOtherNodes};
|
||||
};
|
||||
|
@ -127,6 +136,7 @@ class NodeRef {
|
|||
template<typename T, typename>
|
||||
friend class Array;
|
||||
friend class APIVariantValue;
|
||||
friend class Node;
|
||||
NodeRef() = default;
|
||||
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
|
||||
/*! \brief the internal node */
|
||||
|
|
|
@ -82,6 +82,9 @@ class UnaryOpNode : public ExprNode {
|
|||
node_type_ = kUnaryOpNode;
|
||||
dtype_ = this->src.dtype();
|
||||
}
|
||||
~UnaryOpNode() {
|
||||
this->Destroy();
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "UnaryOpNode";
|
||||
}
|
||||
|
@ -114,6 +117,9 @@ struct BinaryOpNode : public ExprNode {
|
|||
node_type_ = kBinaryOpNode;
|
||||
dtype_ = this->lhs.dtype();
|
||||
}
|
||||
~BinaryOpNode() {
|
||||
this->Destroy();
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "BinaryOpNode";
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ class NodeBase(object):
|
|||
check_call(_LIB.TVMNodeGetAttr(
|
||||
self.handle, c_str(name),
|
||||
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
|
||||
return RET_SWITCH[ret_typeid.value](ret_val)
|
||||
ret = RET_SWITCH[ret_typeid.value](ret_val)
|
||||
|
||||
|
||||
def _type_key(handle):
|
||||
|
|
|
@ -11,6 +11,32 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
|
|||
|
||||
namespace tvm {
|
||||
|
||||
void Node::Destroy() {
|
||||
bool safe = true;
|
||||
this->VisitNodeRefFields([&safe](const char* k, NodeRef* r) {
|
||||
if (r->node_.get() != nullptr) safe = false;
|
||||
});
|
||||
|
||||
if (!safe) {
|
||||
// explicit deletion via DFS
|
||||
// this is used to avoid stackoverflow caused by chain of deletions
|
||||
std::vector<Node*> stack{this};
|
||||
std::vector<std::shared_ptr<Node> > to_delete;
|
||||
while (!stack.empty()) {
|
||||
Node* n = stack.back();
|
||||
stack.pop_back();
|
||||
n->VisitNodeRefFields([&safe, &stack, &to_delete](const char* k, NodeRef* r) {
|
||||
if (r->node_.unique()) {
|
||||
stack.push_back(r->node_.get());
|
||||
to_delete.emplace_back(std::move(r->node_));
|
||||
} else {
|
||||
r->node_.reset();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(VarNode);
|
||||
TVM_REGISTER_NODE_TYPE(IntNode);
|
||||
TVM_REGISTER_NODE_TYPE(FloatNode);
|
||||
|
|
Загрузка…
Ссылка в новой задаче