onnxruntime-tvm/include/tvm/base.h

94 lines
3.1 KiB
C++

/*!
* Copyright (c) 2016 by Contributors
* \file tvm/base.h
* \brief Defines the base data structure
*/
#ifndef TVM_BASE_H_
#define TVM_BASE_H_
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node.h>
#include <string>
#include <memory>
#include <functional>
#include "./runtime/registry.h"
namespace tvm {
/*!
 * \brief Using-declarations that name Node, NodeRef and AttrVisitor of the
 *  ::tvm namespace inside this namespace block.
 *  NOTE(review): presumably these types are declared in <tvm/node.h>; confirm.
 */
using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
/*!
 * \brief Macro to make it easy to define a node reference type for a given node type.
 *
 *  The generated class TypeName publicly derives from ::tvm::NodeRef and provides
 *  - a default constructor,
 *  - an explicit constructor taking std::shared_ptr<::tvm::Node>, which is
 *    passed on to the NodeRef constructor,
 *  - a const operator-> that static_casts the stored node pointer to
 *    const NodeName* (a static_cast, so no runtime type check happens here),
 *  - the member alias ContainerType = NodeName.
 *
 *  The final line of the definition deliberately carries no trailing
 *  backslash, so the macro never absorbs whatever line follows it.
 *
 * \param TypeName Name of the reference class to define.
 * \param NodeName Node class that operator-> exposes.
 */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                       \
  class TypeName : public ::tvm::NodeRef {                            \
   public:                                                            \
    TypeName() {}                                                     \
    explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} \
    const NodeName* operator->() const {                              \
      return static_cast<const NodeName*>(node_.get());               \
    }                                                                 \
    using ContainerType = NodeName;                                   \
  };
/*!
 * \brief Save the node, as well as all the nodes it depends on, as JSON.
 *  This can be used to serialize any TVM object.
 *
 * \param node The node to save.
 * \return The string representation of the node.
 */
std::string SaveJSON(const NodeRef& node);
/*!
 * \brief Internal implementation of LoadJSON.
 *  Loads a TVM Node object from JSON and returns it as a shared_ptr of Node.
 *  The typed LoadJSON<NodeType> template calls this function and wraps the
 *  result in NodeType.
 *
 * \param json_str The JSON string to load from (taken by value).
 * \return The shared_ptr of the loaded Node.
 */
std::shared_ptr<Node> LoadJSON_(std::string json_str);
/*!
 * \brief Load a node reference of type NodeType from a JSON string.
 *  This can be used to deserialize any TVM object.
 *
 *  The template only takes part in overload resolution when NodeType is
 *  NodeRef or a class derived from it. The node itself is produced by
 *  LoadJSON_ and then handed to the NodeType constructor.
 *
 * \param json_str The JSON string to load from.
 * \tparam NodeType The reference type to construct.
 * \return A NodeType built from the shared_ptr returned by LoadJSON_.
 *
 * \code
 *  Expr e = LoadJSON<Expr>(json_str);
 * \endcode
 */
template <typename NodeType,
          typename = typename std::enable_if<
              std::is_base_of<NodeRef, NodeType>::value>::type>
inline auto LoadJSON(const std::string& json_str) -> NodeType {
  return NodeType(LoadJSON_(json_str));
}
/*!
 * \brief Type of a node factory function: a callable that takes no
 *  arguments and returns a std::shared_ptr<Node>.
 */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*!
 * \brief Registry entry type for NodeFactory functions.
 *  Declares no members of its own; everything is inherited from
 *  dmlc::FunctionRegEntryBase<NodeFactoryReg, NodeFactory>.
 */
struct NodeFactoryReg : dmlc::FunctionRegEntryBase<NodeFactoryReg, NodeFactory> {};
/*!
 * \brief Register a node type through ::dmlc::Registry<::tvm::NodeFactoryReg>::Get().
 *
 *  Expands to the definition of a static NodeFactoryReg reference named
 *  __make_Node_<TypeName>__, initialized from the chained call
 *  __REGISTER__(TypeName::_type_key).set_body(...), where the body is a
 *  lambda returning std::make_shared<TypeName>().
 *  The expansion has no trailing semicolon; the user of the macro supplies it.
 *
 * \param TypeName The node class to register; it must provide _type_key.
 */
#define TVM_REGISTER_NODE_TYPE(TypeName)                                           \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg& __make_Node_##TypeName##__ = \
      ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()                               \
          ->__REGISTER__(TypeName::_type_key)                                      \
          .set_body([] { return std::make_shared<TypeName>(); })
/*!
 * \brief Get a pointer to a dmlc registry of NodeFactoryReg entries.
 *  TVM_EXTERNAL_REGISTER_NODE_TYPE registers node types through this
 *  function instead of calling ::dmlc::Registry<NodeFactoryReg>::Get() directly.
 *  NOTE(review): presumably this returns the registry instance owned by the
 *  TVM library so that code outside of it can register into the same
 *  registry; confirm against the definition.
 *
 * \return Pointer to the registry.
 */
TVM_DLL::dmlc::Registry<::tvm::NodeFactoryReg > * GetTVMNodeFactoryRegistry();
/*!
 * \brief Register a node type through ::tvm::GetTVMNodeFactoryRegistry().
 *
 *  Same expansion as TVM_REGISTER_NODE_TYPE, except that the registry is
 *  obtained from ::tvm::GetTVMNodeFactoryRegistry(): a static NodeFactoryReg
 *  reference named __make_Node_<TypeName>__ is initialized from
 *  __REGISTER__(TypeName::_type_key).set_body(...), where the body is a
 *  lambda returning std::make_shared<TypeName>().
 *  The expansion has no trailing semicolon; the user of the macro supplies it.
 *
 * \param TypeName The node class to register; it must provide _type_key.
 */
#define TVM_EXTERNAL_REGISTER_NODE_TYPE(TypeName)                                  \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg& __make_Node_##TypeName##__ = \
      ::tvm::GetTVMNodeFactoryRegistry()                                           \
          ->__REGISTER__(TypeName::_type_key)                                      \
          .set_body([] { return std::make_shared<TypeName>(); })
} // namespace tvm
#endif // TVM_BASE_H_