139 строки
3.7 KiB
C++
139 строки
3.7 KiB
C++
/*!
|
|
* Copyright (c) 2017 by Contributors
|
|
* \file tvm/api_registry.h
|
|
* \brief This file contains utilities related to
* TVM's global function registry.
|
|
*/
|
|
#ifndef TVM_API_REGISTRY_H_
|
|
#define TVM_API_REGISTRY_H_
|
|
|
|
#include <string>
#include <utility>
#include "base.h"
#include "packed_func_ext.h"
#include "runtime/registry.h"
|
|
|
|
namespace tvm {
|
|
/*!
 * \brief Register an API function globally.
 *
 *  This macro is a thin alias: it expands to TVM_REGISTER_GLOBAL with the
 *  same argument and adds nothing of its own.
 *
 * \code
 *   TVM_REGISTER_API(MyPrint)
 *   .set_body([](TVMArgs args, TVMRetValue* rv) {
 *     // my code.
 *   });
 * \endcode
 */
#define TVM_REGISTER_API(FuncName) TVM_REGISTER_GLOBAL(FuncName)
|
|
|
|
/*!
|
|
* \brief Node container of EnvFunc
|
|
* \sa EnvFunc
|
|
*/
|
|
class EnvFuncNode : public Node {
|
|
public:
|
|
/*! \brief Unique name of the global function */
|
|
std::string name;
|
|
/*! \brief The internal packed function */
|
|
PackedFunc func;
|
|
/*! \brief constructor */
|
|
EnvFuncNode() {}
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("name", &name);
|
|
}
|
|
|
|
static constexpr const char* _type_key = "EnvFunc";
|
|
TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node);
|
|
};
|
|
|
|
/*!
|
|
* \brief A serializable function backed by TVM's global environment.
|
|
*
|
|
* This is a wrapper to enable serializable global PackedFunc.
|
|
* An EnvFunc is saved by its name in the global registry
|
|
* under the assumption that the same function is registered during load.
|
|
*/
|
|
class EnvFunc : public NodeRef {
|
|
public:
|
|
EnvFunc() {}
|
|
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
|
|
/*! \return The internal global function pointer */
|
|
const EnvFuncNode* operator->() const {
|
|
return static_cast<EnvFuncNode*>(node_.get());
|
|
}
|
|
/*!
|
|
* \brief Invoke the function.
|
|
* \param args The arguments
|
|
* \returns The return value.
|
|
*/
|
|
template<typename... Args>
|
|
runtime::TVMRetValue operator()(Args&&... args) const {
|
|
const EnvFuncNode* n = operator->();
|
|
CHECK(n != nullptr);
|
|
return n->func(std::forward<Args>(args)...);
|
|
}
|
|
/*!
|
|
* \brief Get a global function based on the name.
|
|
* \param name The name of the global function.
|
|
* \return The created global function.
|
|
* \note The function can be unique
|
|
*/
|
|
TVM_DLL static EnvFunc Get(const std::string& name);
|
|
/*! \brief specify container node */
|
|
using ContainerType = EnvFuncNode;
|
|
};
|
|
|
|
/*!
 * \brief Primary template, declared only. Please refer to
 *  \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>" for the
 *  function-signature specialization.
 */
template<typename Signature>
class TypedEnvFunc;
|
|
|
|
/*!
|
|
* \anchor TypedEnvFuncAnchor
|
|
* \brief A typed version of EnvFunc.
|
|
* It is backed by a GlobalFuncNode internally.
|
|
*
|
|
* \tparam R The return value of the function.
|
|
* \tparam Args The argument signature of the function.
|
|
* \sa EnvFunc
|
|
*/
|
|
template<typename R, typename... Args>
|
|
class TypedEnvFunc<R(Args...)> : public NodeRef {
|
|
public:
|
|
/*! \brief short hand for this function type */
|
|
using TSelf = TypedEnvFunc<R(Args...)>;
|
|
TypedEnvFunc() {}
|
|
explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief Assign global function to a TypedEnvFunc
|
|
* \param other Another global function.
|
|
* \return reference to self.
|
|
*/
|
|
TSelf& operator=(const EnvFunc& other) {
|
|
this->node_ = other.node_;
|
|
return *this;
|
|
}
|
|
/*! \return The internal global function pointer */
|
|
const EnvFuncNode* operator->() const {
|
|
return static_cast<EnvFuncNode*>(node_.get());
|
|
}
|
|
/*!
|
|
* \brief Invoke the function.
|
|
* \param args The arguments
|
|
* \returns The return value.
|
|
*/
|
|
R operator()(Args... args) const {
|
|
const EnvFuncNode* n = operator->();
|
|
CHECK(n != nullptr);
|
|
return runtime::detail::typed_packed_call_dispatcher<R>
|
|
::run(n->func, std::forward<Args>(args)...);
|
|
}
|
|
/*! \brief specify container node */
|
|
using ContainerType = EnvFuncNode;
|
|
};
|
|
|
|
} // namespace tvm
|
|
#endif // TVM_API_REGISTRY_H_
|