2016-10-15 04:07:37 +03:00
|
|
|
/*!
|
|
|
|
* Copyright (c) 2016 by Contributors
|
|
|
|
* \file base.h
|
|
|
|
* \brief Defines the base data structure
|
|
|
|
*/
|
|
|
|
#ifndef TVM_BASE_H_
|
|
|
|
#define TVM_BASE_H_
|
|
|
|
|
|
|
|
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node.h>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "./runtime/registry.h"
|
2016-10-15 04:07:37 +03:00
|
|
|
|
|
|
|
namespace tvm {
|
|
|
|
|
2016-10-26 21:32:43 +03:00
|
|
|
// Using-declarations for the core node types, written inside namespace tvm
// so that Node, NodeRef and AttrVisitor can be named unqualified below.
using ::tvm::AttrVisitor;
using ::tvm::Node;
using ::tvm::NodeRef;
|
2016-10-15 04:07:37 +03:00
|
|
|
|
2017-03-26 04:26:28 +03:00
|
|
|
/*!
 * \brief Macro to make it easy to define a node ref type given a node type.
 *
 * Expands to a class TypeName that publicly derives from ::tvm::NodeRef and
 * provides:
 *  - a default constructor,
 *  - an explicit constructor taking std::shared_ptr<::tvm::Node>, which is
 *    moved into the NodeRef base,
 *  - a const operator-> returning the held node as const NodeName*,
 *  - a ContainerType alias naming NodeName.
 *
 * \param TypeName The name of the reference class to define.
 * \param NodeName The node (container) type that the reference points to.
 *
 * \note operator-> uses an unchecked static_cast on node_.get(); this macro
 *  itself performs no runtime type verification.
 * \note The expansion already ends with the class's closing "};". The final
 *  line deliberately carries no line continuation, so the macro definition
 *  ends here regardless of what follows it in the file.
 */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                          \
  class TypeName : public ::tvm::NodeRef {                               \
   public:                                                               \
    TypeName() {}                                                        \
    explicit TypeName(std::shared_ptr<::tvm::Node> n)                    \
        : NodeRef(std::move(n)) {}                                       \
    const NodeName* operator->() const {                                 \
      return static_cast<const NodeName*>(node_.get());                  \
    }                                                                    \
    using ContainerType = NodeName;                                      \
  };
|
|
|
|
|
|
|
|
|
2017-01-12 21:07:37 +03:00
|
|
|
/*!
|
|
|
|
* \brief save the node as well as all the node it depends on as json.
|
|
|
|
* This can be used to serialize any TVM object
|
|
|
|
*
|
|
|
|
* \return the string representation of the node.
|
|
|
|
*/
|
|
|
|
std::string SaveJSON(const NodeRef& node);
|
|
|
|
|
|
|
|
/*!
|
|
|
|
* \brief Internal implementation of LoadJSON
|
|
|
|
* Load tvm Node object from json and return a shared_ptr of Node.
|
|
|
|
* \param json_str The json string to load from.
|
|
|
|
*
|
|
|
|
* \return The shared_ptr of the Node.
|
|
|
|
*/
|
|
|
|
std::shared_ptr<Node> LoadJSON_(std::string json_str);
|
|
|
|
|
|
|
|
/*!
|
|
|
|
* \brief Load the node from json string.
|
|
|
|
* This can be used to deserialize any TVM object.
|
|
|
|
*
|
|
|
|
* \param json_str The json string to load from.
|
|
|
|
*
|
|
|
|
* \tparam NodeType the nodetype
|
|
|
|
*
|
|
|
|
* \code
|
|
|
|
* Expr e = LoadJSON<Expr>(json_str);
|
|
|
|
* \endcode
|
|
|
|
*/
|
|
|
|
template<typename NodeType,
|
|
|
|
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
|
|
|
|
inline NodeType LoadJSON(const std::string& json_str) {
|
|
|
|
return NodeType(LoadJSON_(json_str));
|
|
|
|
}
|
|
|
|
|
2016-10-15 04:07:37 +03:00
|
|
|
/*! \brief typedef the factory function of data iterator */
|
|
|
|
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
|
|
|
|
/*!
|
2016-10-17 07:33:42 +03:00
|
|
|
* \brief Registry entry for NodeFactory
|
2016-10-15 04:07:37 +03:00
|
|
|
*/
|
|
|
|
struct NodeFactoryReg
|
|
|
|
: public dmlc::FunctionRegEntryBase<NodeFactoryReg,
|
|
|
|
NodeFactory> {
|
|
|
|
};
|
|
|
|
|
|
|
|
/*!
 * \brief Register a factory for a node type in the NodeFactoryReg registry.
 *
 * Expands to a static reference named __make_Node_<TypeName>__ that is bound
 * to the registry entry registered under TypeName::_type_key. The entry's
 * body is a lambda returning std::make_shared<TypeName>().
 *
 * \param TypeName The node type to register. It must expose a static
 *  _type_key and be constructible via std::make_shared<TypeName>(). Because
 *  the name is token-pasted into an identifier, it has to be a single
 *  unqualified identifier.
 *
 * \note The expansion has no trailing semicolon; the use site supplies it.
 */
#define TVM_REGISTER_NODE_TYPE(TypeName)                                  \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg &                    \
      __make_Node ## _ ## TypeName ## __ =                                \
          ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()                  \
              ->__REGISTER__(TypeName::_type_key)                         \
              .set_body([]() { return std::make_shared<TypeName>(); })
|
2016-10-15 04:07:37 +03:00
|
|
|
|
|
|
|
} // namespace tvm
|
|
|
|
#endif // TVM_BASE_H_
|