This commit is contained in:
tqchen 2016-10-16 22:53:47 -07:00
Родитель 3e693f53e0
Коммит 35277c2f95
4 изменённых файлов: 43 добавлений и 1 удалений

Просмотреть файл

@ -105,6 +105,15 @@ class Node {
protected:
// node ref can see this
friend class NodeRef;
/*!
* \brief optional: safe destruction function
* Can be called in destructor of composite types.
* This can be used to avoid stack overflow when
* recursive destruction long graph(1M nodes),
*
* It is totally OK to not call this in destructor.
*/
void Destroy();
/*! \brief the node type enum */
NodeType node_type_{kOtherNodes};
};
@ -127,6 +136,7 @@ class NodeRef {
template<typename T, typename>
friend class Array;
friend class APIVariantValue;
friend class Node;
NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
/*! \brief the internal node */

Просмотреть файл

@ -82,6 +82,9 @@ class UnaryOpNode : public ExprNode {
node_type_ = kUnaryOpNode;
dtype_ = this->src.dtype();
}
~UnaryOpNode() {
this->Destroy();
}
const char* type_key() const override {
return "UnaryOpNode";
}
@ -114,6 +117,9 @@ struct BinaryOpNode : public ExprNode {
node_type_ = kBinaryOpNode;
dtype_ = this->lhs.dtype();
}
~BinaryOpNode() {
this->Destroy();
}
const char* type_key() const override {
return "BinaryOpNode";
}

Просмотреть файл

@ -50,7 +50,7 @@ class NodeBase(object):
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val)
ret = RET_SWITCH[ret_typeid.value](ret_val)
def _type_key(handle):

Просмотреть файл

@ -11,6 +11,32 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
namespace tvm {
void Node::Destroy() {
bool safe = true;
this->VisitNodeRefFields([&safe](const char* k, NodeRef* r) {
if (r->node_.get() != nullptr) safe = false;
});
if (!safe) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
n->VisitNodeRefFields([&safe, &stack, &to_delete](const char* k, NodeRef* r) {
if (r->node_.unique()) {
stack.push_back(r->node_.get());
to_delete.emplace_back(std::move(r->node_));
} else {
r->node_.reset();
}
});
}
}
}
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE(FloatNode);