[API/Refactor] Unified PackedFunc for API and Generated Functions (#26)
This commit is contained in:
Родитель
4242b9cff5
Коммит
ff06917c59
|
@ -4,7 +4,7 @@ language: cpp
|
|||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
# - osx
|
||||
|
||||
env:
|
||||
# code analysis
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file api_registry.h
|
||||
* \brief This file defines the TVM API registry.
|
||||
*
|
||||
* The API registry stores type-erased functions.
|
||||
* Each registered function is automatically exposed
|
||||
* to front-end language(e.g. python).
|
||||
* Front-end can also pass callbacks as PackedFunc, or register
|
||||
* then into the same global registry in C++.
|
||||
* The goal is to mix the front-end language and the TVM back-end.
|
||||
*
|
||||
* \code
|
||||
* // register the function as MyAPIFuncName
|
||||
* TVM_REGISTER_API(MyAPIFuncName)
|
||||
* .set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
* // my code.
|
||||
* });
|
||||
* \endcode
|
||||
*/
|
||||
#ifndef TVM_API_REGISTRY_H_
|
||||
#define TVM_API_REGISTRY_H_
|
||||
|
||||
#include <dmlc/base.h>
|
||||
#include <string>
|
||||
#include "./base.h"
|
||||
#include "./runtime/packed_func.h"
|
||||
#include "./packed_func_ext.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
/*! \brief Utility to register API. */
|
||||
class APIRegistry {
|
||||
public:
|
||||
/*!
|
||||
* \brief set the body of the function to be f
|
||||
* \param f The body of the function.
|
||||
*/
|
||||
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
|
||||
/*!
|
||||
* \brief set the body of the function to be f
|
||||
* \param f The body of the function.
|
||||
*/
|
||||
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
|
||||
return set_body(PackedFunc(f));
|
||||
}
|
||||
/*!
|
||||
* \brief Register a function with given name
|
||||
* \param name The name of the function.
|
||||
*/
|
||||
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)
|
||||
|
||||
private:
|
||||
/*! \brief name of the function */
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Get API function by name.
|
||||
*
|
||||
* \param name The name of the function.
|
||||
* \return the corresponding API function.
|
||||
* \note It is really PackedFunc::GetGlobal under the hood.
|
||||
*/
|
||||
inline PackedFunc GetAPIFunc(const std::string& name) {
|
||||
return PackedFunc::GetGlobal(name);
|
||||
}
|
||||
|
||||
#define _TVM_REGISTER_VAR_DEF_ \
|
||||
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_
|
||||
|
||||
/*!
|
||||
* \brief Register API function globally.
|
||||
* \code
|
||||
* TVM_REGISTER_API(MyPrint)
|
||||
* .set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
* // my code.
|
||||
* });
|
||||
* \endcode
|
||||
*/
|
||||
#define TVM_REGISTER_API(OpName) \
|
||||
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
|
||||
::tvm::APIRegistry::__REGISTER__(#OpName)
|
||||
} // namespace tvm
|
||||
#endif // TVM_API_REGISTRY_H_
|
|
@ -2,6 +2,13 @@
|
|||
* Copyright (c) 2016 by Contributors
|
||||
* \file c_api.h
|
||||
* \brief C API of TVM DSL
|
||||
*
|
||||
* \note The API is designed in a minimum way.
|
||||
* Most of the API functions are registered and can be pulled out.
|
||||
*
|
||||
* The common flow is:
|
||||
* - Use TVMFuncListGlobalNames to get global function name
|
||||
* - Use TVMFuncCall to call these functions.
|
||||
*/
|
||||
#ifndef TVM_C_API_H_
|
||||
#define TVM_C_API_H_
|
||||
|
@ -9,76 +16,9 @@
|
|||
#include "./runtime/c_runtime_api.h"
|
||||
|
||||
TVM_EXTERN_C {
|
||||
/*! \brief handle to functions */
|
||||
typedef void* APIFuncHandle;
|
||||
/*! \brief handle to node */
|
||||
typedef void* NodeHandle;
|
||||
|
||||
/*!
|
||||
* \brief List all the node function name
|
||||
* \param out_size The number of functions
|
||||
* \param out_array The array of function names.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMListAPIFuncNames(int *out_size,
|
||||
const char*** out_array);
|
||||
/*!
|
||||
* \brief get function handle by name
|
||||
* \param name The name of function
|
||||
* \param handle The returning function handle
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
|
||||
APIFuncHandle *handle);
|
||||
|
||||
/*!
|
||||
* \brief Get the detailed information about function.
|
||||
* \param handle The operator handle.
|
||||
* \param real_name The returned name of the function.
|
||||
* This name is not the alias name of the atomic symbol.
|
||||
* \param description The returned description of the symbol.
|
||||
* \param num_doc_args Number of arguments that contain documents.
|
||||
* \param arg_names Name of the arguments of doc args
|
||||
* \param arg_type_infos Type informations about the arguments.
|
||||
* \param arg_descriptions Description information about the arguments.
|
||||
* \param return_type Return type of the function, if any.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
|
||||
const char **real_name,
|
||||
const char **description,
|
||||
int *num_doc_args,
|
||||
const char ***arg_names,
|
||||
const char ***arg_type_infos,
|
||||
const char ***arg_descriptions,
|
||||
const char **return_type);
|
||||
|
||||
/*!
|
||||
* \brief Push an argument to the function calling stack.
|
||||
* If push fails, the stack will be reset to empty
|
||||
*
|
||||
* \param arg The argument
|
||||
* \param type_code The type_code of argument as in TVMTypeCode
|
||||
* \return 0 when success, -1 when failure happens
|
||||
* \note API calls always exchanges with type bits=64, lanes=1
|
||||
*/
|
||||
TVM_DLL int TVMAPIPushStack(TVMValue arg,
|
||||
int type_code);
|
||||
|
||||
/*!
|
||||
* \brief call a function by using arguments in the stack.
|
||||
* The stack will be cleanup to empty after this call, whether the call is successful.
|
||||
*
|
||||
* \param handle The function handle
|
||||
* \param ret_val The return value.
|
||||
* \param ret_type_code the type code of return value.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
* \note API calls always exchanges with type bits=64, lanes=1
|
||||
*/
|
||||
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
|
||||
TVMValue* ret_val,
|
||||
int* ret_type_code);
|
||||
|
||||
/*!
|
||||
* \brief free the node handle
|
||||
* \param handle The node handle to be freed.
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "./base.h"
|
||||
#include "./runtime/packed_func.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file packed_func_ext.h
|
||||
* \brief Extension package to PackedFunc
|
||||
* This enales pass NodeRef types into/from PackedFunc.
|
||||
*/
|
||||
#ifndef TVM_PACKED_FUNC_EXT_H_
|
||||
#define TVM_PACKED_FUNC_EXT_H_
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "./base.h"
|
||||
#include "./expr.h"
|
||||
|
||||
namespace tvm {
|
||||
using runtime::TVMArgs;
|
||||
using runtime::TVMRetValue;
|
||||
using runtime::PackedFunc;
|
||||
|
||||
namespace runtime {
|
||||
/*!
|
||||
* \brief Runtime type checker for node type.
|
||||
* \tparam T the type to be checked.
|
||||
*/
|
||||
template<typename T>
|
||||
struct NodeTypeChecker {
|
||||
static inline bool Check(Node* sptr) {
|
||||
// This is the only place in the project where RTTI is used
|
||||
// It can be turned off, but will make non strict checking.
|
||||
// TODO(tqchen) possibly find alternative to turn of RTTI
|
||||
using ContainerType = typename T::ContainerType;
|
||||
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
using ContainerType = typename T::ContainerType;
|
||||
os << ContainerType::_type_key;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct NodeTypeChecker<Array<T> > {
|
||||
static inline bool Check(Node* sptr) {
|
||||
if (sptr == nullptr) return false;
|
||||
if (!sptr->is_type<ArrayNode>()) return false;
|
||||
ArrayNode* n = static_cast<ArrayNode*>(sptr);
|
||||
for (const auto& p : n->data) {
|
||||
if (!NodeTypeChecker<T>::Check(p.get())) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
os << "array<";
|
||||
NodeTypeChecker<T>::PrintName(os);
|
||||
os << ">";
|
||||
}
|
||||
};
|
||||
|
||||
template<typename K, typename V>
|
||||
struct NodeTypeChecker<Map<K, V> > {
|
||||
static inline bool Check(Node* sptr) {
|
||||
if (sptr == nullptr) return false;
|
||||
if (!sptr->is_type<MapNode>()) return false;
|
||||
MapNode* n = static_cast<MapNode*>(sptr);
|
||||
for (const auto& kv : n->data) {
|
||||
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
|
||||
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
os << "map<";
|
||||
NodeTypeChecker<K>::PrintName(os);
|
||||
os << ',';
|
||||
NodeTypeChecker<V>::PrintName(os);
|
||||
os << '>';
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline std::string NodeTypeName() {
|
||||
std::ostringstream os;
|
||||
NodeTypeChecker<T>::PrintName(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
// extensions for tvm arg value
|
||||
|
||||
template<typename TNodeRef, typename>
|
||||
inline TVMArgValue::operator TNodeRef() const {
|
||||
static_assert(
|
||||
std::is_base_of<NodeRef, TNodeRef>::value,
|
||||
"Conversion only works for NodeRef");
|
||||
if (type_code_ == kNull) return TNodeRef();
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
|
||||
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
|
||||
<< "Expected type " << NodeTypeName<TNodeRef>()
|
||||
<< " but get " << sptr->type_key();
|
||||
return TNodeRef(sptr);
|
||||
}
|
||||
|
||||
inline TVMArgValue::operator Halide::Expr() const {
|
||||
if (type_code_ == kNull) return Expr();
|
||||
if (type_code_ == kInt) {
|
||||
return Expr(static_cast<int>(value_.v_int64));
|
||||
}
|
||||
if (type_code_ == kFloat) {
|
||||
return Expr(static_cast<float>(value_.v_float64));
|
||||
}
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
|
||||
if (sptr->is_type<IterVarNode>()) {
|
||||
return IterVar(sptr)->var;
|
||||
}
|
||||
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
|
||||
<< "Expected type " << NodeTypeName<Expr>()
|
||||
<< " but get " << sptr->type_key();
|
||||
return Expr(sptr);
|
||||
}
|
||||
|
||||
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||
return *ptr<std::shared_ptr<Node> >();
|
||||
}
|
||||
|
||||
|
||||
template<typename TNodeRef, typename>
|
||||
inline bool TVMArgValue::IsNodeType() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||
std::shared_ptr<Node>& sptr =
|
||||
*ptr<std::shared_ptr<Node> >();
|
||||
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
|
||||
}
|
||||
|
||||
// extensions for TVMRetValue
|
||||
inline TVMRetValue& TVMRetValue::operator=(
|
||||
const std::shared_ptr<Node>& other) {
|
||||
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
|
||||
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename TNodeRef, typename>
|
||||
inline TVMRetValue::operator TNodeRef() const {
|
||||
static_assert(
|
||||
std::is_base_of<NodeRef, TNodeRef>::value,
|
||||
"Conversion only works for NodeRef");
|
||||
if (type_code_ == kNull) return TNodeRef();
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||
return TNodeRef(*ptr<std::shared_ptr<Node> >());
|
||||
}
|
||||
|
||||
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
|
||||
values_[i].v_handle = &(other.node_);
|
||||
type_codes_[i] = kNodeHandle;
|
||||
}
|
||||
|
||||
// Type related stuffs
|
||||
inline Type TVMType2Type(TVMType t) {
|
||||
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
|
||||
}
|
||||
|
||||
inline TVMType Type2TVMType(Type t) {
|
||||
TVMType ret;
|
||||
ret.code = static_cast<uint8_t>(t.code());
|
||||
ret.bits = static_cast<uint8_t>(t.bits());
|
||||
ret.lanes = static_cast<uint16_t>(t.lanes());
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
|
||||
return this->operator=(Type2TVMType(t));
|
||||
}
|
||||
|
||||
inline TVMRetValue::operator Halide::Type() const {
|
||||
return TVMType2Type(operator TVMType());
|
||||
}
|
||||
|
||||
inline TVMArgValue::operator Halide::Type() const {
|
||||
return TVMType2Type(operator TVMType());
|
||||
}
|
||||
|
||||
inline void TVMArgsSetter::operator()(
|
||||
size_t i, const Halide::Type& t) const {
|
||||
this->operator()(i, Type2TVMType(t));
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
#endif // TVM_PACKED_FUNC_EXT_H_
|
|
@ -36,18 +36,6 @@
|
|||
TVM_EXTERN_C {
|
||||
/*! \brief type of array index. */
|
||||
typedef uint32_t tvm_index_t;
|
||||
|
||||
/*!
|
||||
* \brief Union type of values
|
||||
* being passed through API and function calls.
|
||||
*/
|
||||
typedef union {
|
||||
int64_t v_int64;
|
||||
double v_float64;
|
||||
void* v_handle;
|
||||
const char* v_str;
|
||||
} TVMValue;
|
||||
|
||||
/*!
|
||||
* \brief The type code in TVMType
|
||||
* \note TVMType is used in two places.
|
||||
|
@ -60,9 +48,11 @@ typedef enum {
|
|||
// The next few fields are extension types
|
||||
// that is used by TVM API calls.
|
||||
kNull = 4U,
|
||||
kNodeHandle = 5U,
|
||||
kStr = 6U,
|
||||
kFuncHandle = 7U
|
||||
kArrayHandle = 5U,
|
||||
kTVMType = 6U,
|
||||
kNodeHandle = 7U,
|
||||
kStr = 8U,
|
||||
kFuncHandle = 9U
|
||||
} TVMTypeCode;
|
||||
|
||||
/*!
|
||||
|
@ -77,13 +67,25 @@ typedef enum {
|
|||
*/
|
||||
typedef struct {
|
||||
/*! \brief type code, in TVMTypeCode */
|
||||
uint8_t type_code;
|
||||
uint8_t code;
|
||||
/*! \brief number of bits of the type */
|
||||
uint8_t bits;
|
||||
/*! \brief number of lanes, */
|
||||
uint16_t lanes;
|
||||
} TVMType;
|
||||
|
||||
/*!
|
||||
* \brief Union type of values
|
||||
* being passed through API and function calls.
|
||||
*/
|
||||
typedef union {
|
||||
int64_t v_int64;
|
||||
double v_float64;
|
||||
void* v_handle;
|
||||
const char* v_str;
|
||||
TVMType v_type;
|
||||
} TVMValue;
|
||||
|
||||
/*!
|
||||
* \brief The device type
|
||||
*/
|
||||
|
@ -133,11 +135,10 @@ typedef struct {
|
|||
* can be NULL, which indicates the default one.
|
||||
*/
|
||||
typedef void* TVMStreamHandle;
|
||||
/*!
|
||||
* \brief Pointer to function handle that points to
|
||||
* a generated TVM function.
|
||||
*/
|
||||
/*! \brief Handle to packed function handle. */
|
||||
typedef void* TVMFunctionHandle;
|
||||
/*! \brief Handle to hold return value. */
|
||||
typedef void* TVMRetValueHandle;
|
||||
/*! \brief the array handle */
|
||||
typedef TVMArray* TVMArrayHandle;
|
||||
|
||||
|
@ -228,20 +229,45 @@ TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
|
|||
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
|
||||
|
||||
/*!
|
||||
* \brief Call a function whose parameters are all packed.
|
||||
* \brief Call a Packed TVM Function.
|
||||
*
|
||||
* \param func node handle of the function.
|
||||
* \param args The arguments
|
||||
* \param arg_values The arguments
|
||||
* \param type_codes The type codes of the arguments
|
||||
* \param num_args Number of arguments.
|
||||
*
|
||||
* \param ret_val The return value.
|
||||
* \param ret_type_code the type code of return value.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
* \note TVM calls always exchanges with type bits=64, lanes=1
|
||||
*
|
||||
* \note API calls always exchanges with type bits=64, lanes=1
|
||||
* If API call returns container handles (e.g. FunctionHandle)
|
||||
* these handles should be managed by the front-end.
|
||||
* The front-end need to call free function (e.g. TVMFuncFree)
|
||||
* to free these handles.
|
||||
*/
|
||||
TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
|
||||
TVMValue* args,
|
||||
TVMValue* arg_values,
|
||||
int* type_codes,
|
||||
int num_args);
|
||||
int num_args,
|
||||
TVMValue* ret_val,
|
||||
int* ret_type_code);
|
||||
|
||||
/*!
|
||||
* \brief Set the return value of TVMPackedCFunc.
|
||||
*
|
||||
* This function is called by TVMPackedCFunc to set the return value.
|
||||
* When this function is not called, the function returns null by default.
|
||||
*
|
||||
* \param ret The return value handle, pass by ret in TVMPackedCFunc
|
||||
* \param value The value to be returned.
|
||||
* \param type_code The type of the value to be returned.
|
||||
*/
|
||||
TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
|
||||
TVMValue value,
|
||||
int type_code);
|
||||
|
||||
/*!
|
||||
* \brief C type of packed function.
|
||||
|
@ -249,10 +275,17 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
|
|||
* \param args The arguments
|
||||
* \param type_codes The type codes of the arguments
|
||||
* \param num_args Number of arguments.
|
||||
* \param ret The return value handle.
|
||||
* \param resource_handle The handle additional resouce handle from fron-end.
|
||||
*
|
||||
* \sa TVMCFuncSetReturn
|
||||
*/
|
||||
typedef void (*TVMPackedCFunc)(
|
||||
TVMValue* args, int* type_codes, int num_args, void* resource_handle);
|
||||
TVMValue* args,
|
||||
int* type_codes,
|
||||
int num_args,
|
||||
TVMRetValueHandle ret,
|
||||
void* resource_handle);
|
||||
|
||||
/*!
|
||||
* \brief C callback to free the resource handle in C packed function.
|
||||
|
@ -291,8 +324,20 @@ TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
|
|||
*
|
||||
* \param name The name of the function.
|
||||
* \param out the result function pointer.
|
||||
*
|
||||
* \note The function handle of global function is managed by TVM runtime,
|
||||
* So TVMFuncFree is should not be called when it get deleted.
|
||||
*/
|
||||
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
|
||||
|
||||
/*!
|
||||
* \brief List all the globally registered function name
|
||||
* \param out_size The number of functions
|
||||
* \param out_array The array of function names.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
TVM_DLL int TVMFuncListGlobalNames(int *out_size,
|
||||
const char*** out_array);
|
||||
} // TVM_EXTERN_C
|
||||
|
||||
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
|
||||
|
|
|
@ -1,19 +1,41 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file packed_func.h
|
||||
* \brief Runtime related c++ class.
|
||||
*/
|
||||
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
|
||||
#define TVM_RUNTIME_PACKED_FUNC_H_
|
||||
|
||||
#include <dmlc/logging.h>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include "./c_runtime_api.h"
|
||||
|
||||
namespace Halide {
|
||||
// Forward declare type for extensions
|
||||
// The header works fine without depending on this.
|
||||
struct Type;
|
||||
struct Expr;
|
||||
}
|
||||
|
||||
namespace tvm {
|
||||
// Forward declare NodeRef and Node for extensions.
|
||||
// This header works fine without depend on NodeRef
|
||||
// as long as it is not used.
|
||||
class Node;
|
||||
class NodeRef;
|
||||
|
||||
namespace runtime {
|
||||
// forward declarations
|
||||
class TVMArgs;
|
||||
class TVMArgValue;
|
||||
class TVMRetValue;
|
||||
class TVMArgsSetter;
|
||||
|
||||
/*!
|
||||
* \brief Packed function is a type-erased function.
|
||||
|
@ -25,8 +47,25 @@ namespace runtime {
|
|||
*/
|
||||
class PackedFunc {
|
||||
public:
|
||||
/*! \brief The internal std::function */
|
||||
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
|
||||
/*!
|
||||
* \brief The internal std::function
|
||||
* \param args The arguments to the function.
|
||||
* \param rv The return value.
|
||||
*
|
||||
* \code
|
||||
* // Example code on how to implemented FType
|
||||
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
|
||||
* // automatically convert arguments to desired type.
|
||||
* int a0 = args[0];
|
||||
* float a1 = args[1];
|
||||
* ...
|
||||
* // automatically assign values to rv
|
||||
* std::string my_return_value = "x";
|
||||
* *rv = my_return_value;
|
||||
* }
|
||||
* \endcode
|
||||
*/
|
||||
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
|
||||
/*! \brief default constructor */
|
||||
PackedFunc() {}
|
||||
/*!
|
||||
|
@ -38,16 +77,24 @@ class PackedFunc {
|
|||
* \brief Call packed function by directly passing in unpacked format.
|
||||
* \param args Arguments to be passed.
|
||||
* \tparam Args arguments to be passed.
|
||||
*
|
||||
* \code
|
||||
* // Example code on how to call packed function
|
||||
* void CallPacked(PackedFunc f) {
|
||||
* // call like normal functions by pass in arguments
|
||||
* // return value is automatically converted back
|
||||
* int rvalue = f(1, 2.0);
|
||||
* }
|
||||
* \endcode
|
||||
*/
|
||||
template<typename... Args>
|
||||
inline void operator()(Args&& ...args) const;
|
||||
inline TVMRetValue operator()(Args&& ...args) const;
|
||||
/*!
|
||||
* \brief Call the function in packed format.
|
||||
* \param args The arguments
|
||||
* \param type_codes The type_codes of the arguments
|
||||
* \param num_args Number of arguments.
|
||||
* \param rv The return value.
|
||||
*/
|
||||
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
|
||||
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
|
||||
/*! \return the internal body function */
|
||||
inline FType body() const;
|
||||
/*!
|
||||
|
@ -74,82 +121,552 @@ class PackedFunc {
|
|||
FType body_;
|
||||
};
|
||||
|
||||
// implementations
|
||||
inline void PackedFunc::CallPacked(
|
||||
const TVMValue* args, const int* type_codes, int num_args) const {
|
||||
body_(args, type_codes, num_args);
|
||||
/*! \brief Arguments into TVM functions. */
|
||||
class TVMArgs {
|
||||
public:
|
||||
const TVMValue* values;
|
||||
const int* type_codes;
|
||||
int num_args;
|
||||
/*!
|
||||
* \brief constructor
|
||||
* \param values The argument values
|
||||
* \param type_codes The argument type codes
|
||||
* \param num_args number of arguments.
|
||||
*/
|
||||
TVMArgs(const TVMValue* values,
|
||||
const int* type_codes,
|
||||
int num_args)
|
||||
: values(values),
|
||||
type_codes(type_codes),
|
||||
num_args(num_args) { }
|
||||
/*! \return size of the arguments */
|
||||
inline int size() const;
|
||||
/*!
|
||||
* \brief Get i-th argument
|
||||
* \param i the index.
|
||||
* \return the ith argument.
|
||||
*/
|
||||
inline TVMArgValue operator[](int i) const;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Convert type code to its name
|
||||
* \param type_code The type code .
|
||||
* \return The name of type code.
|
||||
*/
|
||||
inline const char* TypeCode2Str(int type_code);
|
||||
|
||||
/*!
|
||||
* \brief convert a string to TVM type.
|
||||
* \param s The string to be converted.
|
||||
* \return The corresponding tvm type.
|
||||
*/
|
||||
inline TVMType String2TVMType(std::string s);
|
||||
|
||||
// macro to check type code.
|
||||
#define TVM_CHECK_TYPE_CODE(CODE, T) \
|
||||
CHECK_EQ(CODE, T) << " expected " \
|
||||
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
|
||||
|
||||
/*!
|
||||
* \brief Internal base class to
|
||||
* handle conversion to POD values.
|
||||
*/
|
||||
class TVMPODValue_ {
|
||||
public:
|
||||
operator double() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kFloat);
|
||||
return value_.v_float64;
|
||||
}
|
||||
operator int64_t() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||
return value_.v_int64;
|
||||
}
|
||||
operator uint64_t() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||
return value_.v_int64;
|
||||
}
|
||||
operator int() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||
CHECK_LE(value_.v_int64,
|
||||
std::numeric_limits<int>::max());
|
||||
return static_cast<int>(value_.v_int64);
|
||||
}
|
||||
operator bool() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||
return value_.v_int64 != 0;
|
||||
}
|
||||
operator void*() const {
|
||||
if (type_code_ == kNull) return nullptr;
|
||||
if (type_code_ == kArrayHandle) return value_.v_handle;
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kHandle);
|
||||
return value_.v_handle;
|
||||
}
|
||||
operator TVMArray*() const {
|
||||
if (type_code_ == kNull) return nullptr;
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle);
|
||||
return static_cast<TVMArray*>(value_.v_handle);
|
||||
}
|
||||
int type_code() const {
|
||||
return type_code_;
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class TVMArgsSetter;
|
||||
friend class TVMRetValue;
|
||||
TVMPODValue_() : type_code_(kNull) {}
|
||||
TVMPODValue_(TVMValue value, int type_code)
|
||||
: value_(value), type_code_(type_code) {}
|
||||
/*!
|
||||
* \brief return handle as specific pointer type.
|
||||
* \tparam T the data type.
|
||||
* \return The pointer type.
|
||||
*/
|
||||
template<typename T>
|
||||
T* ptr() const {
|
||||
return static_cast<T*>(value_.v_handle);
|
||||
}
|
||||
/*! \brief The value */
|
||||
TVMValue value_;
|
||||
/*! \brief the type code */
|
||||
int type_code_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief A single argument value to PackedFunc.
|
||||
* Containing both type_code and TVMValue
|
||||
*
|
||||
* Provides utilities to do type cast into other types.
|
||||
*/
|
||||
class TVMArgValue : public TVMPODValue_ {
|
||||
public:
|
||||
/*!
|
||||
* \brief constructor
|
||||
* \param value of the function
|
||||
* \param type_code The type code.
|
||||
*/
|
||||
TVMArgValue(TVMValue value, int type_code)
|
||||
: TVMPODValue_(value, type_code) {
|
||||
}
|
||||
// reuse converter from parent
|
||||
using TVMPODValue_::operator double;
|
||||
using TVMPODValue_::operator int64_t;
|
||||
using TVMPODValue_::operator uint64_t;
|
||||
using TVMPODValue_::operator int;
|
||||
using TVMPODValue_::operator bool;
|
||||
using TVMPODValue_::operator void*;
|
||||
using TVMPODValue_::operator TVMArray*;
|
||||
// conversion operator.
|
||||
operator std::string() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kStr);
|
||||
return std::string(value_.v_str);
|
||||
}
|
||||
operator TVMType() const {
|
||||
if (type_code_ == kStr) {
|
||||
return String2TVMType(operator std::string());
|
||||
}
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
|
||||
return value_.v_type;
|
||||
}
|
||||
operator PackedFunc() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
|
||||
return *ptr<PackedFunc>();
|
||||
}
|
||||
const TVMValue& value() const {
|
||||
return value_;
|
||||
}
|
||||
// NodeRef related extenstions: in tvm/packed_func_ext.h
|
||||
template<typename TNodeRef,
|
||||
typename = typename std::enable_if<
|
||||
std::is_class<TNodeRef>::value>::type>
|
||||
inline operator TNodeRef() const;
|
||||
template<typename TNodeRef,
|
||||
typename = typename std::enable_if<
|
||||
std::is_class<TNodeRef>::value>::type>
|
||||
inline bool IsNodeType() const;
|
||||
inline operator Halide::Type() const;
|
||||
inline operator Halide::Expr() const;
|
||||
// get internal node ptr, if it is node
|
||||
inline std::shared_ptr<Node>& node_sptr();
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Return Value container,
|
||||
* Unlike TVMArgValue, which only holds reference and do not delete
|
||||
* the underlying container during destruction.
|
||||
*
|
||||
* TVMRetValue holds value and will manage the underlying containers
|
||||
* when it stores a complicated data type.
|
||||
*/
|
||||
class TVMRetValue : public TVMPODValue_ {
|
||||
public:
|
||||
/*! \brief default constructor */
|
||||
TVMRetValue() {}
|
||||
/*!
|
||||
* \brief move constructor from anoter return value.
|
||||
* \param other The other return value.
|
||||
*/
|
||||
TVMRetValue(TVMRetValue&& other)
|
||||
: TVMPODValue_(other.value_, other.type_code_) {
|
||||
other.type_code_ = kNull;
|
||||
}
|
||||
/*! \brief destructor */
|
||||
~TVMRetValue() {
|
||||
this->Clear();
|
||||
}
|
||||
// reuse converter from parent
|
||||
using TVMPODValue_::operator double;
|
||||
using TVMPODValue_::operator int64_t;
|
||||
using TVMPODValue_::operator uint64_t;
|
||||
using TVMPODValue_::operator int;
|
||||
using TVMPODValue_::operator bool;
|
||||
using TVMPODValue_::operator void*;
|
||||
using TVMPODValue_::operator TVMArray*;
|
||||
// Disable copy and assign from another value, but allow move.
|
||||
TVMRetValue(const TVMRetValue& other) {
|
||||
this->Assign(other);
|
||||
}
|
||||
// conversion operators
|
||||
operator std::string() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kStr);
|
||||
return *ptr<std::string>();
|
||||
}
|
||||
operator TVMType() const {
|
||||
if (type_code_ == kStr) {
|
||||
return String2TVMType(operator std::string());
|
||||
}
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
|
||||
return value_.v_type;
|
||||
}
|
||||
operator PackedFunc() const {
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
|
||||
return *ptr<PackedFunc>();
|
||||
}
|
||||
// Assign operators
|
||||
TVMRetValue& operator=(TVMRetValue&& other) {
|
||||
this->Clear();
|
||||
value_ = other.value_;
|
||||
type_code_ = other.type_code_;
|
||||
other.type_code_ = kNull;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(double value) {
|
||||
this->SwitchToPOD(kFloat);
|
||||
value_.v_float64 = value;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(std::nullptr_t value) {
|
||||
this->SwitchToPOD(kNull);
|
||||
value_.v_handle = value;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(void* value) {
|
||||
this->SwitchToPOD(kHandle);
|
||||
value_.v_handle = value;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(int64_t value) {
|
||||
this->SwitchToPOD(kInt);
|
||||
value_.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(int value) {
|
||||
this->SwitchToPOD(kInt);
|
||||
value_.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(TVMType t) {
|
||||
this->SwitchToPOD(kTVMType);
|
||||
value_.v_type = t;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(bool value) {
|
||||
this->SwitchToPOD(kInt);
|
||||
value_.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(std::string value) {
|
||||
this->SwitchToClass(kStr, value);
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(PackedFunc f) {
|
||||
this->SwitchToClass(kFuncHandle, f);
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
|
||||
this->Assign(other);
|
||||
return *this;
|
||||
}
|
||||
TVMRetValue& operator=(TVMArgValue other) {
|
||||
this->Assign(other);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Move the value back to front-end via C API.
|
||||
* This marks the current container as null.
|
||||
* The managed resources is moved to front-end and
|
||||
* the front end should take charge in managing them.
|
||||
*
|
||||
* \param ret_value The return value.
|
||||
* \param ret_type_code The return type code.
|
||||
*/
|
||||
void MoveToCHost(TVMValue* ret_value,
|
||||
int* ret_type_code) {
|
||||
// cannot move str; need specially handle.
|
||||
CHECK(type_code_ != kStr);
|
||||
*ret_value = value_;
|
||||
*ret_type_code = type_code_;
|
||||
type_code_ = kNull;
|
||||
}
|
||||
// NodeRef related extenstions: in tvm/packed_func_ext.h
|
||||
inline TVMRetValue& operator=(const NodeRef& other);
|
||||
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
|
||||
template<typename TNodeRef,
|
||||
typename = typename std::enable_if<
|
||||
std::is_class<TNodeRef>::value>::type>
|
||||
inline operator TNodeRef() const;
|
||||
// type related
|
||||
inline operator Halide::Type() const;
|
||||
inline TVMRetValue& operator=(const Halide::Type& other);
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
void Assign(const T& other) {
|
||||
switch (other.type_code()) {
|
||||
case kStr: {
|
||||
SwitchToClass<std::string>(kStr, other);
|
||||
break;
|
||||
}
|
||||
case kFuncHandle: {
|
||||
SwitchToClass<PackedFunc>(kFuncHandle, other);
|
||||
break;
|
||||
}
|
||||
case kNodeHandle: {
|
||||
SwitchToClass<std::shared_ptr<Node> >(
|
||||
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >());
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
SwitchToPOD(other.type_code());
|
||||
value_ = other.value_;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get the internal container.
|
||||
void SwitchToPOD(int type_code) {
|
||||
if (type_code_ != type_code) {
|
||||
this->Clear();
|
||||
type_code_ = type_code;
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
void SwitchToClass(int type_code, T v) {
|
||||
if (type_code_ != type_code) {
|
||||
this->Clear();
|
||||
type_code_ = type_code;
|
||||
value_.v_handle = new T(v);
|
||||
} else {
|
||||
*static_cast<T*>(value_.v_handle) = v;
|
||||
}
|
||||
}
|
||||
void Clear() {
|
||||
if (type_code_ == kNull) return;
|
||||
switch (type_code_) {
|
||||
case kStr: delete ptr<std::string>(); break;
|
||||
case kFuncHandle: delete ptr<PackedFunc>(); break;
|
||||
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
|
||||
}
|
||||
type_code_ = kNull;
|
||||
}
|
||||
};
|
||||
|
||||
// implementation details
|
||||
inline const char* TypeCode2Str(int type_code) {
|
||||
switch (type_code) {
|
||||
case kInt: return "int";
|
||||
case kFloat: return "float";
|
||||
case kStr: return "str";
|
||||
case kHandle: return "Handle";
|
||||
case kNull: return "NULL";
|
||||
case kNodeHandle: return "NodeHandle";
|
||||
case kArrayHandle: return "ArrayHandle";
|
||||
case kTVMType: return "TVMType";
|
||||
case kFuncHandle: return "FunctionHandle";
|
||||
default: LOG(FATAL) << "unknown type_code="
|
||||
<< static_cast<int>(type_code); return "";
|
||||
}
|
||||
}
|
||||
|
||||
inline TVMType String2TVMType(std::string s) {
|
||||
TVMType t;
|
||||
t.bits = 32; t.lanes = 1;
|
||||
const char* scan;
|
||||
if (s.substr(0, 3) == "int") {
|
||||
t.code = kInt; scan = s.c_str() + 3;
|
||||
} else if (s.substr(0, 4) == "uint") {
|
||||
t.code = kUInt; scan = s.c_str() + 4;
|
||||
} else if (s.substr(0, 5) == "float") {
|
||||
t.code = kFloat; scan = s.c_str() + 5;
|
||||
} else if (s == "handle") {
|
||||
t.code = kHandle;
|
||||
t.bits = 64; // handle uses 64 bit by default.
|
||||
scan = s.c_str() + 6;
|
||||
} else {
|
||||
scan = s.c_str();
|
||||
LOG(FATAL) << "unknown type " << s;
|
||||
}
|
||||
unsigned bits = t.bits, lanes = t.lanes;
|
||||
sscanf(scan, "%ux%u", &bits, &lanes);
|
||||
t.bits = static_cast<uint8_t>(bits);
|
||||
t.lanes = static_cast<uint16_t>(lanes);
|
||||
return t;
|
||||
}
|
||||
|
||||
inline TVMArgValue TVMArgs::operator[](int i) const {
|
||||
CHECK_LT(i, num_args)
|
||||
<< "not enough argument passed, "
|
||||
<< num_args << " passed"
|
||||
<< "but request arg" << i;
|
||||
return TVMArgValue(values[i], type_codes[i]);
|
||||
}
|
||||
|
||||
inline int TVMArgs::size() const {
|
||||
return num_args;
|
||||
}
|
||||
|
||||
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
|
||||
body_(args, rv);
|
||||
}
|
||||
|
||||
inline PackedFunc::FType PackedFunc::body() const {
|
||||
return body_;
|
||||
}
|
||||
|
||||
// internal namespace
|
||||
namespace detail {
|
||||
template<bool stop, std::size_t I, typename F, typename ...Args>
|
||||
struct for_each_dispatcher_ {
|
||||
static inline void run(const std::tuple<Args...>& args, F f) {
|
||||
struct for_each_dispatcher {
|
||||
static void run(std::tuple<Args...>& args, const F& f) { // NOLINT(*)
|
||||
f(I, std::get<I>(args));
|
||||
for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
|
||||
for_each_dispatcher<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
|
||||
}
|
||||
};
|
||||
|
||||
template<std::size_t I, typename F, typename ...Args>
|
||||
struct for_each_dispatcher_<true, I, F, Args...> {
|
||||
static inline void run(const std::tuple<Args...>& args, F f) {}
|
||||
struct for_each_dispatcher<true, I, F, Args...> {
|
||||
static void run(std::tuple<Args...>& args, const F& f) {} // NOLINT(*)
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template<typename F, typename ...Args>
|
||||
inline void for_each(const std::tuple<Args...>& args, F f) {
|
||||
for_each_dispatcher_<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
|
||||
inline void for_each(std::tuple<Args...>& args, const F& f) { // NOLINT(*)
|
||||
detail::for_each_dispatcher<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
|
||||
}
|
||||
|
||||
namespace arg_setter {
|
||||
template<typename T>
|
||||
inline void Set(TVMValue& arg, int& t, T v); // NOLINT(*)
|
||||
template<>
|
||||
inline void Set<double>(TVMValue& arg, int& t, double value) { // NOLINT(*)
|
||||
arg.v_float64 = value;
|
||||
t = kFloat;
|
||||
/* \brief argument settter to PackedFunc */
|
||||
class TVMArgsSetter {
|
||||
public:
|
||||
TVMArgsSetter(TVMValue* values, int* type_codes)
|
||||
: values_(values), type_codes_(type_codes) {}
|
||||
// setters for POD types
|
||||
template<typename T,
|
||||
typename = typename std::enable_if<std::is_integral<T>::value>::type>
|
||||
void operator()(size_t i, T value) const {
|
||||
values_[i].v_int64 = static_cast<int64_t>(value);
|
||||
type_codes_[i] = kInt;
|
||||
}
|
||||
template<>
|
||||
inline void Set<int>(TVMValue& arg, int& t, int value) { // NOLINT(*)
|
||||
arg.v_int64 = value;
|
||||
t = kInt;
|
||||
void operator()(size_t i, uint64_t value) const {
|
||||
values_[i].v_int64 = static_cast<int64_t>(value);
|
||||
CHECK_LE(value,
|
||||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
|
||||
type_codes_[i] = kInt;
|
||||
}
|
||||
template<>
|
||||
inline void Set<long>(TVMValue& arg, int& t, long value) { // NOLINT(*)
|
||||
arg.v_int64 = value;
|
||||
t = kInt;
|
||||
void operator()(size_t i, double value) const {
|
||||
values_[i].v_float64 = value;
|
||||
type_codes_[i] = kFloat;
|
||||
}
|
||||
template<>
|
||||
inline void Set<TVMArray*>(TVMValue& arg, int& t, TVMArray* value) { // NOLINT(*)
|
||||
arg.v_handle = value;
|
||||
t = kHandle;
|
||||
void operator()(size_t i, std::nullptr_t value) const {
|
||||
values_[i].v_handle = value;
|
||||
type_codes_[i] = kNull;
|
||||
}
|
||||
template<>
|
||||
inline void Set<void*>(TVMValue& arg, int& t, void* value) { // NOLINT(*)
|
||||
arg.v_handle = value;
|
||||
t = kHandle;
|
||||
void operator()(size_t i, const TVMArgValue& value) const {
|
||||
values_[i] = value.value_;
|
||||
type_codes_[i] = value.type_code_;
|
||||
}
|
||||
} // namespace arg_setter
|
||||
void operator()(size_t i, void* value) const {
|
||||
values_[i].v_handle = value;
|
||||
type_codes_[i] = kHandle;
|
||||
}
|
||||
void operator()(size_t i, TVMArray* value) const {
|
||||
values_[i].v_handle = value;
|
||||
type_codes_[i] = kArrayHandle;
|
||||
}
|
||||
void operator()(size_t i, TVMType value) const {
|
||||
values_[i].v_type = value;
|
||||
type_codes_[i] = kTVMType;
|
||||
}
|
||||
void operator()(size_t i, const char* value) const {
|
||||
values_[i].v_str = value;
|
||||
type_codes_[i] = kStr;
|
||||
}
|
||||
// setters for container type
|
||||
// They must be reference(instead of const ref)
|
||||
// to make sure they are alive in the tuple(instead of getting converted)
|
||||
void operator()(size_t i, std::string& value) const { // NOLINT(*)
|
||||
values_[i].v_str = value.c_str();
|
||||
type_codes_[i] = kStr;
|
||||
}
|
||||
void operator()(size_t i, PackedFunc& value) const { // NOLINT(*)
|
||||
values_[i].v_handle = &value;
|
||||
type_codes_[i] = kFuncHandle;
|
||||
}
|
||||
void operator()(size_t i, TVMRetValue& value) const { // NOLINT(*)
|
||||
if (value.type_code() == kStr) {
|
||||
values_[i].v_str = value.ptr<std::string>()->c_str();
|
||||
type_codes_[i] = kStr;
|
||||
} else {
|
||||
values_[i] = value.value_;
|
||||
type_codes_[i] = value.type_code();
|
||||
}
|
||||
}
|
||||
// NodeRef related extenstions: in tvm/packed_func_ext.h
|
||||
inline void operator()(size_t i, NodeRef& other) const; // NOLINT(*)
|
||||
inline void operator()(size_t i, const Halide::Type& t) const;
|
||||
|
||||
private:
|
||||
/*! \brief The values fields */
|
||||
TVMValue* values_;
|
||||
/*! \brief The type code fields */
|
||||
int* type_codes_;
|
||||
};
|
||||
|
||||
class TVMArgsGetter {
|
||||
public:
|
||||
explicit TVMArgsGetter(TVMArgs args)
|
||||
: args_(args) {}
|
||||
|
||||
struct PackedFuncArgSetter {
|
||||
TVMValue* args;
|
||||
int* type_codes;
|
||||
template<typename T>
|
||||
inline void operator()(size_t i, T v) const {
|
||||
arg_setter::Set(args[i], type_codes[i], v);
|
||||
inline void operator()(size_t i, T& target) const { // NOLINT(*)
|
||||
target = args_[i].operator T();
|
||||
}
|
||||
private:
|
||||
TVMArgs args_;
|
||||
};
|
||||
|
||||
template<typename... Args>
|
||||
inline void PackedFunc::operator()(Args&& ...args) const {
|
||||
auto targ = std::make_tuple(std::forward<Args>(args)...);
|
||||
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
|
||||
auto targs = std::make_tuple(std::forward<Args>(args)...);
|
||||
const int kNumArgs = sizeof...(Args);
|
||||
TVMValue tvm_args[kNumArgs];
|
||||
int tvm_arg_type_ids[kNumArgs];
|
||||
for_each(targ, PackedFuncArgSetter{tvm_args, tvm_arg_type_ids});
|
||||
body_(tvm_args, tvm_arg_type_ids, kNumArgs);
|
||||
TVMValue values[kNumArgs];
|
||||
int type_codes[kNumArgs];
|
||||
for_each(targs, TVMArgsSetter(values, type_codes));
|
||||
TVMRetValue rv;
|
||||
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
|
||||
return rv;
|
||||
}
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
#endif // TVM_RUNTIME_PACKED_FUNC_H_
|
||||
|
|
|
@ -10,5 +10,6 @@
|
|||
#include "./expr.h"
|
||||
#include "./tensor.h"
|
||||
#include "./operation.h"
|
||||
#include "./packed_func_ext.h"
|
||||
|
||||
#endif // TVM_TVM_H_
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# pylint: disable=redefined-builtin, wildcard-import
|
||||
"""C++ backend related python scripts"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._api import register_node
|
||||
from ._ctypes._node import register_node
|
||||
|
||||
from . import tensor
|
||||
from . import expr
|
||||
|
|
|
@ -91,45 +91,3 @@ def c_array(ctype, values):
|
|||
Created ctypes array
|
||||
"""
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
|
||||
"""Convert ctypes returned doc string information into parameters docstring.
|
||||
|
||||
num_args : nn_uint
|
||||
Number of arguments.
|
||||
|
||||
arg_names : ctypes.POINTER(ctypes.c_char_p)
|
||||
Argument names.
|
||||
|
||||
arg_types : ctypes.POINTER(ctypes.c_char_p)
|
||||
Argument type information.
|
||||
|
||||
arg_descs : ctypes.POINTER(ctypes.c_char_p)
|
||||
Argument description information.
|
||||
|
||||
remove_dup : boolean, optional
|
||||
Whether remove duplication or not.
|
||||
|
||||
Returns
|
||||
-------
|
||||
docstr : str
|
||||
Python docstring of parameter sections.
|
||||
"""
|
||||
param_keys = set()
|
||||
param_str = []
|
||||
for i in range(num_args.value):
|
||||
key = py_str(arg_names[i])
|
||||
if key in param_keys and remove_dup:
|
||||
continue
|
||||
param_keys.add(key)
|
||||
type_info = py_str(arg_types[i])
|
||||
ret = '%s : %s' % (key, type_info)
|
||||
if len(arg_descs[i]) != 0:
|
||||
ret += '\n ' + py_str(arg_descs[i])
|
||||
param_str.append(ret)
|
||||
doc_str = ('Parameters\n' +
|
||||
'----------\n' +
|
||||
'%s\n')
|
||||
doc_str = doc_str % ('\n'.join(param_str))
|
||||
return doc_str
|
||||
|
|
|
@ -1,416 +0,0 @@
|
|||
# coding: utf-8
|
||||
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
|
||||
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements
|
||||
"""Symbolic configuration API."""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
import ctypes
|
||||
import sys
|
||||
from numbers import Number, Integral
|
||||
|
||||
from .._base import _LIB
|
||||
from .._base import c_str, py_str, string_types
|
||||
from .._base import check_call, ctypes2docstring
|
||||
from .. import _api_internal
|
||||
from . import _runtime_api
|
||||
from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer
|
||||
|
||||
# type definitions
|
||||
APIFuncHandle = ctypes.c_void_p
|
||||
NodeHandle = ctypes.c_void_p
|
||||
FunctionHandle = ctypes.c_void_p
|
||||
|
||||
class APIType(object):
|
||||
"""TVMType used in API calls"""
|
||||
INT = ctypes.c_int(TypeCode.INT)
|
||||
UINT = ctypes.c_int(TypeCode.UINT)
|
||||
FLOAT = ctypes.c_int(TypeCode.FLOAT)
|
||||
HANDLE = ctypes.c_int(TypeCode.HANDLE)
|
||||
NULL = ctypes.c_int(TypeCode.NULL)
|
||||
NODE_HANDLE = ctypes.c_int(TypeCode.NODE_HANDLE)
|
||||
STR = ctypes.c_int(TypeCode.STR)
|
||||
FUNC_HANDLE = ctypes.c_int(TypeCode.FUNC_HANDLE)
|
||||
|
||||
|
||||
NODE_TYPE = {
|
||||
}
|
||||
|
||||
def _return_node(x):
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, NodeHandle):
|
||||
handle = NodeHandle(handle)
|
||||
ret_val = TVMValue()
|
||||
ret_type_code = ctypes.c_int()
|
||||
ret_success = ctypes.c_int()
|
||||
check_call(_LIB.TVMNodeGetAttr(
|
||||
handle, c_str("type_key"),
|
||||
ctypes.byref(ret_val),
|
||||
ctypes.byref(ret_type_code),
|
||||
ctypes.byref(ret_success)))
|
||||
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
|
||||
|
||||
|
||||
def _return_func(x):
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, FunctionHandle):
|
||||
handle = FunctionHandle(handle)
|
||||
return _runtime_api._function_cls(handle)
|
||||
|
||||
|
||||
def _return_handle(x):
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, ctypes.c_void_p):
|
||||
handle = ctypes.c_void_p(handle)
|
||||
return handle
|
||||
|
||||
|
||||
RET_SWITCH = {
|
||||
TypeCode.NULL: lambda x: None,
|
||||
TypeCode.INT: lambda x: x.v_int64,
|
||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||
TypeCode.STR: lambda x: py_str(x.v_str),
|
||||
TypeCode.NODE_HANDLE: _return_node,
|
||||
TypeCode.FUNC_HANDLE: _return_func
|
||||
}
|
||||
|
||||
PACK_ARG_SWITCH = {
|
||||
TypeCode.NULL: lambda x: None,
|
||||
TypeCode.INT: lambda x: x.v_int64,
|
||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||
TypeCode.STR: lambda x: py_str(x.v_str),
|
||||
TypeCode.HANDLE: lambda x: _return_handle,
|
||||
}
|
||||
|
||||
|
||||
class SliceBase(object):
|
||||
"""base class of slice object"""
|
||||
pass
|
||||
|
||||
class NodeBase(object):
|
||||
"""Symbol is symbolic graph."""
|
||||
__slots__ = ["handle"]
|
||||
# pylint: disable=no-member
|
||||
def __init__(self, handle):
|
||||
"""Initialize the function with handle
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handle : SymbolHandle
|
||||
the handle to the underlying C++ Symbol
|
||||
"""
|
||||
self.handle = handle
|
||||
|
||||
def __repr__(self):
|
||||
return _api_internal._format_str(self)
|
||||
|
||||
def __del__(self):
|
||||
check_call(_LIB.TVMNodeFree(self.handle))
|
||||
|
||||
def __getattr__(self, name):
|
||||
ret_val = TVMValue()
|
||||
ret_type_code = ctypes.c_int()
|
||||
ret_success = ctypes.c_int()
|
||||
check_call(_LIB.TVMNodeGetAttr(
|
||||
self.handle, c_str(name),
|
||||
ctypes.byref(ret_val),
|
||||
ctypes.byref(ret_type_code),
|
||||
ctypes.byref(ret_success)))
|
||||
value = RET_SWITCH[ret_type_code.value](ret_val)
|
||||
if not ret_success.value:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute '%s'" % (str(type(self)), name))
|
||||
return value
|
||||
|
||||
def __hash__(self):
|
||||
return _api_internal._raw_ptr(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, NodeBase):
|
||||
return False
|
||||
return self.__hash__() == other.__hash__()
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __dir__(self):
|
||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||
size = ctypes.c_uint()
|
||||
check_call(_LIB.TVMNodeListAttrNames(
|
||||
self.handle, ctypes.byref(size), ctypes.byref(plist)))
|
||||
names = []
|
||||
for i in range(size.value):
|
||||
names.append(py_str(plist[i]))
|
||||
return names
|
||||
|
||||
def __reduce__(self):
|
||||
return (type(self), (None,), self.__getstate__())
|
||||
|
||||
def __getstate__(self):
|
||||
handle = self.handle
|
||||
if handle is not None:
|
||||
return {'handle': _api_internal._save_json(self)}
|
||||
else:
|
||||
return {'handle': None}
|
||||
|
||||
def __setstate__(self, state):
|
||||
# pylint: disable=assigning-non-slot
|
||||
handle = state['handle']
|
||||
if handle is not None:
|
||||
json_str = handle
|
||||
_push_arg(json_str)
|
||||
other = _api_internal._load_json(json_str)
|
||||
self.handle = other.handle
|
||||
other.handle = None
|
||||
else:
|
||||
self.handle = None
|
||||
|
||||
|
||||
def const(value, dtype=None):
|
||||
"""construct a constant"""
|
||||
if dtype is None:
|
||||
if isinstance(value, Integral):
|
||||
dtype = 'int32'
|
||||
else:
|
||||
dtype = 'float32'
|
||||
return _api_internal._const(value, dtype)
|
||||
|
||||
|
||||
def _ctypes_free_resource(rhandle):
|
||||
"""callback to free resources when it it not needed."""
|
||||
pyobj = ctypes.cast(rhandle, ctypes.py_object)
|
||||
ctypes.pythonapi.Py_DecRef(pyobj)
|
||||
|
||||
# Global callback that is always alive
|
||||
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
|
||||
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
|
||||
|
||||
def convert_to_tvm_func(pyfunc):
|
||||
"""Convert a python function to TVM function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pyfunc : python function
|
||||
The python function to be converted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tvmfunc: tvm.nd.Function
|
||||
The converted tvm function.
|
||||
"""
|
||||
local_pyfunc = pyfunc
|
||||
def cfun(args, type_codes, num_args, _):
|
||||
""" ctypes function """
|
||||
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
||||
pyargs = [PACK_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
||||
local_pyfunc(*pyargs)
|
||||
handle = FunctionHandle()
|
||||
f = TVMPackedCFunc(cfun)
|
||||
# NOTE: We will need to use python-api to increase ref count of the f
|
||||
# TVM_FREE_PYOBJ will be called after it is no longer needed.
|
||||
pyobj = ctypes.py_object(f)
|
||||
ctypes.pythonapi.Py_IncRef(pyobj)
|
||||
check_call(_LIB.TVMFuncCreateFromCFunc(
|
||||
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
|
||||
return _runtime_api._function_cls(handle)
|
||||
|
||||
|
||||
def convert(value):
|
||||
"""Convert a value to expression."""
|
||||
if isinstance(value, (NodeBase, _runtime_api.FunctionBase)):
|
||||
return value
|
||||
elif isinstance(value, Number):
|
||||
return const(value)
|
||||
elif isinstance(value, string_types):
|
||||
return _api_internal._str(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value = [convert(x) for x in value]
|
||||
return _api_internal._Array(*value)
|
||||
elif isinstance(value, dict):
|
||||
vlist = []
|
||||
for it in value.items():
|
||||
if not isinstance(it[0], NodeBase):
|
||||
raise ValueError("key of map must already been a container type")
|
||||
vlist.append(it[0])
|
||||
vlist.append(convert(it[1]))
|
||||
return _api_internal._Map(*vlist)
|
||||
elif isinstance(value, SliceBase):
|
||||
return value.tensor(*value.indices)
|
||||
elif callable(value):
|
||||
return convert_to_tvm_func(value)
|
||||
else:
|
||||
raise ValueError("don't know how to handle type %s" % type(value))
|
||||
return value
|
||||
|
||||
|
||||
def _push_arg(arg):
|
||||
a = TVMValue()
|
||||
if arg is None:
|
||||
_LIB.TVMAPIPushStack(a, APIType.NULL)
|
||||
elif isinstance(arg, NodeBase):
|
||||
a.v_handle = arg.handle
|
||||
_LIB.TVMAPIPushStack(a, APIType.NODE_HANDLE)
|
||||
elif isinstance(arg, Integral):
|
||||
a.v_int64 = ctypes.c_int64(arg)
|
||||
_LIB.TVMAPIPushStack(a, APIType.INT)
|
||||
elif isinstance(arg, Number):
|
||||
a.v_double = ctypes.c_double(arg)
|
||||
_LIB.TVMAPIPushStack(a, APIType.FLOAT)
|
||||
elif isinstance(arg, string_types):
|
||||
a.v_str = c_str(arg)
|
||||
_LIB.TVMAPIPushStack(a, APIType.STR)
|
||||
else:
|
||||
raise TypeError("Don't know how to handle type %s" % type(arg))
|
||||
|
||||
|
||||
def _make_function(handle, name):
|
||||
"""Create an atomic symbol function by handle and funciton name."""
|
||||
real_name = ctypes.c_char_p()
|
||||
desc = ctypes.c_char_p()
|
||||
num_args = ctypes.c_int()
|
||||
arg_names = ctypes.POINTER(ctypes.c_char_p)()
|
||||
arg_types = ctypes.POINTER(ctypes.c_char_p)()
|
||||
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
|
||||
ret_type = ctypes.c_char_p()
|
||||
|
||||
check_call(_LIB.TVMGetAPIFuncInfo(
|
||||
handle, ctypes.byref(real_name), ctypes.byref(desc),
|
||||
ctypes.byref(num_args),
|
||||
ctypes.byref(arg_names),
|
||||
ctypes.byref(arg_types),
|
||||
ctypes.byref(arg_descs),
|
||||
ctypes.byref(ret_type)))
|
||||
|
||||
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
|
||||
func_name = name
|
||||
desc = py_str(desc.value)
|
||||
|
||||
doc_str = ('%s\n\n' +
|
||||
'%s\n')
|
||||
doc_str = doc_str % (desc, param_str)
|
||||
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
|
||||
|
||||
def func(*args):
|
||||
"""TVM function"""
|
||||
cargs = []
|
||||
for x in args:
|
||||
if isinstance(x, (list, tuple, dict, SliceBase)):
|
||||
cargs.append(convert(x))
|
||||
else:
|
||||
cargs.append(x)
|
||||
|
||||
for arg in cargs:
|
||||
_push_arg(arg)
|
||||
ret_val = TVMValue()
|
||||
ret_type_code = ctypes.c_int()
|
||||
check_call(_LIB.TVMAPIFuncCall(
|
||||
handle, ctypes.byref(ret_val), ctypes.byref(ret_type_code)))
|
||||
return RET_SWITCH[ret_type_code.value](ret_val)
|
||||
|
||||
func.__name__ = func_name
|
||||
func.__doc__ = doc_str
|
||||
return func
|
||||
|
||||
|
||||
def register_node(type_key=None):
|
||||
"""register node type
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type_key : str or cls
|
||||
The type key of the node
|
||||
"""
|
||||
if isinstance(type_key, str):
|
||||
def register(cls):
|
||||
"""internal register function"""
|
||||
NODE_TYPE[type_key] = cls
|
||||
return cls
|
||||
return register
|
||||
else:
|
||||
cls = type_key
|
||||
NODE_TYPE[cls.__name__] = cls
|
||||
return cls
|
||||
|
||||
|
||||
def register_func(func_name, f=None):
|
||||
"""Register global function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func_name : str or function
|
||||
The function name
|
||||
|
||||
f : function
|
||||
The function to be registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fregister : function
|
||||
Register function if f is not specified.
|
||||
"""
|
||||
if callable(func_name):
|
||||
f = func_name
|
||||
func_name = f.__name__
|
||||
|
||||
if not isinstance(func_name, str):
|
||||
raise ValueError("expect string function name")
|
||||
def register(myf):
|
||||
"""internal register function"""
|
||||
if not isinstance(myf, _runtime_api.FunctionBase):
|
||||
myf = convert_to_tvm_func(myf)
|
||||
check_call(_LIB.TVMFuncRegisterGlobal(
|
||||
c_str(func_name), myf.handle))
|
||||
if f:
|
||||
register(f)
|
||||
else:
|
||||
return register
|
||||
|
||||
|
||||
def get_global_func(name):
|
||||
"""Get a global function by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the global function
|
||||
|
||||
Returns
|
||||
-------
|
||||
func : tvm.nd.Function
|
||||
The function to be returned.
|
||||
"""
|
||||
handle = FunctionHandle()
|
||||
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
|
||||
return _runtime_api._function_cls(handle)
|
||||
|
||||
|
||||
def _init_api_module(root_namespace):
|
||||
"""List and add all the functions to current module."""
|
||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||
size = ctypes.c_uint()
|
||||
|
||||
check_call(_LIB.TVMListAPIFuncNames(ctypes.byref(size),
|
||||
ctypes.byref(plist)))
|
||||
op_names = []
|
||||
for i in range(size.value):
|
||||
op_names.append(py_str(plist[i]))
|
||||
|
||||
module_obj = sys.modules["%s.api" % root_namespace]
|
||||
module_internal = sys.modules["%s._api_internal" % root_namespace]
|
||||
namespace_match = {
|
||||
"_make_": sys.modules["%s.make" % root_namespace],
|
||||
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
|
||||
"_codegen_": sys.modules["%s.codegen" % root_namespace],
|
||||
"_schedule_": sys.modules["%s.schedule" % root_namespace]
|
||||
}
|
||||
|
||||
for name in op_names:
|
||||
hdl = APIFuncHandle()
|
||||
check_call(_LIB.TVMGetAPIFuncHandle(c_str(name), ctypes.byref(hdl)))
|
||||
fname = name
|
||||
target_module = module_internal if name.startswith('_') else module_obj
|
||||
for k, v in namespace_match.items():
|
||||
if name.startswith(k):
|
||||
fname = name[len(k):]
|
||||
target_module = v
|
||||
function = _make_function(hdl, fname)
|
||||
setattr(target_module, function.__name__, function)
|
|
@ -0,0 +1,250 @@
|
|||
# coding: utf-8
|
||||
# pylint: disable=invalid-name, protected-access
|
||||
"""Symbolic configuration API."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import ctypes
|
||||
import sys
|
||||
from numbers import Number, Integral
|
||||
|
||||
from .._base import _LIB, check_call
|
||||
from .._base import c_str, py_str, string_types
|
||||
from ._types import TVMValue, TypeCode, TVMType
|
||||
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
|
||||
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
|
||||
from ._node import NodeBase, SliceBase, convert_to_node
|
||||
from ._ndarray import NDArrayBase
|
||||
|
||||
FunctionHandle = ctypes.c_void_p
|
||||
TVMRetValueHandle = ctypes.c_void_p
|
||||
|
||||
def _ctypes_free_resource(rhandle):
|
||||
"""callback to free resources when it it not needed."""
|
||||
pyobj = ctypes.cast(rhandle, ctypes.py_object)
|
||||
ctypes.pythonapi.Py_DecRef(pyobj)
|
||||
|
||||
# Global callback that is always alive
|
||||
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
|
||||
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
|
||||
|
||||
def convert_to_tvm_func(pyfunc):
|
||||
"""Convert a python function to TVM function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pyfunc : python function
|
||||
The python function to be converted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tvmfunc: tvm.nd.Function
|
||||
The converted tvm function.
|
||||
"""
|
||||
local_pyfunc = pyfunc
|
||||
def cfun(args, type_codes, num_args, ret, _):
|
||||
""" ctypes function """
|
||||
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
||||
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
||||
rv = local_pyfunc(*pyargs)
|
||||
if rv is not None:
|
||||
if isinstance(rv, tuple):
|
||||
raise ValueError("PackedFunction can only support one reurn value")
|
||||
temp_args = []
|
||||
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
|
||||
if not isinstance(ret, TVMRetValueHandle):
|
||||
ret = TVMRetValueHandle(ret)
|
||||
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
|
||||
_ = temp_args
|
||||
_ = rv
|
||||
|
||||
handle = FunctionHandle()
|
||||
f = TVMPackedCFunc(cfun)
|
||||
# NOTE: We will need to use python-api to increase ref count of the f
|
||||
# TVM_FREE_PYOBJ will be called after it is no longer needed.
|
||||
pyobj = ctypes.py_object(f)
|
||||
ctypes.pythonapi.Py_IncRef(pyobj)
|
||||
check_call(_LIB.TVMFuncCreateFromCFunc(
|
||||
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
|
||||
return Function(handle)
|
||||
|
||||
|
||||
def _make_tvm_args(args, temp_args):
|
||||
"""Pack arguments into c args tvm call accept"""
|
||||
num_args = len(args)
|
||||
values = (TVMValue * num_args)()
|
||||
type_codes = (ctypes.c_int * num_args)()
|
||||
for i, arg in enumerate(args):
|
||||
if arg is None:
|
||||
values[i].v_handle = None
|
||||
type_codes[i] = TypeCode.NULL
|
||||
elif isinstance(arg, NDArrayBase):
|
||||
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
|
||||
type_codes[i] = TypeCode.ARRAY_HANDLE
|
||||
elif isinstance(arg, NodeBase):
|
||||
values[i].v_handle = arg.handle
|
||||
type_codes[i] = TypeCode.NODE_HANDLE
|
||||
elif isinstance(arg, Integral):
|
||||
values[i].v_int64 = arg
|
||||
type_codes[i] = TypeCode.INT
|
||||
elif isinstance(arg, Number):
|
||||
values[i].v_float64 = arg
|
||||
type_codes[i] = TypeCode.FLOAT
|
||||
elif isinstance(arg, TVMType):
|
||||
values[i].v_type = arg
|
||||
type_codes[i] = TypeCode.TVM_TYPE
|
||||
elif isinstance(arg, string_types):
|
||||
values[i].v_str = c_str(arg)
|
||||
type_codes[i] = TypeCode.STR
|
||||
elif isinstance(arg, (list, tuple, dict, SliceBase)):
|
||||
arg = convert_to_node(arg)
|
||||
values[i].v_handle = arg.handle
|
||||
type_codes[i] = TypeCode.NODE_HANDLE
|
||||
temp_args.append(arg)
|
||||
elif isinstance(arg, Function):
|
||||
values[i].v_handle = arg.handle
|
||||
type_codes[i] = TypeCode.FUNC_HANDLE
|
||||
elif callable(arg):
|
||||
arg = convert_to_tvm_func(arg)
|
||||
values[i].v_handle = arg.handle
|
||||
type_codes[i] = TypeCode.FUNC_HANDLE
|
||||
temp_args.append(arg)
|
||||
else:
|
||||
raise TypeError("Don't know how to handle type %s" % type(arg))
|
||||
return values, type_codes, num_args
|
||||
|
||||
|
||||
class Function(object):
|
||||
"""A function object at runtime."""
|
||||
__slots__ = ["handle", "is_global"]
|
||||
# pylint: disable=no-member
|
||||
def __init__(self, handle, is_global=False):
|
||||
"""Initialize the function with handle
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handle : FunctionHandle
|
||||
the handle to the underlying function.
|
||||
|
||||
is_global : bool, optional
|
||||
Whether it is global function
|
||||
"""
|
||||
self.handle = handle
|
||||
self.is_global = is_global
|
||||
|
||||
def __del__(self):
|
||||
if not self.is_global:
|
||||
check_call(_LIB.TVMFuncFree(self.handle))
|
||||
|
||||
def __call__(self, *args):
|
||||
temp_args = []
|
||||
values, tcodes, num_args = _make_tvm_args(args, temp_args)
|
||||
ret_val = TVMValue()
|
||||
ret_tcode = ctypes.c_int()
|
||||
check_call(_LIB.TVMFuncCall(
|
||||
self.handle, values, tcodes, ctypes.c_int(num_args),
|
||||
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
|
||||
_ = temp_args
|
||||
_ = args
|
||||
return RETURN_SWITCH[ret_tcode.value](ret_val)
|
||||
|
||||
|
||||
def _handle_return_func(x):
|
||||
"""Return function"""
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, FunctionHandle):
|
||||
handle = FunctionHandle(handle)
|
||||
return Function(handle, False)
|
||||
|
||||
# setup return handle for function type
|
||||
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
|
||||
|
||||
def register_func(func_name, f=None):
|
||||
"""Register global function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func_name : str or function
|
||||
The function name
|
||||
|
||||
f : function
|
||||
The function to be registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fregister : function
|
||||
Register function if f is not specified.
|
||||
"""
|
||||
if callable(func_name):
|
||||
f = func_name
|
||||
func_name = f.__name__
|
||||
|
||||
if not isinstance(func_name, str):
|
||||
raise ValueError("expect string function name")
|
||||
def register(myf):
|
||||
"""internal register function"""
|
||||
if not isinstance(myf, Function):
|
||||
myf = convert_to_tvm_func(myf)
|
||||
check_call(_LIB.TVMFuncRegisterGlobal(
|
||||
c_str(func_name), myf.handle))
|
||||
if f:
|
||||
register(f)
|
||||
else:
|
||||
return register
|
||||
|
||||
|
||||
def get_global_func(name):
|
||||
"""Get a global function by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the global function
|
||||
|
||||
Returns
|
||||
-------
|
||||
func : tvm.nd.Function
|
||||
The function to be returned.
|
||||
"""
|
||||
handle = FunctionHandle()
|
||||
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
|
||||
return Function(handle, True)
|
||||
|
||||
|
||||
def list_global_func_names():
|
||||
"""Get list of global functions registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
names : list
|
||||
List of global functions names.
|
||||
"""
|
||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||
size = ctypes.c_uint()
|
||||
|
||||
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
|
||||
ctypes.byref(plist)))
|
||||
fnames = []
|
||||
for i in range(size.value):
|
||||
fnames.append(py_str(plist[i]))
|
||||
return fnames
|
||||
|
||||
|
||||
def _init_api_functions(root_namespace):
|
||||
"""List and add all the functions to current module."""
|
||||
module_obj = sys.modules["%s.api" % root_namespace]
|
||||
module_internal = sys.modules["%s._api_internal" % root_namespace]
|
||||
namespace_match = {
|
||||
"_make_": sys.modules["%s.make" % root_namespace],
|
||||
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
|
||||
"_codegen_": sys.modules["%s.codegen" % root_namespace],
|
||||
"_schedule_": sys.modules["%s.schedule" % root_namespace]
|
||||
}
|
||||
for name in list_global_func_names():
|
||||
fname = name
|
||||
target_module = module_internal if name.startswith('_') else module_obj
|
||||
for k, v in namespace_match.items():
|
||||
if name.startswith(k):
|
||||
fname = name[len(k):]
|
||||
target_module = v
|
||||
f = get_global_func(name)
|
||||
setattr(target_module, fname, f)
|
|
@ -1,18 +1,15 @@
|
|||
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
|
||||
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
|
||||
"""Symbolic configuration API."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from __future__ import absolute_import
|
||||
|
||||
import ctypes
|
||||
from numbers import Number, Integral
|
||||
import numpy as np
|
||||
|
||||
from .._base import _LIB
|
||||
from .._base import c_array, c_str, string_types
|
||||
from .._base import check_call
|
||||
from ._types import TVMValue, TypeCode, TVMType
|
||||
from .._base import _LIB, check_call
|
||||
from .._base import c_array, c_str
|
||||
from ._types import TVMType, tvm_index_t
|
||||
|
||||
tvm_index_t = ctypes.c_uint32
|
||||
|
||||
class TVMContext(ctypes.Structure):
|
||||
"""TVM context strucure."""
|
||||
|
@ -39,6 +36,19 @@ class TVMContext(ctypes.Structure):
|
|||
return ret.value != 0
|
||||
|
||||
|
||||
class TVMArray(ctypes.Structure):
|
||||
"""TVMValue in C API"""
|
||||
_fields_ = [("data", ctypes.c_void_p),
|
||||
("shape", ctypes.POINTER(tvm_index_t)),
|
||||
("strides", ctypes.POINTER(tvm_index_t)),
|
||||
("ndim", tvm_index_t),
|
||||
("dtype", TVMType),
|
||||
("ctx", TVMContext)]
|
||||
|
||||
|
||||
TVMArrayHandle = ctypes.POINTER(TVMArray)
|
||||
|
||||
|
||||
def cpu(dev_id=0):
|
||||
"""Construct a CPU device
|
||||
|
||||
|
@ -72,18 +82,6 @@ def opencl(dev_id=0):
|
|||
return TVMContext(4, dev_id)
|
||||
|
||||
|
||||
class TVMArray(ctypes.Structure):
|
||||
"""TVMValue in C API"""
|
||||
_fields_ = [("data", ctypes.c_void_p),
|
||||
("shape", ctypes.POINTER(tvm_index_t)),
|
||||
("strides", ctypes.POINTER(tvm_index_t)),
|
||||
("ndim", tvm_index_t),
|
||||
("dtype", TVMType),
|
||||
("ctx", TVMContext)]
|
||||
|
||||
TVMArrayHandle = ctypes.POINTER(TVMArray)
|
||||
|
||||
|
||||
def numpyasarray(np_data):
|
||||
"""Return a TVMArray representation of a numpy array.
|
||||
"""
|
||||
|
@ -102,7 +100,6 @@ def numpyasarray(np_data):
|
|||
|
||||
|
||||
_ndarray_cls = None
|
||||
_function_cls = None
|
||||
|
||||
|
||||
def empty(shape, dtype="float32", ctx=cpu(0)):
|
||||
|
@ -275,51 +272,6 @@ class NDArrayBase(object):
|
|||
return target
|
||||
|
||||
|
||||
class FunctionBase(object):
|
||||
"""A function object at runtim."""
|
||||
__slots__ = ["handle"]
|
||||
# pylint: disable=no-member
|
||||
def __init__(self, handle):
|
||||
"""Initialize the function with handle
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handle : FunctionHandle
|
||||
the handle to the underlying function.
|
||||
"""
|
||||
self.handle = handle
|
||||
|
||||
def __del__(self):
|
||||
check_call(_LIB.TVMFuncFree(self.handle))
|
||||
|
||||
def __call__(self, *args):
|
||||
num_args = len(args)
|
||||
tvm_args = (TVMValue * num_args)()
|
||||
tvm_type_code = (ctypes.c_int * num_args)()
|
||||
for i, arg in enumerate(args):
|
||||
if arg is None:
|
||||
tvm_args[i].v_handle = None
|
||||
tvm_type_code[i] = TypeCode.NULL
|
||||
elif isinstance(arg, NDArrayBase):
|
||||
tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
|
||||
tvm_type_code[i] = TypeCode.HANDLE
|
||||
elif isinstance(arg, Integral):
|
||||
tvm_args[i].v_int64 = arg
|
||||
tvm_type_code[i] = TypeCode.INT
|
||||
elif isinstance(arg, Number):
|
||||
tvm_args[i].v_float64 = arg
|
||||
tvm_type_code[i] = TypeCode.FLOAT
|
||||
elif isinstance(arg, string_types):
|
||||
tvm_args[i].v_str = c_str(arg)
|
||||
tvm_type_code[i] = TypeCode.STR
|
||||
else:
|
||||
raise TypeError("Don't know how to handle type %s" % type(arg))
|
||||
check_call(_LIB.TVMFuncCall(
|
||||
self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args)))
|
||||
|
||||
|
||||
def _init_runtime_module(ndarray_class, function_class):
|
||||
def _init_ndarray_module(ndarray_class):
|
||||
global _ndarray_cls
|
||||
global _function_cls
|
||||
_ndarray_cls = ndarray_class
|
||||
_function_cls = function_class
|
|
@ -0,0 +1,184 @@
|
|||
# coding: utf-8
|
||||
# pylint: disable=invalid-name, protected-access
|
||||
# pylint: disable=no-member, missing-docstring
|
||||
"""Symbolic configuration API."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import ctypes
|
||||
from numbers import Number, Integral
|
||||
|
||||
from .._base import _LIB, check_call
|
||||
from .._base import c_str, py_str, string_types
|
||||
from .. import _api_internal
|
||||
from ._types import TVMValue, TypeCode, RETURN_SWITCH
|
||||
|
||||
NodeHandle = ctypes.c_void_p
|
||||
|
||||
"""Maps node type to its constructor"""
|
||||
NODE_TYPE = {
|
||||
}
|
||||
|
||||
def _return_node(x):
|
||||
"""Return function"""
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, NodeHandle):
|
||||
handle = NodeHandle(handle)
|
||||
ret_val = TVMValue()
|
||||
ret_type_code = ctypes.c_int()
|
||||
ret_success = ctypes.c_int()
|
||||
check_call(_LIB.TVMNodeGetAttr(
|
||||
handle, c_str("type_key"),
|
||||
ctypes.byref(ret_val),
|
||||
ctypes.byref(ret_type_code),
|
||||
ctypes.byref(ret_success)))
|
||||
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
|
||||
|
||||
|
||||
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
|
||||
|
||||
|
||||
class SliceBase(object):
|
||||
"""base class of slice object"""
|
||||
pass
|
||||
|
||||
class NodeBase(object):
|
||||
"""Symbol is symbolic graph."""
|
||||
__slots__ = ["handle"]
|
||||
# pylint: disable=no-member
|
||||
def __init__(self, handle):
|
||||
"""Initialize the function with handle
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handle : SymbolHandle
|
||||
the handle to the underlying C++ Symbol
|
||||
"""
|
||||
self.handle = handle
|
||||
|
||||
def __repr__(self):
|
||||
return _api_internal._format_str(self)
|
||||
|
||||
def __del__(self):
|
||||
check_call(_LIB.TVMNodeFree(self.handle))
|
||||
|
||||
def __getattr__(self, name):
|
||||
ret_val = TVMValue()
|
||||
ret_type_code = ctypes.c_int()
|
||||
ret_success = ctypes.c_int()
|
||||
check_call(_LIB.TVMNodeGetAttr(
|
||||
self.handle, c_str(name),
|
||||
ctypes.byref(ret_val),
|
||||
ctypes.byref(ret_type_code),
|
||||
ctypes.byref(ret_success)))
|
||||
if not ret_success.value:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute '%s'" % (str(type(self)), name))
|
||||
return RETURN_SWITCH[ret_type_code.value](ret_val)
|
||||
|
||||
def __hash__(self):
|
||||
return _api_internal._raw_ptr(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, NodeBase):
|
||||
return False
|
||||
return self.__hash__() == other.__hash__()
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __dir__(self):
|
||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||
size = ctypes.c_uint()
|
||||
check_call(_LIB.TVMNodeListAttrNames(
|
||||
self.handle, ctypes.byref(size), ctypes.byref(plist)))
|
||||
names = []
|
||||
for i in range(size.value):
|
||||
names.append(py_str(plist[i]))
|
||||
return names
|
||||
|
||||
def __reduce__(self):
|
||||
return (type(self), (None,), self.__getstate__())
|
||||
|
||||
def __getstate__(self):
|
||||
handle = self.handle
|
||||
if handle is not None:
|
||||
return {'handle': _api_internal._save_json(self)}
|
||||
else:
|
||||
return {'handle': None}
|
||||
|
||||
def __setstate__(self, state):
|
||||
# pylint: disable=assigning-non-slot
|
||||
handle = state['handle']
|
||||
if handle is not None:
|
||||
json_str = handle
|
||||
other = _api_internal._load_json(json_str)
|
||||
self.handle = other.handle
|
||||
other.handle = None
|
||||
else:
|
||||
self.handle = None
|
||||
|
||||
|
||||
def const(value, dtype=None):
|
||||
"""construct a constant"""
|
||||
if dtype is None:
|
||||
if isinstance(value, Integral):
|
||||
dtype = 'int32'
|
||||
else:
|
||||
dtype = 'float32'
|
||||
return _api_internal._const(value, dtype)
|
||||
|
||||
|
||||
def convert_to_node(value):
|
||||
"""Convert a python value to corresponding node type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : str
|
||||
The value to be inspected.
|
||||
|
||||
Returns
|
||||
-------
|
||||
node : Node
|
||||
The corresponding node value.
|
||||
"""
|
||||
if isinstance(value, NodeBase):
|
||||
return value
|
||||
elif isinstance(value, Number):
|
||||
return const(value)
|
||||
elif isinstance(value, string_types):
|
||||
return _api_internal._str(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value = [convert_to_node(x) for x in value]
|
||||
return _api_internal._Array(*value)
|
||||
elif isinstance(value, dict):
|
||||
vlist = []
|
||||
for it in value.items():
|
||||
if not isinstance(it[0], NodeBase):
|
||||
raise ValueError("key of map must already been a container type")
|
||||
vlist.append(it[0])
|
||||
vlist.append(convert_to_node(it[1]))
|
||||
return _api_internal._Map(*vlist)
|
||||
elif isinstance(value, SliceBase):
|
||||
return value.tensor(*value.indices)
|
||||
else:
|
||||
raise ValueError("don't know how to convert type %s to node" % type(value))
|
||||
|
||||
|
||||
def register_node(type_key=None):
|
||||
"""register node type
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type_key : str or cls
|
||||
The type key of the node
|
||||
"""
|
||||
if isinstance(type_key, str):
|
||||
def register(cls):
|
||||
"""internal register function"""
|
||||
NODE_TYPE[type_key] = cls
|
||||
return cls
|
||||
return register
|
||||
else:
|
||||
cls = type_key
|
||||
NODE_TYPE[cls.__name__] = cls
|
||||
return cls
|
|
@ -4,13 +4,9 @@ from __future__ import absolute_import as _abs
|
|||
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from .._base import py_str
|
||||
|
||||
class TVMValue(ctypes.Union):
|
||||
"""TVMValue in C API"""
|
||||
_fields_ = [("v_int64", ctypes.c_int64),
|
||||
("v_float64", ctypes.c_double),
|
||||
("v_handle", ctypes.c_void_p),
|
||||
("v_str", ctypes.c_char_p)]
|
||||
tvm_index_t = ctypes.c_uint32
|
||||
|
||||
class TypeCode(object):
|
||||
"""Type code used in API calls"""
|
||||
|
@ -19,9 +15,11 @@ class TypeCode(object):
|
|||
FLOAT = 2
|
||||
HANDLE = 3
|
||||
NULL = 4
|
||||
NODE_HANDLE = 5
|
||||
STR = 6
|
||||
FUNC_HANDLE = 7
|
||||
ARRAY_HANDLE = 5
|
||||
TVM_TYPE = 6
|
||||
NODE_HANDLE = 7
|
||||
STR = 8
|
||||
FUNC_HANDLE = 9
|
||||
|
||||
def _api_type(code):
|
||||
"""create a type accepted by API"""
|
||||
|
@ -40,13 +38,13 @@ class TVMType(ctypes.Structure):
|
|||
CODE2STR = {
|
||||
0 : 'int',
|
||||
1 : 'uint',
|
||||
2 : 'float'
|
||||
2 : 'float',
|
||||
4 : 'handle'
|
||||
}
|
||||
def __init__(self, type_str, lanes=1):
|
||||
super(TVMType, self).__init__()
|
||||
if isinstance(type_str, np.dtype):
|
||||
type_str = str(type_str)
|
||||
|
||||
if type_str.startswith("int"):
|
||||
self.type_code = 0
|
||||
bits = int(type_str[3:])
|
||||
|
@ -56,6 +54,9 @@ class TVMType(ctypes.Structure):
|
|||
elif type_str.startswith("float"):
|
||||
self.type_code = 2
|
||||
bits = int(type_str[5:])
|
||||
elif type_str.startswith("handle"):
|
||||
self.type_code = 4
|
||||
bits = 64
|
||||
else:
|
||||
raise ValueError("Donot know how to handle type %s" % type_str)
|
||||
|
||||
|
@ -71,15 +72,61 @@ class TVMType(ctypes.Structure):
|
|||
x += "x%d" % self.lanes
|
||||
return x
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.bits == other.bits and
|
||||
self.type_code == other.type_code and
|
||||
self.lanes == other.lanes)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class TVMValue(ctypes.Union):
|
||||
"""TVMValue in C API"""
|
||||
_fields_ = [("v_int64", ctypes.c_int64),
|
||||
("v_float64", ctypes.c_double),
|
||||
("v_handle", ctypes.c_void_p),
|
||||
("v_str", ctypes.c_char_p),
|
||||
("v_type", TVMType)]
|
||||
|
||||
|
||||
TVMPackedCFunc = ctypes.CFUNCTYPE(
|
||||
None,
|
||||
ctypes.POINTER(TVMValue),
|
||||
ctypes.POINTER(ctypes.c_int),
|
||||
ctypes.c_int,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p)
|
||||
|
||||
|
||||
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
|
||||
None,
|
||||
ctypes.c_void_p)
|
||||
|
||||
|
||||
def _return_handle(x):
|
||||
"""return handle"""
|
||||
handle = x.v_handle
|
||||
if not isinstance(handle, ctypes.c_void_p):
|
||||
handle = ctypes.c_void_p(handle)
|
||||
return handle
|
||||
|
||||
|
||||
RETURN_SWITCH = {
|
||||
TypeCode.INT: lambda x: x.v_int64,
|
||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||
TypeCode.HANDLE: _return_handle,
|
||||
TypeCode.NULL: lambda x: None,
|
||||
TypeCode.TVM_TYPE: lambda x: x.v_type,
|
||||
TypeCode.STR: lambda x: py_str(x.v_str)
|
||||
}
|
||||
|
||||
|
||||
C_TO_PY_ARG_SWITCH = {
|
||||
TypeCode.INT: lambda x: x.v_int64,
|
||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||
TypeCode.HANDLE: _return_handle,
|
||||
TypeCode.NULL: lambda x: None,
|
||||
TypeCode.TVM_TYPE: lambda x: x.v_type,
|
||||
TypeCode.STR: lambda x: py_str(x.v_str)
|
||||
}
|
||||
|
|
|
@ -2,16 +2,23 @@
|
|||
# pylint: disable=redefined-builtin, undefined-variable, unused-import
|
||||
"""Functions defined in TVM."""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
from numbers import Integral as _Integral
|
||||
from ._ctypes._api import _init_api_module, convert, register_func, get_global_func
|
||||
|
||||
from ._ctypes._types import TVMType
|
||||
from ._ctypes._node import register_node, NodeBase
|
||||
from ._ctypes._node import convert_to_node as _convert_to_node
|
||||
from ._ctypes._function import Function
|
||||
from ._ctypes._function import _init_api_functions, register_func, get_global_func
|
||||
from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
|
||||
from . import _api_internal
|
||||
from . import make as _make
|
||||
from . import expr as _expr
|
||||
from . import collections as _collections
|
||||
|
||||
int32 = "int32"
|
||||
float32 = "float32"
|
||||
handle = "handle"
|
||||
int32 = TVMType("int32")
|
||||
float32 = TVMType("float32")
|
||||
handle = TVMType("handle")
|
||||
|
||||
def const(value, dtype=None):
|
||||
"""construct a constant"""
|
||||
|
@ -266,4 +273,25 @@ def Schedule(ops):
|
|||
return _api_internal._Schedule(ops)
|
||||
|
||||
|
||||
_init_api_module("tvm")
|
||||
def convert(value):
|
||||
"""Convert value to TVM node or function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : python value
|
||||
|
||||
Returns
|
||||
-------
|
||||
tvm_val : Node or function
|
||||
Converted value in TVM
|
||||
"""
|
||||
if isinstance(value, (Function, NodeBase)):
|
||||
return value
|
||||
|
||||
if callable(value):
|
||||
return _convert_tvm_func(value)
|
||||
else:
|
||||
return _convert_to_node(value)
|
||||
|
||||
|
||||
_init_api_functions("tvm")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# pylint: disable=protected-access, no-member
|
||||
"""Collection structure in the high level DSL."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._api import NodeBase, register_node
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
from . import _api_internal
|
||||
from . import expr as _expr
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# pylint: disable=protected-access, no-member, missing-docstring
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._api import NodeBase, register_node
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
from . import make as _make
|
||||
|
||||
class ExprOp(object):
|
||||
|
|
|
@ -6,11 +6,11 @@ This is a simplified runtime API for quick testing and proptyping.
|
|||
from __future__ import absolute_import as _abs
|
||||
import numpy as _np
|
||||
|
||||
from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase
|
||||
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
|
||||
from ._ctypes._runtime_api import _init_runtime_module
|
||||
from ._ctypes._runtime_api import init_opencl
|
||||
|
||||
from ._ctypes._ndarray import TVMContext, TVMType, NDArrayBase
|
||||
from ._ctypes._ndarray import cpu, gpu, opencl, empty, sync
|
||||
from ._ctypes._ndarray import _init_ndarray_module
|
||||
from ._ctypes._ndarray import init_opencl
|
||||
from ._ctypes._function import Function
|
||||
|
||||
class NDArray(NDArrayBase):
|
||||
"""Lightweight NDArray class of TVM runtime.
|
||||
|
@ -26,11 +26,6 @@ class NDArray(NDArrayBase):
|
|||
pass
|
||||
|
||||
|
||||
class Function(FunctionBase):
|
||||
"""Function class that can executed a generated code."""
|
||||
pass
|
||||
|
||||
|
||||
def array(arr, ctx=cpu(0)):
|
||||
"""Create an array from source arr.
|
||||
|
||||
|
@ -54,4 +49,4 @@ def array(arr, ctx=cpu(0)):
|
|||
return ret
|
||||
|
||||
|
||||
_init_runtime_module(NDArray, Function)
|
||||
_init_ndarray_module(NDArray)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# pylint: disable=protected-access, no-member
|
||||
"""Collection structure in the high level DSL."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._api import NodeBase, register_node
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
from . import _api_internal
|
||||
from . import tensor as _tensor
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# pylint: disable=protected-access, no-member, missing-docstring
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._api import NodeBase, register_node
|
||||
from ._ctypes._node import NodeBase, register_node
|
||||
|
||||
class Stmt(NodeBase):
|
||||
pass
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# pylint: disable=protected-access, no-member, invalid-name
|
||||
"""Tensor related abstractions"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from ._ctypes._api import NodeBase, SliceBase, register_node, convert
|
||||
from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node
|
||||
from . import collections as _collections
|
||||
from . import _api_internal
|
||||
from . import make as _make
|
||||
|
@ -26,7 +26,7 @@ class Tensor(NodeBase):
|
|||
ndim = self.ndim
|
||||
if len(indices) != ndim:
|
||||
raise ValueError("Need to provide %d index in tensor slice" % ndim)
|
||||
indices = convert(indices)
|
||||
indices = convert_to_node(indices)
|
||||
args = []
|
||||
for x in indices:
|
||||
if isinstance(x, _collections.IterVar):
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* Implementation of basic API functions
|
||||
* \file api_base.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/api_registry.h>
|
||||
|
||||
namespace tvm {
|
||||
|
||||
TVM_REGISTER_API(_format_str)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
CHECK(args[0].type_code() == kNodeHandle);
|
||||
std::ostringstream os;
|
||||
os << args[0].operator NodeRef();
|
||||
*ret = os.str();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_raw_ptr)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
CHECK(args[0].type_code() == kNodeHandle);
|
||||
*ret = reinterpret_cast<int64_t>(
|
||||
args[0].node_sptr().get());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_save_json)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = SaveJSON(args[0]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_load_json)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = NodeRef(LoadJSON_(args[0]));
|
||||
});
|
||||
|
||||
} // namespace tvm
|
|
@ -0,0 +1,57 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of API functions related to Codegen
|
||||
* \file c_api_codegen.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/codegen.h>
|
||||
#include <tvm/api_registry.h>
|
||||
#include "../codegen/codegen_c.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
TVM_REGISTER_API(_codegen_CompileToC)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = CodeGenC().Compile(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_MakeAPI)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = MakeAPI(
|
||||
args[0], args[1], args[2], args[3]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_SplitHostDevice)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = SplitHostDevice(args[0]);
|
||||
});
|
||||
|
||||
// generate a dummy packed function for testing
|
||||
void DummyHelloFunction(TVMArgs args, TVMRetValue* rv) {
|
||||
LOG(INFO) << args.size() << " arguments";
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
switch (args.type_codes[i]) {
|
||||
case kNull: LOG(INFO) << i << ":nullptr"; break;
|
||||
case kFloat: LOG(INFO) << i << ": double=" << args.values[i].v_float64; break;
|
||||
case kInt: LOG(INFO) << i << ": long=" << args.values[i].v_int64; break;
|
||||
case kHandle: LOG(INFO) << i << ": handle=" << args.values[i].v_handle; break;
|
||||
case kArrayHandle: LOG(INFO) << i << ": array_handle=" << args.values[i].v_handle; break;
|
||||
default: LOG(FATAL) << "unhandled type " << runtime::TypeCode2Str(args.type_codes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_API(_codegen_DummyHelloFunction)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = runtime::PackedFunc(DummyHelloFunction);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_BuildStackVM)
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = BuildStackVM(args[0]);
|
||||
});
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
|
@ -1,98 +1,93 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of API functions related to IR build
|
||||
* \file c_api_ir.cc
|
||||
* \file api_ir.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <ir/IROperator.h>
|
||||
#include "./c_api_registry.h"
|
||||
#include <tvm/api_registry.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace ir {
|
||||
|
||||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
TVM_REGISTER_API(_Var)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Variable::make(args.at(1), args.at(0));
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = Variable::make(args[1], args[0]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_make_For)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = For::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2),
|
||||
static_cast<ForType>(args.at(3).operator int()),
|
||||
static_cast<Halide::DeviceAPI>(args.at(4).operator int()),
|
||||
args.at(5));
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = For::make(args[0],
|
||||
args[1],
|
||||
args[2],
|
||||
static_cast<ForType>(args[3].operator int()),
|
||||
static_cast<Halide::DeviceAPI>(args[4].operator int()),
|
||||
args[5]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_make_Realize)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Realize::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2),
|
||||
args.at(3),
|
||||
args.at(4),
|
||||
args.at(5));
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = Realize::make(args[0],
|
||||
args[1],
|
||||
args[2],
|
||||
args[3],
|
||||
args[4],
|
||||
args[5]);
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_make_Call)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Call::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2),
|
||||
static_cast<Call::CallType>(args.at(3).operator int()),
|
||||
args.at(4),
|
||||
args.at(5));
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = Call::make(args[0],
|
||||
args[1],
|
||||
args[2],
|
||||
static_cast<Call::CallType>(args[3].operator int()),
|
||||
args[4],
|
||||
args[5]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_make_Allocate)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Allocate::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2),
|
||||
args.at(3),
|
||||
args.at(4));
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = Allocate::make(args[0],
|
||||
args[1],
|
||||
args[2],
|
||||
args[3],
|
||||
args[4]);
|
||||
});
|
||||
|
||||
// make from two arguments
|
||||
#define REGISTER_MAKE1(Node) \
|
||||
TVM_REGISTER_API(_make_## Node) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = Node::make(args.at(0)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = Node::make(args[0]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_MAKE2(Node) \
|
||||
TVM_REGISTER_API(_make_## Node) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = Node::make(args.at(0), args.at(1)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = Node::make(args[0], args[1]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_MAKE3(Node) \
|
||||
TVM_REGISTER_API(_make_## Node) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = Node::make(args[0], args[1], args[2]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_MAKE4(Node) \
|
||||
TVM_REGISTER_API(_make_## Node) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = Node::make(args[0], args[1], args[2], args[3]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_MAKE_BINARY_OP(Node) \
|
||||
TVM_REGISTER_API(_make_## Node) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
Expr a = args.at(0), b = args.at(1); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
Expr a = args[0], b = args[1]; \
|
||||
match_types(a, b); \
|
||||
*ret = Node::make(a, b); \
|
||||
}) \
|
||||
.add_argument("lhs", "Expr", "left operand") \
|
||||
.add_argument("rhs", "Expr", "right operand")
|
||||
})
|
||||
|
||||
REGISTER_MAKE3(Reduce);
|
||||
REGISTER_MAKE4(AttrStmt);
|
|
@ -0,0 +1,256 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of API functions related to Higher DSL build.
|
||||
* \file api_lang.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/buffer.h>
|
||||
#include <tvm/schedule.h>
|
||||
#include <tvm/api_registry.h>
|
||||
|
||||
namespace tvm {
|
||||
|
||||
TVM_REGISTER_API(_const)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
if (args[0].type_code() == kInt) {
|
||||
*ret = make_const(args[1], args[0].operator int64_t());
|
||||
} else if (args[0].type_code() == kFloat) {
|
||||
*ret = make_const(args[1], args[0].operator double());
|
||||
} else {
|
||||
LOG(FATAL) << "only accept int or float";
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_str)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = ir::StringImm::make(args[0]);
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_Array)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
std::vector<std::shared_ptr<Node> > data;
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
data.push_back(args[i].node_sptr());
|
||||
}
|
||||
auto node = std::make_shared<ArrayNode>();
|
||||
node->data = std::move(data);
|
||||
*ret = node;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ArrayGetItem)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
int64_t i = args[1];
|
||||
auto& sptr = args[0].node_sptr();
|
||||
CHECK(sptr->is_type<ArrayNode>());
|
||||
auto* n = static_cast<const ArrayNode*>(sptr.get());
|
||||
CHECK_LT(static_cast<size_t>(i), n->data.size())
|
||||
<< "out of bound of array";
|
||||
*ret = n->data[i];
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ArraySize)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
auto& sptr = args[0].node_sptr();
|
||||
CHECK(sptr->is_type<ArrayNode>());
|
||||
*ret = static_cast<int64_t>(
|
||||
static_cast<const ArrayNode*>(sptr.get())->data.size());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Map)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
CHECK_EQ(args.size() % 2, 0);
|
||||
MapNode::ContainerType data;
|
||||
for (int i = 0; i < args.num_args; i += 2) {
|
||||
CHECK(args[i].type_code() == kNodeHandle)
|
||||
<< "need content of array to be NodeBase";
|
||||
CHECK(args[i + 1].type_code() == kNodeHandle)
|
||||
<< "need content of array to be NodeBase";
|
||||
data.emplace(std::make_pair(args[i].node_sptr(),
|
||||
args[i + 1].node_sptr()));
|
||||
}
|
||||
auto node = std::make_shared<MapNode>();
|
||||
node->data = std::move(data);
|
||||
*ret = node;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapSize)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
auto& sptr = args[0].node_sptr();
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
*ret = static_cast<int64_t>(n->data.size());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapGetItem)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
CHECK(args[0].type_code() == kNodeHandle);
|
||||
CHECK(args[1].type_code() == kNodeHandle);
|
||||
auto& sptr = args[0].node_sptr();
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
auto it = n->data.find(args[1].node_sptr());
|
||||
CHECK(it != n->data.end())
|
||||
<< "cannot find the corresponding key in the Map";
|
||||
*ret = (*it).second;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapCount)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
CHECK(args[0].type_code() == kNodeHandle);
|
||||
CHECK(args[1].type_code() == kNodeHandle);
|
||||
auto& sptr = args[0].node_sptr();
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
*ret = static_cast<int64_t>(
|
||||
n->data.count(args[1].node_sptr()));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapItems)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
auto& sptr = args[0].node_sptr();
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
auto rkvs = std::make_shared<ArrayNode>();
|
||||
for (const auto& kv : n->data) {
|
||||
rkvs->data.push_back(kv.first);
|
||||
rkvs->data.push_back(kv.second);
|
||||
}
|
||||
*ret = rkvs;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(Range)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
if (args.size() == 1) {
|
||||
*ret = Range(0, args[0]);
|
||||
} else {
|
||||
*ret = Range(args[0], args[1]);
|
||||
}
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Buffer)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = BufferNode::make(args[0],
|
||||
args[1],
|
||||
args[2],
|
||||
args[3],
|
||||
args[4]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Tensor)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = TensorNode::make(args[0],
|
||||
args[1],
|
||||
args[2],
|
||||
args[3]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_TensorEqual)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = args[0].operator Tensor() == args[1].operator Tensor();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_TensorHash)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = static_cast<int64_t>(
|
||||
std::hash<Tensor>()(args[0].operator Tensor()));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Placeholder)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = Placeholder(args[0],
|
||||
args[1],
|
||||
args[2]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ComputeOp)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = ComputeOpNode::make(args[0],
|
||||
args[1],
|
||||
args[2]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_OpGetOutput)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = args[0].operator Operation().output(
|
||||
args[1].operator int64_t());
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_IterVar)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = IterVar(args[0], args[1], args[2]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Schedule)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = Schedule(args[0].operator Array<Operation>());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageSetScope)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
args[0].operator Stage()
|
||||
.set_scope(args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageSplitByFactor)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
IterVar outer, inner;
|
||||
args[0].operator Stage()
|
||||
.split(args[1], &outer, &inner, args[2]);
|
||||
*ret = Array<IterVar>({outer, inner});
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageSplitByOuter)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
IterVar inner;
|
||||
args[0].operator Stage()
|
||||
.split(args[1], args[2], &inner, args[3]);
|
||||
*ret = inner;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageFuse)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
IterVar fused;
|
||||
args[0].operator Stage()
|
||||
.split(args[1], args[2], &fused);
|
||||
*ret = fused;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageComputeAt)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
args[0].operator Stage()
|
||||
.compute_at(args[1], args[2]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageComputeInline)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
args[0].operator Stage()
|
||||
.compute_inline();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageComputeRoot)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
args[0].operator Stage()
|
||||
.compute_root();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageReorder)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
args[0].operator Stage()
|
||||
.reorder(args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageTile)
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
IterVar x_outer, y_outer, x_inner, y_inner;
|
||||
args[0].operator Stage()
|
||||
.tile(args[1], args[2], &x_outer, &y_outer,
|
||||
&x_inner, &y_inner, args[3], args[4]);
|
||||
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
|
||||
});
|
||||
|
||||
} // namespace tvm
|
|
@ -1,53 +1,51 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* Exposre of pass functions.
|
||||
* \file c_api_pass.cc
|
||||
* \file api_pass.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include "./c_api_registry.h"
|
||||
#include <tvm/api_registry.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace ir {
|
||||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
TVM_REGISTER_API(_pass_Simplify)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
|
||||
*ret = Simplify(args.at(0).operator Stmt());
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
if (args[0].IsNodeType<Stmt>()) {
|
||||
*ret = Simplify(args[0].operator Stmt());
|
||||
} else {
|
||||
*ret = Simplify(args.at(0).operator Expr());
|
||||
*ret = Simplify(args[0].operator Expr());
|
||||
}
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_pass_Equal)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
|
||||
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
if (args[0].IsNodeType<Stmt>()) {
|
||||
*ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
|
||||
} else {
|
||||
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
|
||||
*ret = Equal(args[0].operator Expr(), args[1].operator Expr());
|
||||
}
|
||||
});
|
||||
|
||||
// make from two arguments
|
||||
#define REGISTER_PASS1(PassName) \
|
||||
TVM_REGISTER_API(_pass_## PassName) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = PassName(args.at(0)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = PassName(args[0]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_PASS2(PassName) \
|
||||
TVM_REGISTER_API(_pass_## PassName) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = PassName(args.at(0), args.at(1)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = PassName(args[0], args[1]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_PASS4(PassName) \
|
||||
TVM_REGISTER_API(_pass_## PassName) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = PassName(args[0], args[1], args[2], args[3]); \
|
||||
}) \
|
||||
|
||||
REGISTER_PASS1(ConvertSSA);
|
|
@ -0,0 +1,35 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file api_registry.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/api_registry.h>
|
||||
#include <memory>
|
||||
|
||||
namespace tvm {
|
||||
|
||||
struct APIManager {
|
||||
std::unordered_map<std::string, std::unique_ptr<APIRegistry> > fmap;
|
||||
|
||||
static APIManager* Global() {
|
||||
static APIManager inst;
|
||||
return &inst;
|
||||
}
|
||||
};
|
||||
|
||||
APIRegistry& APIRegistry::__REGISTER__(const std::string& name) { // NOLINT(*)
|
||||
APIManager* m = APIManager::Global();
|
||||
CHECK(!m->fmap.count(name))
|
||||
<< "API function " << name << " has already been registered";
|
||||
std::unique_ptr<APIRegistry> p(new APIRegistry());
|
||||
p->name_ = name;
|
||||
m->fmap[name] = std::move(p);
|
||||
return *(m->fmap[name]);
|
||||
}
|
||||
|
||||
APIRegistry& APIRegistry::set_body(PackedFunc f) { // NOLINT(*)
|
||||
PackedFunc::RegisterGlobal(name_, f);
|
||||
return *this;
|
||||
}
|
||||
} // namespace tvm
|
|
@ -1,30 +1,28 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* Implementation of API functions related to schedule pass.
|
||||
* \file c_api_lang.cc
|
||||
* \file api_schedule.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/schedule.h>
|
||||
#include <tvm/schedule_pass.h>
|
||||
#include "./c_api_registry.h"
|
||||
#include <tvm/api_registry.h>
|
||||
#include "../schedule/graph.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace schedule {
|
||||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
#define REGISTER_SCHEDULE_PASS1(PassName) \
|
||||
TVM_REGISTER_API(_schedule_## PassName) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = PassName(args.at(0)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = PassName(args[0]); \
|
||||
}) \
|
||||
|
||||
#define REGISTER_SCHEDULE_PASS2(PassName) \
|
||||
TVM_REGISTER_API(_schedule_## PassName) \
|
||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
||||
*ret = PassName(args.at(0), args.at(1)); \
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||
*ret = PassName(args[0], args[1]); \
|
||||
}) \
|
||||
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of C API
|
||||
* \file c_api.cc
|
||||
*/
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/c_api.h>
|
||||
#include <tvm/api_registry.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <exception>
|
||||
#include "../runtime/runtime_base.h"
|
||||
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct TVMAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief result holder for retruning string */
|
||||
std::string ret_str;
|
||||
};
|
||||
|
||||
using namespace tvm;
|
||||
|
||||
/*! \brief Thread local store that can be used to hold return values. */
|
||||
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
|
||||
|
||||
using TVMAPINode = std::shared_ptr<Node>;
|
||||
|
||||
struct APIAttrGetter : public AttrVisitor {
|
||||
std::string skey;
|
||||
TVMRetValue* ret;
|
||||
bool found_node_ref{false};
|
||||
|
||||
void Visit(const char* key, double* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, int64_t* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, uint64_t* value) final {
|
||||
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
|
||||
<< "cannot return too big constant";
|
||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||
}
|
||||
void Visit(const char* key, int* value) final {
|
||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||
}
|
||||
void Visit(const char* key, bool* value) final {
|
||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||
}
|
||||
void Visit(const char* key, Type* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, std::string* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, NodeRef* value) final {
|
||||
if (skey == key) {
|
||||
*ret = value[0];
|
||||
found_node_ref = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct APIAttrDir : public AttrVisitor {
|
||||
std::vector<std::string>* names;
|
||||
|
||||
void Visit(const char* key, double* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, int64_t* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, uint64_t* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, bool* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, int* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, Type* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, std::string* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, NodeRef* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
int TVMNodeFree(NodeHandle handle) {
|
||||
API_BEGIN();
|
||||
delete static_cast<TVMAPINode*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMNodeGetAttr(NodeHandle handle,
|
||||
const char* key,
|
||||
TVMValue* ret_val,
|
||||
int* ret_type_code,
|
||||
int* ret_success) {
|
||||
API_BEGIN();
|
||||
TVMRetValue rv;
|
||||
APIAttrGetter getter;
|
||||
getter.skey = key;
|
||||
getter.ret = &rv;
|
||||
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
||||
if (getter.skey == "type_key") {
|
||||
ret_val->v_str = (*tnode)->type_key();
|
||||
*ret_type_code = kStr;
|
||||
*ret_success = 1;
|
||||
} else {
|
||||
(*tnode)->VisitAttrs(&getter);
|
||||
*ret_success = getter.found_node_ref || rv.type_code() != kNull;
|
||||
if (rv.type_code() == kStr) {
|
||||
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
|
||||
e->ret_str = rv.operator std::string();
|
||||
*ret_type_code = kStr;
|
||||
ret_val->v_str = e->ret_str.c_str();
|
||||
} else {
|
||||
rv.MoveToCHost(ret_val, ret_type_code);
|
||||
}
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMNodeListAttrNames(NodeHandle handle,
|
||||
int *out_size,
|
||||
const char*** out_array) {
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
API_BEGIN();
|
||||
ret->ret_vec_str.clear();
|
||||
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
||||
APIAttrDir dir;
|
||||
dir.names = &(ret->ret_vec_str);
|
||||
(*tnode)->VisitAttrs(&dir);
|
||||
ret->ret_vec_charp.clear();
|
||||
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
||||
}
|
||||
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
||||
API_END();
|
||||
}
|
|
@ -42,5 +42,78 @@ inline Type String2Type(std::string s) {
|
|||
return Type(code, bits, lanes);
|
||||
}
|
||||
|
||||
inline const char* TVMTypeCode2Str(int type_code) {
|
||||
switch (type_code) {
|
||||
case kInt: return "int";
|
||||
case kFloat: return "float";
|
||||
case kStr: return "str";
|
||||
case kHandle: return "Handle";
|
||||
case kNull: return "NULL";
|
||||
case kNodeHandle: return "NodeHandle";
|
||||
default: LOG(FATAL) << "unknown type_code="
|
||||
<< static_cast<int>(type_code); return "";
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
struct NodeTypeChecker {
|
||||
static inline bool Check(Node* sptr) {
|
||||
// This is the only place in the project where RTTI is used
|
||||
// It can be turned off, but will make non strict checking.
|
||||
// TODO(tqchen) possibly find alternative to turn of RTTI
|
||||
using ContainerType = typename T::ContainerType;
|
||||
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
using ContainerType = typename T::ContainerType;
|
||||
os << ContainerType::_type_key;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct NodeTypeChecker<Array<T> > {
|
||||
static inline bool Check(Node* sptr) {
|
||||
if (sptr == nullptr) return false;
|
||||
if (!sptr->is_type<ArrayNode>()) return false;
|
||||
ArrayNode* n = static_cast<ArrayNode*>(sptr);
|
||||
for (const auto& p : n->data) {
|
||||
if (!NodeTypeChecker<T>::Check(p.get())) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
os << "array<";
|
||||
NodeTypeChecker<T>::PrintName(os);
|
||||
os << ">";
|
||||
}
|
||||
};
|
||||
|
||||
template<typename K, typename V>
|
||||
struct NodeTypeChecker<Map<K, V> > {
|
||||
static inline bool Check(Node* sptr) {
|
||||
if (sptr == nullptr) return false;
|
||||
if (!sptr->is_type<MapNode>()) return false;
|
||||
MapNode* n = static_cast<MapNode*>(sptr);
|
||||
for (const auto& kv : n->data) {
|
||||
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
|
||||
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
os << "map<";
|
||||
NodeTypeChecker<K>::PrintName(os);
|
||||
os << ',';
|
||||
NodeTypeChecker<V>::PrintName(os);
|
||||
os << '>';
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline std::string NodeTypeName() {
|
||||
std::ostringstream os;
|
||||
NodeTypeChecker<T>::PrintName(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace tvm
|
||||
#endif // TVM_BASE_COMMON_H_
|
||||
|
|
|
@ -1,260 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of C API
|
||||
* \file c_api.cc
|
||||
*/
|
||||
#include <tvm/c_api.h>
|
||||
#include "./c_api_common.h"
|
||||
#include "./c_api_registry.h"
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct TVMAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief argument stack */
|
||||
std::vector<tvm::APIVariantValue> arg_stack;
|
||||
/*! \brief return value */
|
||||
tvm::APIVariantValue ret_value;
|
||||
// clear calling stack
|
||||
inline void Clear() {
|
||||
arg_stack.clear();
|
||||
ret_value.sptr.reset();
|
||||
}
|
||||
inline void SetReturn(TVMValue* ret_val, int* ret_type_code);
|
||||
};
|
||||
|
||||
using namespace tvm;
|
||||
|
||||
/*! \brief Thread local store that can be used to hold return values. */
|
||||
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
|
||||
|
||||
using TVMAPINode = std::shared_ptr<Node>;
|
||||
|
||||
struct APIAttrGetter : public AttrVisitor {
|
||||
std::string skey;
|
||||
APIVariantValue* ret;
|
||||
bool found_node_ref{false};
|
||||
|
||||
void Visit(const char* key, double* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, int64_t* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, uint64_t* value) final {
|
||||
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
|
||||
<< "cannot return too big constant";
|
||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||
}
|
||||
void Visit(const char* key, int* value) final {
|
||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||
}
|
||||
void Visit(const char* key, bool* value) final {
|
||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||
}
|
||||
void Visit(const char* key, Type* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, std::string* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, NodeRef* value) final {
|
||||
if (skey == key) {
|
||||
*ret = value[0];
|
||||
found_node_ref = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct APIAttrDir : public AttrVisitor {
|
||||
std::vector<std::string>* names;
|
||||
|
||||
void Visit(const char* key, double* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, int64_t* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, uint64_t* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, bool* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, int* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, Type* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, std::string* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
void Visit(const char* key, NodeRef* value) final {
|
||||
names->push_back(key);
|
||||
}
|
||||
};
|
||||
|
||||
int TVMListAPIFuncNames(int *out_size,
|
||||
const char*** out_array) {
|
||||
API_BEGIN();
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
ret->ret_vec_str = dmlc::Registry<APIFuncReg>::ListAllNames();
|
||||
ret->ret_vec_charp.clear();
|
||||
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
||||
}
|
||||
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMGetAPIFuncHandle(const char* fname,
|
||||
APIFuncHandle* out) {
|
||||
API_BEGIN();
|
||||
const APIFuncReg* reg = dmlc::Registry<APIFuncReg>::Find(fname);
|
||||
CHECK(reg != nullptr) << "cannot find function " << fname;
|
||||
*out = (APIFuncHandle)reg;
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMGetAPIFuncInfo(APIFuncHandle handle,
|
||||
const char **real_name,
|
||||
const char **description,
|
||||
int *num_doc_args,
|
||||
const char ***arg_names,
|
||||
const char ***arg_type_infos,
|
||||
const char ***arg_descriptions,
|
||||
const char **return_type) {
|
||||
const auto *op = static_cast<const APIFuncReg *>(handle);
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
|
||||
API_BEGIN();
|
||||
*real_name = op->name.c_str();
|
||||
*description = op->description.c_str();
|
||||
*num_doc_args = static_cast<int>(op->arguments.size());
|
||||
if (return_type) *return_type = nullptr;
|
||||
ret->ret_vec_charp.clear();
|
||||
for (size_t i = 0; i < op->arguments.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
|
||||
}
|
||||
for (size_t i = 0; i < op->arguments.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
|
||||
}
|
||||
for (size_t i = 0; i < op->arguments.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
|
||||
}
|
||||
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
|
||||
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMAPIPushStack(TVMValue arg,
|
||||
int type_code) {
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
API_BEGIN();
|
||||
ret->arg_stack.resize(ret->arg_stack.size() + 1);
|
||||
APIVariantValue& v = ret->arg_stack.back();
|
||||
|
||||
v.type_code = type_code;
|
||||
switch (type_code) {
|
||||
case kInt: case kUInt: case kFloat: case kNull: {
|
||||
v.v_union = arg; break;
|
||||
}
|
||||
case kStr: {
|
||||
v.str = arg.v_str; break;
|
||||
}
|
||||
case kNodeHandle: {
|
||||
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); break;
|
||||
}
|
||||
default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code);
|
||||
}
|
||||
API_END_HANDLE_ERROR(ret->Clear());
|
||||
}
|
||||
|
||||
int TVMAPIFuncCall(APIFuncHandle handle,
|
||||
TVMValue* ret_val,
|
||||
int* ret_type_code) {
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
API_BEGIN();
|
||||
const auto *op = static_cast<const APIFuncReg *>(handle);
|
||||
op->body(ret->arg_stack, &(ret->ret_value));
|
||||
ret->SetReturn(ret_val, ret_type_code);
|
||||
ret->arg_stack.clear();
|
||||
API_END_HANDLE_ERROR(ret->Clear());
|
||||
}
|
||||
|
||||
int TVMNodeFree(NodeHandle handle) {
|
||||
API_BEGIN();
|
||||
delete static_cast<TVMAPINode*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMNodeGetAttr(NodeHandle handle,
|
||||
const char* key,
|
||||
TVMValue* ret_val,
|
||||
int* ret_type_code,
|
||||
int* ret_success) {
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
API_BEGIN();
|
||||
ret->ret_value.type_code = kNull;
|
||||
APIAttrGetter getter;
|
||||
getter.skey = key;
|
||||
getter.ret = &(ret->ret_value);
|
||||
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
||||
if (getter.skey == "type_key") {
|
||||
ret_val->v_str = (*tnode)->type_key();
|
||||
*ret_type_code = kStr;
|
||||
*ret_success = 1;
|
||||
} else {
|
||||
(*tnode)->VisitAttrs(&getter);
|
||||
if (ret->ret_value.type_code != kNull) {
|
||||
ret->SetReturn(ret_val, ret_type_code);
|
||||
*ret_success = 1;
|
||||
} else {
|
||||
*ret_success = getter.found_node_ref ? 1 : 0;
|
||||
*ret_type_code = kNull;
|
||||
}
|
||||
}
|
||||
API_END_HANDLE_ERROR(ret->Clear());
|
||||
}
|
||||
|
||||
int TVMNodeListAttrNames(NodeHandle handle,
|
||||
int *out_size,
|
||||
const char*** out_array) {
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
API_BEGIN();
|
||||
ret->ret_vec_str.clear();
|
||||
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
||||
APIAttrDir dir;
|
||||
dir.names = &(ret->ret_vec_str);
|
||||
(*tnode)->VisitAttrs(&dir);
|
||||
ret->ret_vec_charp.clear();
|
||||
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
||||
}
|
||||
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val,
|
||||
int* ret_type_code) {
|
||||
APIVariantValue& rv = ret_value;
|
||||
*ret_type_code = rv.type_code;
|
||||
if (rv.type_code == kNodeHandle) {
|
||||
if (rv.sptr.get() != nullptr) {
|
||||
ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
|
||||
} else {
|
||||
ret_val->v_handle = nullptr;
|
||||
}
|
||||
} else if (rv.type_code == kFuncHandle) {
|
||||
ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func));
|
||||
} else {
|
||||
*ret_val = rv.v_union;
|
||||
}
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of API functions related to Codegen
|
||||
* \file c_api_codegen.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/codegen.h>
|
||||
|
||||
#include "./c_api_registry.h"
|
||||
#include "../codegen/codegen_c.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
TVM_REGISTER_API(_codegen_CompileToC)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = CodeGenC().Compile(args.at(0), args.at(1));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_MakeAPI)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = MakeAPI(
|
||||
args.at(0), args.at(1), args.at(2), args.at(3));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_SplitHostDevice)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = SplitHostDevice(args.at(0));
|
||||
});
|
||||
|
||||
|
||||
// generate a dummy packed function for testing
|
||||
void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) {
|
||||
LOG(INFO) << num_args << " arguments";
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
switch (type_code[i]) {
|
||||
case kNull: LOG(INFO) << i << ":nullptr"; break;
|
||||
case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break;
|
||||
case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break;
|
||||
case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break;
|
||||
default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_API(_codegen_DummyHelloFunction)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = runtime::PackedFunc(DummyHelloFunction);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_codegen_BuildStackVM)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = BuildStackVM(args.at(0));
|
||||
});
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
|
@ -1,19 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file c_api_common.h
|
||||
* \brief Common fields of all C APIs
|
||||
*/
|
||||
#ifndef TVM_C_API_C_API_COMMON_H_
|
||||
#define TVM_C_API_C_API_COMMON_H_
|
||||
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/c_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <exception>
|
||||
#include "./c_api_registry.h"
|
||||
#include "../runtime/runtime_base.h"
|
||||
|
||||
#endif // TVM_C_API_C_API_COMMON_H_
|
|
@ -1,47 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of API functions
|
||||
* \file c_api_impl.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include "./c_api_registry.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace tvm {
|
||||
|
||||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
TVM_REGISTER_API(_format_str)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
std::ostringstream os;
|
||||
os << args.at(0).operator NodeRef();
|
||||
*ret = os.str();
|
||||
})
|
||||
.add_argument("expr", "Node", "expression to be printed");
|
||||
|
||||
TVM_REGISTER_API(_raw_ptr)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
|
||||
})
|
||||
.add_argument("src", "NodeBase", "the node base");
|
||||
|
||||
TVM_REGISTER_API(_save_json)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = SaveJSON(args.at(0));
|
||||
})
|
||||
.add_argument("src", "json_str", "the node ");
|
||||
|
||||
TVM_REGISTER_API(_load_json)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = NodeRef(LoadJSON_(args.at(0)));
|
||||
})
|
||||
.add_argument("src", "NodeBase", "the node");
|
||||
|
||||
} // namespace tvm
|
|
@ -1,273 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Implementation of API functions related to Higher DSL build.
|
||||
* \file c_api_lang.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/tensor.h>
|
||||
#include <tvm/buffer.h>
|
||||
#include <tvm/schedule.h>
|
||||
#include "./c_api_registry.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
TVM_REGISTER_API(_const)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
if (args.at(0).type_code == kInt) {
|
||||
*ret = make_const(args.at(1), args.at(0).operator int64_t());
|
||||
} else if (args.at(0).type_code == kFloat) {
|
||||
*ret = make_const(args.at(1), args.at(0).operator double());
|
||||
} else {
|
||||
LOG(FATAL) << "only accept int or float";
|
||||
}
|
||||
})
|
||||
.add_argument("src", "Number", "source number")
|
||||
.add_argument("dtype", "str", "data type");
|
||||
|
||||
|
||||
TVM_REGISTER_API(_str)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = ir::StringImm::make(args.at(0));
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_Array)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
std::vector<std::shared_ptr<Node> > data;
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
CHECK(args.at(i).type_code == kNodeHandle)
|
||||
<< "need content of array to be NodeBase";
|
||||
data.push_back(args.at(i).sptr);
|
||||
}
|
||||
auto node = std::make_shared<ArrayNode>();
|
||||
node->data = std::move(data);
|
||||
ret->type_code = kNodeHandle;
|
||||
ret->sptr = node;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ArrayGetItem)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
int64_t i = args.at(1);
|
||||
auto& sptr = args.at(0).sptr;
|
||||
CHECK(sptr->is_type<ArrayNode>());
|
||||
auto* n = static_cast<const ArrayNode*>(sptr.get());
|
||||
CHECK_LT(static_cast<size_t>(i), n->data.size())
|
||||
<< "out of bound of array";
|
||||
ret->sptr = n->data[i];
|
||||
ret->type_code = kNodeHandle;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ArraySize)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
auto& sptr = args.at(0).sptr;
|
||||
CHECK(sptr->is_type<ArrayNode>());
|
||||
*ret = static_cast<int64_t>(
|
||||
static_cast<const ArrayNode*>(sptr.get())->data.size());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Map)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK_EQ(args.size() % 2, 0U);
|
||||
MapNode::ContainerType data;
|
||||
for (size_t i = 0; i < args.size(); i += 2) {
|
||||
CHECK(args.at(i).type_code == kNodeHandle)
|
||||
<< "need content of array to be NodeBase";
|
||||
CHECK(args.at(i + 1).type_code == kNodeHandle)
|
||||
<< "need content of array to be NodeBase";
|
||||
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
|
||||
}
|
||||
auto node = std::make_shared<MapNode>();
|
||||
node->data = std::move(data);
|
||||
ret->type_code = kNodeHandle;
|
||||
ret->sptr = node;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapSize)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
auto& sptr = args.at(0).sptr;
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
*ret = static_cast<int64_t>(n->data.size());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapGetItem)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
CHECK(args.at(1).type_code == kNodeHandle);
|
||||
auto& sptr = args.at(0).sptr;
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
auto it = n->data.find(args.at(1).sptr);
|
||||
CHECK(it != n->data.end())
|
||||
<< "cannot find the corresponding key in the Map";
|
||||
ret->sptr = (*it).second;
|
||||
ret->type_code = kNodeHandle;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapCount)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
CHECK(args.at(1).type_code == kNodeHandle);
|
||||
auto& sptr = args.at(0).sptr;
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
*ret = static_cast<int64_t>(n->data.count(args.at(1).sptr));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_MapItems)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_code == kNodeHandle);
|
||||
auto& sptr = args.at(0).sptr;
|
||||
CHECK(sptr->is_type<MapNode>());
|
||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||
auto rkvs = std::make_shared<ArrayNode>();
|
||||
for (const auto& kv : n->data) {
|
||||
rkvs->data.push_back(kv.first);
|
||||
rkvs->data.push_back(kv.second);
|
||||
}
|
||||
ret->sptr = rkvs;
|
||||
ret->type_code = kNodeHandle;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(Range)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
if (args.size() == 1) {
|
||||
*ret = Range(0, args.at(0));
|
||||
} else {
|
||||
*ret = Range(args.at(0), args.at(1));
|
||||
}
|
||||
})
|
||||
.describe("create a domain range")
|
||||
.add_argument("begin", "Expr", "beginning of the range.")
|
||||
.add_argument("end", "Expr", "extent of the range");
|
||||
|
||||
TVM_REGISTER_API(_Buffer)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = BufferNode::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2),
|
||||
args.at(3),
|
||||
args.at(4));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Tensor)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = TensorNode::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2),
|
||||
args.at(3));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_TensorEqual)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_TensorHash)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = static_cast<int64_t>(
|
||||
std::hash<Tensor>()(args.at(0).operator Tensor()));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Placeholder)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Placeholder(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_ComputeOp)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = ComputeOpNode::make(args.at(0),
|
||||
args.at(1),
|
||||
args.at(2));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_OpGetOutput)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = args.at(0).operator Operation().output(
|
||||
args.at(1).operator int64_t());
|
||||
});
|
||||
|
||||
|
||||
TVM_REGISTER_API(_IterVar)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = IterVar(args.at(0), args.at(1), args.at(2));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_Schedule)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Schedule(args.at(0).operator Array<Operation>());
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageSetScope)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
args.at(0).operator Stage()
|
||||
.set_scope(args.at(1));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageSplitByFactor)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
IterVar outer, inner;
|
||||
args.at(0).operator Stage()
|
||||
.split(args.at(1), &outer, &inner, args.at(2));
|
||||
*ret = Array<IterVar>({outer, inner});
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageSplitByOuter)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
IterVar inner;
|
||||
args.at(0).operator Stage()
|
||||
.split(args.at(1), args.at(2), &inner, args.at(3));
|
||||
*ret = inner;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageFuse)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
IterVar fused;
|
||||
args.at(0).operator Stage()
|
||||
.split(args.at(1), args.at(2), &fused);
|
||||
*ret = fused;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageComputeAt)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
args.at(0).operator Stage()
|
||||
.compute_at(args.at(1), args.at(2));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageComputeInline)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
args.at(0).operator Stage()
|
||||
.compute_inline();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageComputeRoot)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
args.at(0).operator Stage()
|
||||
.compute_root();
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageReorder)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
args.at(0).operator Stage()
|
||||
.reorder(args.at(1));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API(_StageTile)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
IterVar x_outer, y_outer, x_inner, y_inner;
|
||||
args.at(0).operator Stage()
|
||||
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
|
||||
&x_inner, &y_inner, args.at(3), args.at(4));
|
||||
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
|
||||
});
|
||||
|
||||
} // namespace tvm
|
|
@ -1,240 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file c_api_registry.h
|
||||
* \brief Quick registry for C API.
|
||||
*/
|
||||
#ifndef TVM_C_API_C_API_REGISTRY_H_
|
||||
#define TVM_C_API_C_API_REGISTRY_H_
|
||||
|
||||
#include <tvm/base.h>
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/c_api.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <memory>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "../base/common.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
inline const char* TVMTypeCode2Str(int type_code) {
|
||||
switch (type_code) {
|
||||
case kInt: return "int";
|
||||
case kFloat: return "float";
|
||||
case kStr: return "str";
|
||||
case kHandle: return "Handle";
|
||||
case kNull: return "NULL";
|
||||
case kNodeHandle: return "NodeHandle";
|
||||
default: LOG(FATAL) << "unknown type_code="
|
||||
<< static_cast<int>(type_code); return "";
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct NodeTypeChecker {
|
||||
static inline bool Check(Node* sptr) {
|
||||
// This is the only place in the project where RTTI is used
|
||||
// It can be turned off, but will make non strict checking.
|
||||
// TODO(tqchen) possibly find alternative to turn of RTTI
|
||||
using ContainerType = typename T::ContainerType;
|
||||
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
using ContainerType = typename T::ContainerType;
|
||||
os << ContainerType::_type_key;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct NodeTypeChecker<Array<T> > {
|
||||
static inline bool Check(Node* sptr) {
|
||||
if (sptr == nullptr) return false;
|
||||
if (!sptr->is_type<ArrayNode>()) return false;
|
||||
ArrayNode* n = static_cast<ArrayNode*>(sptr);
|
||||
for (const auto& p : n->data) {
|
||||
if (!NodeTypeChecker<T>::Check(p.get())) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
os << "array<";
|
||||
NodeTypeChecker<T>::PrintName(os);
|
||||
os << ">";
|
||||
}
|
||||
};
|
||||
|
||||
template<typename K, typename V>
|
||||
struct NodeTypeChecker<Map<K, V> > {
|
||||
static inline bool Check(Node* sptr) {
|
||||
if (sptr == nullptr) return false;
|
||||
if (!sptr->is_type<MapNode>()) return false;
|
||||
MapNode* n = static_cast<MapNode*>(sptr);
|
||||
for (const auto& kv : n->data) {
|
||||
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
|
||||
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||
os << "map<";
|
||||
NodeTypeChecker<K>::PrintName(os);
|
||||
os << ',';
|
||||
NodeTypeChecker<V>::PrintName(os);
|
||||
os << '>';
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline std::string NodeTypeName() {
|
||||
std::ostringstream os;
|
||||
NodeTypeChecker<T>::PrintName(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
/*! \brief Variant container for API calls */
|
||||
class APIVariantValue {
|
||||
public:
|
||||
/*! \brief the type id */
|
||||
int type_code{kNull};
|
||||
/*! \brief shared pointer container */
|
||||
std::shared_ptr<Node> sptr;
|
||||
/*! \brief string container */
|
||||
std::string str;
|
||||
/*! \brief the variant holder */
|
||||
TVMValue v_union;
|
||||
/*! \brief std::function */
|
||||
runtime::PackedFunc::FType func;
|
||||
// constructor
|
||||
APIVariantValue() {
|
||||
}
|
||||
// clear value
|
||||
inline void Clear() {
|
||||
}
|
||||
// assign op
|
||||
inline APIVariantValue& operator=(double value) {
|
||||
type_code = kFloat;
|
||||
v_union.v_float64 = value;
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(std::nullptr_t value) {
|
||||
type_code = kHandle;
|
||||
v_union.v_handle = value;
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(int64_t value) {
|
||||
type_code = kInt;
|
||||
v_union.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(bool value) {
|
||||
type_code = kInt;
|
||||
v_union.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(std::string value) {
|
||||
type_code = kStr;
|
||||
str = std::move(value);
|
||||
v_union.v_str = str.c_str();
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(const NodeRef& ref) {
|
||||
if (ref.node_.get() == nullptr) {
|
||||
type_code = kNull;
|
||||
} else {
|
||||
type_code = kNodeHandle;
|
||||
this->sptr = ref.node_;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(const runtime::PackedFunc& f) {
|
||||
type_code = kFuncHandle;
|
||||
this->func = f.body();
|
||||
return *this;
|
||||
}
|
||||
inline APIVariantValue& operator=(const Type& value) {
|
||||
return operator=(Type2String(value));
|
||||
}
|
||||
template<typename T,
|
||||
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
|
||||
inline operator T() const {
|
||||
if (type_code == kNull) return T();
|
||||
CHECK_EQ(type_code, kNodeHandle);
|
||||
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
|
||||
<< "Did not get expected type " << NodeTypeName<T>();
|
||||
return T(sptr);
|
||||
}
|
||||
inline operator Expr() const {
|
||||
if (type_code == kNull) {
|
||||
return Expr();
|
||||
}
|
||||
if (type_code == kInt) return Expr(operator int());
|
||||
if (type_code == kFloat) {
|
||||
return Expr(static_cast<float>(operator double()));
|
||||
}
|
||||
CHECK_EQ(type_code, kNodeHandle);
|
||||
if (sptr->is_type<IterVarNode>()) {
|
||||
return IterVar(sptr)->var;
|
||||
} else {
|
||||
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
|
||||
<< "did not pass in Expr in a place need Expr";
|
||||
return Expr(sptr);
|
||||
}
|
||||
}
|
||||
inline operator double() const {
|
||||
CHECK_EQ(type_code, kFloat);
|
||||
return v_union.v_float64;
|
||||
}
|
||||
inline operator int64_t() const {
|
||||
CHECK_EQ(type_code, kInt);
|
||||
return v_union.v_int64;
|
||||
}
|
||||
inline operator uint64_t() const {
|
||||
CHECK_EQ(type_code, kInt);
|
||||
return v_union.v_int64;
|
||||
}
|
||||
inline operator int() const {
|
||||
CHECK_EQ(type_code, kInt);
|
||||
CHECK_LE(v_union.v_int64,
|
||||
std::numeric_limits<int>::max());
|
||||
return v_union.v_int64;
|
||||
}
|
||||
inline operator bool() const {
|
||||
CHECK_EQ(type_code, kInt)
|
||||
<< "expect boolean(int) but get "
|
||||
<< TVMTypeCode2Str(type_code);
|
||||
return v_union.v_int64 != 0;
|
||||
}
|
||||
inline operator std::string() const {
|
||||
CHECK_EQ(type_code, kStr)
|
||||
<< "expect Str but get "
|
||||
<< TVMTypeCode2Str(type_code);
|
||||
return str;
|
||||
}
|
||||
inline operator Type() const {
|
||||
return String2Type(operator std::string());
|
||||
}
|
||||
inline operator runtime::PackedFunc() const {
|
||||
CHECK_EQ(type_code, kFuncHandle);
|
||||
return runtime::PackedFunc(func);
|
||||
}
|
||||
};
|
||||
|
||||
// common defintiion of API function.
|
||||
using APIFunc = std::function<
|
||||
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for DataIterator factory functions.
|
||||
*/
|
||||
struct APIFuncReg
|
||||
: public dmlc::FunctionRegEntryBase<APIFuncReg,
|
||||
APIFunc> {
|
||||
};
|
||||
|
||||
#define TVM_REGISTER_API(TypeName) \
|
||||
DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \
|
||||
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_C_API_C_API_REGISTRY_H_
|
|
@ -12,19 +12,22 @@ using namespace ir;
|
|||
|
||||
runtime::PackedFunc BuildStackVM(LoweredFunc func) {
|
||||
StackVM vm = codegen::CodeGenStackVM().Compile(func);
|
||||
auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) {
|
||||
LOG(INFO) << "Run stack VM";
|
||||
using runtime::TVMArgs;
|
||||
using runtime::TVMRetValue;
|
||||
|
||||
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
|
||||
StackVM::State* s = StackVM::ThreadLocalState();
|
||||
s->sp = 0;
|
||||
s->pc = 0;
|
||||
if (s->heap.size() < vm.heap_size) {
|
||||
s->heap.resize(vm.heap_size);
|
||||
}
|
||||
s->heap[0].v_handle = (void*)args; // NOLINT(*)
|
||||
s->heap[1].v_handle = (void*)type_codes; // NOLINT(*)
|
||||
s->heap[2].v_int64 = num_args;
|
||||
s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
|
||||
s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
|
||||
s->heap[2].v_int64 = args.num_args;
|
||||
vm.Run(s);
|
||||
};
|
||||
|
||||
return runtime::PackedFunc(f);
|
||||
}
|
||||
|
||||
|
@ -118,6 +121,9 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) {
|
|||
auto it = fun_idmap_.find(name);
|
||||
if (it != fun_idmap_.end()) return it->second;
|
||||
using runtime::PackedFunc;
|
||||
using runtime::TVMArgs;
|
||||
using runtime::TVMRetValue;
|
||||
|
||||
PackedFunc f = PackedFunc::GetGlobal(name);
|
||||
auto extern_f = [f](const TVMValue* args, int num_args) {
|
||||
CHECK_EQ(num_args % 2, 0);
|
||||
|
@ -128,7 +134,8 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) {
|
|||
int code = (tcode >> (8 * 3)) & 255;
|
||||
type_codes[i] = code;
|
||||
}
|
||||
f.CallPacked(args, &type_codes[0], num_args);
|
||||
TVMRetValue rv;
|
||||
f.CallPacked(TVMArgs(args, &type_codes[0], num_args), &rv);
|
||||
TVMValue r; r.v_int64 = 0;
|
||||
return r;
|
||||
};
|
||||
|
|
|
@ -136,7 +136,6 @@ class HostDeviceSplitter : public IRMutator {
|
|||
public:
|
||||
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
|
||||
if (op->type_key == "thread_extent") {
|
||||
LOG(INFO) << "??";
|
||||
IterVar iv(op->node.node_);
|
||||
return SplitDeviceFunc(s);
|
||||
}
|
||||
|
|
|
@ -302,7 +302,8 @@ void StackVM::Run(State* s) const {
|
|||
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
|
||||
}
|
||||
case TVM_LOAD_ARG_HANDLE: {
|
||||
STACK_VM_TVM_LOAD_ARG(tc == kHandle || tc == kNull, "handle"); break;
|
||||
STACK_VM_TVM_LOAD_ARG(
|
||||
tc == kHandle || tc == kNull || tc == kArrayHandle, "handle"); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_DATA: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
|
||||
|
@ -317,7 +318,7 @@ void StackVM::Run(State* s) const {
|
|||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_CODE: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.type_code); break;
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
|
||||
}
|
||||
case TVM_ARRAY_GET_TYPE_BITS: {
|
||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
|
||||
|
|
|
@ -3,9 +3,11 @@
|
|||
* \file c_runtime_api.cc
|
||||
* \brief Device specific implementations
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "./runtime_base.h"
|
||||
#include "./device_api.h"
|
||||
|
||||
|
@ -37,7 +39,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
|
|||
|
||||
inline void VerifyType(TVMType dtype) {
|
||||
CHECK_GE(dtype.lanes, 1U);
|
||||
if (dtype.type_code == kFloat) {
|
||||
if (dtype.code == kFloat) {
|
||||
CHECK_EQ(dtype.bits % 32U, 0U);
|
||||
} else {
|
||||
CHECK_EQ(dtype.bits % 8U, 0U);
|
||||
|
@ -65,6 +67,12 @@ inline size_t GetDataAlignment(TVMArray* arr) {
|
|||
|
||||
using namespace tvm::runtime;
|
||||
|
||||
struct TVMRuntimeEntry {
|
||||
std::string ret_str;
|
||||
};
|
||||
|
||||
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
|
||||
|
||||
int TVMDeviceInit(int dev_mask,
|
||||
const char** option_keys,
|
||||
const char** option_vals,
|
||||
|
@ -177,10 +185,31 @@ int TVMFuncFree(TVMFunctionHandle func) {
|
|||
int TVMFuncCall(TVMFunctionHandle func,
|
||||
TVMValue* args,
|
||||
int* arg_type_codes,
|
||||
int num_args) {
|
||||
int num_args,
|
||||
TVMValue* ret_val,
|
||||
int* ret_type_code) {
|
||||
API_BEGIN();
|
||||
TVMRetValue rv;
|
||||
(*static_cast<const PackedFunc*>(func)).CallPacked(
|
||||
args, arg_type_codes, num_args);
|
||||
TVMArgs(args, arg_type_codes, num_args), &rv);
|
||||
// handle return string.
|
||||
if (rv.type_code() == kStr) {
|
||||
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
|
||||
e->ret_str = rv.operator std::string();
|
||||
*ret_type_code = kStr;
|
||||
ret_val->v_str = e->ret_str.c_str();
|
||||
} else {
|
||||
rv.MoveToCHost(ret_val, ret_type_code);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMCFuncSetReturn(TVMRetValueHandle ret,
|
||||
TVMValue value,
|
||||
int type_code) {
|
||||
API_BEGIN();
|
||||
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
|
||||
*rv = TVMArgValue(value, type_code);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
@ -191,22 +220,18 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
|
|||
API_BEGIN();
|
||||
if (fin == nullptr) {
|
||||
*out = new PackedFunc(
|
||||
[func, resource_handle](const TVMValue* args,
|
||||
const int* type_codes,
|
||||
int num_args) {
|
||||
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
|
||||
num_args, resource_handle);
|
||||
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
|
||||
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
||||
args.num_args, rv, resource_handle);
|
||||
});
|
||||
} else {
|
||||
// wrap it in a shared_ptr, with fin as deleter.
|
||||
// so fin will be called when the lambda went out of scope.
|
||||
std::shared_ptr<void> rpack(resource_handle, fin);
|
||||
*out = new PackedFunc(
|
||||
[func, rpack](const TVMValue* args,
|
||||
const int* type_codes,
|
||||
int num_args) {
|
||||
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
|
||||
num_args, rpack.get());
|
||||
[func, rpack](TVMArgs args, TVMRetValue* rv) {
|
||||
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
||||
args.num_args, rv, rpack.get());
|
||||
});
|
||||
}
|
||||
API_END();
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
* \brief The global registry of packed function.
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
@ -58,6 +59,18 @@ std::vector<std::string> PackedFunc::ListGlobalNames() {
|
|||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct TVMFuncThreadLocalEntry {
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
};
|
||||
|
||||
/*! \brief Thread local store that can be used to hold return values. */
|
||||
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
|
||||
|
||||
|
||||
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
|
||||
using tvm::runtime::PackedFunc;
|
||||
API_BEGIN();
|
||||
|
@ -68,6 +81,22 @@ int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
|
|||
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
|
||||
using tvm::runtime::PackedFunc;
|
||||
API_BEGIN();
|
||||
*out = new PackedFunc(PackedFunc::GetGlobal(name));
|
||||
const PackedFunc& f = PackedFunc::GetGlobal(name);
|
||||
*out = (TVMFunctionHandle)(&f); // NOLINT(*)
|
||||
API_END();
|
||||
}
|
||||
|
||||
int TVMFuncListGlobalNames(int *out_size,
|
||||
const char*** out_array) {
|
||||
using tvm::runtime::PackedFunc;
|
||||
API_BEGIN();
|
||||
TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
|
||||
ret->ret_vec_str = PackedFunc::ListGlobalNames();
|
||||
ret->ret_vec_charp.clear();
|
||||
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
||||
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
||||
}
|
||||
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
||||
API_END();
|
||||
}
|
||||
|
|
|
@ -1,24 +1,116 @@
|
|||
#include <dmlc/logging.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <tvm/tvm.h>
|
||||
#include <tvm/ir.h>
|
||||
|
||||
TEST(PackedFunc, Basic) {
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
int x = 0;
|
||||
void* handle = &x;
|
||||
TVMArray a;
|
||||
|
||||
PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) {
|
||||
CHECK(num_args == 3);
|
||||
CHECK(args[0].v_float64 == 1.0);
|
||||
CHECK(type_codes[0] == kFloat);
|
||||
CHECK(args[1].v_handle == &a);
|
||||
CHECK(type_codes[1] == kHandle);
|
||||
CHECK(args[2].v_handle == &x);
|
||||
CHECK(type_codes[2] == kHandle);
|
||||
Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||
CHECK(args.num_args == 3);
|
||||
CHECK(args.values[0].v_float64 == 1.0);
|
||||
CHECK(args.type_codes[0] == kFloat);
|
||||
CHECK(args.values[1].v_handle == &a);
|
||||
CHECK(args.type_codes[1] == kArrayHandle);
|
||||
CHECK(args.values[2].v_handle == &x);
|
||||
CHECK(args.type_codes[2] == kHandle);
|
||||
*rv = Var("a");
|
||||
})(1.0, &a, handle);
|
||||
CHECK(v->name_hint == "a");
|
||||
}
|
||||
|
||||
TEST(PackedFunc, Node) {
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
Var x;
|
||||
Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||
CHECK(args.num_args == 1);
|
||||
CHECK(args.type_codes[0] == kNodeHandle);
|
||||
Var b = args[0];
|
||||
CHECK(x.same_as(b));
|
||||
*rv = b;
|
||||
})(x);
|
||||
CHECK(t.same_as(x));
|
||||
}
|
||||
|
||||
TEST(PackedFunc, str) {
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||
CHECK(args.num_args == 1);
|
||||
std::string x = args[0];
|
||||
CHECK(x == "hello");
|
||||
*rv = x;
|
||||
})("hello");
|
||||
}
|
||||
|
||||
|
||||
TEST(PackedFunc, func) {
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) {
|
||||
*rv = args[0].operator int() + 1;
|
||||
});
|
||||
// function as arguments
|
||||
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||
PackedFunc f = args[0];
|
||||
// TVMArgValue -> Arguments as function
|
||||
*rv = f(args[1]).operator int();
|
||||
})(addone, 1);
|
||||
CHECK_EQ(r0, 2);
|
||||
|
||||
int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||
// TVMArgValue -> TVMRetValue
|
||||
*rv = args[1];
|
||||
})(2, 100);
|
||||
CHECK_EQ(r1, 100);
|
||||
|
||||
int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||
// re-assignment
|
||||
*rv = args[0];
|
||||
// TVMRetValue -> Function argument
|
||||
*rv = addone(args[0].operator PackedFunc()(args[1], 1));
|
||||
})(addone, 100);
|
||||
CHECK_EQ(r2, 102);
|
||||
}
|
||||
|
||||
TEST(PackedFunc, Expr) {
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
// automatic conversion of int to expr
|
||||
PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
|
||||
Expr x = args[0];
|
||||
*rv = x.as<tvm::ir::IntImm>()->value + 1;
|
||||
});
|
||||
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||
PackedFunc f = args[0];
|
||||
// TVMArgValue -> Arguments as function
|
||||
*rv = f(args[1]).operator int();
|
||||
})(addone, 1);
|
||||
CHECK_EQ(r0, 2);
|
||||
}
|
||||
|
||||
TEST(PackedFunc, Type) {
|
||||
using namespace tvm;
|
||||
using namespace tvm::runtime;
|
||||
auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||
Type x = args[0];
|
||||
*rv = x;
|
||||
});
|
||||
auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||
*rv = args[0];
|
||||
});
|
||||
CHECK(get_type("int32").operator Type() == Int(32));
|
||||
CHECK(get_type("float").operator Type() == Float(32));
|
||||
CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
|
|
|
@ -2,7 +2,8 @@ import tvm
|
|||
|
||||
def test_const():
|
||||
x = tvm.const(1)
|
||||
assert x.dtype == 'int32'
|
||||
print(x.dtype)
|
||||
assert x.dtype == tvm.int32
|
||||
assert isinstance(x, tvm.expr.IntImm)
|
||||
|
||||
def test_const_saveload_json():
|
||||
|
|
|
@ -17,10 +17,22 @@ def test_get_global():
|
|||
@tvm.register_func
|
||||
def my_packed_func(*args):
|
||||
assert(tuple(args) == targs)
|
||||
return 10
|
||||
# get it out from global function table
|
||||
f = tvm.get_global_func("my_packed_func")
|
||||
assert isinstance(f, tvm.nd.Function)
|
||||
f(*targs)
|
||||
y = f(*targs)
|
||||
assert y == 10
|
||||
|
||||
|
||||
def test_return_func():
|
||||
def addy(y):
|
||||
def add(x):
|
||||
return tvm.convert(x + y)
|
||||
return add
|
||||
myf = tvm.convert(addy)
|
||||
f = myf(10)
|
||||
assert f(11).value == 21
|
||||
|
||||
|
||||
def test_convert():
|
||||
|
@ -38,3 +50,4 @@ if __name__ == "__main__":
|
|||
test_function()
|
||||
test_convert()
|
||||
test_get_global()
|
||||
test_return_func()
|
||||
|
|
|
@ -38,10 +38,10 @@ fi
|
|||
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
|
||||
make all || exit -1
|
||||
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
|
||||
python -m nose tests/python/ || exit -1
|
||||
python3 -m nose tests/python/ || exit -1
|
||||
python -m nose -v tests/python/ || exit -1
|
||||
python3 -m nose -v tests/python/ || exit -1
|
||||
else
|
||||
nosetests tests/python/ || exit -1
|
||||
nosetests3 tests/python/ || exit -1
|
||||
nosetests -v tests/python/ || exit -1
|
||||
nosetests3 -v tests/python/ || exit -1
|
||||
fi
|
||||
fi
|
||||
|
|
Загрузка…
Ссылка в новой задаче