178 строки
5.8 KiB
C++
178 строки
5.8 KiB
C++
/*!
 *  Copyright (c) 2017 by Contributors
 * \file module.h
 * \brief Runtime container of the functions generated by TVM.
 *  This is used to support dynamically linking, loading and saving
 *  functions from different conventions under a unified API.
 */
|
|
#ifndef TVM_RUNTIME_MODULE_H_
|
|
#define TVM_RUNTIME_MODULE_H_
|
|
|
|
#include <dmlc/io.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include "./c_runtime_api.h"
|
|
|
|
namespace tvm {
|
|
namespace runtime {
|
|
|
|
// The internal container of module.
|
|
class ModuleNode;
|
|
class PackedFunc;
|
|
|
|
/*!
|
|
* \brief Module container of TVM.
|
|
*/
|
|
class Module {
|
|
public:
|
|
Module() {}
|
|
// constructor from container.
|
|
explicit Module(std::shared_ptr<ModuleNode> n)
|
|
: node_(n) {}
|
|
/*!
|
|
* \brief Get packed function from current module by name.
|
|
*
|
|
* \param name The name of the function.
|
|
* \param query_imports Whether also query dependency modules.
|
|
* \return The result function.
|
|
* This function will return PackedFunc(nullptr) if function do not exist.
|
|
* \note Implemented in packed_func.cc
|
|
*/
|
|
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
|
|
/*! \return internal container */
|
|
inline ModuleNode* operator->();
|
|
/*! \return internal container */
|
|
inline const ModuleNode* operator->() const;
|
|
// The following functions requires link with runtime.
|
|
/*!
|
|
* \brief Import another module into this module.
|
|
* \param other The module to be imported.
|
|
*
|
|
* \note Cyclic dependency is not allowed among modules,
|
|
* An error will be thrown when cyclic dependency is detected.
|
|
*/
|
|
TVM_DLL void Import(Module other);
|
|
/*!
|
|
* \brief Load a module from file.
|
|
* \param file_name The name of the host function module.
|
|
* \param format The format of the file.
|
|
* \note This function won't load the import relationship.
|
|
* Re-create import relationship by calling Import.
|
|
*/
|
|
TVM_DLL static Module LoadFromFile(const std::string& file_name,
|
|
const std::string& format = "");
|
|
|
|
private:
|
|
std::shared_ptr<ModuleNode> node_;
|
|
};
|
|
|
|
/*!
|
|
* \brief Base node container of module.
|
|
* Do not create this directly, instead use Module.
|
|
*/
|
|
class ModuleNode {
|
|
public:
|
|
/*! \brief virtual destructor */
|
|
virtual ~ModuleNode() {}
|
|
/*! \return The module type key */
|
|
virtual const char* type_key() const = 0;
|
|
/*!
|
|
* \brief Get a PackedFunc from module.
|
|
*
|
|
* The PackedFunc may not be fully initialized,
|
|
* there might still be first time running overhead when
|
|
* executing the function on certain devices.
|
|
* For benchmarking, use prepare to eliminate
|
|
*
|
|
* \param name the name of the function.
|
|
* \param sptr_to_self The shared_ptr that points to this module node.
|
|
*
|
|
* \return PackedFunc(nullptr) when it is not available.
|
|
*
|
|
* \note The function will always remain valid.
|
|
* If the function need resource from the module(e.g. late linking),
|
|
* it should capture sptr_to_self.
|
|
*/
|
|
virtual PackedFunc GetFunction(
|
|
const std::string& name,
|
|
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
|
|
/*!
|
|
* \brief Save the module to file.
|
|
* \param file_name The file to be saved to.
|
|
* \param format The format of the file.
|
|
*/
|
|
virtual void SaveToFile(const std::string& file_name,
|
|
const std::string& format);
|
|
/*!
|
|
* \brief Save the module to binary stream.
|
|
* \param stream The binary stream to save to.
|
|
* \note It is recommended to implement this for device modules,
|
|
* but not necessarily host modules.
|
|
* We can use this to do AOT loading of bundled device functions.
|
|
*/
|
|
TVM_DLL virtual void SaveToBinary(dmlc::Stream* stream);
|
|
/*!
|
|
* \brief Get the source code of module, when available.
|
|
* \param format Format of the source code, can be empty by default.
|
|
* \return Possible source code when available.
|
|
*/
|
|
TVM_DLL virtual std::string GetSource(const std::string& format = "");
|
|
/*!
|
|
* \brief Get a function from current environment
|
|
* The environment includes all the imports as well as Global functions.
|
|
*
|
|
* \param name name of the function.
|
|
* \return The corresponding function.
|
|
*/
|
|
TVM_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
|
|
/*! \return The module it imports from */
|
|
const std::vector<Module>& imports() const {
|
|
return imports_;
|
|
}
|
|
|
|
protected:
|
|
friend class Module;
|
|
/*! \brief The modules this module depend on */
|
|
std::vector<Module> imports_;
|
|
|
|
private:
|
|
/*! \brief Cache used by GetImport */
|
|
std::unordered_map<std::string,
|
|
std::unique_ptr<PackedFunc> > import_cache_;
|
|
};
|
|
|
|
/*!
 * \brief Namespace holding the names of constant symbols.
 */
namespace symbol {
/*! \brief Name of the global variable that stores the module context. */
constexpr const char* tvm_module_ctx{"__tvm_module_ctx"};
/*! \brief Name of the global variable that stores the device module blob. */
constexpr const char* tvm_dev_mblob{"__tvm_dev_mblob"};
/*! \brief Name of the symbol giving the size in bytes of the device module blob. */
constexpr const char* tvm_dev_mblob_nbytes{"__tvm_dev_mblob_nbytes"};
/*! \brief Name of the global function used to set the device. */
constexpr const char* tvm_set_device{"__tvm_set_device"};
/*! \brief Name of the auxiliary counter for the global barrier. */
constexpr const char* tvm_global_barrier_state{"__tvm_global_barrier_state"};
/*! \brief Name of the function that prepares the global barrier before kernels that use it. */
constexpr const char* tvm_prepare_global_barrier{"__tvm_prepare_global_barrier"};
/*! \brief Placeholder name for the module's entry function. */
constexpr const char* tvm_module_main{"__tvm_main__"};
}  // namespace symbol
|
|
|
|
// implementations of inline functions.
|
|
/*!
 * \brief Mutable access to the internal container.
 * \return The raw pointer held by node_; ownership is not transferred.
 */
inline ModuleNode* Module::operator->() {
  ModuleNode* container = node_.get();
  return container;
}
|
|
|
|
/*!
 * \brief Read-only access to the internal container.
 * \return The raw pointer held by node_, as pointer-to-const; ownership is not transferred.
 */
inline const ModuleNode* Module::operator->() const {
  const ModuleNode* container = node_.get();
  return container;
}
|
|
|
|
} // namespace runtime
|
|
} // namespace tvm
|
|
|
|
#include "./packed_func.h"
|
|
#endif // TVM_RUNTIME_MODULE_H_
|