diff --git a/.travis.yml b/.travis.yml index e7110ecb..31d6e49f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ language: cpp os: - linux - - osx + # - osx env: # code analysis diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h new file mode 100644 index 00000000..c4c2a42e --- /dev/null +++ b/include/tvm/api_registry.h @@ -0,0 +1,85 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file api_registry.h + * \brief This file defines the TVM API registry. + * + * The API registry stores type-erased functions. + * Each registered function is automatically exposed + * to front-end language(e.g. python). + * Front-end can also pass callbacks as PackedFunc, or register + * then into the same global registry in C++. + * The goal is to mix the front-end language and the TVM back-end. + * + * \code + * // register the function as MyAPIFuncName + * TVM_REGISTER_API(MyAPIFuncName) + * .set_body([](TVMArgs args, TVMRetValue* rv) { + * // my code. + * }); + * \endcode + */ +#ifndef TVM_API_REGISTRY_H_ +#define TVM_API_REGISTRY_H_ + +#include +#include +#include "./base.h" +#include "./runtime/packed_func.h" +#include "./packed_func_ext.h" + +namespace tvm { + +/*! \brief Utility to register API. */ +class APIRegistry { + public: + /*! + * \brief set the body of the function to be f + * \param f The body of the function. + */ + APIRegistry& set_body(PackedFunc f); // NOLINT(*) + /*! + * \brief set the body of the function to be f + * \param f The body of the function. + */ + APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*) + return set_body(PackedFunc(f)); + } + /*! + * \brief Register a function with given name + * \param name The name of the function. + */ + static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*) + + private: + /*! \brief name of the function */ + std::string name_; +}; + +/*! + * \brief Get API function by name. + * + * \param name The name of the function. + * \return the corresponding API function. + * \note It is really PackedFunc::GetGlobal under the hood. + */ +inline PackedFunc GetAPIFunc(const std::string& name) { + return PackedFunc::GetGlobal(name); +} + +#define _TVM_REGISTER_VAR_DEF_ \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_ + +/*! + * \brief Register API function globally. + * \code + * TVM_REGISTER_API(MyPrint) + * .set_body([](TVMArgs args, TVMRetValue* rv) { + * // my code. + * }); + * \endcode + */ +#define TVM_REGISTER_API(OpName) \ + DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \ + ::tvm::APIRegistry::__REGISTER__(#OpName) +} // namespace tvm +#endif // TVM_API_REGISTRY_H_ diff --git a/include/tvm/c_api.h b/include/tvm/c_api.h index 30fd3ea1..e8abf552 100644 --- a/include/tvm/c_api.h +++ b/include/tvm/c_api.h @@ -2,6 +2,13 @@ * Copyright (c) 2016 by Contributors * \file c_api.h * \brief C API of TVM DSL + * + * \note The API is designed in a minimum way. + * Most of the API functions are registered and can be pulled out. + * + * The common flow is: + * - Use TVMFuncListGlobalNames to get global function name + * - Use TVMFuncCall to call these functions. */ #ifndef TVM_C_API_H_ #define TVM_C_API_H_ @@ -9,76 +16,9 @@ #include "./runtime/c_runtime_api.h" TVM_EXTERN_C { -/*! \brief handle to functions */ -typedef void* APIFuncHandle; /*! \brief handle to node */ typedef void* NodeHandle; -/*! - * \brief List all the node function name - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMListAPIFuncNames(int *out_size, - const char*** out_array); -/*! - * \brief get function handle by name - * \param name The name of function - * \param handle The returning function handle - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMGetAPIFuncHandle(const char* name, - APIFuncHandle *handle); - -/*! - * \brief Get the detailed information about function. - * \param handle The operator handle. - * \param real_name The returned name of the function. - * This name is not the alias name of the atomic symbol. - * \param description The returned description of the symbol. - * \param num_doc_args Number of arguments that contain documents. - * \param arg_names Name of the arguments of doc args - * \param arg_type_infos Type informations about the arguments. - * \param arg_descriptions Description information about the arguments. - * \param return_type Return type of the function, if any. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle, - const char **real_name, - const char **description, - int *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); - -/*! - * \brief Push an argument to the function calling stack. - * If push fails, the stack will be reset to empty - * - * \param arg The argument - * \param type_code The type_code of argument as in TVMTypeCode - * \return 0 when success, -1 when failure happens - * \note API calls always exchanges with type bits=64, lanes=1 - */ -TVM_DLL int TVMAPIPushStack(TVMValue arg, - int type_code); - -/*! - * \brief call a function by using arguments in the stack. - * The stack will be cleanup to empty after this call, whether the call is successful. - * - * \param handle The function handle - * \param ret_val The return value. - * \param ret_type_code the type code of return value. - * \return 0 when success, -1 when failure happens - * \note API calls always exchanges with type bits=64, lanes=1 - */ -TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle, - TVMValue* ret_val, - int* ret_type_code); - /*! * \brief free the node handle * \param handle The node handle to be freed. diff --git a/include/tvm/expr.h b/include/tvm/expr.h index e6057b29..67b5dcdd 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -12,6 +12,7 @@ #include #include #include "./base.h" +#include "./runtime/packed_func.h" namespace tvm { diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h new file mode 100644 index 00000000..edaf43a8 --- /dev/null +++ b/include/tvm/packed_func_ext.h @@ -0,0 +1,196 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file packed_func_ext.h + * \brief Extension package to PackedFunc + * This enales pass NodeRef types into/from PackedFunc. + */ +#ifndef TVM_PACKED_FUNC_EXT_H_ +#define TVM_PACKED_FUNC_EXT_H_ + +#include +#include +#include +#include + +#include "./base.h" +#include "./expr.h" + +namespace tvm { +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::PackedFunc; + +namespace runtime { +/*! + * \brief Runtime type checker for node type. + * \tparam T the type to be checked. + */ +template +struct NodeTypeChecker { + static inline bool Check(Node* sptr) { + // This is the only place in the project where RTTI is used + // It can be turned off, but will make non strict checking. + // TODO(tqchen) possibly find alternative to turn of RTTI + using ContainerType = typename T::ContainerType; + return (dynamic_cast(sptr) != nullptr); + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + using ContainerType = typename T::ContainerType; + os << ContainerType::_type_key; + } +}; + +template +struct NodeTypeChecker > { + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; + ArrayNode* n = static_cast(sptr); + for (const auto& p : n->data) { + if (!NodeTypeChecker::Check(p.get())) return false; + } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "array<"; + NodeTypeChecker::PrintName(os); + os << ">"; + } +}; + +template +struct NodeTypeChecker > { + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; + MapNode* n = static_cast(sptr); + for (const auto& kv : n->data) { + if (!NodeTypeChecker::Check(kv.first.get())) return false; + if (!NodeTypeChecker::Check(kv.second.get())) return false; + } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "map<"; + NodeTypeChecker::PrintName(os); + os << ','; + NodeTypeChecker::PrintName(os); + os << '>'; + } +}; + +template +inline std::string NodeTypeName() { + std::ostringstream os; + NodeTypeChecker::PrintName(os); + return os.str(); +} + +// extensions for tvm arg value + +template +inline TVMArgValue::operator TNodeRef() const { + static_assert( + std::is_base_of::value, + "Conversion only works for NodeRef"); + if (type_code_ == kNull) return TNodeRef(); + TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); + std::shared_ptr& sptr = *ptr >(); + CHECK(NodeTypeChecker::Check(sptr.get())) + << "Expected type " << NodeTypeName() + << " but get " << sptr->type_key(); + return TNodeRef(sptr); +} + +inline TVMArgValue::operator Halide::Expr() const { + if (type_code_ == kNull) return Expr(); + if (type_code_ == kInt) { + return Expr(static_cast(value_.v_int64)); + } + if (type_code_ == kFloat) { + return Expr(static_cast(value_.v_float64)); + } + TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); + std::shared_ptr& sptr = *ptr >(); + if (sptr->is_type()) { + return IterVar(sptr)->var; + } + CHECK(NodeTypeChecker::Check(sptr.get())) + << "Expected type " << NodeTypeName() + << " but get " << sptr->type_key(); + return Expr(sptr); +} + +inline std::shared_ptr& TVMArgValue::node_sptr() { + TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); + return *ptr >(); +} + + +template +inline bool TVMArgValue::IsNodeType() const { + TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); + std::shared_ptr& sptr = + *ptr >(); + return NodeTypeChecker::Check(sptr.get()); +} + +// extensions for TVMRetValue +inline TVMRetValue& TVMRetValue::operator=( + const std::shared_ptr& other) { + SwitchToClass >(kNodeHandle, other); + return *this; +} + +inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { + SwitchToClass >(kNodeHandle, other.node_); + return *this; +} + +template +inline TVMRetValue::operator TNodeRef() const { + static_assert( + std::is_base_of::value, + "Conversion only works for NodeRef"); + if (type_code_ == kNull) return TNodeRef(); + TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); + return TNodeRef(*ptr >()); +} + +inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*) + values_[i].v_handle = &(other.node_); + type_codes_[i] = kNodeHandle; +} + +// Type related stuffs +inline Type TVMType2Type(TVMType t) { + return Type(static_cast(t.code), t.bits, t.lanes); +} + +inline TVMType Type2TVMType(Type t) { + TVMType ret; + ret.code = static_cast(t.code()); + ret.bits = static_cast(t.bits()); + ret.lanes = static_cast(t.lanes()); + return ret; +} + +inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) { + return this->operator=(Type2TVMType(t)); +} + +inline TVMRetValue::operator Halide::Type() const { + return TVMType2Type(operator TVMType()); +} + +inline TVMArgValue::operator Halide::Type() const { + return TVMType2Type(operator TVMType()); +} + +inline void TVMArgsSetter::operator()( + size_t i, const Halide::Type& t) const { + this->operator()(i, Type2TVMType(t)); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_PACKED_FUNC_EXT_H_ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f3e9eee8..5151fb53 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -36,18 +36,6 @@ TVM_EXTERN_C { /*! \brief type of array index. */ typedef uint32_t tvm_index_t; - -/*! - * \brief Union type of values - * being passed through API and function calls. - */ -typedef union { - int64_t v_int64; - double v_float64; - void* v_handle; - const char* v_str; -} TVMValue; - /*! * \brief The type code in TVMType * \note TVMType is used in two places. @@ -60,9 +48,11 @@ typedef enum { // The next few fields are extension types // that is used by TVM API calls. kNull = 4U, - kNodeHandle = 5U, - kStr = 6U, - kFuncHandle = 7U + kArrayHandle = 5U, + kTVMType = 6U, + kNodeHandle = 7U, + kStr = 8U, + kFuncHandle = 9U } TVMTypeCode; /*! @@ -77,13 +67,25 @@ typedef enum { */ typedef struct { /*! \brief type code, in TVMTypeCode */ - uint8_t type_code; + uint8_t code; /*! \brief number of bits of the type */ uint8_t bits; /*! \brief number of lanes, */ uint16_t lanes; } TVMType; +/*! + * \brief Union type of values + * being passed through API and function calls. + */ +typedef union { + int64_t v_int64; + double v_float64; + void* v_handle; + const char* v_str; + TVMType v_type; +} TVMValue; + /*! * \brief The device type */ @@ -133,11 +135,10 @@ typedef struct { * can be NULL, which indicates the default one. */ typedef void* TVMStreamHandle; -/*! - * \brief Pointer to function handle that points to - * a generated TVM function. - */ +/*! \brief Handle to packed function handle. */ typedef void* TVMFunctionHandle; +/*! \brief Handle to hold return value. */ +typedef void* TVMRetValueHandle; /*! \brief the array handle */ typedef TVMArray* TVMArrayHandle; @@ -228,20 +229,45 @@ TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); TVM_DLL int TVMFuncFree(TVMFunctionHandle func); /*! - * \brief Call a function whose parameters are all packed. + * \brief Call a Packed TVM Function. * * \param func node handle of the function. - * \param args The arguments + * \param arg_values The arguments * \param type_codes The type codes of the arguments * \param num_args Number of arguments. * + * \param ret_val The return value. + * \param ret_type_code the type code of return value. + * * \return 0 when success, -1 when failure happens * \note TVM calls always exchanges with type bits=64, lanes=1 + * + * \note API calls always exchanges with type bits=64, lanes=1 + * If API call returns container handles (e.g. FunctionHandle) + * these handles should be managed by the front-end. + * The front-end need to call free function (e.g. TVMFuncFree) + * to free these handles. */ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, - TVMValue* args, + TVMValue* arg_values, int* type_codes, - int num_args); + int num_args, + TVMValue* ret_val, + int* ret_type_code); + +/*! + * \brief Set the return value of TVMPackedCFunc. + * + * This function is called by TVMPackedCFunc to set the return value. + * When this function is not called, the function returns null by default. + * + * \param ret The return value handle, pass by ret in TVMPackedCFunc + * \param value The value to be returned. + * \param type_code The type of the value to be returned. + */ +TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, + TVMValue value, + int type_code); /*! * \brief C type of packed function. @@ -249,10 +275,17 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, * \param args The arguments * \param type_codes The type codes of the arguments * \param num_args Number of arguments. + * \param ret The return value handle. * \param resource_handle The handle additional resouce handle from fron-end. + * + * \sa TVMCFuncSetReturn */ typedef void (*TVMPackedCFunc)( - TVMValue* args, int* type_codes, int num_args, void* resource_handle); + TVMValue* args, + int* type_codes, + int num_args, + TVMRetValueHandle ret, + void* resource_handle); /*! * \brief C callback to free the resource handle in C packed function. @@ -291,8 +324,20 @@ TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f); * * \param name The name of the function. * \param out the result function pointer. + * + * \note The function handle of global function is managed by TVM runtime, + * So TVMFuncFree is should not be called when it get deleted. */ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + +/*! + * \brief List all the globally registered function name + * \param out_size The number of functions + * \param out_array The array of function names. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMFuncListGlobalNames(int *out_size, + const char*** out_array); } // TVM_EXTERN_C #endif // TVM_RUNTIME_C_RUNTIME_API_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index b4868f75..e3a391c8 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -1,19 +1,41 @@ /*! - * Copyright (c) 2016 by Contributors + * Copyright (c) 2017 by Contributors * \file packed_func.h * \brief Runtime related c++ class. */ #ifndef TVM_RUNTIME_PACKED_FUNC_H_ #define TVM_RUNTIME_PACKED_FUNC_H_ +#include #include #include #include #include +#include +#include +#include #include "./c_runtime_api.h" +namespace Halide { +// Forward declare type for extensions +// The header works fine without depending on this. +struct Type; +struct Expr; +} + namespace tvm { +// Forward declare NodeRef and Node for extensions. +// This header works fine without depend on NodeRef +// as long as it is not used. +class Node; +class NodeRef; + namespace runtime { +// forward declarations +class TVMArgs; +class TVMArgValue; +class TVMRetValue; +class TVMArgsSetter; /*! * \brief Packed function is a type-erased function. @@ -25,8 +47,25 @@ namespace runtime { */ class PackedFunc { public: - /*! \brief The internal std::function */ - using FType = std::function; + /*! + * \brief The internal std::function + * \param args The arguments to the function. + * \param rv The return value. + * + * \code + * // Example code on how to implemented FType + * void MyPackedFunc(TVMArgs args, TVMRetValue* rv) { + * // automatically convert arguments to desired type. + * int a0 = args[0]; + * float a1 = args[1]; + * ... + * // automatically assign values to rv + * std::string my_return_value = "x"; + * *rv = my_return_value; + * } + * \endcode + */ + using FType = std::function; /*! \brief default constructor */ PackedFunc() {} /*! @@ -38,16 +77,24 @@ class PackedFunc { * \brief Call packed function by directly passing in unpacked format. * \param args Arguments to be passed. * \tparam Args arguments to be passed. + * + * \code + * // Example code on how to call packed function + * void CallPacked(PackedFunc f) { + * // call like normal functions by pass in arguments + * // return value is automatically converted back + * int rvalue = f(1, 2.0); + * } + * \endcode */ template - inline void operator()(Args&& ...args) const; + inline TVMRetValue operator()(Args&& ...args) const; /*! * \brief Call the function in packed format. * \param args The arguments - * \param type_codes The type_codes of the arguments - * \param num_args Number of arguments. + * \param rv The return value. */ - inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const; + inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; /*! \return the internal body function */ inline FType body() const; /*! @@ -74,82 +121,552 @@ class PackedFunc { FType body_; }; -// implementations -inline void PackedFunc::CallPacked( - const TVMValue* args, const int* type_codes, int num_args) const { - body_(args, type_codes, num_args); +/*! \brief Arguments into TVM functions. */ +class TVMArgs { + public: + const TVMValue* values; + const int* type_codes; + int num_args; + /*! + * \brief constructor + * \param values The argument values + * \param type_codes The argument type codes + * \param num_args number of arguments. + */ + TVMArgs(const TVMValue* values, + const int* type_codes, + int num_args) + : values(values), + type_codes(type_codes), + num_args(num_args) { } + /*! \return size of the arguments */ + inline int size() const; + /*! + * \brief Get i-th argument + * \param i the index. + * \return the ith argument. + */ + inline TVMArgValue operator[](int i) const; +}; + +/*! + * \brief Convert type code to its name + * \param type_code The type code . + * \return The name of type code. + */ +inline const char* TypeCode2Str(int type_code); + +/*! + * \brief convert a string to TVM type. + * \param s The string to be converted. + * \return The corresponding tvm type. + */ +inline TVMType String2TVMType(std::string s); + +// macro to check type code. +#define TVM_CHECK_TYPE_CODE(CODE, T) \ + CHECK_EQ(CODE, T) << " expected " \ + << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ + +/*! + * \brief Internal base class to + * handle conversion to POD values. + */ +class TVMPODValue_ { + public: + operator double() const { + TVM_CHECK_TYPE_CODE(type_code_, kFloat); + return value_.v_float64; + } + operator int64_t() const { + TVM_CHECK_TYPE_CODE(type_code_, kInt); + return value_.v_int64; + } + operator uint64_t() const { + TVM_CHECK_TYPE_CODE(type_code_, kInt); + return value_.v_int64; + } + operator int() const { + TVM_CHECK_TYPE_CODE(type_code_, kInt); + CHECK_LE(value_.v_int64, + std::numeric_limits::max()); + return static_cast(value_.v_int64); + } + operator bool() const { + TVM_CHECK_TYPE_CODE(type_code_, kInt); + return value_.v_int64 != 0; + } + operator void*() const { + if (type_code_ == kNull) return nullptr; + if (type_code_ == kArrayHandle) return value_.v_handle; + TVM_CHECK_TYPE_CODE(type_code_, kHandle); + return value_.v_handle; + } + operator TVMArray*() const { + if (type_code_ == kNull) return nullptr; + TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle); + return static_cast(value_.v_handle); + } + int type_code() const { + return type_code_; + } + + protected: + friend class TVMArgsSetter; + friend class TVMRetValue; + TVMPODValue_() : type_code_(kNull) {} + TVMPODValue_(TVMValue value, int type_code) + : value_(value), type_code_(type_code) {} + /*! + * \brief return handle as specific pointer type. + * \tparam T the data type. + * \return The pointer type. + */ + template + T* ptr() const { + return static_cast(value_.v_handle); + } + /*! \brief The value */ + TVMValue value_; + /*! \brief the type code */ + int type_code_; +}; + +/*! + * \brief A single argument value to PackedFunc. + * Containing both type_code and TVMValue + * + * Provides utilities to do type cast into other types. + */ +class TVMArgValue : public TVMPODValue_ { + public: + /*! + * \brief constructor + * \param value of the function + * \param type_code The type code. + */ + TVMArgValue(TVMValue value, int type_code) + : TVMPODValue_(value, type_code) { + } + // reuse converter from parent + using TVMPODValue_::operator double; + using TVMPODValue_::operator int64_t; + using TVMPODValue_::operator uint64_t; + using TVMPODValue_::operator int; + using TVMPODValue_::operator bool; + using TVMPODValue_::operator void*; + using TVMPODValue_::operator TVMArray*; + // conversion operator. + operator std::string() const { + TVM_CHECK_TYPE_CODE(type_code_, kStr); + return std::string(value_.v_str); + } + operator TVMType() const { + if (type_code_ == kStr) { + return String2TVMType(operator std::string()); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMType); + return value_.v_type; + } + operator PackedFunc() const { + TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); + return *ptr(); + } + const TVMValue& value() const { + return value_; + } + // NodeRef related extenstions: in tvm/packed_func_ext.h + template::value>::type> + inline operator TNodeRef() const; + template::value>::type> + inline bool IsNodeType() const; + inline operator Halide::Type() const; + inline operator Halide::Expr() const; + // get internal node ptr, if it is node + inline std::shared_ptr& node_sptr(); +}; + +/*! + * \brief Return Value container, + * Unlike TVMArgValue, which only holds reference and do not delete + * the underlying container during destruction. + * + * TVMRetValue holds value and will manage the underlying containers + * when it stores a complicated data type. + */ +class TVMRetValue : public TVMPODValue_ { + public: + /*! \brief default constructor */ + TVMRetValue() {} + /*! + * \brief move constructor from anoter return value. + * \param other The other return value. + */ + TVMRetValue(TVMRetValue&& other) + : TVMPODValue_(other.value_, other.type_code_) { + other.type_code_ = kNull; + } + /*! \brief destructor */ + ~TVMRetValue() { + this->Clear(); + } + // reuse converter from parent + using TVMPODValue_::operator double; + using TVMPODValue_::operator int64_t; + using TVMPODValue_::operator uint64_t; + using TVMPODValue_::operator int; + using TVMPODValue_::operator bool; + using TVMPODValue_::operator void*; + using TVMPODValue_::operator TVMArray*; + // Disable copy and assign from another value, but allow move. + TVMRetValue(const TVMRetValue& other) { + this->Assign(other); + } + // conversion operators + operator std::string() const { + TVM_CHECK_TYPE_CODE(type_code_, kStr); + return *ptr(); + } + operator TVMType() const { + if (type_code_ == kStr) { + return String2TVMType(operator std::string()); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMType); + return value_.v_type; + } + operator PackedFunc() const { + TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); + return *ptr(); + } + // Assign operators + TVMRetValue& operator=(TVMRetValue&& other) { + this->Clear(); + value_ = other.value_; + type_code_ = other.type_code_; + other.type_code_ = kNull; + return *this; + } + TVMRetValue& operator=(double value) { + this->SwitchToPOD(kFloat); + value_.v_float64 = value; + return *this; + } + TVMRetValue& operator=(std::nullptr_t value) { + this->SwitchToPOD(kNull); + value_.v_handle = value; + return *this; + } + TVMRetValue& operator=(void* value) { + this->SwitchToPOD(kHandle); + value_.v_handle = value; + return *this; + } + TVMRetValue& operator=(int64_t value) { + this->SwitchToPOD(kInt); + value_.v_int64 = value; + return *this; + } + TVMRetValue& operator=(int value) { + this->SwitchToPOD(kInt); + value_.v_int64 = value; + return *this; + } + TVMRetValue& operator=(TVMType t) { + this->SwitchToPOD(kTVMType); + value_.v_type = t; + return *this; + } + TVMRetValue& operator=(bool value) { + this->SwitchToPOD(kInt); + value_.v_int64 = value; + return *this; + } + TVMRetValue& operator=(std::string value) { + this->SwitchToClass(kStr, value); + return *this; + } + TVMRetValue& operator=(PackedFunc f) { + this->SwitchToClass(kFuncHandle, f); + return *this; + } + TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 + this->Assign(other); + return *this; + } + TVMRetValue& operator=(TVMArgValue other) { + this->Assign(other); + return *this; + } + /*! + * \brief Move the value back to front-end via C API. + * This marks the current container as null. + * The managed resources is moved to front-end and + * the front end should take charge in managing them. + * + * \param ret_value The return value. + * \param ret_type_code The return type code. + */ + void MoveToCHost(TVMValue* ret_value, + int* ret_type_code) { + // cannot move str; need specially handle. + CHECK(type_code_ != kStr); + *ret_value = value_; + *ret_type_code = type_code_; + type_code_ = kNull; + } + // NodeRef related extenstions: in tvm/packed_func_ext.h + inline TVMRetValue& operator=(const NodeRef& other); + inline TVMRetValue& operator=(const std::shared_ptr& other); + template::value>::type> + inline operator TNodeRef() const; + // type related + inline operator Halide::Type() const; + inline TVMRetValue& operator=(const Halide::Type& other); + + private: + template + void Assign(const T& other) { + switch (other.type_code()) { + case kStr: { + SwitchToClass(kStr, other); + break; + } + case kFuncHandle: { + SwitchToClass(kFuncHandle, other); + break; + } + case kNodeHandle: { + SwitchToClass >( + kNodeHandle, *other.template ptr >()); + break; + } + default: { + SwitchToPOD(other.type_code()); + value_ = other.value_; + break; + } + } + } + + // get the internal container. + void SwitchToPOD(int type_code) { + if (type_code_ != type_code) { + this->Clear(); + type_code_ = type_code; + } + } + template + void SwitchToClass(int type_code, T v) { + if (type_code_ != type_code) { + this->Clear(); + type_code_ = type_code; + value_.v_handle = new T(v); + } else { + *static_cast(value_.v_handle) = v; + } + } + void Clear() { + if (type_code_ == kNull) return; + switch (type_code_) { + case kStr: delete ptr(); break; + case kFuncHandle: delete ptr(); break; + case kNodeHandle: delete ptr >(); break; + } + type_code_ = kNull; + } +}; + +// implementation details +inline const char* TypeCode2Str(int type_code) { + switch (type_code) { + case kInt: return "int"; + case kFloat: return "float"; + case kStr: return "str"; + case kHandle: return "Handle"; + case kNull: return "NULL"; + case kNodeHandle: return "NodeHandle"; + case kArrayHandle: return "ArrayHandle"; + case kTVMType: return "TVMType"; + case kFuncHandle: return "FunctionHandle"; + default: LOG(FATAL) << "unknown type_code=" + << static_cast(type_code); return ""; + } +} + +inline TVMType String2TVMType(std::string s) { + TVMType t; + t.bits = 32; t.lanes = 1; + const char* scan; + if (s.substr(0, 3) == "int") { + t.code = kInt; scan = s.c_str() + 3; + } else if (s.substr(0, 4) == "uint") { + t.code = kUInt; scan = s.c_str() + 4; + } else if (s.substr(0, 5) == "float") { + t.code = kFloat; scan = s.c_str() + 5; + } else if (s == "handle") { + t.code = kHandle; + t.bits = 64; // handle uses 64 bit by default. + scan = s.c_str() + 6; + } else { + scan = s.c_str(); + LOG(FATAL) << "unknown type " << s; + } + unsigned bits = t.bits, lanes = t.lanes; + sscanf(scan, "%ux%u", &bits, &lanes); + t.bits = static_cast(bits); + t.lanes = static_cast(lanes); + return t; +} + +inline TVMArgValue TVMArgs::operator[](int i) const { + CHECK_LT(i, num_args) + << "not enough argument passed, " + << num_args << " passed" + << "but request arg" << i; + return TVMArgValue(values[i], type_codes[i]); +} + +inline int TVMArgs::size() const { + return num_args; +} + +inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { + body_(args, rv); } inline PackedFunc::FType PackedFunc::body() const { return body_; } +// internal namespace +namespace detail { template -struct for_each_dispatcher_ { - static inline void run(const std::tuple& args, F f) { +struct for_each_dispatcher { + static void run(std::tuple& args, const F& f) { // NOLINT(*) f(I, std::get(args)); - for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f); + for_each_dispatcher<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f); } }; template -struct for_each_dispatcher_ { - static inline void run(const std::tuple& args, F f) {} +struct for_each_dispatcher { + static void run(std::tuple& args, const F& f) {} // NOLINT(*) }; +} // namespace detail template -inline void for_each(const std::tuple& args, F f) { - for_each_dispatcher_::run(args, f); +inline void for_each(std::tuple& args, const F& f) { // NOLINT(*) + detail::for_each_dispatcher::run(args, f); } -namespace arg_setter { -template -inline void Set(TVMValue& arg, int& t, T v); // NOLINT(*) -template<> -inline void Set(TVMValue& arg, int& t, double value) { // NOLINT(*) - arg.v_float64 = value; - t = kFloat; -} -template<> -inline void Set(TVMValue& arg, int& t, int value) { // NOLINT(*) - arg.v_int64 = value; - t = kInt; -} -template<> -inline void Set(TVMValue& arg, int& t, long value) { // NOLINT(*) - arg.v_int64 = value; - t = kInt; -} -template<> -inline void Set(TVMValue& arg, int& t, TVMArray* value) { // NOLINT(*) - arg.v_handle = value; - t = kHandle; -} -template<> -inline void Set(TVMValue& arg, int& t, void* value) { // NOLINT(*) - arg.v_handle = value; - t = kHandle; -} -} // namespace arg_setter - -struct PackedFuncArgSetter { - TVMValue* args; - int* type_codes; - template - inline void operator()(size_t i, T v) const { - arg_setter::Set(args[i], type_codes[i], v); +/* \brief argument settter to PackedFunc */ +class TVMArgsSetter { + public: + TVMArgsSetter(TVMValue* values, int* type_codes) + : values_(values), type_codes_(type_codes) {} + // setters for POD types + template::value>::type> + void operator()(size_t i, T value) const { + values_[i].v_int64 = static_cast(value); + type_codes_[i] = kInt; } + void operator()(size_t i, uint64_t value) const { + values_[i].v_int64 = static_cast(value); + CHECK_LE(value, + static_cast(std::numeric_limits::max())); + type_codes_[i] = kInt; + } + void operator()(size_t i, double value) const { + values_[i].v_float64 = value; + type_codes_[i] = kFloat; + } + void operator()(size_t i, std::nullptr_t value) const { + values_[i].v_handle = value; + type_codes_[i] = kNull; + } + void operator()(size_t i, const TVMArgValue& value) const { + values_[i] = value.value_; + type_codes_[i] = value.type_code_; + } + void operator()(size_t i, void* value) const { + values_[i].v_handle = value; + type_codes_[i] = kHandle; + } + void operator()(size_t i, TVMArray* value) const { + values_[i].v_handle = value; + type_codes_[i] = kArrayHandle; + } + void operator()(size_t i, TVMType value) const { + values_[i].v_type = value; + type_codes_[i] = kTVMType; + } + void operator()(size_t i, const char* value) const { + values_[i].v_str = value; + type_codes_[i] = kStr; + } + // setters for container type + // They must be reference(instead of const ref) + // to make sure they are alive in the tuple(instead of getting converted) + void operator()(size_t i, std::string& value) const { // NOLINT(*) + values_[i].v_str = value.c_str(); + type_codes_[i] = kStr; + } + void operator()(size_t i, PackedFunc& value) const { // NOLINT(*) + values_[i].v_handle = &value; + type_codes_[i] = kFuncHandle; + } + void operator()(size_t i, TVMRetValue& value) const { // NOLINT(*) + if (value.type_code() == kStr) { + values_[i].v_str = value.ptr()->c_str(); + type_codes_[i] = kStr; + } else { + values_[i] = value.value_; + type_codes_[i] = value.type_code(); + } + } + // NodeRef related extenstions: in tvm/packed_func_ext.h + inline void operator()(size_t i, NodeRef& other) const; // NOLINT(*) + inline void operator()(size_t i, const Halide::Type& t) const; + + private: + /*! \brief The values fields */ + TVMValue* values_; + /*! \brief The type code fields */ + int* type_codes_; +}; + +class TVMArgsGetter { + public: + explicit TVMArgsGetter(TVMArgs args) + : args_(args) {} + + template + inline void operator()(size_t i, T& target) const { // NOLINT(*) + target = args_[i].operator T(); + } + private: + TVMArgs args_; }; template -inline void PackedFunc::operator()(Args&& ...args) const { - auto targ = std::make_tuple(std::forward(args)...); +inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { + auto targs = std::make_tuple(std::forward(args)...); const int kNumArgs = sizeof...(Args); - TVMValue tvm_args[kNumArgs]; - int tvm_arg_type_ids[kNumArgs]; - for_each(targ, PackedFuncArgSetter{tvm_args, tvm_arg_type_ids}); - body_(tvm_args, tvm_arg_type_ids, kNumArgs); + TVMValue values[kNumArgs]; + int type_codes[kNumArgs]; + for_each(targs, TVMArgsSetter(values, type_codes)); + TVMRetValue rv; + body_(TVMArgs(values, type_codes, kNumArgs), &rv); + return rv; } + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/include/tvm/tvm.h b/include/tvm/tvm.h index 5b7113df..6272b9d0 100644 --- a/include/tvm/tvm.h +++ b/include/tvm/tvm.h @@ -10,5 +10,6 @@ #include "./expr.h" #include "./tensor.h" #include "./operation.h" +#include "./packed_func_ext.h" #endif // TVM_TVM_H_ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 6729a8bd..2bd6359b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -1,7 +1,7 @@ # pylint: disable=redefined-builtin, wildcard-import """C++ backend related python scripts""" from __future__ import absolute_import as _abs -from ._ctypes._api import register_node +from ._ctypes._node import register_node from . import tensor from . import expr diff --git a/python/tvm/_base.py b/python/tvm/_base.py index 33c9848f..530b883c 100644 --- a/python/tvm/_base.py +++ b/python/tvm/_base.py @@ -91,45 +91,3 @@ def c_array(ctype, values): Created ctypes array """ return (ctype * len(values))(*values) - - -def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True): - """Convert ctypes returned doc string information into parameters docstring. - - num_args : nn_uint - Number of arguments. - - arg_names : ctypes.POINTER(ctypes.c_char_p) - Argument names. - - arg_types : ctypes.POINTER(ctypes.c_char_p) - Argument type information. - - arg_descs : ctypes.POINTER(ctypes.c_char_p) - Argument description information. - - remove_dup : boolean, optional - Whether remove duplication or not. - - Returns - ------- - docstr : str - Python docstring of parameter sections. - """ - param_keys = set() - param_str = [] - for i in range(num_args.value): - key = py_str(arg_names[i]) - if key in param_keys and remove_dup: - continue - param_keys.add(key) - type_info = py_str(arg_types[i]) - ret = '%s : %s' % (key, type_info) - if len(arg_descs[i]) != 0: - ret += '\n ' + py_str(arg_descs[i]) - param_str.append(ret) - doc_str = ('Parameters\n' + - '----------\n' + - '%s\n') - doc_str = doc_str % ('\n'.join(param_str)) - return doc_str diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py deleted file mode 100644 index f6eb9ac0..00000000 --- a/python/tvm/_ctypes/_api.py +++ /dev/null @@ -1,416 +0,0 @@ -# coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines -# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements -"""Symbolic configuration API.""" -from __future__ import absolute_import as _abs - -import ctypes -import sys -from numbers import Number, Integral - -from .._base import _LIB -from .._base import c_str, py_str, string_types -from .._base import check_call, ctypes2docstring -from .. import _api_internal -from . import _runtime_api -from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer - -# type definitions -APIFuncHandle = ctypes.c_void_p -NodeHandle = ctypes.c_void_p -FunctionHandle = ctypes.c_void_p - -class APIType(object): - """TVMType used in API calls""" - INT = ctypes.c_int(TypeCode.INT) - UINT = ctypes.c_int(TypeCode.UINT) - FLOAT = ctypes.c_int(TypeCode.FLOAT) - HANDLE = ctypes.c_int(TypeCode.HANDLE) - NULL = ctypes.c_int(TypeCode.NULL) - NODE_HANDLE = ctypes.c_int(TypeCode.NODE_HANDLE) - STR = ctypes.c_int(TypeCode.STR) - FUNC_HANDLE = ctypes.c_int(TypeCode.FUNC_HANDLE) - - -NODE_TYPE = { -} - -def _return_node(x): - handle = x.v_handle - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - ret_success = ctypes.c_int() - check_call(_LIB.TVMNodeGetAttr( - handle, c_str("type_key"), - ctypes.byref(ret_val), - ctypes.byref(ret_type_code), - ctypes.byref(ret_success))) - return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle) - - -def _return_func(x): - handle = x.v_handle - if not isinstance(handle, FunctionHandle): - handle = FunctionHandle(handle) - return _runtime_api._function_cls(handle) - - -def _return_handle(x): - handle = x.v_handle - if not isinstance(handle, ctypes.c_void_p): - handle = ctypes.c_void_p(handle) - return handle - - -RET_SWITCH = { - TypeCode.NULL: lambda x: None, - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.NODE_HANDLE: _return_node, - TypeCode.FUNC_HANDLE: _return_func -} - -PACK_ARG_SWITCH = { - TypeCode.NULL: lambda x: None, - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.HANDLE: lambda x: _return_handle, -} - - -class SliceBase(object): - """base class of slice object""" - pass - -class NodeBase(object): - """Symbol is symbolic graph.""" - __slots__ = ["handle"] - # pylint: disable=no-member - def __init__(self, handle): - """Initialize the function with handle - - Parameters - ---------- - handle : SymbolHandle - the handle to the underlying C++ Symbol - """ - self.handle = handle - - def __repr__(self): - return _api_internal._format_str(self) - - def __del__(self): - check_call(_LIB.TVMNodeFree(self.handle)) - - def __getattr__(self, name): - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - ret_success = ctypes.c_int() - check_call(_LIB.TVMNodeGetAttr( - self.handle, c_str(name), - ctypes.byref(ret_val), - ctypes.byref(ret_type_code), - ctypes.byref(ret_success))) - value = RET_SWITCH[ret_type_code.value](ret_val) - if not ret_success.value: - raise AttributeError( - "'%s' object has no attribute '%s'" % (str(type(self)), name)) - return value - - def __hash__(self): - return _api_internal._raw_ptr(self) - - def __eq__(self, other): - if not isinstance(other, NodeBase): - return False - return self.__hash__() == other.__hash__() - - def __ne__(self, other): - return not self.__eq__(other) - - def __dir__(self): - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - check_call(_LIB.TVMNodeListAttrNames( - self.handle, ctypes.byref(size), ctypes.byref(plist))) - names = [] - for i in range(size.value): - names.append(py_str(plist[i])) - return names - - def __reduce__(self): - return (type(self), (None,), self.__getstate__()) - - def __getstate__(self): - handle = self.handle - if handle is not None: - return {'handle': _api_internal._save_json(self)} - else: - return {'handle': None} - - def __setstate__(self, state): - # pylint: disable=assigning-non-slot - handle = state['handle'] - if handle is not None: - json_str = handle - _push_arg(json_str) - other = _api_internal._load_json(json_str) - self.handle = other.handle - other.handle = None - else: - self.handle = None - - -def const(value, dtype=None): - """construct a constant""" - if dtype is None: - if isinstance(value, Integral): - dtype = 'int32' - else: - dtype = 'float32' - return _api_internal._const(value, dtype) - - -def _ctypes_free_resource(rhandle): - """callback to free resources when it it not needed.""" - pyobj = ctypes.cast(rhandle, ctypes.py_object) - ctypes.pythonapi.Py_DecRef(pyobj) - -# Global callback that is always alive -TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource) -ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ)) - -def convert_to_tvm_func(pyfunc): - """Convert a python function to TVM function - - Parameters - ---------- - pyfunc : python function - The python function to be converted. - - Returns - ------- - tvmfunc: tvm.nd.Function - The converted tvm function. - """ - local_pyfunc = pyfunc - def cfun(args, type_codes, num_args, _): - """ ctypes function """ - num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args - pyargs = [PACK_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)] - local_pyfunc(*pyargs) - handle = FunctionHandle() - f = TVMPackedCFunc(cfun) - # NOTE: We will need to use python-api to increase ref count of the f - # TVM_FREE_PYOBJ will be called after it is no longer needed. - pyobj = ctypes.py_object(f) - ctypes.pythonapi.Py_IncRef(pyobj) - check_call(_LIB.TVMFuncCreateFromCFunc( - f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle))) - return _runtime_api._function_cls(handle) - - -def convert(value): - """Convert a value to expression.""" - if isinstance(value, (NodeBase, _runtime_api.FunctionBase)): - return value - elif isinstance(value, Number): - return const(value) - elif isinstance(value, string_types): - return _api_internal._str(value) - elif isinstance(value, (list, tuple)): - value = [convert(x) for x in value] - return _api_internal._Array(*value) - elif isinstance(value, dict): - vlist = [] - for it in value.items(): - if not isinstance(it[0], NodeBase): - raise ValueError("key of map must already been a container type") - vlist.append(it[0]) - vlist.append(convert(it[1])) - return _api_internal._Map(*vlist) - elif isinstance(value, SliceBase): - return value.tensor(*value.indices) - elif callable(value): - return convert_to_tvm_func(value) - else: - raise ValueError("don't know how to handle type %s" % type(value)) - return value - - -def _push_arg(arg): - a = TVMValue() - if arg is None: - _LIB.TVMAPIPushStack(a, APIType.NULL) - elif isinstance(arg, NodeBase): - a.v_handle = arg.handle - _LIB.TVMAPIPushStack(a, APIType.NODE_HANDLE) - elif isinstance(arg, Integral): - a.v_int64 = ctypes.c_int64(arg) - _LIB.TVMAPIPushStack(a, APIType.INT) - elif isinstance(arg, Number): - a.v_double = ctypes.c_double(arg) - _LIB.TVMAPIPushStack(a, APIType.FLOAT) - elif isinstance(arg, string_types): - a.v_str = c_str(arg) - _LIB.TVMAPIPushStack(a, APIType.STR) - else: - raise TypeError("Don't know how to handle type %s" % type(arg)) - - -def _make_function(handle, name): - """Create an atomic symbol function by handle and funciton name.""" - real_name = ctypes.c_char_p() - desc = ctypes.c_char_p() - num_args = ctypes.c_int() - arg_names = ctypes.POINTER(ctypes.c_char_p)() - arg_types = ctypes.POINTER(ctypes.c_char_p)() - arg_descs = ctypes.POINTER(ctypes.c_char_p)() - ret_type = ctypes.c_char_p() - - check_call(_LIB.TVMGetAPIFuncInfo( - handle, ctypes.byref(real_name), ctypes.byref(desc), - ctypes.byref(num_args), - ctypes.byref(arg_names), - ctypes.byref(arg_types), - ctypes.byref(arg_descs), - ctypes.byref(ret_type))) - - param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) - func_name = name - desc = py_str(desc.value) - - doc_str = ('%s\n\n' + - '%s\n') - doc_str = doc_str % (desc, param_str) - arg_names = [py_str(arg_names[i]) for i in range(num_args.value)] - - def func(*args): - """TVM function""" - cargs = [] - for x in args: - if isinstance(x, (list, tuple, dict, SliceBase)): - cargs.append(convert(x)) - else: - cargs.append(x) - - for arg in cargs: - _push_arg(arg) - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - check_call(_LIB.TVMAPIFuncCall( - handle, ctypes.byref(ret_val), ctypes.byref(ret_type_code))) - return RET_SWITCH[ret_type_code.value](ret_val) - - func.__name__ = func_name - func.__doc__ = doc_str - return func - - -def register_node(type_key=None): - """register node type - - Parameters - ---------- - type_key : str or cls - The type key of the node - """ - if isinstance(type_key, str): - def register(cls): - """internal register function""" - NODE_TYPE[type_key] = cls - return cls - return register - else: - cls = type_key - NODE_TYPE[cls.__name__] = cls - return cls - - -def register_func(func_name, f=None): - """Register global function - - Parameters - ---------- - func_name : str or function - The function name - - f : function - The function to be registered. - - Returns - ------- - fregister : function - Register function if f is not specified. - """ - if callable(func_name): - f = func_name - func_name = f.__name__ - - if not isinstance(func_name, str): - raise ValueError("expect string function name") - def register(myf): - """internal register function""" - if not isinstance(myf, _runtime_api.FunctionBase): - myf = convert_to_tvm_func(myf) - check_call(_LIB.TVMFuncRegisterGlobal( - c_str(func_name), myf.handle)) - if f: - register(f) - else: - return register - - -def get_global_func(name): - """Get a global function by name - - Parameters - ---------- - name : str - The name of the global function - - Returns - ------- - func : tvm.nd.Function - The function to be returned. - """ - handle = FunctionHandle() - check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) - return _runtime_api._function_cls(handle) - - -def _init_api_module(root_namespace): - """List and add all the functions to current module.""" - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.TVMListAPIFuncNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = sys.modules["%s.api" % root_namespace] - module_internal = sys.modules["%s._api_internal" % root_namespace] - namespace_match = { - "_make_": sys.modules["%s.make" % root_namespace], - "_pass_": sys.modules["%s.ir_pass" % root_namespace], - "_codegen_": sys.modules["%s.codegen" % root_namespace], - "_schedule_": sys.modules["%s.schedule" % root_namespace] - } - - for name in op_names: - hdl = APIFuncHandle() - check_call(_LIB.TVMGetAPIFuncHandle(c_str(name), ctypes.byref(hdl))) - fname = name - target_module = module_internal if name.startswith('_') else module_obj - for k, v in namespace_match.items(): - if name.startswith(k): - fname = name[len(k):] - target_module = v - function = _make_function(hdl, fname) - setattr(target_module, function.__name__, function) diff --git a/python/tvm/_ctypes/_function.py b/python/tvm/_ctypes/_function.py new file mode 100644 index 00000000..3b133e55 --- /dev/null +++ b/python/tvm/_ctypes/_function.py @@ -0,0 +1,250 @@ +# coding: utf-8 +# pylint: disable=invalid-name, protected-access +"""Symbolic configuration API.""" +from __future__ import absolute_import + +import ctypes +import sys +from numbers import Number, Integral + +from .._base import _LIB, check_call +from .._base import c_str, py_str, string_types +from ._types import TVMValue, TypeCode, TVMType +from ._types import TVMPackedCFunc, TVMCFuncFinalizer +from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH +from ._node import NodeBase, SliceBase, convert_to_node +from ._ndarray import NDArrayBase + +FunctionHandle = ctypes.c_void_p +TVMRetValueHandle = ctypes.c_void_p + +def _ctypes_free_resource(rhandle): + """callback to free resources when it it not needed.""" + pyobj = ctypes.cast(rhandle, ctypes.py_object) + ctypes.pythonapi.Py_DecRef(pyobj) + +# Global callback that is always alive +TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource) +ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ)) + +def convert_to_tvm_func(pyfunc): + """Convert a python function to TVM function + + Parameters + ---------- + pyfunc : python function + The python function to be converted. + + Returns + ------- + tvmfunc: tvm.nd.Function + The converted tvm function. + """ + local_pyfunc = pyfunc + def cfun(args, type_codes, num_args, ret, _): + """ ctypes function """ + num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args + pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)] + rv = local_pyfunc(*pyargs) + if rv is not None: + if isinstance(rv, tuple): + raise ValueError("PackedFunction can only support one reurn value") + temp_args = [] + values, tcodes, _ = _make_tvm_args((rv,), temp_args) + if not isinstance(ret, TVMRetValueHandle): + ret = TVMRetValueHandle(ret) + check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0]))) + _ = temp_args + _ = rv + + handle = FunctionHandle() + f = TVMPackedCFunc(cfun) + # NOTE: We will need to use python-api to increase ref count of the f + # TVM_FREE_PYOBJ will be called after it is no longer needed. + pyobj = ctypes.py_object(f) + ctypes.pythonapi.Py_IncRef(pyobj) + check_call(_LIB.TVMFuncCreateFromCFunc( + f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle))) + return Function(handle) + + +def _make_tvm_args(args, temp_args): + """Pack arguments into c args tvm call accept""" + num_args = len(args) + values = (TVMValue * num_args)() + type_codes = (ctypes.c_int * num_args)() + for i, arg in enumerate(args): + if arg is None: + values[i].v_handle = None + type_codes[i] = TypeCode.NULL + elif isinstance(arg, NDArrayBase): + values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) + type_codes[i] = TypeCode.ARRAY_HANDLE + elif isinstance(arg, NodeBase): + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.NODE_HANDLE + elif isinstance(arg, Integral): + values[i].v_int64 = arg + type_codes[i] = TypeCode.INT + elif isinstance(arg, Number): + values[i].v_float64 = arg + type_codes[i] = TypeCode.FLOAT + elif isinstance(arg, TVMType): + values[i].v_type = arg + type_codes[i] = TypeCode.TVM_TYPE + elif isinstance(arg, string_types): + values[i].v_str = c_str(arg) + type_codes[i] = TypeCode.STR + elif isinstance(arg, (list, tuple, dict, SliceBase)): + arg = convert_to_node(arg) + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.NODE_HANDLE + temp_args.append(arg) + elif isinstance(arg, Function): + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.FUNC_HANDLE + elif callable(arg): + arg = convert_to_tvm_func(arg) + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.FUNC_HANDLE + temp_args.append(arg) + else: + raise TypeError("Don't know how to handle type %s" % type(arg)) + return values, type_codes, num_args + + +class Function(object): + """A function object at runtime.""" + __slots__ = ["handle", "is_global"] + # pylint: disable=no-member + def __init__(self, handle, is_global=False): + """Initialize the function with handle + + Parameters + ---------- + handle : FunctionHandle + the handle to the underlying function. + + is_global : bool, optional + Whether it is global function + """ + self.handle = handle + self.is_global = is_global + + def __del__(self): + if not self.is_global: + check_call(_LIB.TVMFuncFree(self.handle)) + + def __call__(self, *args): + temp_args = [] + values, tcodes, num_args = _make_tvm_args(args, temp_args) + ret_val = TVMValue() + ret_tcode = ctypes.c_int() + check_call(_LIB.TVMFuncCall( + self.handle, values, tcodes, ctypes.c_int(num_args), + ctypes.byref(ret_val), ctypes.byref(ret_tcode))) + _ = temp_args + _ = args + return RETURN_SWITCH[ret_tcode.value](ret_val) + + +def _handle_return_func(x): + """Return function""" + handle = x.v_handle + if not isinstance(handle, FunctionHandle): + handle = FunctionHandle(handle) + return Function(handle, False) + +# setup return handle for function type +RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func + +def register_func(func_name, f=None): + """Register global function + + Parameters + ---------- + func_name : str or function + The function name + + f : function + The function to be registered. + + Returns + ------- + fregister : function + Register function if f is not specified. + """ + if callable(func_name): + f = func_name + func_name = f.__name__ + + if not isinstance(func_name, str): + raise ValueError("expect string function name") + def register(myf): + """internal register function""" + if not isinstance(myf, Function): + myf = convert_to_tvm_func(myf) + check_call(_LIB.TVMFuncRegisterGlobal( + c_str(func_name), myf.handle)) + if f: + register(f) + else: + return register + + +def get_global_func(name): + """Get a global function by name + + Parameters + ---------- + name : str + The name of the global function + + Returns + ------- + func : tvm.nd.Function + The function to be returned. + """ + handle = FunctionHandle() + check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) + return Function(handle, True) + + +def list_global_func_names(): + """Get list of global functions registered. + + Returns + ------- + names : list + List of global functions names. + """ + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size), + ctypes.byref(plist))) + fnames = [] + for i in range(size.value): + fnames.append(py_str(plist[i])) + return fnames + + +def _init_api_functions(root_namespace): + """List and add all the functions to current module.""" + module_obj = sys.modules["%s.api" % root_namespace] + module_internal = sys.modules["%s._api_internal" % root_namespace] + namespace_match = { + "_make_": sys.modules["%s.make" % root_namespace], + "_pass_": sys.modules["%s.ir_pass" % root_namespace], + "_codegen_": sys.modules["%s.codegen" % root_namespace], + "_schedule_": sys.modules["%s.schedule" % root_namespace] + } + for name in list_global_func_names(): + fname = name + target_module = module_internal if name.startswith('_') else module_obj + for k, v in namespace_match.items(): + if name.startswith(k): + fname = name[len(k):] + target_module = v + f = get_global_func(name) + setattr(target_module, fname, f) diff --git a/python/tvm/_ctypes/_runtime_api.py b/python/tvm/_ctypes/_ndarray.py similarity index 78% rename from python/tvm/_ctypes/_runtime_api.py rename to python/tvm/_ctypes/_ndarray.py index dc4a64ee..b6fc4d4d 100644 --- a/python/tvm/_ctypes/_runtime_api.py +++ b/python/tvm/_ctypes/_ndarray.py @@ -1,18 +1,15 @@ # pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement # pylint: disable=attribute-defined-outside-init, no-member, missing-docstring """Symbolic configuration API.""" -from __future__ import absolute_import as _abs +from __future__ import absolute_import import ctypes -from numbers import Number, Integral import numpy as np -from .._base import _LIB -from .._base import c_array, c_str, string_types -from .._base import check_call -from ._types import TVMValue, TypeCode, TVMType +from .._base import _LIB, check_call +from .._base import c_array, c_str +from ._types import TVMType, tvm_index_t -tvm_index_t = ctypes.c_uint32 class TVMContext(ctypes.Structure): """TVM context strucure.""" @@ -39,6 +36,19 @@ class TVMContext(ctypes.Structure): return ret.value != 0 +class TVMArray(ctypes.Structure): + """TVMValue in C API""" + _fields_ = [("data", ctypes.c_void_p), + ("shape", ctypes.POINTER(tvm_index_t)), + ("strides", ctypes.POINTER(tvm_index_t)), + ("ndim", tvm_index_t), + ("dtype", TVMType), + ("ctx", TVMContext)] + + +TVMArrayHandle = ctypes.POINTER(TVMArray) + + def cpu(dev_id=0): """Construct a CPU device @@ -72,18 +82,6 @@ def opencl(dev_id=0): return TVMContext(4, dev_id) -class TVMArray(ctypes.Structure): - """TVMValue in C API""" - _fields_ = [("data", ctypes.c_void_p), - ("shape", ctypes.POINTER(tvm_index_t)), - ("strides", ctypes.POINTER(tvm_index_t)), - ("ndim", tvm_index_t), - ("dtype", TVMType), - ("ctx", TVMContext)] - -TVMArrayHandle = ctypes.POINTER(TVMArray) - - def numpyasarray(np_data): """Return a TVMArray representation of a numpy array. """ @@ -102,7 +100,6 @@ def numpyasarray(np_data): _ndarray_cls = None -_function_cls = None def empty(shape, dtype="float32", ctx=cpu(0)): @@ -275,51 +272,6 @@ class NDArrayBase(object): return target -class FunctionBase(object): - """A function object at runtim.""" - __slots__ = ["handle"] - # pylint: disable=no-member - def __init__(self, handle): - """Initialize the function with handle - - Parameters - ---------- - handle : FunctionHandle - the handle to the underlying function. - """ - self.handle = handle - - def __del__(self): - check_call(_LIB.TVMFuncFree(self.handle)) - - def __call__(self, *args): - num_args = len(args) - tvm_args = (TVMValue * num_args)() - tvm_type_code = (ctypes.c_int * num_args)() - for i, arg in enumerate(args): - if arg is None: - tvm_args[i].v_handle = None - tvm_type_code[i] = TypeCode.NULL - elif isinstance(arg, NDArrayBase): - tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) - tvm_type_code[i] = TypeCode.HANDLE - elif isinstance(arg, Integral): - tvm_args[i].v_int64 = arg - tvm_type_code[i] = TypeCode.INT - elif isinstance(arg, Number): - tvm_args[i].v_float64 = arg - tvm_type_code[i] = TypeCode.FLOAT - elif isinstance(arg, string_types): - tvm_args[i].v_str = c_str(arg) - tvm_type_code[i] = TypeCode.STR - else: - raise TypeError("Don't know how to handle type %s" % type(arg)) - check_call(_LIB.TVMFuncCall( - self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args))) - - -def _init_runtime_module(ndarray_class, function_class): +def _init_ndarray_module(ndarray_class): global _ndarray_cls - global _function_cls _ndarray_cls = ndarray_class - _function_cls = function_class diff --git a/python/tvm/_ctypes/_node.py b/python/tvm/_ctypes/_node.py new file mode 100644 index 00000000..d91b9cac --- /dev/null +++ b/python/tvm/_ctypes/_node.py @@ -0,0 +1,184 @@ +# coding: utf-8 +# pylint: disable=invalid-name, protected-access +# pylint: disable=no-member, missing-docstring +"""Symbolic configuration API.""" +from __future__ import absolute_import + +import ctypes +from numbers import Number, Integral + +from .._base import _LIB, check_call +from .._base import c_str, py_str, string_types +from .. import _api_internal +from ._types import TVMValue, TypeCode, RETURN_SWITCH + +NodeHandle = ctypes.c_void_p + +"""Maps node type to its constructor""" +NODE_TYPE = { +} + +def _return_node(x): + """Return function""" + handle = x.v_handle + if not isinstance(handle, NodeHandle): + handle = NodeHandle(handle) + ret_val = TVMValue() + ret_type_code = ctypes.c_int() + ret_success = ctypes.c_int() + check_call(_LIB.TVMNodeGetAttr( + handle, c_str("type_key"), + ctypes.byref(ret_val), + ctypes.byref(ret_type_code), + ctypes.byref(ret_success))) + return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle) + + +RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node + + +class SliceBase(object): + """base class of slice object""" + pass + +class NodeBase(object): + """Symbol is symbolic graph.""" + __slots__ = ["handle"] + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + self.handle = handle + + def __repr__(self): + return _api_internal._format_str(self) + + def __del__(self): + check_call(_LIB.TVMNodeFree(self.handle)) + + def __getattr__(self, name): + ret_val = TVMValue() + ret_type_code = ctypes.c_int() + ret_success = ctypes.c_int() + check_call(_LIB.TVMNodeGetAttr( + self.handle, c_str(name), + ctypes.byref(ret_val), + ctypes.byref(ret_type_code), + ctypes.byref(ret_success))) + if not ret_success.value: + raise AttributeError( + "'%s' object has no attribute '%s'" % (str(type(self)), name)) + return RETURN_SWITCH[ret_type_code.value](ret_val) + + def __hash__(self): + return _api_internal._raw_ptr(self) + + def __eq__(self, other): + if not isinstance(other, NodeBase): + return False + return self.__hash__() == other.__hash__() + + def __ne__(self, other): + return not self.__eq__(other) + + def __dir__(self): + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + check_call(_LIB.TVMNodeListAttrNames( + self.handle, ctypes.byref(size), ctypes.byref(plist))) + names = [] + for i in range(size.value): + names.append(py_str(plist[i])) + return names + + def __reduce__(self): + return (type(self), (None,), self.__getstate__()) + + def __getstate__(self): + handle = self.handle + if handle is not None: + return {'handle': _api_internal._save_json(self)} + else: + return {'handle': None} + + def __setstate__(self, state): + # pylint: disable=assigning-non-slot + handle = state['handle'] + if handle is not None: + json_str = handle + other = _api_internal._load_json(json_str) + self.handle = other.handle + other.handle = None + else: + self.handle = None + + +def const(value, dtype=None): + """construct a constant""" + if dtype is None: + if isinstance(value, Integral): + dtype = 'int32' + else: + dtype = 'float32' + return _api_internal._const(value, dtype) + + +def convert_to_node(value): + """Convert a python value to corresponding node type. + + Parameters + ---------- + value : str + The value to be inspected. + + Returns + ------- + node : Node + The corresponding node value. + """ + if isinstance(value, NodeBase): + return value + elif isinstance(value, Number): + return const(value) + elif isinstance(value, string_types): + return _api_internal._str(value) + elif isinstance(value, (list, tuple)): + value = [convert_to_node(x) for x in value] + return _api_internal._Array(*value) + elif isinstance(value, dict): + vlist = [] + for it in value.items(): + if not isinstance(it[0], NodeBase): + raise ValueError("key of map must already been a container type") + vlist.append(it[0]) + vlist.append(convert_to_node(it[1])) + return _api_internal._Map(*vlist) + elif isinstance(value, SliceBase): + return value.tensor(*value.indices) + else: + raise ValueError("don't know how to convert type %s to node" % type(value)) + + +def register_node(type_key=None): + """register node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if isinstance(type_key, str): + def register(cls): + """internal register function""" + NODE_TYPE[type_key] = cls + return cls + return register + else: + cls = type_key + NODE_TYPE[cls.__name__] = cls + return cls diff --git a/python/tvm/_ctypes/_types.py b/python/tvm/_ctypes/_types.py index b7d8e353..f6d0f5e7 100644 --- a/python/tvm/_ctypes/_types.py +++ b/python/tvm/_ctypes/_types.py @@ -4,13 +4,9 @@ from __future__ import absolute_import as _abs import ctypes import numpy as np +from .._base import py_str -class TVMValue(ctypes.Union): - """TVMValue in C API""" - _fields_ = [("v_int64", ctypes.c_int64), - ("v_float64", ctypes.c_double), - ("v_handle", ctypes.c_void_p), - ("v_str", ctypes.c_char_p)] +tvm_index_t = ctypes.c_uint32 class TypeCode(object): """Type code used in API calls""" @@ -19,9 +15,11 @@ class TypeCode(object): FLOAT = 2 HANDLE = 3 NULL = 4 - NODE_HANDLE = 5 - STR = 6 - FUNC_HANDLE = 7 + ARRAY_HANDLE = 5 + TVM_TYPE = 6 + NODE_HANDLE = 7 + STR = 8 + FUNC_HANDLE = 9 def _api_type(code): """create a type accepted by API""" @@ -40,13 +38,13 @@ class TVMType(ctypes.Structure): CODE2STR = { 0 : 'int', 1 : 'uint', - 2 : 'float' + 2 : 'float', + 4 : 'handle' } def __init__(self, type_str, lanes=1): super(TVMType, self).__init__() if isinstance(type_str, np.dtype): type_str = str(type_str) - if type_str.startswith("int"): self.type_code = 0 bits = int(type_str[3:]) @@ -56,6 +54,9 @@ class TVMType(ctypes.Structure): elif type_str.startswith("float"): self.type_code = 2 bits = int(type_str[5:]) + elif type_str.startswith("handle"): + self.type_code = 4 + bits = 64 else: raise ValueError("Donot know how to handle type %s" % type_str) @@ -71,15 +72,61 @@ class TVMType(ctypes.Structure): x += "x%d" % self.lanes return x + def __eq__(self, other): + return (self.bits == other.bits and + self.type_code == other.type_code and + self.lanes == other.lanes) + + def __ne__(self, other): + return not self.__eq__(other) + + +class TVMValue(ctypes.Union): + """TVMValue in C API""" + _fields_ = [("v_int64", ctypes.c_int64), + ("v_float64", ctypes.c_double), + ("v_handle", ctypes.c_void_p), + ("v_str", ctypes.c_char_p), + ("v_type", TVMType)] + TVMPackedCFunc = ctypes.CFUNCTYPE( None, ctypes.POINTER(TVMValue), ctypes.POINTER(ctypes.c_int), ctypes.c_int, + ctypes.c_void_p, ctypes.c_void_p) TVMCFuncFinalizer = ctypes.CFUNCTYPE( None, ctypes.c_void_p) + + +def _return_handle(x): + """return handle""" + handle = x.v_handle + if not isinstance(handle, ctypes.c_void_p): + handle = ctypes.c_void_p(handle) + return handle + + +RETURN_SWITCH = { + TypeCode.INT: lambda x: x.v_int64, + TypeCode.FLOAT: lambda x: x.v_float64, + TypeCode.HANDLE: _return_handle, + TypeCode.NULL: lambda x: None, + TypeCode.TVM_TYPE: lambda x: x.v_type, + TypeCode.STR: lambda x: py_str(x.v_str) +} + + +C_TO_PY_ARG_SWITCH = { + TypeCode.INT: lambda x: x.v_int64, + TypeCode.FLOAT: lambda x: x.v_float64, + TypeCode.HANDLE: _return_handle, + TypeCode.NULL: lambda x: None, + TypeCode.TVM_TYPE: lambda x: x.v_type, + TypeCode.STR: lambda x: py_str(x.v_str) +} diff --git a/python/tvm/api.py b/python/tvm/api.py index e181db9f..344ea107 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -2,16 +2,23 @@ # pylint: disable=redefined-builtin, undefined-variable, unused-import """Functions defined in TVM.""" from __future__ import absolute_import as _abs + from numbers import Integral as _Integral -from ._ctypes._api import _init_api_module, convert, register_func, get_global_func + +from ._ctypes._types import TVMType +from ._ctypes._node import register_node, NodeBase +from ._ctypes._node import convert_to_node as _convert_to_node +from ._ctypes._function import Function +from ._ctypes._function import _init_api_functions, register_func, get_global_func +from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func from . import _api_internal from . import make as _make from . import expr as _expr from . import collections as _collections -int32 = "int32" -float32 = "float32" -handle = "handle" +int32 = TVMType("int32") +float32 = TVMType("float32") +handle = TVMType("handle") def const(value, dtype=None): """construct a constant""" @@ -266,4 +273,25 @@ def Schedule(ops): return _api_internal._Schedule(ops) -_init_api_module("tvm") +def convert(value): + """Convert value to TVM node or function. + + Parameters + ---------- + value : python value + + Returns + ------- + tvm_val : Node or function + Converted value in TVM + """ + if isinstance(value, (Function, NodeBase)): + return value + + if callable(value): + return _convert_tvm_func(value) + else: + return _convert_to_node(value) + + +_init_api_functions("tvm") diff --git a/python/tvm/collections.py b/python/tvm/collections.py index 810b726c..b3bff80b 100644 --- a/python/tvm/collections.py +++ b/python/tvm/collections.py @@ -1,7 +1,7 @@ # pylint: disable=protected-access, no-member """Collection structure in the high level DSL.""" from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, register_node +from ._ctypes._node import NodeBase, register_node from . import _api_internal from . import expr as _expr diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 3d75fa4c..c3b0845a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -1,6 +1,6 @@ # pylint: disable=protected-access, no-member, missing-docstring from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, register_node +from ._ctypes._node import NodeBase, register_node from . import make as _make class ExprOp(object): diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index c3059662..8dbc22fd 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -6,11 +6,11 @@ This is a simplified runtime API for quick testing and proptyping. from __future__ import absolute_import as _abs import numpy as _np -from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase -from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync -from ._ctypes._runtime_api import _init_runtime_module -from ._ctypes._runtime_api import init_opencl - +from ._ctypes._ndarray import TVMContext, TVMType, NDArrayBase +from ._ctypes._ndarray import cpu, gpu, opencl, empty, sync +from ._ctypes._ndarray import _init_ndarray_module +from ._ctypes._ndarray import init_opencl +from ._ctypes._function import Function class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. @@ -26,11 +26,6 @@ class NDArray(NDArrayBase): pass -class Function(FunctionBase): - """Function class that can executed a generated code.""" - pass - - def array(arr, ctx=cpu(0)): """Create an array from source arr. @@ -54,4 +49,4 @@ def array(arr, ctx=cpu(0)): return ret -_init_runtime_module(NDArray, Function) +_init_ndarray_module(NDArray) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 93767b4c..da5413e6 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -1,7 +1,7 @@ # pylint: disable=protected-access, no-member """Collection structure in the high level DSL.""" from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, register_node +from ._ctypes._node import NodeBase, register_node from . import _api_internal from . import tensor as _tensor diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 97ef3b6b..49ce8cf9 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -1,6 +1,6 @@ # pylint: disable=protected-access, no-member, missing-docstring from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, register_node +from ._ctypes._node import NodeBase, register_node class Stmt(NodeBase): pass diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index 51767ee3..47a7ec88 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -1,7 +1,7 @@ # pylint: disable=protected-access, no-member, invalid-name """Tensor related abstractions""" from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, SliceBase, register_node, convert +from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node from . import collections as _collections from . import _api_internal from . import make as _make @@ -26,7 +26,7 @@ class Tensor(NodeBase): ndim = self.ndim if len(indices) != ndim: raise ValueError("Need to provide %d index in tensor slice" % ndim) - indices = convert(indices) + indices = convert_to_node(indices) args = [] for x in indices: if isinstance(x, _collections.IterVar): diff --git a/src/api/api_base.cc b/src/api/api_base.cc new file mode 100644 index 00000000..8f4e6f91 --- /dev/null +++ b/src/api/api_base.cc @@ -0,0 +1,37 @@ +/*! + * Copyright (c) 2017 by Contributors + * Implementation of basic API functions + * \file api_base.cc + */ +#include +#include +#include + +namespace tvm { + +TVM_REGISTER_API(_format_str) +.set_body([](TVMArgs args, TVMRetValue *ret) { + CHECK(args[0].type_code() == kNodeHandle); + std::ostringstream os; + os << args[0].operator NodeRef(); + *ret = os.str(); + }); + +TVM_REGISTER_API(_raw_ptr) +.set_body([](TVMArgs args, TVMRetValue *ret) { + CHECK(args[0].type_code() == kNodeHandle); + *ret = reinterpret_cast( + args[0].node_sptr().get()); + }); + +TVM_REGISTER_API(_save_json) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SaveJSON(args[0]); + }); + +TVM_REGISTER_API(_load_json) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = NodeRef(LoadJSON_(args[0])); + }); + +} // namespace tvm diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc new file mode 100644 index 00000000..5f2c958c --- /dev/null +++ b/src/api/api_codegen.cc @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2016 by Contributors + * Implementation of API functions related to Codegen + * \file c_api_codegen.cc + */ +#include +#include +#include +#include +#include "../codegen/codegen_c.h" + +namespace tvm { +namespace codegen { + +TVM_REGISTER_API(_codegen_CompileToC) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CodeGenC().Compile(args[0], args[1]); + }); + +TVM_REGISTER_API(_codegen_MakeAPI) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = MakeAPI( + args[0], args[1], args[2], args[3]); + }); + +TVM_REGISTER_API(_codegen_SplitHostDevice) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SplitHostDevice(args[0]); + }); + +// generate a dummy packed function for testing +void DummyHelloFunction(TVMArgs args, TVMRetValue* rv) { + LOG(INFO) << args.size() << " arguments"; + for (int i = 0; i < args.size(); ++i) { + switch (args.type_codes[i]) { + case kNull: LOG(INFO) << i << ":nullptr"; break; + case kFloat: LOG(INFO) << i << ": double=" << args.values[i].v_float64; break; + case kInt: LOG(INFO) << i << ": long=" << args.values[i].v_int64; break; + case kHandle: LOG(INFO) << i << ": handle=" << args.values[i].v_handle; break; + case kArrayHandle: LOG(INFO) << i << ": array_handle=" << args.values[i].v_handle; break; + default: LOG(FATAL) << "unhandled type " << runtime::TypeCode2Str(args.type_codes[i]); + } + } +} + +TVM_REGISTER_API(_codegen_DummyHelloFunction) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = runtime::PackedFunc(DummyHelloFunction); + }); + +TVM_REGISTER_API(_codegen_BuildStackVM) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BuildStackVM(args[0]); + }); + +} // namespace codegen +} // namespace tvm diff --git a/src/c_api/c_api_ir.cc b/src/api/api_ir.cc similarity index 51% rename from src/c_api/c_api_ir.cc rename to src/api/api_ir.cc index 41739032..af079580 100644 --- a/src/c_api/c_api_ir.cc +++ b/src/api/api_ir.cc @@ -1,98 +1,93 @@ /*! * Copyright (c) 2016 by Contributors * Implementation of API functions related to IR build - * \file c_api_ir.cc + * \file api_ir.cc */ #include #include #include -#include "./c_api_registry.h" +#include namespace tvm { namespace ir { -using ArgStack = const std::vector; -using RetValue = APIVariantValue; - TVM_REGISTER_API(_Var) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Variable::make(args.at(1), args.at(0)); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Variable::make(args[1], args[0]); }); TVM_REGISTER_API(_make_For) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = For::make(args.at(0), - args.at(1), - args.at(2), - static_cast(args.at(3).operator int()), - static_cast(args.at(4).operator int()), - args.at(5)); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = For::make(args[0], + args[1], + args[2], + static_cast(args[3].operator int()), + static_cast(args[4].operator int()), + args[5]); }); TVM_REGISTER_API(_make_Realize) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Realize::make(args.at(0), - args.at(1), - args.at(2), - args.at(3), - args.at(4), - args.at(5)); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Realize::make(args[0], + args[1], + args[2], + args[3], + args[4], + args[5]); }); TVM_REGISTER_API(_make_Call) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Call::make(args.at(0), - args.at(1), - args.at(2), - static_cast(args.at(3).operator int()), - args.at(4), - args.at(5)); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Call::make(args[0], + args[1], + args[2], + static_cast(args[3].operator int()), + args[4], + args[5]); }); TVM_REGISTER_API(_make_Allocate) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Allocate::make(args.at(0), - args.at(1), - args.at(2), - args.at(3), - args.at(4)); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Allocate::make(args[0], + args[1], + args[2], + args[3], + args[4]); }); // make from two arguments #define REGISTER_MAKE1(Node) \ TVM_REGISTER_API(_make_## Node) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = Node::make(args.at(0)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = Node::make(args[0]); \ }) \ #define REGISTER_MAKE2(Node) \ TVM_REGISTER_API(_make_## Node) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = Node::make(args.at(0), args.at(1)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = Node::make(args[0], args[1]); \ }) \ #define REGISTER_MAKE3(Node) \ TVM_REGISTER_API(_make_## Node) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = Node::make(args.at(0), args.at(1), args.at(2)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = Node::make(args[0], args[1], args[2]); \ }) \ -#define REGISTER_MAKE4(Node) \ +#define REGISTER_MAKE4(Node) \ TVM_REGISTER_API(_make_## Node) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ -*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \ - }) \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = Node::make(args[0], args[1], args[2], args[3]); \ + }) \ #define REGISTER_MAKE_BINARY_OP(Node) \ TVM_REGISTER_API(_make_## Node) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - Expr a = args.at(0), b = args.at(1); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + Expr a = args[0], b = args[1]; \ match_types(a, b); \ *ret = Node::make(a, b); \ - }) \ - .add_argument("lhs", "Expr", "left operand") \ - .add_argument("rhs", "Expr", "right operand") + }) REGISTER_MAKE3(Reduce); REGISTER_MAKE4(AttrStmt); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc new file mode 100644 index 00000000..e570e55c --- /dev/null +++ b/src/api/api_lang.cc @@ -0,0 +1,256 @@ +/*! + * Copyright (c) 2016 by Contributors + * Implementation of API functions related to Higher DSL build. + * \file api_lang.cc + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { + +TVM_REGISTER_API(_const) +.set_body([](TVMArgs args, TVMRetValue* ret) { + if (args[0].type_code() == kInt) { + *ret = make_const(args[1], args[0].operator int64_t()); + } else if (args[0].type_code() == kFloat) { + *ret = make_const(args[1], args[0].operator double()); + } else { + LOG(FATAL) << "only accept int or float"; + } + }); + + +TVM_REGISTER_API(_str) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ir::StringImm::make(args[0]); +}); + + +TVM_REGISTER_API(_Array) +.set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector > data; + for (int i = 0; i < args.size(); ++i) { + data.push_back(args[i].node_sptr()); + } + auto node = std::make_shared(); + node->data = std::move(data); + *ret = node; + }); + +TVM_REGISTER_API(_ArrayGetItem) +.set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + auto& sptr = args[0].node_sptr(); + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + CHECK_LT(static_cast(i), n->data.size()) + << "out of bound of array"; + *ret = n->data[i]; + }); + +TVM_REGISTER_API(_ArraySize) +.set_body([](TVMArgs args, TVMRetValue* ret) { + auto& sptr = args[0].node_sptr(); + CHECK(sptr->is_type()); + *ret = static_cast( + static_cast(sptr.get())->data.size()); + }); + +TVM_REGISTER_API(_Map) +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size() % 2, 0); + MapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].type_code() == kNodeHandle) + << "need content of array to be NodeBase"; + CHECK(args[i + 1].type_code() == kNodeHandle) + << "need content of array to be NodeBase"; + data.emplace(std::make_pair(args[i].node_sptr(), + args[i + 1].node_sptr())); + } + auto node = std::make_shared(); + node->data = std::move(data); + *ret = node; + }); + +TVM_REGISTER_API(_MapSize) +.set_body([](TVMArgs args, TVMRetValue* ret) { + auto& sptr = args[0].node_sptr(); + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + *ret = static_cast(n->data.size()); + }); + +TVM_REGISTER_API(_MapGetItem) +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK(args[0].type_code() == kNodeHandle); + CHECK(args[1].type_code() == kNodeHandle); + auto& sptr = args[0].node_sptr(); + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + auto it = n->data.find(args[1].node_sptr()); + CHECK(it != n->data.end()) + << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + }); + +TVM_REGISTER_API(_MapCount) +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK(args[0].type_code() == kNodeHandle); + CHECK(args[1].type_code() == kNodeHandle); + auto& sptr = args[0].node_sptr(); + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + *ret = static_cast( + n->data.count(args[1].node_sptr())); + }); + +TVM_REGISTER_API(_MapItems) +.set_body([](TVMArgs args, TVMRetValue* ret) { + auto& sptr = args[0].node_sptr(); + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + auto rkvs = std::make_shared(); + for (const auto& kv : n->data) { + rkvs->data.push_back(kv.first); + rkvs->data.push_back(kv.second); + } + *ret = rkvs; + }); + +TVM_REGISTER_API(Range) +.set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = Range(0, args[0]); + } else { + *ret = Range(args[0], args[1]); + } + }); + +TVM_REGISTER_API(_Buffer) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = BufferNode::make(args[0], + args[1], + args[2], + args[3], + args[4]); + }); + +TVM_REGISTER_API(_Tensor) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TensorNode::make(args[0], + args[1], + args[2], + args[3]); + }); + +TVM_REGISTER_API(_TensorEqual) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Tensor() == args[1].operator Tensor(); + }); + +TVM_REGISTER_API(_TensorHash) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = static_cast( + std::hash()(args[0].operator Tensor())); + }); + +TVM_REGISTER_API(_Placeholder) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Placeholder(args[0], + args[1], + args[2]); + }); + +TVM_REGISTER_API(_ComputeOp) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ComputeOpNode::make(args[0], + args[1], + args[2]); + }); + +TVM_REGISTER_API(_OpGetOutput) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Operation().output( + args[1].operator int64_t()); + }); + + +TVM_REGISTER_API(_IterVar) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = IterVar(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API(_Schedule) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Schedule(args[0].operator Array()); + }); + +TVM_REGISTER_API(_StageSetScope) +.set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .set_scope(args[1]); + }); + +TVM_REGISTER_API(_StageSplitByFactor) +.set_body([](TVMArgs args, TVMRetValue* ret) { + IterVar outer, inner; + args[0].operator Stage() + .split(args[1], &outer, &inner, args[2]); + *ret = Array({outer, inner}); + }); + +TVM_REGISTER_API(_StageSplitByOuter) +.set_body([](TVMArgs args, TVMRetValue* ret) { + IterVar inner; + args[0].operator Stage() + .split(args[1], args[2], &inner, args[3]); + *ret = inner; + }); + +TVM_REGISTER_API(_StageFuse) +.set_body([](TVMArgs args, TVMRetValue* ret) { + IterVar fused; + args[0].operator Stage() + .split(args[1], args[2], &fused); + *ret = fused; + }); + +TVM_REGISTER_API(_StageComputeAt) +.set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .compute_at(args[1], args[2]); + }); + +TVM_REGISTER_API(_StageComputeInline) +.set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .compute_inline(); + }); + +TVM_REGISTER_API(_StageComputeRoot) +.set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .compute_root(); + }); + +TVM_REGISTER_API(_StageReorder) +.set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .reorder(args[1]); + }); + +TVM_REGISTER_API(_StageTile) + .set_body([](TVMArgs args, TVMRetValue* ret) { + IterVar x_outer, y_outer, x_inner, y_inner; + args[0].operator Stage() + .tile(args[1], args[2], &x_outer, &y_outer, + &x_inner, &y_inner, args[3], args[4]); + *ret = Array({x_outer, y_outer, x_inner, y_inner}); + }); + +} // namespace tvm diff --git a/src/c_api/c_api_pass.cc b/src/api/api_pass.cc similarity index 51% rename from src/c_api/c_api_pass.cc rename to src/api/api_pass.cc index d7d41c4f..0a88e3b1 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/api/api_pass.cc @@ -1,53 +1,51 @@ /*! - * Copyright (c) 2016 by Contributors + * Copyright (c) 2017 by Contributors * Exposre of pass functions. - * \file c_api_pass.cc + * \file api_pass.cc */ #include #include #include -#include "./c_api_registry.h" +#include namespace tvm { namespace ir { -using ArgStack = const std::vector; -using RetValue = APIVariantValue; TVM_REGISTER_API(_pass_Simplify) -.set_body([](const ArgStack& args, RetValue *ret) { - if (NodeTypeChecker::Check(args.at(0).sptr.get())) { - *ret = Simplify(args.at(0).operator Stmt()); +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args[0].IsNodeType()) { + *ret = Simplify(args[0].operator Stmt()); } else { - *ret = Simplify(args.at(0).operator Expr()); + *ret = Simplify(args[0].operator Expr()); } }); TVM_REGISTER_API(_pass_Equal) -.set_body([](const ArgStack& args, RetValue *ret) { - if (NodeTypeChecker::Check(args.at(0).sptr.get())) { - *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args[0].IsNodeType()) { + *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); } else { - *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); + *ret = Equal(args[0].operator Expr(), args[1].operator Expr()); } }); // make from two arguments #define REGISTER_PASS1(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = PassName(args.at(0)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = PassName(args[0]); \ }) \ #define REGISTER_PASS2(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = PassName(args.at(0), args.at(1)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = PassName(args[0], args[1]); \ }) \ #define REGISTER_PASS4(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = PassName(args[0], args[1], args[2], args[3]); \ }) \ REGISTER_PASS1(ConvertSSA); diff --git a/src/api/api_registry.cc b/src/api/api_registry.cc new file mode 100644 index 00000000..d4ab9028 --- /dev/null +++ b/src/api/api_registry.cc @@ -0,0 +1,35 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file api_registry.cc + */ +#include +#include +#include +#include + +namespace tvm { + +struct APIManager { + std::unordered_map > fmap; + + static APIManager* Global() { + static APIManager inst; + return &inst; + } +}; + +APIRegistry& APIRegistry::__REGISTER__(const std::string& name) { // NOLINT(*) + APIManager* m = APIManager::Global(); + CHECK(!m->fmap.count(name)) + << "API function " << name << " has already been registered"; + std::unique_ptr p(new APIRegistry()); + p->name_ = name; + m->fmap[name] = std::move(p); + return *(m->fmap[name]); +} + +APIRegistry& APIRegistry::set_body(PackedFunc f) { // NOLINT(*) + PackedFunc::RegisterGlobal(name_, f); + return *this; +} +} // namespace tvm diff --git a/src/c_api/c_api_schedule.cc b/src/api/api_schedule.cc similarity index 65% rename from src/c_api/c_api_schedule.cc rename to src/api/api_schedule.cc index 6ee41b2d..a84642f9 100644 --- a/src/c_api/c_api_schedule.cc +++ b/src/api/api_schedule.cc @@ -1,30 +1,28 @@ /*! - * Copyright (c) 2016 by Contributors + * Copyright (c) 2017 by Contributors * Implementation of API functions related to schedule pass. - * \file c_api_lang.cc + * \file api_schedule.cc */ #include #include #include #include -#include "./c_api_registry.h" +#include #include "../schedule/graph.h" namespace tvm { namespace schedule { -using ArgStack = const std::vector; -using RetValue = APIVariantValue; #define REGISTER_SCHEDULE_PASS1(PassName) \ TVM_REGISTER_API(_schedule_## PassName) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = PassName(args.at(0)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = PassName(args[0]); \ }) \ #define REGISTER_SCHEDULE_PASS2(PassName) \ TVM_REGISTER_API(_schedule_## PassName) \ - .set_body([](const ArgStack& args, RetValue *ret) { \ - *ret = PassName(args.at(0), args.at(1)); \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + *ret = PassName(args[0], args[1]); \ }) \ diff --git a/src/api/c_api.cc b/src/api/c_api.cc new file mode 100644 index 00000000..c4290c57 --- /dev/null +++ b/src/api/c_api.cc @@ -0,0 +1,153 @@ +/*! + * Copyright (c) 2016 by Contributors + * Implementation of C API + * \file c_api.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "../runtime/runtime_base.h" + + +/*! \brief entry to to easily hold returning information */ +struct TVMAPIThreadLocalEntry { + /*! \brief result holder for returning strings */ + std::vector ret_vec_str; + /*! \brief result holder for returning string pointers */ + std::vector ret_vec_charp; + /*! \brief result holder for retruning string */ + std::string ret_str; +}; + +using namespace tvm; + +/*! \brief Thread local store that can be used to hold return values. */ +typedef dmlc::ThreadLocalStore TVMAPIThreadLocalStore; + +using TVMAPINode = std::shared_ptr; + +struct APIAttrGetter : public AttrVisitor { + std::string skey; + TVMRetValue* ret; + bool found_node_ref{false}; + + void Visit(const char* key, double* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, int64_t* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, uint64_t* value) final { + CHECK_LE(value[0], static_cast(std::numeric_limits::max())) + << "cannot return too big constant"; + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, int* value) final { + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, bool* value) final { + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, Type* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, std::string* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, NodeRef* value) final { + if (skey == key) { + *ret = value[0]; + found_node_ref = true; + } + } +}; + +struct APIAttrDir : public AttrVisitor { + std::vector* names; + + void Visit(const char* key, double* value) final { + names->push_back(key); + } + void Visit(const char* key, int64_t* value) final { + names->push_back(key); + } + void Visit(const char* key, uint64_t* value) final { + names->push_back(key); + } + void Visit(const char* key, bool* value) final { + names->push_back(key); + } + void Visit(const char* key, int* value) final { + names->push_back(key); + } + void Visit(const char* key, Type* value) final { + names->push_back(key); + } + void Visit(const char* key, std::string* value) final { + names->push_back(key); + } + void Visit(const char* key, NodeRef* value) final { + names->push_back(key); + } +}; + + +int TVMNodeFree(NodeHandle handle) { + API_BEGIN(); + delete static_cast(handle); + API_END(); +} + +int TVMNodeGetAttr(NodeHandle handle, + const char* key, + TVMValue* ret_val, + int* ret_type_code, + int* ret_success) { + API_BEGIN(); + TVMRetValue rv; + APIAttrGetter getter; + getter.skey = key; + getter.ret = &rv; + TVMAPINode* tnode = static_cast(handle); + if (getter.skey == "type_key") { + ret_val->v_str = (*tnode)->type_key(); + *ret_type_code = kStr; + *ret_success = 1; + } else { + (*tnode)->VisitAttrs(&getter); + *ret_success = getter.found_node_ref || rv.type_code() != kNull; + if (rv.type_code() == kStr) { + TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); + e->ret_str = rv.operator std::string(); + *ret_type_code = kStr; + ret_val->v_str = e->ret_str.c_str(); + } else { + rv.MoveToCHost(ret_val, ret_type_code); + } + } + API_END(); +} + +int TVMNodeListAttrNames(NodeHandle handle, + int *out_size, + const char*** out_array) { + TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); + API_BEGIN(); + ret->ret_vec_str.clear(); + TVMAPINode* tnode = static_cast(handle); + APIAttrDir dir; + dir.names = &(ret->ret_vec_str); + (*tnode)->VisitAttrs(&dir); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_array = dmlc::BeginPtr(ret->ret_vec_charp); + *out_size = static_cast(ret->ret_vec_str.size()); + API_END(); +} diff --git a/src/base/common.h b/src/base/common.h index 432ec74d..4b1c799e 100644 --- a/src/base/common.h +++ b/src/base/common.h @@ -42,5 +42,78 @@ inline Type String2Type(std::string s) { return Type(code, bits, lanes); } +inline const char* TVMTypeCode2Str(int type_code) { + switch (type_code) { + case kInt: return "int"; + case kFloat: return "float"; + case kStr: return "str"; + case kHandle: return "Handle"; + case kNull: return "NULL"; + case kNodeHandle: return "NodeHandle"; + default: LOG(FATAL) << "unknown type_code=" + << static_cast(type_code); return ""; + } +} +template +struct NodeTypeChecker { + static inline bool Check(Node* sptr) { + // This is the only place in the project where RTTI is used + // It can be turned off, but will make non strict checking. + // TODO(tqchen) possibly find alternative to turn of RTTI + using ContainerType = typename T::ContainerType; + return (dynamic_cast(sptr) != nullptr); + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + using ContainerType = typename T::ContainerType; + os << ContainerType::_type_key; + } +}; + +template +struct NodeTypeChecker > { + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; + ArrayNode* n = static_cast(sptr); + for (const auto& p : n->data) { + if (!NodeTypeChecker::Check(p.get())) return false; + } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "array<"; + NodeTypeChecker::PrintName(os); + os << ">"; + } +}; + +template +struct NodeTypeChecker > { + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; + MapNode* n = static_cast(sptr); + for (const auto& kv : n->data) { + if (!NodeTypeChecker::Check(kv.first.get())) return false; + if (!NodeTypeChecker::Check(kv.second.get())) return false; + } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "map<"; + NodeTypeChecker::PrintName(os); + os << ','; + NodeTypeChecker::PrintName(os); + os << '>'; + } +}; + +template +inline std::string NodeTypeName() { + std::ostringstream os; + NodeTypeChecker::PrintName(os); + return os.str(); +} + } // namespace tvm #endif // TVM_BASE_COMMON_H_ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc deleted file mode 100644 index 7b360594..00000000 --- a/src/c_api/c_api.cc +++ /dev/null @@ -1,260 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * Implementation of C API - * \file c_api.cc - */ -#include -#include "./c_api_common.h" -#include "./c_api_registry.h" - -/*! \brief entry to to easily hold returning information */ -struct TVMAPIThreadLocalEntry { - /*! \brief result holder for returning strings */ - std::vector ret_vec_str; - /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; - /*! \brief argument stack */ - std::vector arg_stack; - /*! \brief return value */ - tvm::APIVariantValue ret_value; - // clear calling stack - inline void Clear() { - arg_stack.clear(); - ret_value.sptr.reset(); - } - inline void SetReturn(TVMValue* ret_val, int* ret_type_code); -}; - -using namespace tvm; - -/*! \brief Thread local store that can be used to hold return values. */ -typedef dmlc::ThreadLocalStore TVMAPIThreadLocalStore; - -using TVMAPINode = std::shared_ptr; - -struct APIAttrGetter : public AttrVisitor { - std::string skey; - APIVariantValue* ret; - bool found_node_ref{false}; - - void Visit(const char* key, double* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, int64_t* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, uint64_t* value) final { - CHECK_LE(value[0], static_cast(std::numeric_limits::max())) - << "cannot return too big constant"; - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, int* value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, bool* value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, Type* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, std::string* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, NodeRef* value) final { - if (skey == key) { - *ret = value[0]; - found_node_ref = true; - } - } -}; - -struct APIAttrDir : public AttrVisitor { - std::vector* names; - - void Visit(const char* key, double* value) final { - names->push_back(key); - } - void Visit(const char* key, int64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, uint64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, bool* value) final { - names->push_back(key); - } - void Visit(const char* key, int* value) final { - names->push_back(key); - } - void Visit(const char* key, Type* value) final { - names->push_back(key); - } - void Visit(const char* key, std::string* value) final { - names->push_back(key); - } - void Visit(const char* key, NodeRef* value) final { - names->push_back(key); - } -}; - -int TVMListAPIFuncNames(int *out_size, - const char*** out_array) { - API_BEGIN(); - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - ret->ret_vec_str = dmlc::Registry::ListAllNames(); - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast(ret->ret_vec_str.size()); - API_END(); -} - -int TVMGetAPIFuncHandle(const char* fname, - APIFuncHandle* out) { - API_BEGIN(); - const APIFuncReg* reg = dmlc::Registry::Find(fname); - CHECK(reg != nullptr) << "cannot find function " << fname; - *out = (APIFuncHandle)reg; - API_END(); -} - -int TVMGetAPIFuncInfo(APIFuncHandle handle, - const char **real_name, - const char **description, - int *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - const auto *op = static_cast(handle); - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - - API_BEGIN(); - *real_name = op->name.c_str(); - *description = op->description.c_str(); - *num_doc_args = static_cast(op->arguments.size()); - if (return_type) *return_type = nullptr; - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < op->arguments.size(); ++i) { - ret->ret_vec_charp.push_back(op->arguments[i].name.c_str()); - } - for (size_t i = 0; i < op->arguments.size(); ++i) { - ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str()); - } - for (size_t i = 0; i < op->arguments.size(); ++i) { - ret->ret_vec_charp.push_back(op->arguments[i].description.c_str()); - } - *arg_names = dmlc::BeginPtr(ret->ret_vec_charp); - *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size(); - *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2); - API_END(); -} - -int TVMAPIPushStack(TVMValue arg, - int type_code) { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - API_BEGIN(); - ret->arg_stack.resize(ret->arg_stack.size() + 1); - APIVariantValue& v = ret->arg_stack.back(); - - v.type_code = type_code; - switch (type_code) { - case kInt: case kUInt: case kFloat: case kNull: { - v.v_union = arg; break; - } - case kStr: { - v.str = arg.v_str; break; - } - case kNodeHandle: { - v.sptr = *static_cast(arg.v_handle); break; - } - default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code); - } - API_END_HANDLE_ERROR(ret->Clear()); -} - -int TVMAPIFuncCall(APIFuncHandle handle, - TVMValue* ret_val, - int* ret_type_code) { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - API_BEGIN(); - const auto *op = static_cast(handle); - op->body(ret->arg_stack, &(ret->ret_value)); - ret->SetReturn(ret_val, ret_type_code); - ret->arg_stack.clear(); - API_END_HANDLE_ERROR(ret->Clear()); -} - -int TVMNodeFree(NodeHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - -int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* ret_val, - int* ret_type_code, - int* ret_success) { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - API_BEGIN(); - ret->ret_value.type_code = kNull; - APIAttrGetter getter; - getter.skey = key; - getter.ret = &(ret->ret_value); - TVMAPINode* tnode = static_cast(handle); - if (getter.skey == "type_key") { - ret_val->v_str = (*tnode)->type_key(); - *ret_type_code = kStr; - *ret_success = 1; - } else { - (*tnode)->VisitAttrs(&getter); - if (ret->ret_value.type_code != kNull) { - ret->SetReturn(ret_val, ret_type_code); - *ret_success = 1; - } else { - *ret_success = getter.found_node_ref ? 1 : 0; - *ret_type_code = kNull; - } - } - API_END_HANDLE_ERROR(ret->Clear()); -} - -int TVMNodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array) { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - API_BEGIN(); - ret->ret_vec_str.clear(); - TVMAPINode* tnode = static_cast(handle); - APIAttrDir dir; - dir.names = &(ret->ret_vec_str); - (*tnode)->VisitAttrs(&dir); - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast(ret->ret_vec_str.size()); - API_END(); -} - - -inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val, - int* ret_type_code) { - APIVariantValue& rv = ret_value; - *ret_type_code = rv.type_code; - if (rv.type_code == kNodeHandle) { - if (rv.sptr.get() != nullptr) { - ret_val->v_handle = new TVMAPINode(std::move(rv.sptr)); - } else { - ret_val->v_handle = nullptr; - } - } else if (rv.type_code == kFuncHandle) { - ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func)); - } else { - *ret_val = rv.v_union; - } -} diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc deleted file mode 100644 index a198365f..00000000 --- a/src/c_api/c_api_codegen.cc +++ /dev/null @@ -1,61 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * Implementation of API functions related to Codegen - * \file c_api_codegen.cc - */ -#include -#include -#include - -#include "./c_api_registry.h" -#include "../codegen/codegen_c.h" - -namespace tvm { -namespace codegen { - -using ArgStack = const std::vector; -using RetValue = APIVariantValue; - -TVM_REGISTER_API(_codegen_CompileToC) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = CodeGenC().Compile(args.at(0), args.at(1)); - }); - -TVM_REGISTER_API(_codegen_MakeAPI) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = MakeAPI( - args.at(0), args.at(1), args.at(2), args.at(3)); - }); - -TVM_REGISTER_API(_codegen_SplitHostDevice) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = SplitHostDevice(args.at(0)); - }); - - -// generate a dummy packed function for testing -void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) { - LOG(INFO) << num_args << " arguments"; - for (int i = 0; i < num_args; ++i) { - switch (type_code[i]) { - case kNull: LOG(INFO) << i << ":nullptr"; break; - case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break; - case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break; - case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break; - default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]); - } - } -} - -TVM_REGISTER_API(_codegen_DummyHelloFunction) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = runtime::PackedFunc(DummyHelloFunction); - }); - -TVM_REGISTER_API(_codegen_BuildStackVM) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = BuildStackVM(args.at(0)); - }); - -} // namespace codegen -} // namespace tvm diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h deleted file mode 100644 index 55b3fe0b..00000000 --- a/src/c_api/c_api_common.h +++ /dev/null @@ -1,19 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file c_api_common.h - * \brief Common fields of all C APIs - */ -#ifndef TVM_C_API_C_API_COMMON_H_ -#define TVM_C_API_C_API_COMMON_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "./c_api_registry.h" -#include "../runtime/runtime_base.h" - -#endif // TVM_C_API_C_API_COMMON_H_ diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc deleted file mode 100644 index f550804a..00000000 --- a/src/c_api/c_api_function.cc +++ /dev/null @@ -1,47 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * Implementation of API functions - * \file c_api_impl.cc - */ -#include -#include -#include "./c_api_registry.h" - -namespace dmlc { -DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg); -} // namespace dmlc - -namespace tvm { - -using ArgStack = const std::vector; -using RetValue = APIVariantValue; - -TVM_REGISTER_API(_format_str) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - std::ostringstream os; - os << args.at(0).operator NodeRef(); - *ret = os.str(); - }) -.add_argument("expr", "Node", "expression to be printed"); - -TVM_REGISTER_API(_raw_ptr) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - *ret = reinterpret_cast(args.at(0).sptr.get()); - }) -.add_argument("src", "NodeBase", "the node base"); - -TVM_REGISTER_API(_save_json) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = SaveJSON(args.at(0)); - }) -.add_argument("src", "json_str", "the node "); - -TVM_REGISTER_API(_load_json) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = NodeRef(LoadJSON_(args.at(0))); - }) -.add_argument("src", "NodeBase", "the node"); - -} // namespace tvm diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc deleted file mode 100644 index 2c65d81b..00000000 --- a/src/c_api/c_api_lang.cc +++ /dev/null @@ -1,273 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * Implementation of API functions related to Higher DSL build. - * \file c_api_lang.cc - */ -#include -#include -#include -#include -#include -#include "./c_api_registry.h" - -namespace tvm { - -using ArgStack = const std::vector; -using RetValue = APIVariantValue; - -TVM_REGISTER_API(_const) -.set_body([](const ArgStack& args, RetValue *ret) { - if (args.at(0).type_code == kInt) { - *ret = make_const(args.at(1), args.at(0).operator int64_t()); - } else if (args.at(0).type_code == kFloat) { - *ret = make_const(args.at(1), args.at(0).operator double()); - } else { - LOG(FATAL) << "only accept int or float"; - } - }) -.add_argument("src", "Number", "source number") -.add_argument("dtype", "str", "data type"); - - -TVM_REGISTER_API(_str) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = ir::StringImm::make(args.at(0)); -}); - - -TVM_REGISTER_API(_Array) -.set_body([](const ArgStack& args, RetValue *ret) { - std::vector > data; - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args.at(i).type_code == kNodeHandle) - << "need content of array to be NodeBase"; - data.push_back(args.at(i).sptr); - } - auto node = std::make_shared(); - node->data = std::move(data); - ret->type_code = kNodeHandle; - ret->sptr = node; - }); - -TVM_REGISTER_API(_ArrayGetItem) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - int64_t i = args.at(1); - auto& sptr = args.at(0).sptr; - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - CHECK_LT(static_cast(i), n->data.size()) - << "out of bound of array"; - ret->sptr = n->data[i]; - ret->type_code = kNodeHandle; - }); - -TVM_REGISTER_API(_ArraySize) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - auto& sptr = args.at(0).sptr; - CHECK(sptr->is_type()); - *ret = static_cast( - static_cast(sptr.get())->data.size()); - }); - -TVM_REGISTER_API(_Map) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK_EQ(args.size() % 2, 0U); - MapNode::ContainerType data; - for (size_t i = 0; i < args.size(); i += 2) { - CHECK(args.at(i).type_code == kNodeHandle) - << "need content of array to be NodeBase"; - CHECK(args.at(i + 1).type_code == kNodeHandle) - << "need content of array to be NodeBase"; - data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr)); - } - auto node = std::make_shared(); - node->data = std::move(data); - ret->type_code = kNodeHandle; - ret->sptr = node; - }); - -TVM_REGISTER_API(_MapSize) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - auto& sptr = args.at(0).sptr; - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - *ret = static_cast(n->data.size()); - }); - -TVM_REGISTER_API(_MapGetItem) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - CHECK(args.at(1).type_code == kNodeHandle); - auto& sptr = args.at(0).sptr; - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - auto it = n->data.find(args.at(1).sptr); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - ret->sptr = (*it).second; - ret->type_code = kNodeHandle; - }); - -TVM_REGISTER_API(_MapCount) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - CHECK(args.at(1).type_code == kNodeHandle); - auto& sptr = args.at(0).sptr; - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - *ret = static_cast(n->data.count(args.at(1).sptr)); - }); - -TVM_REGISTER_API(_MapItems) -.set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_code == kNodeHandle); - auto& sptr = args.at(0).sptr; - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - auto rkvs = std::make_shared(); - for (const auto& kv : n->data) { - rkvs->data.push_back(kv.first); - rkvs->data.push_back(kv.second); - } - ret->sptr = rkvs; - ret->type_code = kNodeHandle; - }); - -TVM_REGISTER_API(Range) -.set_body([](const ArgStack& args, RetValue *ret) { - if (args.size() == 1) { - *ret = Range(0, args.at(0)); - } else { - *ret = Range(args.at(0), args.at(1)); - } - }) -.describe("create a domain range") -.add_argument("begin", "Expr", "beginning of the range.") -.add_argument("end", "Expr", "extent of the range"); - -TVM_REGISTER_API(_Buffer) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = BufferNode::make(args.at(0), - args.at(1), - args.at(2), - args.at(3), - args.at(4)); - }); - -TVM_REGISTER_API(_Tensor) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = TensorNode::make(args.at(0), - args.at(1), - args.at(2), - args.at(3)); - }); - -TVM_REGISTER_API(_TensorEqual) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = args.at(0).operator Tensor() == args.at(1).operator Tensor(); - }); - -TVM_REGISTER_API(_TensorHash) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = static_cast( - std::hash()(args.at(0).operator Tensor())); - }); - -TVM_REGISTER_API(_Placeholder) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Placeholder(args.at(0), - args.at(1), - args.at(2)); - }); - -TVM_REGISTER_API(_ComputeOp) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = ComputeOpNode::make(args.at(0), - args.at(1), - args.at(2)); - }); - -TVM_REGISTER_API(_OpGetOutput) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = args.at(0).operator Operation().output( - args.at(1).operator int64_t()); - }); - - -TVM_REGISTER_API(_IterVar) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = IterVar(args.at(0), args.at(1), args.at(2)); - }); - -TVM_REGISTER_API(_Schedule) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Schedule(args.at(0).operator Array()); - }); - -TVM_REGISTER_API(_StageSetScope) -.set_body([](const ArgStack& args, RetValue *ret) { - args.at(0).operator Stage() - .set_scope(args.at(1)); - }); - -TVM_REGISTER_API(_StageSplitByFactor) -.set_body([](const ArgStack& args, RetValue *ret) { - IterVar outer, inner; - args.at(0).operator Stage() - .split(args.at(1), &outer, &inner, args.at(2)); - *ret = Array({outer, inner}); - }); - -TVM_REGISTER_API(_StageSplitByOuter) -.set_body([](const ArgStack& args, RetValue *ret) { - IterVar inner; - args.at(0).operator Stage() - .split(args.at(1), args.at(2), &inner, args.at(3)); - *ret = inner; - }); - -TVM_REGISTER_API(_StageFuse) -.set_body([](const ArgStack& args, RetValue *ret) { - IterVar fused; - args.at(0).operator Stage() - .split(args.at(1), args.at(2), &fused); - *ret = fused; - }); - -TVM_REGISTER_API(_StageComputeAt) -.set_body([](const ArgStack& args, RetValue *ret) { - args.at(0).operator Stage() - .compute_at(args.at(1), args.at(2)); - }); - -TVM_REGISTER_API(_StageComputeInline) -.set_body([](const ArgStack& args, RetValue *ret) { - args.at(0).operator Stage() - .compute_inline(); - }); - -TVM_REGISTER_API(_StageComputeRoot) -.set_body([](const ArgStack& args, RetValue *ret) { - args.at(0).operator Stage() - .compute_root(); - }); - -TVM_REGISTER_API(_StageReorder) -.set_body([](const ArgStack& args, RetValue *ret) { - args.at(0).operator Stage() - .reorder(args.at(1)); - }); - -TVM_REGISTER_API(_StageTile) - .set_body([](const ArgStack& args, RetValue *ret) { - IterVar x_outer, y_outer, x_inner, y_inner; - args.at(0).operator Stage() - .tile(args.at(1), args.at(2), &x_outer, &y_outer, - &x_inner, &y_inner, args.at(3), args.at(4)); - *ret = Array({x_outer, y_outer, x_inner, y_inner}); - }); - -} // namespace tvm diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h deleted file mode 100644 index 8852b937..00000000 --- a/src/c_api/c_api_registry.h +++ /dev/null @@ -1,240 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file c_api_registry.h - * \brief Quick registry for C API. - */ -#ifndef TVM_C_API_C_API_REGISTRY_H_ -#define TVM_C_API_C_API_REGISTRY_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "../base/common.h" - -namespace tvm { - -inline const char* TVMTypeCode2Str(int type_code) { - switch (type_code) { - case kInt: return "int"; - case kFloat: return "float"; - case kStr: return "str"; - case kHandle: return "Handle"; - case kNull: return "NULL"; - case kNodeHandle: return "NodeHandle"; - default: LOG(FATAL) << "unknown type_code=" - << static_cast(type_code); return ""; - } -} - -template -struct NodeTypeChecker { - static inline bool Check(Node* sptr) { - // This is the only place in the project where RTTI is used - // It can be turned off, but will make non strict checking. - // TODO(tqchen) possibly find alternative to turn of RTTI - using ContainerType = typename T::ContainerType; - return (dynamic_cast(sptr) != nullptr); - } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - using ContainerType = typename T::ContainerType; - os << ContainerType::_type_key; - } -}; - -template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return false; - if (!sptr->is_type()) return false; - ArrayNode* n = static_cast(sptr); - for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) return false; - } - return true; - } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "array<"; - NodeTypeChecker::PrintName(os); - os << ">"; - } -}; - -template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return false; - if (!sptr->is_type()) return false; - MapNode* n = static_cast(sptr); - for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.first.get())) return false; - if (!NodeTypeChecker::Check(kv.second.get())) return false; - } - return true; - } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map<"; - NodeTypeChecker::PrintName(os); - os << ','; - NodeTypeChecker::PrintName(os); - os << '>'; - } -}; - -template -inline std::string NodeTypeName() { - std::ostringstream os; - NodeTypeChecker::PrintName(os); - return os.str(); -} - -/*! \brief Variant container for API calls */ -class APIVariantValue { - public: - /*! \brief the type id */ - int type_code{kNull}; - /*! \brief shared pointer container */ - std::shared_ptr sptr; - /*! \brief string container */ - std::string str; - /*! \brief the variant holder */ - TVMValue v_union; - /*! \brief std::function */ - runtime::PackedFunc::FType func; - // constructor - APIVariantValue() { - } - // clear value - inline void Clear() { - } - // assign op - inline APIVariantValue& operator=(double value) { - type_code = kFloat; - v_union.v_float64 = value; - return *this; - } - inline APIVariantValue& operator=(std::nullptr_t value) { - type_code = kHandle; - v_union.v_handle = value; - return *this; - } - inline APIVariantValue& operator=(int64_t value) { - type_code = kInt; - v_union.v_int64 = value; - return *this; - } - inline APIVariantValue& operator=(bool value) { - type_code = kInt; - v_union.v_int64 = value; - return *this; - } - inline APIVariantValue& operator=(std::string value) { - type_code = kStr; - str = std::move(value); - v_union.v_str = str.c_str(); - return *this; - } - inline APIVariantValue& operator=(const NodeRef& ref) { - if (ref.node_.get() == nullptr) { - type_code = kNull; - } else { - type_code = kNodeHandle; - this->sptr = ref.node_; - } - return *this; - } - inline APIVariantValue& operator=(const runtime::PackedFunc& f) { - type_code = kFuncHandle; - this->func = f.body(); - return *this; - } - inline APIVariantValue& operator=(const Type& value) { - return operator=(Type2String(value)); - } - template::value>::type> - inline operator T() const { - if (type_code == kNull) return T(); - CHECK_EQ(type_code, kNodeHandle); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Did not get expected type " << NodeTypeName(); - return T(sptr); - } - inline operator Expr() const { - if (type_code == kNull) { - return Expr(); - } - if (type_code == kInt) return Expr(operator int()); - if (type_code == kFloat) { - return Expr(static_cast(operator double())); - } - CHECK_EQ(type_code, kNodeHandle); - if (sptr->is_type()) { - return IterVar(sptr)->var; - } else { - CHECK(NodeTypeChecker::Check(sptr.get())) - << "did not pass in Expr in a place need Expr"; - return Expr(sptr); - } - } - inline operator double() const { - CHECK_EQ(type_code, kFloat); - return v_union.v_float64; - } - inline operator int64_t() const { - CHECK_EQ(type_code, kInt); - return v_union.v_int64; - } - inline operator uint64_t() const { - CHECK_EQ(type_code, kInt); - return v_union.v_int64; - } - inline operator int() const { - CHECK_EQ(type_code, kInt); - CHECK_LE(v_union.v_int64, - std::numeric_limits::max()); - return v_union.v_int64; - } - inline operator bool() const { - CHECK_EQ(type_code, kInt) - << "expect boolean(int) but get " - << TVMTypeCode2Str(type_code); - return v_union.v_int64 != 0; - } - inline operator std::string() const { - CHECK_EQ(type_code, kStr) - << "expect Str but get " - << TVMTypeCode2Str(type_code); - return str; - } - inline operator Type() const { - return String2Type(operator std::string()); - } - inline operator runtime::PackedFunc() const { - CHECK_EQ(type_code, kFuncHandle); - return runtime::PackedFunc(func); - } -}; - -// common defintiion of API function. -using APIFunc = std::function< - void(const std::vector &args, APIVariantValue* ret)>; - -/*! - * \brief Registry entry for DataIterator factory functions. - */ -struct APIFuncReg - : public dmlc::FunctionRegEntryBase { -}; - -#define TVM_REGISTER_API(TypeName) \ - DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \ - -} // namespace tvm - -#endif // TVM_C_API_C_API_REGISTRY_H_ diff --git a/src/codegen/codegen_stack_vm.cc b/src/codegen/codegen_stack_vm.cc index d9ca3d38..7748504f 100644 --- a/src/codegen/codegen_stack_vm.cc +++ b/src/codegen/codegen_stack_vm.cc @@ -12,19 +12,22 @@ using namespace ir; runtime::PackedFunc BuildStackVM(LoweredFunc func) { StackVM vm = codegen::CodeGenStackVM().Compile(func); - auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) { - LOG(INFO) << "Run stack VM"; + using runtime::TVMArgs; + using runtime::TVMRetValue; + + auto f = [vm](TVMArgs args, TVMRetValue* rv) { StackVM::State* s = StackVM::ThreadLocalState(); s->sp = 0; s->pc = 0; if (s->heap.size() < vm.heap_size) { s->heap.resize(vm.heap_size); } - s->heap[0].v_handle = (void*)args; // NOLINT(*) - s->heap[1].v_handle = (void*)type_codes; // NOLINT(*) - s->heap[2].v_int64 = num_args; + s->heap[0].v_handle = (void*)args.values; // NOLINT(*) + s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) + s->heap[2].v_int64 = args.num_args; vm.Run(s); }; + return runtime::PackedFunc(f); } @@ -118,6 +121,9 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) { auto it = fun_idmap_.find(name); if (it != fun_idmap_.end()) return it->second; using runtime::PackedFunc; + using runtime::TVMArgs; + using runtime::TVMRetValue; + PackedFunc f = PackedFunc::GetGlobal(name); auto extern_f = [f](const TVMValue* args, int num_args) { CHECK_EQ(num_args % 2, 0); @@ -128,7 +134,8 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) { int code = (tcode >> (8 * 3)) & 255; type_codes[i] = code; } - f.CallPacked(args, &type_codes[0], num_args); + TVMRetValue rv; + f.CallPacked(TVMArgs(args, &type_codes[0], num_args), &rv); TVMValue r; r.v_int64 = 0; return r; }; diff --git a/src/codegen/split_host_device.cc b/src/codegen/split_host_device.cc index 1560fda4..383e307a 100644 --- a/src/codegen/split_host_device.cc +++ b/src/codegen/split_host_device.cc @@ -136,7 +136,6 @@ class HostDeviceSplitter : public IRMutator { public: Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->type_key == "thread_extent") { - LOG(INFO) << "??"; IterVar iv(op->node.node_); return SplitDeviceFunc(s); } diff --git a/src/jit/stack_vm.cc b/src/jit/stack_vm.cc index 7c92b1ee..80c4bcbf 100644 --- a/src/jit/stack_vm.cc +++ b/src/jit/stack_vm.cc @@ -302,7 +302,8 @@ void StackVM::Run(State* s) const { STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break; } case TVM_LOAD_ARG_HANDLE: { - STACK_VM_TVM_LOAD_ARG(tc == kHandle || tc == kNull, "handle"); break; + STACK_VM_TVM_LOAD_ARG( + tc == kHandle || tc == kNull || tc == kArrayHandle, "handle"); break; } case TVM_ARRAY_GET_DATA: { STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break; @@ -317,7 +318,7 @@ void StackVM::Run(State* s) const { STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break; } case TVM_ARRAY_GET_TYPE_CODE: { - STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.type_code); break; + STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break; } case TVM_ARRAY_GET_TYPE_BITS: { STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break; diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 8bf4ba4b..17fb97de 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -3,9 +3,11 @@ * \file c_runtime_api.cc * \brief Device specific implementations */ +#include #include #include #include +#include #include "./runtime_base.h" #include "./device_api.h" @@ -37,7 +39,7 @@ inline void TVMArrayFree_(TVMArray* arr) { inline void VerifyType(TVMType dtype) { CHECK_GE(dtype.lanes, 1U); - if (dtype.type_code == kFloat) { + if (dtype.code == kFloat) { CHECK_EQ(dtype.bits % 32U, 0U); } else { CHECK_EQ(dtype.bits % 8U, 0U); @@ -65,6 +67,12 @@ inline size_t GetDataAlignment(TVMArray* arr) { using namespace tvm::runtime; +struct TVMRuntimeEntry { + std::string ret_str; +}; + +typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; + int TVMDeviceInit(int dev_mask, const char** option_keys, const char** option_vals, @@ -177,10 +185,31 @@ int TVMFuncFree(TVMFunctionHandle func) { int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, - int num_args) { + int num_args, + TVMValue* ret_val, + int* ret_type_code) { API_BEGIN(); + TVMRetValue rv; (*static_cast(func)).CallPacked( - args, arg_type_codes, num_args); + TVMArgs(args, arg_type_codes, num_args), &rv); + // handle return string. + if (rv.type_code() == kStr) { + TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); + e->ret_str = rv.operator std::string(); + *ret_type_code = kStr; + ret_val->v_str = e->ret_str.c_str(); + } else { + rv.MoveToCHost(ret_val, ret_type_code); + } + API_END(); +} + +int TVMCFuncSetReturn(TVMRetValueHandle ret, + TVMValue value, + int type_code) { + API_BEGIN(); + TVMRetValue* rv = static_cast(ret); + *rv = TVMArgValue(value, type_code); API_END(); } @@ -191,22 +220,18 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, API_BEGIN(); if (fin == nullptr) { *out = new PackedFunc( - [func, resource_handle](const TVMValue* args, - const int* type_codes, - int num_args) { - func((TVMValue*)args, (int*)type_codes, // NOLINT(*) - num_args, resource_handle); + [func, resource_handle](TVMArgs args, TVMRetValue* rv) { + func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, resource_handle); }); } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. std::shared_ptr rpack(resource_handle, fin); *out = new PackedFunc( - [func, rpack](const TVMValue* args, - const int* type_codes, - int num_args) { - func((TVMValue*)args, (int*)type_codes, // NOLINT(*) - num_args, rpack.get()); + [func, rpack](TVMArgs args, TVMRetValue* rv) { + func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, rpack.get()); }); } API_END(); diff --git a/src/runtime/packed_func_registry.cc b/src/runtime/packed_func_registry.cc index acc09d8a..8a9c6094 100644 --- a/src/runtime/packed_func_registry.cc +++ b/src/runtime/packed_func_registry.cc @@ -4,6 +4,7 @@ * \brief The global registry of packed function. */ #include +#include #include #include #include @@ -58,6 +59,18 @@ std::vector PackedFunc::ListGlobalNames() { } // namespace runtime } // namespace tvm +/*! \brief entry to to easily hold returning information */ +struct TVMFuncThreadLocalEntry { + /*! \brief result holder for returning strings */ + std::vector ret_vec_str; + /*! \brief result holder for returning string pointers */ + std::vector ret_vec_charp; +}; + +/*! \brief Thread local store that can be used to hold return values. */ +typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; + + int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) { using tvm::runtime::PackedFunc; API_BEGIN(); @@ -68,6 +81,22 @@ int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) { int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { using tvm::runtime::PackedFunc; API_BEGIN(); - *out = new PackedFunc(PackedFunc::GetGlobal(name)); + const PackedFunc& f = PackedFunc::GetGlobal(name); + *out = (TVMFunctionHandle)(&f); // NOLINT(*) + API_END(); +} + +int TVMFuncListGlobalNames(int *out_size, + const char*** out_array) { + using tvm::runtime::PackedFunc; + API_BEGIN(); + TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get(); + ret->ret_vec_str = PackedFunc::ListGlobalNames(); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_array = dmlc::BeginPtr(ret->ret_vec_charp); + *out_size = static_cast(ret->ret_vec_str.size()); API_END(); } diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index f2a87e48..c801b4c1 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -1,24 +1,116 @@ #include #include #include +#include +#include TEST(PackedFunc, Basic) { + using namespace tvm; using namespace tvm::runtime; int x = 0; void* handle = &x; TVMArray a; - PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) { - CHECK(num_args == 3); - CHECK(args[0].v_float64 == 1.0); - CHECK(type_codes[0] == kFloat); - CHECK(args[1].v_handle == &a); - CHECK(type_codes[1] == kHandle); - CHECK(args[2].v_handle == &x); - CHECK(type_codes[2] == kHandle); + Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK(args.num_args == 3); + CHECK(args.values[0].v_float64 == 1.0); + CHECK(args.type_codes[0] == kFloat); + CHECK(args.values[1].v_handle == &a); + CHECK(args.type_codes[1] == kArrayHandle); + CHECK(args.values[2].v_handle == &x); + CHECK(args.type_codes[2] == kHandle); + *rv = Var("a"); })(1.0, &a, handle); + CHECK(v->name_hint == "a"); } +TEST(PackedFunc, Node) { + using namespace tvm; + using namespace tvm::runtime; + Var x; + Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK(args.num_args == 1); + CHECK(args.type_codes[0] == kNodeHandle); + Var b = args[0]; + CHECK(x.same_as(b)); + *rv = b; + })(x); + CHECK(t.same_as(x)); +} + +TEST(PackedFunc, str) { + using namespace tvm; + using namespace tvm::runtime; + PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK(args.num_args == 1); + std::string x = args[0]; + CHECK(x == "hello"); + *rv = x; + })("hello"); +} + + +TEST(PackedFunc, func) { + using namespace tvm; + using namespace tvm::runtime; + PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { + *rv = args[0].operator int() + 1; + }); + // function as arguments + int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); + CHECK_EQ(r0, 2); + + int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + // TVMArgValue -> TVMRetValue + *rv = args[1]; + })(2, 100); + CHECK_EQ(r1, 100); + + int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + // re-assignment + *rv = args[0]; + // TVMRetValue -> Function argument + *rv = addone(args[0].operator PackedFunc()(args[1], 1)); + })(addone, 100); + CHECK_EQ(r2, 102); +} + +TEST(PackedFunc, Expr) { + using namespace tvm; + using namespace tvm::runtime; + // automatic conversion of int to expr + PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { + Expr x = args[0]; + *rv = x.as()->value + 1; + }); + int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); + CHECK_EQ(r0, 2); +} + +TEST(PackedFunc, Type) { + using namespace tvm; + using namespace tvm::runtime; + auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Type x = args[0]; + *rv = x; + }); + auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + *rv = args[0]; + }); + CHECK(get_type("int32").operator Type() == Int(32)); + CHECK(get_type("float").operator Type() == Float(32)); + CHECK(get_type2("float32x2").operator Type() == Float(32, 2)); +} + + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/test_lang_basic.py b/tests/python/test_lang_basic.py index b1eb2990..b12a94a2 100644 --- a/tests/python/test_lang_basic.py +++ b/tests/python/test_lang_basic.py @@ -2,7 +2,8 @@ import tvm def test_const(): x = tvm.const(1) - assert x.dtype == 'int32' + print(x.dtype) + assert x.dtype == tvm.int32 assert isinstance(x, tvm.expr.IntImm) def test_const_saveload_json(): diff --git a/tests/python/test_runtime_packed_func.py b/tests/python/test_runtime_packed_func.py index ed123e9b..3332e9a3 100644 --- a/tests/python/test_runtime_packed_func.py +++ b/tests/python/test_runtime_packed_func.py @@ -17,10 +17,22 @@ def test_get_global(): @tvm.register_func def my_packed_func(*args): assert(tuple(args) == targs) + return 10 # get it out from global function table f = tvm.get_global_func("my_packed_func") assert isinstance(f, tvm.nd.Function) - f(*targs) + y = f(*targs) + assert y == 10 + + +def test_return_func(): + def addy(y): + def add(x): + return tvm.convert(x + y) + return add + myf = tvm.convert(addy) + f = myf(10) + assert f(11).value == 21 def test_convert(): @@ -38,3 +50,4 @@ if __name__ == "__main__": test_function() test_convert() test_get_global() + test_return_func() diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index f1a070b5..5c332a55 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -38,10 +38,10 @@ fi if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then make all || exit -1 if [ ${TRAVIS_OS_NAME} == "osx" ]; then - python -m nose tests/python/ || exit -1 - python3 -m nose tests/python/ || exit -1 + python -m nose -v tests/python/ || exit -1 + python3 -m nose -v tests/python/ || exit -1 else - nosetests tests/python/ || exit -1 - nosetests3 tests/python/ || exit -1 + nosetests -v tests/python/ || exit -1 + nosetests3 -v tests/python/ || exit -1 fi fi