[API/Refactor] Unified PackedFunc for API and Generated Functions (#26)
This commit is contained in:
Родитель
4242b9cff5
Коммит
ff06917c59
|
@ -4,7 +4,7 @@ language: cpp
|
||||||
|
|
||||||
os:
|
os:
|
||||||
- linux
|
- linux
|
||||||
- osx
|
# - osx
|
||||||
|
|
||||||
env:
|
env:
|
||||||
# code analysis
|
# code analysis
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2016 by Contributors
|
||||||
|
* \file api_registry.h
|
||||||
|
* \brief This file defines the TVM API registry.
|
||||||
|
*
|
||||||
|
* The API registry stores type-erased functions.
|
||||||
|
* Each registered function is automatically exposed
|
||||||
|
* to front-end language(e.g. python).
|
||||||
|
* Front-end can also pass callbacks as PackedFunc, or register
|
||||||
|
* then into the same global registry in C++.
|
||||||
|
* The goal is to mix the front-end language and the TVM back-end.
|
||||||
|
*
|
||||||
|
* \code
|
||||||
|
* // register the function as MyAPIFuncName
|
||||||
|
* TVM_REGISTER_API(MyAPIFuncName)
|
||||||
|
* .set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
* // my code.
|
||||||
|
* });
|
||||||
|
* \endcode
|
||||||
|
*/
|
||||||
|
#ifndef TVM_API_REGISTRY_H_
|
||||||
|
#define TVM_API_REGISTRY_H_
|
||||||
|
|
||||||
|
#include <dmlc/base.h>
|
||||||
|
#include <string>
|
||||||
|
#include "./base.h"
|
||||||
|
#include "./runtime/packed_func.h"
|
||||||
|
#include "./packed_func_ext.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
|
||||||
|
/*! \brief Utility to register API. */
|
||||||
|
class APIRegistry {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief set the body of the function to be f
|
||||||
|
* \param f The body of the function.
|
||||||
|
*/
|
||||||
|
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
|
||||||
|
/*!
|
||||||
|
* \brief set the body of the function to be f
|
||||||
|
* \param f The body of the function.
|
||||||
|
*/
|
||||||
|
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
|
||||||
|
return set_body(PackedFunc(f));
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief Register a function with given name
|
||||||
|
* \param name The name of the function.
|
||||||
|
*/
|
||||||
|
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief name of the function */
|
||||||
|
std::string name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get API function by name.
|
||||||
|
*
|
||||||
|
* \param name The name of the function.
|
||||||
|
* \return the corresponding API function.
|
||||||
|
* \note It is really PackedFunc::GetGlobal under the hood.
|
||||||
|
*/
|
||||||
|
inline PackedFunc GetAPIFunc(const std::string& name) {
|
||||||
|
return PackedFunc::GetGlobal(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define _TVM_REGISTER_VAR_DEF_ \
|
||||||
|
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Register API function globally.
|
||||||
|
* \code
|
||||||
|
* TVM_REGISTER_API(MyPrint)
|
||||||
|
* .set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
* // my code.
|
||||||
|
* });
|
||||||
|
* \endcode
|
||||||
|
*/
|
||||||
|
#define TVM_REGISTER_API(OpName) \
|
||||||
|
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
|
||||||
|
::tvm::APIRegistry::__REGISTER__(#OpName)
|
||||||
|
} // namespace tvm
|
||||||
|
#endif // TVM_API_REGISTRY_H_
|
|
@ -2,6 +2,13 @@
|
||||||
* Copyright (c) 2016 by Contributors
|
* Copyright (c) 2016 by Contributors
|
||||||
* \file c_api.h
|
* \file c_api.h
|
||||||
* \brief C API of TVM DSL
|
* \brief C API of TVM DSL
|
||||||
|
*
|
||||||
|
* \note The API is designed in a minimum way.
|
||||||
|
* Most of the API functions are registered and can be pulled out.
|
||||||
|
*
|
||||||
|
* The common flow is:
|
||||||
|
* - Use TVMFuncListGlobalNames to get global function name
|
||||||
|
* - Use TVMFuncCall to call these functions.
|
||||||
*/
|
*/
|
||||||
#ifndef TVM_C_API_H_
|
#ifndef TVM_C_API_H_
|
||||||
#define TVM_C_API_H_
|
#define TVM_C_API_H_
|
||||||
|
@ -9,76 +16,9 @@
|
||||||
#include "./runtime/c_runtime_api.h"
|
#include "./runtime/c_runtime_api.h"
|
||||||
|
|
||||||
TVM_EXTERN_C {
|
TVM_EXTERN_C {
|
||||||
/*! \brief handle to functions */
|
|
||||||
typedef void* APIFuncHandle;
|
|
||||||
/*! \brief handle to node */
|
/*! \brief handle to node */
|
||||||
typedef void* NodeHandle;
|
typedef void* NodeHandle;
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief List all the node function name
|
|
||||||
* \param out_size The number of functions
|
|
||||||
* \param out_array The array of function names.
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
|
||||||
TVM_DLL int TVMListAPIFuncNames(int *out_size,
|
|
||||||
const char*** out_array);
|
|
||||||
/*!
|
|
||||||
* \brief get function handle by name
|
|
||||||
* \param name The name of function
|
|
||||||
* \param handle The returning function handle
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
|
||||||
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
|
|
||||||
APIFuncHandle *handle);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Get the detailed information about function.
|
|
||||||
* \param handle The operator handle.
|
|
||||||
* \param real_name The returned name of the function.
|
|
||||||
* This name is not the alias name of the atomic symbol.
|
|
||||||
* \param description The returned description of the symbol.
|
|
||||||
* \param num_doc_args Number of arguments that contain documents.
|
|
||||||
* \param arg_names Name of the arguments of doc args
|
|
||||||
* \param arg_type_infos Type informations about the arguments.
|
|
||||||
* \param arg_descriptions Description information about the arguments.
|
|
||||||
* \param return_type Return type of the function, if any.
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
|
||||||
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
|
|
||||||
const char **real_name,
|
|
||||||
const char **description,
|
|
||||||
int *num_doc_args,
|
|
||||||
const char ***arg_names,
|
|
||||||
const char ***arg_type_infos,
|
|
||||||
const char ***arg_descriptions,
|
|
||||||
const char **return_type);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Push an argument to the function calling stack.
|
|
||||||
* If push fails, the stack will be reset to empty
|
|
||||||
*
|
|
||||||
* \param arg The argument
|
|
||||||
* \param type_code The type_code of argument as in TVMTypeCode
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
* \note API calls always exchanges with type bits=64, lanes=1
|
|
||||||
*/
|
|
||||||
TVM_DLL int TVMAPIPushStack(TVMValue arg,
|
|
||||||
int type_code);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief call a function by using arguments in the stack.
|
|
||||||
* The stack will be cleanup to empty after this call, whether the call is successful.
|
|
||||||
*
|
|
||||||
* \param handle The function handle
|
|
||||||
* \param ret_val The return value.
|
|
||||||
* \param ret_type_code the type code of return value.
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
* \note API calls always exchanges with type bits=64, lanes=1
|
|
||||||
*/
|
|
||||||
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
|
|
||||||
TVMValue* ret_val,
|
|
||||||
int* ret_type_code);
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief free the node handle
|
* \brief free the node handle
|
||||||
* \param handle The node handle to be freed.
|
* \param handle The node handle to be freed.
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
|
#include "./runtime/packed_func.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,196 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2016 by Contributors
|
||||||
|
* \file packed_func_ext.h
|
||||||
|
* \brief Extension package to PackedFunc
|
||||||
|
* This enales pass NodeRef types into/from PackedFunc.
|
||||||
|
*/
|
||||||
|
#ifndef TVM_PACKED_FUNC_EXT_H_
|
||||||
|
#define TVM_PACKED_FUNC_EXT_H_
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "./base.h"
|
||||||
|
#include "./expr.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
using runtime::TVMArgs;
|
||||||
|
using runtime::TVMRetValue;
|
||||||
|
using runtime::PackedFunc;
|
||||||
|
|
||||||
|
namespace runtime {
|
||||||
|
/*!
|
||||||
|
* \brief Runtime type checker for node type.
|
||||||
|
* \tparam T the type to be checked.
|
||||||
|
*/
|
||||||
|
template<typename T>
|
||||||
|
struct NodeTypeChecker {
|
||||||
|
static inline bool Check(Node* sptr) {
|
||||||
|
// This is the only place in the project where RTTI is used
|
||||||
|
// It can be turned off, but will make non strict checking.
|
||||||
|
// TODO(tqchen) possibly find alternative to turn of RTTI
|
||||||
|
using ContainerType = typename T::ContainerType;
|
||||||
|
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
|
||||||
|
}
|
||||||
|
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||||
|
using ContainerType = typename T::ContainerType;
|
||||||
|
os << ContainerType::_type_key;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct NodeTypeChecker<Array<T> > {
|
||||||
|
static inline bool Check(Node* sptr) {
|
||||||
|
if (sptr == nullptr) return false;
|
||||||
|
if (!sptr->is_type<ArrayNode>()) return false;
|
||||||
|
ArrayNode* n = static_cast<ArrayNode*>(sptr);
|
||||||
|
for (const auto& p : n->data) {
|
||||||
|
if (!NodeTypeChecker<T>::Check(p.get())) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||||
|
os << "array<";
|
||||||
|
NodeTypeChecker<T>::PrintName(os);
|
||||||
|
os << ">";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename K, typename V>
|
||||||
|
struct NodeTypeChecker<Map<K, V> > {
|
||||||
|
static inline bool Check(Node* sptr) {
|
||||||
|
if (sptr == nullptr) return false;
|
||||||
|
if (!sptr->is_type<MapNode>()) return false;
|
||||||
|
MapNode* n = static_cast<MapNode*>(sptr);
|
||||||
|
for (const auto& kv : n->data) {
|
||||||
|
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
|
||||||
|
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||||
|
os << "map<";
|
||||||
|
NodeTypeChecker<K>::PrintName(os);
|
||||||
|
os << ',';
|
||||||
|
NodeTypeChecker<V>::PrintName(os);
|
||||||
|
os << '>';
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline std::string NodeTypeName() {
|
||||||
|
std::ostringstream os;
|
||||||
|
NodeTypeChecker<T>::PrintName(os);
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// extensions for tvm arg value
|
||||||
|
|
||||||
|
template<typename TNodeRef, typename>
|
||||||
|
inline TVMArgValue::operator TNodeRef() const {
|
||||||
|
static_assert(
|
||||||
|
std::is_base_of<NodeRef, TNodeRef>::value,
|
||||||
|
"Conversion only works for NodeRef");
|
||||||
|
if (type_code_ == kNull) return TNodeRef();
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||||
|
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
|
||||||
|
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
|
||||||
|
<< "Expected type " << NodeTypeName<TNodeRef>()
|
||||||
|
<< " but get " << sptr->type_key();
|
||||||
|
return TNodeRef(sptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMArgValue::operator Halide::Expr() const {
|
||||||
|
if (type_code_ == kNull) return Expr();
|
||||||
|
if (type_code_ == kInt) {
|
||||||
|
return Expr(static_cast<int>(value_.v_int64));
|
||||||
|
}
|
||||||
|
if (type_code_ == kFloat) {
|
||||||
|
return Expr(static_cast<float>(value_.v_float64));
|
||||||
|
}
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||||
|
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
|
||||||
|
if (sptr->is_type<IterVarNode>()) {
|
||||||
|
return IterVar(sptr)->var;
|
||||||
|
}
|
||||||
|
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
|
||||||
|
<< "Expected type " << NodeTypeName<Expr>()
|
||||||
|
<< " but get " << sptr->type_key();
|
||||||
|
return Expr(sptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||||
|
return *ptr<std::shared_ptr<Node> >();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename TNodeRef, typename>
|
||||||
|
inline bool TVMArgValue::IsNodeType() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||||
|
std::shared_ptr<Node>& sptr =
|
||||||
|
*ptr<std::shared_ptr<Node> >();
|
||||||
|
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// extensions for TVMRetValue
|
||||||
|
inline TVMRetValue& TVMRetValue::operator=(
|
||||||
|
const std::shared_ptr<Node>& other) {
|
||||||
|
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
|
||||||
|
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TNodeRef, typename>
|
||||||
|
inline TVMRetValue::operator TNodeRef() const {
|
||||||
|
static_assert(
|
||||||
|
std::is_base_of<NodeRef, TNodeRef>::value,
|
||||||
|
"Conversion only works for NodeRef");
|
||||||
|
if (type_code_ == kNull) return TNodeRef();
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||||
|
return TNodeRef(*ptr<std::shared_ptr<Node> >());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
|
||||||
|
values_[i].v_handle = &(other.node_);
|
||||||
|
type_codes_[i] = kNodeHandle;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type related stuffs
|
||||||
|
inline Type TVMType2Type(TVMType t) {
|
||||||
|
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMType Type2TVMType(Type t) {
|
||||||
|
TVMType ret;
|
||||||
|
ret.code = static_cast<uint8_t>(t.code());
|
||||||
|
ret.bits = static_cast<uint8_t>(t.bits());
|
||||||
|
ret.lanes = static_cast<uint16_t>(t.lanes());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
|
||||||
|
return this->operator=(Type2TVMType(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMRetValue::operator Halide::Type() const {
|
||||||
|
return TVMType2Type(operator TVMType());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMArgValue::operator Halide::Type() const {
|
||||||
|
return TVMType2Type(operator TVMType());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TVMArgsSetter::operator()(
|
||||||
|
size_t i, const Halide::Type& t) const {
|
||||||
|
this->operator()(i, Type2TVMType(t));
|
||||||
|
}
|
||||||
|
} // namespace runtime
|
||||||
|
} // namespace tvm
|
||||||
|
#endif // TVM_PACKED_FUNC_EXT_H_
|
|
@ -36,18 +36,6 @@
|
||||||
TVM_EXTERN_C {
|
TVM_EXTERN_C {
|
||||||
/*! \brief type of array index. */
|
/*! \brief type of array index. */
|
||||||
typedef uint32_t tvm_index_t;
|
typedef uint32_t tvm_index_t;
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Union type of values
|
|
||||||
* being passed through API and function calls.
|
|
||||||
*/
|
|
||||||
typedef union {
|
|
||||||
int64_t v_int64;
|
|
||||||
double v_float64;
|
|
||||||
void* v_handle;
|
|
||||||
const char* v_str;
|
|
||||||
} TVMValue;
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief The type code in TVMType
|
* \brief The type code in TVMType
|
||||||
* \note TVMType is used in two places.
|
* \note TVMType is used in two places.
|
||||||
|
@ -60,9 +48,11 @@ typedef enum {
|
||||||
// The next few fields are extension types
|
// The next few fields are extension types
|
||||||
// that is used by TVM API calls.
|
// that is used by TVM API calls.
|
||||||
kNull = 4U,
|
kNull = 4U,
|
||||||
kNodeHandle = 5U,
|
kArrayHandle = 5U,
|
||||||
kStr = 6U,
|
kTVMType = 6U,
|
||||||
kFuncHandle = 7U
|
kNodeHandle = 7U,
|
||||||
|
kStr = 8U,
|
||||||
|
kFuncHandle = 9U
|
||||||
} TVMTypeCode;
|
} TVMTypeCode;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
@ -77,13 +67,25 @@ typedef enum {
|
||||||
*/
|
*/
|
||||||
typedef struct {
|
typedef struct {
|
||||||
/*! \brief type code, in TVMTypeCode */
|
/*! \brief type code, in TVMTypeCode */
|
||||||
uint8_t type_code;
|
uint8_t code;
|
||||||
/*! \brief number of bits of the type */
|
/*! \brief number of bits of the type */
|
||||||
uint8_t bits;
|
uint8_t bits;
|
||||||
/*! \brief number of lanes, */
|
/*! \brief number of lanes, */
|
||||||
uint16_t lanes;
|
uint16_t lanes;
|
||||||
} TVMType;
|
} TVMType;
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Union type of values
|
||||||
|
* being passed through API and function calls.
|
||||||
|
*/
|
||||||
|
typedef union {
|
||||||
|
int64_t v_int64;
|
||||||
|
double v_float64;
|
||||||
|
void* v_handle;
|
||||||
|
const char* v_str;
|
||||||
|
TVMType v_type;
|
||||||
|
} TVMValue;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief The device type
|
* \brief The device type
|
||||||
*/
|
*/
|
||||||
|
@ -133,11 +135,10 @@ typedef struct {
|
||||||
* can be NULL, which indicates the default one.
|
* can be NULL, which indicates the default one.
|
||||||
*/
|
*/
|
||||||
typedef void* TVMStreamHandle;
|
typedef void* TVMStreamHandle;
|
||||||
/*!
|
/*! \brief Handle to packed function handle. */
|
||||||
* \brief Pointer to function handle that points to
|
|
||||||
* a generated TVM function.
|
|
||||||
*/
|
|
||||||
typedef void* TVMFunctionHandle;
|
typedef void* TVMFunctionHandle;
|
||||||
|
/*! \brief Handle to hold return value. */
|
||||||
|
typedef void* TVMRetValueHandle;
|
||||||
/*! \brief the array handle */
|
/*! \brief the array handle */
|
||||||
typedef TVMArray* TVMArrayHandle;
|
typedef TVMArray* TVMArrayHandle;
|
||||||
|
|
||||||
|
@ -228,20 +229,45 @@ TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
|
||||||
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
|
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Call a function whose parameters are all packed.
|
* \brief Call a Packed TVM Function.
|
||||||
*
|
*
|
||||||
* \param func node handle of the function.
|
* \param func node handle of the function.
|
||||||
* \param args The arguments
|
* \param arg_values The arguments
|
||||||
* \param type_codes The type codes of the arguments
|
* \param type_codes The type codes of the arguments
|
||||||
* \param num_args Number of arguments.
|
* \param num_args Number of arguments.
|
||||||
*
|
*
|
||||||
|
* \param ret_val The return value.
|
||||||
|
* \param ret_type_code the type code of return value.
|
||||||
|
*
|
||||||
* \return 0 when success, -1 when failure happens
|
* \return 0 when success, -1 when failure happens
|
||||||
* \note TVM calls always exchanges with type bits=64, lanes=1
|
* \note TVM calls always exchanges with type bits=64, lanes=1
|
||||||
|
*
|
||||||
|
* \note API calls always exchanges with type bits=64, lanes=1
|
||||||
|
* If API call returns container handles (e.g. FunctionHandle)
|
||||||
|
* these handles should be managed by the front-end.
|
||||||
|
* The front-end need to call free function (e.g. TVMFuncFree)
|
||||||
|
* to free these handles.
|
||||||
*/
|
*/
|
||||||
TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
|
TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
|
||||||
TVMValue* args,
|
TVMValue* arg_values,
|
||||||
int* type_codes,
|
int* type_codes,
|
||||||
int num_args);
|
int num_args,
|
||||||
|
TVMValue* ret_val,
|
||||||
|
int* ret_type_code);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Set the return value of TVMPackedCFunc.
|
||||||
|
*
|
||||||
|
* This function is called by TVMPackedCFunc to set the return value.
|
||||||
|
* When this function is not called, the function returns null by default.
|
||||||
|
*
|
||||||
|
* \param ret The return value handle, pass by ret in TVMPackedCFunc
|
||||||
|
* \param value The value to be returned.
|
||||||
|
* \param type_code The type of the value to be returned.
|
||||||
|
*/
|
||||||
|
TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
|
||||||
|
TVMValue value,
|
||||||
|
int type_code);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief C type of packed function.
|
* \brief C type of packed function.
|
||||||
|
@ -249,10 +275,17 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
|
||||||
* \param args The arguments
|
* \param args The arguments
|
||||||
* \param type_codes The type codes of the arguments
|
* \param type_codes The type codes of the arguments
|
||||||
* \param num_args Number of arguments.
|
* \param num_args Number of arguments.
|
||||||
|
* \param ret The return value handle.
|
||||||
* \param resource_handle The handle additional resouce handle from fron-end.
|
* \param resource_handle The handle additional resouce handle from fron-end.
|
||||||
|
*
|
||||||
|
* \sa TVMCFuncSetReturn
|
||||||
*/
|
*/
|
||||||
typedef void (*TVMPackedCFunc)(
|
typedef void (*TVMPackedCFunc)(
|
||||||
TVMValue* args, int* type_codes, int num_args, void* resource_handle);
|
TVMValue* args,
|
||||||
|
int* type_codes,
|
||||||
|
int num_args,
|
||||||
|
TVMRetValueHandle ret,
|
||||||
|
void* resource_handle);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief C callback to free the resource handle in C packed function.
|
* \brief C callback to free the resource handle in C packed function.
|
||||||
|
@ -291,8 +324,20 @@ TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
|
||||||
*
|
*
|
||||||
* \param name The name of the function.
|
* \param name The name of the function.
|
||||||
* \param out the result function pointer.
|
* \param out the result function pointer.
|
||||||
|
*
|
||||||
|
* \note The function handle of global function is managed by TVM runtime,
|
||||||
|
* So TVMFuncFree is should not be called when it get deleted.
|
||||||
*/
|
*/
|
||||||
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
|
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief List all the globally registered function name
|
||||||
|
* \param out_size The number of functions
|
||||||
|
* \param out_array The array of function names.
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
TVM_DLL int TVMFuncListGlobalNames(int *out_size,
|
||||||
|
const char*** out_array);
|
||||||
} // TVM_EXTERN_C
|
} // TVM_EXTERN_C
|
||||||
|
|
||||||
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
|
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
|
||||||
|
|
|
@ -1,19 +1,41 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2016 by Contributors
|
* Copyright (c) 2017 by Contributors
|
||||||
* \file packed_func.h
|
* \file packed_func.h
|
||||||
* \brief Runtime related c++ class.
|
* \brief Runtime related c++ class.
|
||||||
*/
|
*/
|
||||||
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
|
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
|
||||||
#define TVM_RUNTIME_PACKED_FUNC_H_
|
#define TVM_RUNTIME_PACKED_FUNC_H_
|
||||||
|
|
||||||
|
#include <dmlc/logging.h>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <limits>
|
||||||
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
#include "./c_runtime_api.h"
|
#include "./c_runtime_api.h"
|
||||||
|
|
||||||
|
namespace Halide {
|
||||||
|
// Forward declare type for extensions
|
||||||
|
// The header works fine without depending on this.
|
||||||
|
struct Type;
|
||||||
|
struct Expr;
|
||||||
|
}
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
// Forward declare NodeRef and Node for extensions.
|
||||||
|
// This header works fine without depend on NodeRef
|
||||||
|
// as long as it is not used.
|
||||||
|
class Node;
|
||||||
|
class NodeRef;
|
||||||
|
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
// forward declarations
|
||||||
|
class TVMArgs;
|
||||||
|
class TVMArgValue;
|
||||||
|
class TVMRetValue;
|
||||||
|
class TVMArgsSetter;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Packed function is a type-erased function.
|
* \brief Packed function is a type-erased function.
|
||||||
|
@ -25,8 +47,25 @@ namespace runtime {
|
||||||
*/
|
*/
|
||||||
class PackedFunc {
|
class PackedFunc {
|
||||||
public:
|
public:
|
||||||
/*! \brief The internal std::function */
|
/*!
|
||||||
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
|
* \brief The internal std::function
|
||||||
|
* \param args The arguments to the function.
|
||||||
|
* \param rv The return value.
|
||||||
|
*
|
||||||
|
* \code
|
||||||
|
* // Example code on how to implemented FType
|
||||||
|
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
|
||||||
|
* // automatically convert arguments to desired type.
|
||||||
|
* int a0 = args[0];
|
||||||
|
* float a1 = args[1];
|
||||||
|
* ...
|
||||||
|
* // automatically assign values to rv
|
||||||
|
* std::string my_return_value = "x";
|
||||||
|
* *rv = my_return_value;
|
||||||
|
* }
|
||||||
|
* \endcode
|
||||||
|
*/
|
||||||
|
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
|
||||||
/*! \brief default constructor */
|
/*! \brief default constructor */
|
||||||
PackedFunc() {}
|
PackedFunc() {}
|
||||||
/*!
|
/*!
|
||||||
|
@ -38,16 +77,24 @@ class PackedFunc {
|
||||||
* \brief Call packed function by directly passing in unpacked format.
|
* \brief Call packed function by directly passing in unpacked format.
|
||||||
* \param args Arguments to be passed.
|
* \param args Arguments to be passed.
|
||||||
* \tparam Args arguments to be passed.
|
* \tparam Args arguments to be passed.
|
||||||
|
*
|
||||||
|
* \code
|
||||||
|
* // Example code on how to call packed function
|
||||||
|
* void CallPacked(PackedFunc f) {
|
||||||
|
* // call like normal functions by pass in arguments
|
||||||
|
* // return value is automatically converted back
|
||||||
|
* int rvalue = f(1, 2.0);
|
||||||
|
* }
|
||||||
|
* \endcode
|
||||||
*/
|
*/
|
||||||
template<typename... Args>
|
template<typename... Args>
|
||||||
inline void operator()(Args&& ...args) const;
|
inline TVMRetValue operator()(Args&& ...args) const;
|
||||||
/*!
|
/*!
|
||||||
* \brief Call the function in packed format.
|
* \brief Call the function in packed format.
|
||||||
* \param args The arguments
|
* \param args The arguments
|
||||||
* \param type_codes The type_codes of the arguments
|
* \param rv The return value.
|
||||||
* \param num_args Number of arguments.
|
|
||||||
*/
|
*/
|
||||||
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
|
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
|
||||||
/*! \return the internal body function */
|
/*! \return the internal body function */
|
||||||
inline FType body() const;
|
inline FType body() const;
|
||||||
/*!
|
/*!
|
||||||
|
@ -74,82 +121,552 @@ class PackedFunc {
|
||||||
FType body_;
|
FType body_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// implementations
|
/*! \brief Arguments into TVM functions. */
|
||||||
inline void PackedFunc::CallPacked(
|
class TVMArgs {
|
||||||
const TVMValue* args, const int* type_codes, int num_args) const {
|
public:
|
||||||
body_(args, type_codes, num_args);
|
const TVMValue* values;
|
||||||
|
const int* type_codes;
|
||||||
|
int num_args;
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param values The argument values
|
||||||
|
* \param type_codes The argument type codes
|
||||||
|
* \param num_args number of arguments.
|
||||||
|
*/
|
||||||
|
TVMArgs(const TVMValue* values,
|
||||||
|
const int* type_codes,
|
||||||
|
int num_args)
|
||||||
|
: values(values),
|
||||||
|
type_codes(type_codes),
|
||||||
|
num_args(num_args) { }
|
||||||
|
/*! \return size of the arguments */
|
||||||
|
inline int size() const;
|
||||||
|
/*!
|
||||||
|
* \brief Get i-th argument
|
||||||
|
* \param i the index.
|
||||||
|
* \return the ith argument.
|
||||||
|
*/
|
||||||
|
inline TVMArgValue operator[](int i) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Convert type code to its name
|
||||||
|
* \param type_code The type code .
|
||||||
|
* \return The name of type code.
|
||||||
|
*/
|
||||||
|
inline const char* TypeCode2Str(int type_code);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief convert a string to TVM type.
|
||||||
|
* \param s The string to be converted.
|
||||||
|
* \return The corresponding tvm type.
|
||||||
|
*/
|
||||||
|
inline TVMType String2TVMType(std::string s);
|
||||||
|
|
||||||
|
// macro to check type code.
|
||||||
|
#define TVM_CHECK_TYPE_CODE(CODE, T) \
|
||||||
|
CHECK_EQ(CODE, T) << " expected " \
|
||||||
|
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Internal base class to
|
||||||
|
* handle conversion to POD values.
|
||||||
|
*/
|
||||||
|
class TVMPODValue_ {
|
||||||
|
public:
|
||||||
|
operator double() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kFloat);
|
||||||
|
return value_.v_float64;
|
||||||
|
}
|
||||||
|
operator int64_t() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||||
|
return value_.v_int64;
|
||||||
|
}
|
||||||
|
operator uint64_t() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||||
|
return value_.v_int64;
|
||||||
|
}
|
||||||
|
operator int() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||||
|
CHECK_LE(value_.v_int64,
|
||||||
|
std::numeric_limits<int>::max());
|
||||||
|
return static_cast<int>(value_.v_int64);
|
||||||
|
}
|
||||||
|
operator bool() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kInt);
|
||||||
|
return value_.v_int64 != 0;
|
||||||
|
}
|
||||||
|
operator void*() const {
|
||||||
|
if (type_code_ == kNull) return nullptr;
|
||||||
|
if (type_code_ == kArrayHandle) return value_.v_handle;
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kHandle);
|
||||||
|
return value_.v_handle;
|
||||||
|
}
|
||||||
|
operator TVMArray*() const {
|
||||||
|
if (type_code_ == kNull) return nullptr;
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle);
|
||||||
|
return static_cast<TVMArray*>(value_.v_handle);
|
||||||
|
}
|
||||||
|
int type_code() const {
|
||||||
|
return type_code_;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
friend class TVMArgsSetter;
|
||||||
|
friend class TVMRetValue;
|
||||||
|
TVMPODValue_() : type_code_(kNull) {}
|
||||||
|
TVMPODValue_(TVMValue value, int type_code)
|
||||||
|
: value_(value), type_code_(type_code) {}
|
||||||
|
/*!
|
||||||
|
* \brief return handle as specific pointer type.
|
||||||
|
* \tparam T the data type.
|
||||||
|
* \return The pointer type.
|
||||||
|
*/
|
||||||
|
template<typename T>
|
||||||
|
T* ptr() const {
|
||||||
|
return static_cast<T*>(value_.v_handle);
|
||||||
|
}
|
||||||
|
/*! \brief The value */
|
||||||
|
TVMValue value_;
|
||||||
|
/*! \brief the type code */
|
||||||
|
int type_code_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief A single argument value to PackedFunc.
|
||||||
|
* Containing both type_code and TVMValue
|
||||||
|
*
|
||||||
|
* Provides utilities to do type cast into other types.
|
||||||
|
*/
|
||||||
|
class TVMArgValue : public TVMPODValue_ {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param value of the function
|
||||||
|
* \param type_code The type code.
|
||||||
|
*/
|
||||||
|
TVMArgValue(TVMValue value, int type_code)
|
||||||
|
: TVMPODValue_(value, type_code) {
|
||||||
|
}
|
||||||
|
// reuse converter from parent
|
||||||
|
using TVMPODValue_::operator double;
|
||||||
|
using TVMPODValue_::operator int64_t;
|
||||||
|
using TVMPODValue_::operator uint64_t;
|
||||||
|
using TVMPODValue_::operator int;
|
||||||
|
using TVMPODValue_::operator bool;
|
||||||
|
using TVMPODValue_::operator void*;
|
||||||
|
using TVMPODValue_::operator TVMArray*;
|
||||||
|
// conversion operator.
|
||||||
|
operator std::string() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kStr);
|
||||||
|
return std::string(value_.v_str);
|
||||||
|
}
|
||||||
|
operator TVMType() const {
|
||||||
|
if (type_code_ == kStr) {
|
||||||
|
return String2TVMType(operator std::string());
|
||||||
|
}
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
|
||||||
|
return value_.v_type;
|
||||||
|
}
|
||||||
|
operator PackedFunc() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
|
||||||
|
return *ptr<PackedFunc>();
|
||||||
|
}
|
||||||
|
const TVMValue& value() const {
|
||||||
|
return value_;
|
||||||
|
}
|
||||||
|
// NodeRef related extenstions: in tvm/packed_func_ext.h
|
||||||
|
template<typename TNodeRef,
|
||||||
|
typename = typename std::enable_if<
|
||||||
|
std::is_class<TNodeRef>::value>::type>
|
||||||
|
inline operator TNodeRef() const;
|
||||||
|
template<typename TNodeRef,
|
||||||
|
typename = typename std::enable_if<
|
||||||
|
std::is_class<TNodeRef>::value>::type>
|
||||||
|
inline bool IsNodeType() const;
|
||||||
|
inline operator Halide::Type() const;
|
||||||
|
inline operator Halide::Expr() const;
|
||||||
|
// get internal node ptr, if it is node
|
||||||
|
inline std::shared_ptr<Node>& node_sptr();
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Return Value container,
|
||||||
|
* Unlike TVMArgValue, which only holds reference and do not delete
|
||||||
|
* the underlying container during destruction.
|
||||||
|
*
|
||||||
|
* TVMRetValue holds value and will manage the underlying containers
|
||||||
|
* when it stores a complicated data type.
|
||||||
|
*/
|
||||||
|
class TVMRetValue : public TVMPODValue_ {
|
||||||
|
public:
|
||||||
|
/*! \brief default constructor */
|
||||||
|
TVMRetValue() {}
|
||||||
|
/*!
|
||||||
|
* \brief move constructor from anoter return value.
|
||||||
|
* \param other The other return value.
|
||||||
|
*/
|
||||||
|
TVMRetValue(TVMRetValue&& other)
|
||||||
|
: TVMPODValue_(other.value_, other.type_code_) {
|
||||||
|
other.type_code_ = kNull;
|
||||||
|
}
|
||||||
|
/*! \brief destructor */
|
||||||
|
~TVMRetValue() {
|
||||||
|
this->Clear();
|
||||||
|
}
|
||||||
|
// reuse converter from parent
|
||||||
|
using TVMPODValue_::operator double;
|
||||||
|
using TVMPODValue_::operator int64_t;
|
||||||
|
using TVMPODValue_::operator uint64_t;
|
||||||
|
using TVMPODValue_::operator int;
|
||||||
|
using TVMPODValue_::operator bool;
|
||||||
|
using TVMPODValue_::operator void*;
|
||||||
|
using TVMPODValue_::operator TVMArray*;
|
||||||
|
// Disable copy and assign from another value, but allow move.
|
||||||
|
TVMRetValue(const TVMRetValue& other) {
|
||||||
|
this->Assign(other);
|
||||||
|
}
|
||||||
|
// conversion operators
|
||||||
|
operator std::string() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kStr);
|
||||||
|
return *ptr<std::string>();
|
||||||
|
}
|
||||||
|
operator TVMType() const {
|
||||||
|
if (type_code_ == kStr) {
|
||||||
|
return String2TVMType(operator std::string());
|
||||||
|
}
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
|
||||||
|
return value_.v_type;
|
||||||
|
}
|
||||||
|
operator PackedFunc() const {
|
||||||
|
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
|
||||||
|
return *ptr<PackedFunc>();
|
||||||
|
}
|
||||||
|
// Assign operators
|
||||||
|
TVMRetValue& operator=(TVMRetValue&& other) {
|
||||||
|
this->Clear();
|
||||||
|
value_ = other.value_;
|
||||||
|
type_code_ = other.type_code_;
|
||||||
|
other.type_code_ = kNull;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(double value) {
|
||||||
|
this->SwitchToPOD(kFloat);
|
||||||
|
value_.v_float64 = value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(std::nullptr_t value) {
|
||||||
|
this->SwitchToPOD(kNull);
|
||||||
|
value_.v_handle = value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(void* value) {
|
||||||
|
this->SwitchToPOD(kHandle);
|
||||||
|
value_.v_handle = value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(int64_t value) {
|
||||||
|
this->SwitchToPOD(kInt);
|
||||||
|
value_.v_int64 = value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(int value) {
|
||||||
|
this->SwitchToPOD(kInt);
|
||||||
|
value_.v_int64 = value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(TVMType t) {
|
||||||
|
this->SwitchToPOD(kTVMType);
|
||||||
|
value_.v_type = t;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(bool value) {
|
||||||
|
this->SwitchToPOD(kInt);
|
||||||
|
value_.v_int64 = value;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(std::string value) {
|
||||||
|
this->SwitchToClass(kStr, value);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(PackedFunc f) {
|
||||||
|
this->SwitchToClass(kFuncHandle, f);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
|
||||||
|
this->Assign(other);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
TVMRetValue& operator=(TVMArgValue other) {
|
||||||
|
this->Assign(other);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief Move the value back to front-end via C API.
|
||||||
|
* This marks the current container as null.
|
||||||
|
* The managed resources is moved to front-end and
|
||||||
|
* the front end should take charge in managing them.
|
||||||
|
*
|
||||||
|
* \param ret_value The return value.
|
||||||
|
* \param ret_type_code The return type code.
|
||||||
|
*/
|
||||||
|
void MoveToCHost(TVMValue* ret_value,
|
||||||
|
int* ret_type_code) {
|
||||||
|
// cannot move str; need specially handle.
|
||||||
|
CHECK(type_code_ != kStr);
|
||||||
|
*ret_value = value_;
|
||||||
|
*ret_type_code = type_code_;
|
||||||
|
type_code_ = kNull;
|
||||||
|
}
|
||||||
|
// NodeRef related extenstions: in tvm/packed_func_ext.h
|
||||||
|
inline TVMRetValue& operator=(const NodeRef& other);
|
||||||
|
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
|
||||||
|
template<typename TNodeRef,
|
||||||
|
typename = typename std::enable_if<
|
||||||
|
std::is_class<TNodeRef>::value>::type>
|
||||||
|
inline operator TNodeRef() const;
|
||||||
|
// type related
|
||||||
|
inline operator Halide::Type() const;
|
||||||
|
inline TVMRetValue& operator=(const Halide::Type& other);
|
||||||
|
|
||||||
|
private:
|
||||||
|
template<typename T>
|
||||||
|
void Assign(const T& other) {
|
||||||
|
switch (other.type_code()) {
|
||||||
|
case kStr: {
|
||||||
|
SwitchToClass<std::string>(kStr, other);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kFuncHandle: {
|
||||||
|
SwitchToClass<PackedFunc>(kFuncHandle, other);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kNodeHandle: {
|
||||||
|
SwitchToClass<std::shared_ptr<Node> >(
|
||||||
|
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
SwitchToPOD(other.type_code());
|
||||||
|
value_ = other.value_;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the internal container.
|
||||||
|
void SwitchToPOD(int type_code) {
|
||||||
|
if (type_code_ != type_code) {
|
||||||
|
this->Clear();
|
||||||
|
type_code_ = type_code;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template<typename T>
|
||||||
|
void SwitchToClass(int type_code, T v) {
|
||||||
|
if (type_code_ != type_code) {
|
||||||
|
this->Clear();
|
||||||
|
type_code_ = type_code;
|
||||||
|
value_.v_handle = new T(v);
|
||||||
|
} else {
|
||||||
|
*static_cast<T*>(value_.v_handle) = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Clear() {
|
||||||
|
if (type_code_ == kNull) return;
|
||||||
|
switch (type_code_) {
|
||||||
|
case kStr: delete ptr<std::string>(); break;
|
||||||
|
case kFuncHandle: delete ptr<PackedFunc>(); break;
|
||||||
|
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
|
||||||
|
}
|
||||||
|
type_code_ = kNull;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// implementation details
|
||||||
|
inline const char* TypeCode2Str(int type_code) {
|
||||||
|
switch (type_code) {
|
||||||
|
case kInt: return "int";
|
||||||
|
case kFloat: return "float";
|
||||||
|
case kStr: return "str";
|
||||||
|
case kHandle: return "Handle";
|
||||||
|
case kNull: return "NULL";
|
||||||
|
case kNodeHandle: return "NodeHandle";
|
||||||
|
case kArrayHandle: return "ArrayHandle";
|
||||||
|
case kTVMType: return "TVMType";
|
||||||
|
case kFuncHandle: return "FunctionHandle";
|
||||||
|
default: LOG(FATAL) << "unknown type_code="
|
||||||
|
<< static_cast<int>(type_code); return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMType String2TVMType(std::string s) {
|
||||||
|
TVMType t;
|
||||||
|
t.bits = 32; t.lanes = 1;
|
||||||
|
const char* scan;
|
||||||
|
if (s.substr(0, 3) == "int") {
|
||||||
|
t.code = kInt; scan = s.c_str() + 3;
|
||||||
|
} else if (s.substr(0, 4) == "uint") {
|
||||||
|
t.code = kUInt; scan = s.c_str() + 4;
|
||||||
|
} else if (s.substr(0, 5) == "float") {
|
||||||
|
t.code = kFloat; scan = s.c_str() + 5;
|
||||||
|
} else if (s == "handle") {
|
||||||
|
t.code = kHandle;
|
||||||
|
t.bits = 64; // handle uses 64 bit by default.
|
||||||
|
scan = s.c_str() + 6;
|
||||||
|
} else {
|
||||||
|
scan = s.c_str();
|
||||||
|
LOG(FATAL) << "unknown type " << s;
|
||||||
|
}
|
||||||
|
unsigned bits = t.bits, lanes = t.lanes;
|
||||||
|
sscanf(scan, "%ux%u", &bits, &lanes);
|
||||||
|
t.bits = static_cast<uint8_t>(bits);
|
||||||
|
t.lanes = static_cast<uint16_t>(lanes);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TVMArgValue TVMArgs::operator[](int i) const {
|
||||||
|
CHECK_LT(i, num_args)
|
||||||
|
<< "not enough argument passed, "
|
||||||
|
<< num_args << " passed"
|
||||||
|
<< "but request arg" << i;
|
||||||
|
return TVMArgValue(values[i], type_codes[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int TVMArgs::size() const {
|
||||||
|
return num_args;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
|
||||||
|
body_(args, rv);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline PackedFunc::FType PackedFunc::body() const {
|
inline PackedFunc::FType PackedFunc::body() const {
|
||||||
return body_;
|
return body_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// internal namespace
|
||||||
|
namespace detail {
|
||||||
template<bool stop, std::size_t I, typename F, typename ...Args>
|
template<bool stop, std::size_t I, typename F, typename ...Args>
|
||||||
struct for_each_dispatcher_ {
|
struct for_each_dispatcher {
|
||||||
static inline void run(const std::tuple<Args...>& args, F f) {
|
static void run(std::tuple<Args...>& args, const F& f) { // NOLINT(*)
|
||||||
f(I, std::get<I>(args));
|
f(I, std::get<I>(args));
|
||||||
for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
|
for_each_dispatcher<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<std::size_t I, typename F, typename ...Args>
|
template<std::size_t I, typename F, typename ...Args>
|
||||||
struct for_each_dispatcher_<true, I, F, Args...> {
|
struct for_each_dispatcher<true, I, F, Args...> {
|
||||||
static inline void run(const std::tuple<Args...>& args, F f) {}
|
static void run(std::tuple<Args...>& args, const F& f) {} // NOLINT(*)
|
||||||
};
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
template<typename F, typename ...Args>
|
template<typename F, typename ...Args>
|
||||||
inline void for_each(const std::tuple<Args...>& args, F f) {
|
inline void for_each(std::tuple<Args...>& args, const F& f) { // NOLINT(*)
|
||||||
for_each_dispatcher_<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
|
detail::for_each_dispatcher<sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace arg_setter {
|
/* \brief argument settter to PackedFunc */
|
||||||
template<typename T>
|
class TVMArgsSetter {
|
||||||
inline void Set(TVMValue& arg, int& t, T v); // NOLINT(*)
|
public:
|
||||||
template<>
|
TVMArgsSetter(TVMValue* values, int* type_codes)
|
||||||
inline void Set<double>(TVMValue& arg, int& t, double value) { // NOLINT(*)
|
: values_(values), type_codes_(type_codes) {}
|
||||||
arg.v_float64 = value;
|
// setters for POD types
|
||||||
t = kFloat;
|
template<typename T,
|
||||||
}
|
typename = typename std::enable_if<std::is_integral<T>::value>::type>
|
||||||
template<>
|
void operator()(size_t i, T value) const {
|
||||||
inline void Set<int>(TVMValue& arg, int& t, int value) { // NOLINT(*)
|
values_[i].v_int64 = static_cast<int64_t>(value);
|
||||||
arg.v_int64 = value;
|
type_codes_[i] = kInt;
|
||||||
t = kInt;
|
|
||||||
}
|
|
||||||
template<>
|
|
||||||
inline void Set<long>(TVMValue& arg, int& t, long value) { // NOLINT(*)
|
|
||||||
arg.v_int64 = value;
|
|
||||||
t = kInt;
|
|
||||||
}
|
|
||||||
template<>
|
|
||||||
inline void Set<TVMArray*>(TVMValue& arg, int& t, TVMArray* value) { // NOLINT(*)
|
|
||||||
arg.v_handle = value;
|
|
||||||
t = kHandle;
|
|
||||||
}
|
|
||||||
template<>
|
|
||||||
inline void Set<void*>(TVMValue& arg, int& t, void* value) { // NOLINT(*)
|
|
||||||
arg.v_handle = value;
|
|
||||||
t = kHandle;
|
|
||||||
}
|
|
||||||
} // namespace arg_setter
|
|
||||||
|
|
||||||
struct PackedFuncArgSetter {
|
|
||||||
TVMValue* args;
|
|
||||||
int* type_codes;
|
|
||||||
template<typename T>
|
|
||||||
inline void operator()(size_t i, T v) const {
|
|
||||||
arg_setter::Set(args[i], type_codes[i], v);
|
|
||||||
}
|
}
|
||||||
|
void operator()(size_t i, uint64_t value) const {
|
||||||
|
values_[i].v_int64 = static_cast<int64_t>(value);
|
||||||
|
CHECK_LE(value,
|
||||||
|
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
|
||||||
|
type_codes_[i] = kInt;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, double value) const {
|
||||||
|
values_[i].v_float64 = value;
|
||||||
|
type_codes_[i] = kFloat;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, std::nullptr_t value) const {
|
||||||
|
values_[i].v_handle = value;
|
||||||
|
type_codes_[i] = kNull;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, const TVMArgValue& value) const {
|
||||||
|
values_[i] = value.value_;
|
||||||
|
type_codes_[i] = value.type_code_;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, void* value) const {
|
||||||
|
values_[i].v_handle = value;
|
||||||
|
type_codes_[i] = kHandle;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, TVMArray* value) const {
|
||||||
|
values_[i].v_handle = value;
|
||||||
|
type_codes_[i] = kArrayHandle;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, TVMType value) const {
|
||||||
|
values_[i].v_type = value;
|
||||||
|
type_codes_[i] = kTVMType;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, const char* value) const {
|
||||||
|
values_[i].v_str = value;
|
||||||
|
type_codes_[i] = kStr;
|
||||||
|
}
|
||||||
|
// setters for container type
|
||||||
|
// They must be reference(instead of const ref)
|
||||||
|
// to make sure they are alive in the tuple(instead of getting converted)
|
||||||
|
void operator()(size_t i, std::string& value) const { // NOLINT(*)
|
||||||
|
values_[i].v_str = value.c_str();
|
||||||
|
type_codes_[i] = kStr;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, PackedFunc& value) const { // NOLINT(*)
|
||||||
|
values_[i].v_handle = &value;
|
||||||
|
type_codes_[i] = kFuncHandle;
|
||||||
|
}
|
||||||
|
void operator()(size_t i, TVMRetValue& value) const { // NOLINT(*)
|
||||||
|
if (value.type_code() == kStr) {
|
||||||
|
values_[i].v_str = value.ptr<std::string>()->c_str();
|
||||||
|
type_codes_[i] = kStr;
|
||||||
|
} else {
|
||||||
|
values_[i] = value.value_;
|
||||||
|
type_codes_[i] = value.type_code();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// NodeRef related extenstions: in tvm/packed_func_ext.h
|
||||||
|
inline void operator()(size_t i, NodeRef& other) const; // NOLINT(*)
|
||||||
|
inline void operator()(size_t i, const Halide::Type& t) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief The values fields */
|
||||||
|
TVMValue* values_;
|
||||||
|
/*! \brief The type code fields */
|
||||||
|
int* type_codes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TVMArgsGetter {
|
||||||
|
public:
|
||||||
|
explicit TVMArgsGetter(TVMArgs args)
|
||||||
|
: args_(args) {}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline void operator()(size_t i, T& target) const { // NOLINT(*)
|
||||||
|
target = args_[i].operator T();
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
TVMArgs args_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename... Args>
|
template<typename... Args>
|
||||||
inline void PackedFunc::operator()(Args&& ...args) const {
|
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
|
||||||
auto targ = std::make_tuple(std::forward<Args>(args)...);
|
auto targs = std::make_tuple(std::forward<Args>(args)...);
|
||||||
const int kNumArgs = sizeof...(Args);
|
const int kNumArgs = sizeof...(Args);
|
||||||
TVMValue tvm_args[kNumArgs];
|
TVMValue values[kNumArgs];
|
||||||
int tvm_arg_type_ids[kNumArgs];
|
int type_codes[kNumArgs];
|
||||||
for_each(targ, PackedFuncArgSetter{tvm_args, tvm_arg_type_ids});
|
for_each(targs, TVMArgsSetter(values, type_codes));
|
||||||
body_(tvm_args, tvm_arg_type_ids, kNumArgs);
|
TVMRetValue rv;
|
||||||
|
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
|
||||||
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
#endif // TVM_RUNTIME_PACKED_FUNC_H_
|
#endif // TVM_RUNTIME_PACKED_FUNC_H_
|
||||||
|
|
|
@ -10,5 +10,6 @@
|
||||||
#include "./expr.h"
|
#include "./expr.h"
|
||||||
#include "./tensor.h"
|
#include "./tensor.h"
|
||||||
#include "./operation.h"
|
#include "./operation.h"
|
||||||
|
#include "./packed_func_ext.h"
|
||||||
|
|
||||||
#endif // TVM_TVM_H_
|
#endif // TVM_TVM_H_
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# pylint: disable=redefined-builtin, wildcard-import
|
# pylint: disable=redefined-builtin, wildcard-import
|
||||||
"""C++ backend related python scripts"""
|
"""C++ backend related python scripts"""
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
from ._ctypes._api import register_node
|
from ._ctypes._node import register_node
|
||||||
|
|
||||||
from . import tensor
|
from . import tensor
|
||||||
from . import expr
|
from . import expr
|
||||||
|
|
|
@ -91,45 +91,3 @@ def c_array(ctype, values):
|
||||||
Created ctypes array
|
Created ctypes array
|
||||||
"""
|
"""
|
||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|
||||||
|
|
||||||
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
|
|
||||||
"""Convert ctypes returned doc string information into parameters docstring.
|
|
||||||
|
|
||||||
num_args : nn_uint
|
|
||||||
Number of arguments.
|
|
||||||
|
|
||||||
arg_names : ctypes.POINTER(ctypes.c_char_p)
|
|
||||||
Argument names.
|
|
||||||
|
|
||||||
arg_types : ctypes.POINTER(ctypes.c_char_p)
|
|
||||||
Argument type information.
|
|
||||||
|
|
||||||
arg_descs : ctypes.POINTER(ctypes.c_char_p)
|
|
||||||
Argument description information.
|
|
||||||
|
|
||||||
remove_dup : boolean, optional
|
|
||||||
Whether remove duplication or not.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
docstr : str
|
|
||||||
Python docstring of parameter sections.
|
|
||||||
"""
|
|
||||||
param_keys = set()
|
|
||||||
param_str = []
|
|
||||||
for i in range(num_args.value):
|
|
||||||
key = py_str(arg_names[i])
|
|
||||||
if key in param_keys and remove_dup:
|
|
||||||
continue
|
|
||||||
param_keys.add(key)
|
|
||||||
type_info = py_str(arg_types[i])
|
|
||||||
ret = '%s : %s' % (key, type_info)
|
|
||||||
if len(arg_descs[i]) != 0:
|
|
||||||
ret += '\n ' + py_str(arg_descs[i])
|
|
||||||
param_str.append(ret)
|
|
||||||
doc_str = ('Parameters\n' +
|
|
||||||
'----------\n' +
|
|
||||||
'%s\n')
|
|
||||||
doc_str = doc_str % ('\n'.join(param_str))
|
|
||||||
return doc_str
|
|
||||||
|
|
|
@ -1,416 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
|
|
||||||
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements
|
|
||||||
"""Symbolic configuration API."""
|
|
||||||
from __future__ import absolute_import as _abs
|
|
||||||
|
|
||||||
import ctypes
|
|
||||||
import sys
|
|
||||||
from numbers import Number, Integral
|
|
||||||
|
|
||||||
from .._base import _LIB
|
|
||||||
from .._base import c_str, py_str, string_types
|
|
||||||
from .._base import check_call, ctypes2docstring
|
|
||||||
from .. import _api_internal
|
|
||||||
from . import _runtime_api
|
|
||||||
from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer
|
|
||||||
|
|
||||||
# type definitions
|
|
||||||
APIFuncHandle = ctypes.c_void_p
|
|
||||||
NodeHandle = ctypes.c_void_p
|
|
||||||
FunctionHandle = ctypes.c_void_p
|
|
||||||
|
|
||||||
class APIType(object):
|
|
||||||
"""TVMType used in API calls"""
|
|
||||||
INT = ctypes.c_int(TypeCode.INT)
|
|
||||||
UINT = ctypes.c_int(TypeCode.UINT)
|
|
||||||
FLOAT = ctypes.c_int(TypeCode.FLOAT)
|
|
||||||
HANDLE = ctypes.c_int(TypeCode.HANDLE)
|
|
||||||
NULL = ctypes.c_int(TypeCode.NULL)
|
|
||||||
NODE_HANDLE = ctypes.c_int(TypeCode.NODE_HANDLE)
|
|
||||||
STR = ctypes.c_int(TypeCode.STR)
|
|
||||||
FUNC_HANDLE = ctypes.c_int(TypeCode.FUNC_HANDLE)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_TYPE = {
|
|
||||||
}
|
|
||||||
|
|
||||||
def _return_node(x):
|
|
||||||
handle = x.v_handle
|
|
||||||
if not isinstance(handle, NodeHandle):
|
|
||||||
handle = NodeHandle(handle)
|
|
||||||
ret_val = TVMValue()
|
|
||||||
ret_type_code = ctypes.c_int()
|
|
||||||
ret_success = ctypes.c_int()
|
|
||||||
check_call(_LIB.TVMNodeGetAttr(
|
|
||||||
handle, c_str("type_key"),
|
|
||||||
ctypes.byref(ret_val),
|
|
||||||
ctypes.byref(ret_type_code),
|
|
||||||
ctypes.byref(ret_success)))
|
|
||||||
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
|
|
||||||
|
|
||||||
|
|
||||||
def _return_func(x):
|
|
||||||
handle = x.v_handle
|
|
||||||
if not isinstance(handle, FunctionHandle):
|
|
||||||
handle = FunctionHandle(handle)
|
|
||||||
return _runtime_api._function_cls(handle)
|
|
||||||
|
|
||||||
|
|
||||||
def _return_handle(x):
|
|
||||||
handle = x.v_handle
|
|
||||||
if not isinstance(handle, ctypes.c_void_p):
|
|
||||||
handle = ctypes.c_void_p(handle)
|
|
||||||
return handle
|
|
||||||
|
|
||||||
|
|
||||||
RET_SWITCH = {
|
|
||||||
TypeCode.NULL: lambda x: None,
|
|
||||||
TypeCode.INT: lambda x: x.v_int64,
|
|
||||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
|
||||||
TypeCode.STR: lambda x: py_str(x.v_str),
|
|
||||||
TypeCode.NODE_HANDLE: _return_node,
|
|
||||||
TypeCode.FUNC_HANDLE: _return_func
|
|
||||||
}
|
|
||||||
|
|
||||||
PACK_ARG_SWITCH = {
|
|
||||||
TypeCode.NULL: lambda x: None,
|
|
||||||
TypeCode.INT: lambda x: x.v_int64,
|
|
||||||
TypeCode.FLOAT: lambda x: x.v_float64,
|
|
||||||
TypeCode.STR: lambda x: py_str(x.v_str),
|
|
||||||
TypeCode.HANDLE: lambda x: _return_handle,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SliceBase(object):
|
|
||||||
"""base class of slice object"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
class NodeBase(object):
|
|
||||||
"""Symbol is symbolic graph."""
|
|
||||||
__slots__ = ["handle"]
|
|
||||||
# pylint: disable=no-member
|
|
||||||
def __init__(self, handle):
|
|
||||||
"""Initialize the function with handle
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
handle : SymbolHandle
|
|
||||||
the handle to the underlying C++ Symbol
|
|
||||||
"""
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return _api_internal._format_str(self)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
check_call(_LIB.TVMNodeFree(self.handle))
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
ret_val = TVMValue()
|
|
||||||
ret_type_code = ctypes.c_int()
|
|
||||||
ret_success = ctypes.c_int()
|
|
||||||
check_call(_LIB.TVMNodeGetAttr(
|
|
||||||
self.handle, c_str(name),
|
|
||||||
ctypes.byref(ret_val),
|
|
||||||
ctypes.byref(ret_type_code),
|
|
||||||
ctypes.byref(ret_success)))
|
|
||||||
value = RET_SWITCH[ret_type_code.value](ret_val)
|
|
||||||
if not ret_success.value:
|
|
||||||
raise AttributeError(
|
|
||||||
"'%s' object has no attribute '%s'" % (str(type(self)), name))
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return _api_internal._raw_ptr(self)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if not isinstance(other, NodeBase):
|
|
||||||
return False
|
|
||||||
return self.__hash__() == other.__hash__()
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
return not self.__eq__(other)
|
|
||||||
|
|
||||||
def __dir__(self):
|
|
||||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
|
||||||
size = ctypes.c_uint()
|
|
||||||
check_call(_LIB.TVMNodeListAttrNames(
|
|
||||||
self.handle, ctypes.byref(size), ctypes.byref(plist)))
|
|
||||||
names = []
|
|
||||||
for i in range(size.value):
|
|
||||||
names.append(py_str(plist[i]))
|
|
||||||
return names
|
|
||||||
|
|
||||||
def __reduce__(self):
|
|
||||||
return (type(self), (None,), self.__getstate__())
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
handle = self.handle
|
|
||||||
if handle is not None:
|
|
||||||
return {'handle': _api_internal._save_json(self)}
|
|
||||||
else:
|
|
||||||
return {'handle': None}
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
# pylint: disable=assigning-non-slot
|
|
||||||
handle = state['handle']
|
|
||||||
if handle is not None:
|
|
||||||
json_str = handle
|
|
||||||
_push_arg(json_str)
|
|
||||||
other = _api_internal._load_json(json_str)
|
|
||||||
self.handle = other.handle
|
|
||||||
other.handle = None
|
|
||||||
else:
|
|
||||||
self.handle = None
|
|
||||||
|
|
||||||
|
|
||||||
def const(value, dtype=None):
|
|
||||||
"""construct a constant"""
|
|
||||||
if dtype is None:
|
|
||||||
if isinstance(value, Integral):
|
|
||||||
dtype = 'int32'
|
|
||||||
else:
|
|
||||||
dtype = 'float32'
|
|
||||||
return _api_internal._const(value, dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def _ctypes_free_resource(rhandle):
|
|
||||||
"""callback to free resources when it it not needed."""
|
|
||||||
pyobj = ctypes.cast(rhandle, ctypes.py_object)
|
|
||||||
ctypes.pythonapi.Py_DecRef(pyobj)
|
|
||||||
|
|
||||||
# Global callback that is always alive
|
|
||||||
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
|
|
||||||
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
|
|
||||||
|
|
||||||
def convert_to_tvm_func(pyfunc):
|
|
||||||
"""Convert a python function to TVM function
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
pyfunc : python function
|
|
||||||
The python function to be converted.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
tvmfunc: tvm.nd.Function
|
|
||||||
The converted tvm function.
|
|
||||||
"""
|
|
||||||
local_pyfunc = pyfunc
|
|
||||||
def cfun(args, type_codes, num_args, _):
|
|
||||||
""" ctypes function """
|
|
||||||
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
|
||||||
pyargs = [PACK_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
|
||||||
local_pyfunc(*pyargs)
|
|
||||||
handle = FunctionHandle()
|
|
||||||
f = TVMPackedCFunc(cfun)
|
|
||||||
# NOTE: We will need to use python-api to increase ref count of the f
|
|
||||||
# TVM_FREE_PYOBJ will be called after it is no longer needed.
|
|
||||||
pyobj = ctypes.py_object(f)
|
|
||||||
ctypes.pythonapi.Py_IncRef(pyobj)
|
|
||||||
check_call(_LIB.TVMFuncCreateFromCFunc(
|
|
||||||
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
|
|
||||||
return _runtime_api._function_cls(handle)
|
|
||||||
|
|
||||||
|
|
||||||
def convert(value):
|
|
||||||
"""Convert a value to expression."""
|
|
||||||
if isinstance(value, (NodeBase, _runtime_api.FunctionBase)):
|
|
||||||
return value
|
|
||||||
elif isinstance(value, Number):
|
|
||||||
return const(value)
|
|
||||||
elif isinstance(value, string_types):
|
|
||||||
return _api_internal._str(value)
|
|
||||||
elif isinstance(value, (list, tuple)):
|
|
||||||
value = [convert(x) for x in value]
|
|
||||||
return _api_internal._Array(*value)
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
vlist = []
|
|
||||||
for it in value.items():
|
|
||||||
if not isinstance(it[0], NodeBase):
|
|
||||||
raise ValueError("key of map must already been a container type")
|
|
||||||
vlist.append(it[0])
|
|
||||||
vlist.append(convert(it[1]))
|
|
||||||
return _api_internal._Map(*vlist)
|
|
||||||
elif isinstance(value, SliceBase):
|
|
||||||
return value.tensor(*value.indices)
|
|
||||||
elif callable(value):
|
|
||||||
return convert_to_tvm_func(value)
|
|
||||||
else:
|
|
||||||
raise ValueError("don't know how to handle type %s" % type(value))
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def _push_arg(arg):
|
|
||||||
a = TVMValue()
|
|
||||||
if arg is None:
|
|
||||||
_LIB.TVMAPIPushStack(a, APIType.NULL)
|
|
||||||
elif isinstance(arg, NodeBase):
|
|
||||||
a.v_handle = arg.handle
|
|
||||||
_LIB.TVMAPIPushStack(a, APIType.NODE_HANDLE)
|
|
||||||
elif isinstance(arg, Integral):
|
|
||||||
a.v_int64 = ctypes.c_int64(arg)
|
|
||||||
_LIB.TVMAPIPushStack(a, APIType.INT)
|
|
||||||
elif isinstance(arg, Number):
|
|
||||||
a.v_double = ctypes.c_double(arg)
|
|
||||||
_LIB.TVMAPIPushStack(a, APIType.FLOAT)
|
|
||||||
elif isinstance(arg, string_types):
|
|
||||||
a.v_str = c_str(arg)
|
|
||||||
_LIB.TVMAPIPushStack(a, APIType.STR)
|
|
||||||
else:
|
|
||||||
raise TypeError("Don't know how to handle type %s" % type(arg))
|
|
||||||
|
|
||||||
|
|
||||||
def _make_function(handle, name):
|
|
||||||
"""Create an atomic symbol function by handle and funciton name."""
|
|
||||||
real_name = ctypes.c_char_p()
|
|
||||||
desc = ctypes.c_char_p()
|
|
||||||
num_args = ctypes.c_int()
|
|
||||||
arg_names = ctypes.POINTER(ctypes.c_char_p)()
|
|
||||||
arg_types = ctypes.POINTER(ctypes.c_char_p)()
|
|
||||||
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
|
|
||||||
ret_type = ctypes.c_char_p()
|
|
||||||
|
|
||||||
check_call(_LIB.TVMGetAPIFuncInfo(
|
|
||||||
handle, ctypes.byref(real_name), ctypes.byref(desc),
|
|
||||||
ctypes.byref(num_args),
|
|
||||||
ctypes.byref(arg_names),
|
|
||||||
ctypes.byref(arg_types),
|
|
||||||
ctypes.byref(arg_descs),
|
|
||||||
ctypes.byref(ret_type)))
|
|
||||||
|
|
||||||
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
|
|
||||||
func_name = name
|
|
||||||
desc = py_str(desc.value)
|
|
||||||
|
|
||||||
doc_str = ('%s\n\n' +
|
|
||||||
'%s\n')
|
|
||||||
doc_str = doc_str % (desc, param_str)
|
|
||||||
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
|
|
||||||
|
|
||||||
def func(*args):
|
|
||||||
"""TVM function"""
|
|
||||||
cargs = []
|
|
||||||
for x in args:
|
|
||||||
if isinstance(x, (list, tuple, dict, SliceBase)):
|
|
||||||
cargs.append(convert(x))
|
|
||||||
else:
|
|
||||||
cargs.append(x)
|
|
||||||
|
|
||||||
for arg in cargs:
|
|
||||||
_push_arg(arg)
|
|
||||||
ret_val = TVMValue()
|
|
||||||
ret_type_code = ctypes.c_int()
|
|
||||||
check_call(_LIB.TVMAPIFuncCall(
|
|
||||||
handle, ctypes.byref(ret_val), ctypes.byref(ret_type_code)))
|
|
||||||
return RET_SWITCH[ret_type_code.value](ret_val)
|
|
||||||
|
|
||||||
func.__name__ = func_name
|
|
||||||
func.__doc__ = doc_str
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def register_node(type_key=None):
|
|
||||||
"""register node type
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
type_key : str or cls
|
|
||||||
The type key of the node
|
|
||||||
"""
|
|
||||||
if isinstance(type_key, str):
|
|
||||||
def register(cls):
|
|
||||||
"""internal register function"""
|
|
||||||
NODE_TYPE[type_key] = cls
|
|
||||||
return cls
|
|
||||||
return register
|
|
||||||
else:
|
|
||||||
cls = type_key
|
|
||||||
NODE_TYPE[cls.__name__] = cls
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
def register_func(func_name, f=None):
|
|
||||||
"""Register global function
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
func_name : str or function
|
|
||||||
The function name
|
|
||||||
|
|
||||||
f : function
|
|
||||||
The function to be registered.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
fregister : function
|
|
||||||
Register function if f is not specified.
|
|
||||||
"""
|
|
||||||
if callable(func_name):
|
|
||||||
f = func_name
|
|
||||||
func_name = f.__name__
|
|
||||||
|
|
||||||
if not isinstance(func_name, str):
|
|
||||||
raise ValueError("expect string function name")
|
|
||||||
def register(myf):
|
|
||||||
"""internal register function"""
|
|
||||||
if not isinstance(myf, _runtime_api.FunctionBase):
|
|
||||||
myf = convert_to_tvm_func(myf)
|
|
||||||
check_call(_LIB.TVMFuncRegisterGlobal(
|
|
||||||
c_str(func_name), myf.handle))
|
|
||||||
if f:
|
|
||||||
register(f)
|
|
||||||
else:
|
|
||||||
return register
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_func(name):
|
|
||||||
"""Get a global function by name
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
The name of the global function
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
func : tvm.nd.Function
|
|
||||||
The function to be returned.
|
|
||||||
"""
|
|
||||||
handle = FunctionHandle()
|
|
||||||
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
|
|
||||||
return _runtime_api._function_cls(handle)
|
|
||||||
|
|
||||||
|
|
||||||
def _init_api_module(root_namespace):
|
|
||||||
"""List and add all the functions to current module."""
|
|
||||||
plist = ctypes.POINTER(ctypes.c_char_p)()
|
|
||||||
size = ctypes.c_uint()
|
|
||||||
|
|
||||||
check_call(_LIB.TVMListAPIFuncNames(ctypes.byref(size),
|
|
||||||
ctypes.byref(plist)))
|
|
||||||
op_names = []
|
|
||||||
for i in range(size.value):
|
|
||||||
op_names.append(py_str(plist[i]))
|
|
||||||
|
|
||||||
module_obj = sys.modules["%s.api" % root_namespace]
|
|
||||||
module_internal = sys.modules["%s._api_internal" % root_namespace]
|
|
||||||
namespace_match = {
|
|
||||||
"_make_": sys.modules["%s.make" % root_namespace],
|
|
||||||
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
|
|
||||||
"_codegen_": sys.modules["%s.codegen" % root_namespace],
|
|
||||||
"_schedule_": sys.modules["%s.schedule" % root_namespace]
|
|
||||||
}
|
|
||||||
|
|
||||||
for name in op_names:
|
|
||||||
hdl = APIFuncHandle()
|
|
||||||
check_call(_LIB.TVMGetAPIFuncHandle(c_str(name), ctypes.byref(hdl)))
|
|
||||||
fname = name
|
|
||||||
target_module = module_internal if name.startswith('_') else module_obj
|
|
||||||
for k, v in namespace_match.items():
|
|
||||||
if name.startswith(k):
|
|
||||||
fname = name[len(k):]
|
|
||||||
target_module = v
|
|
||||||
function = _make_function(hdl, fname)
|
|
||||||
setattr(target_module, function.__name__, function)
|
|
|
@ -0,0 +1,250 @@
|
||||||
|
# coding: utf-8
|
||||||
|
# pylint: disable=invalid-name, protected-access
|
||||||
|
"""Symbolic configuration API."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
import sys
|
||||||
|
from numbers import Number, Integral
|
||||||
|
|
||||||
|
from .._base import _LIB, check_call
|
||||||
|
from .._base import c_str, py_str, string_types
|
||||||
|
from ._types import TVMValue, TypeCode, TVMType
|
||||||
|
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
|
||||||
|
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
|
||||||
|
from ._node import NodeBase, SliceBase, convert_to_node
|
||||||
|
from ._ndarray import NDArrayBase
|
||||||
|
|
||||||
|
FunctionHandle = ctypes.c_void_p
|
||||||
|
TVMRetValueHandle = ctypes.c_void_p
|
||||||
|
|
||||||
|
def _ctypes_free_resource(rhandle):
|
||||||
|
"""callback to free resources when it it not needed."""
|
||||||
|
pyobj = ctypes.cast(rhandle, ctypes.py_object)
|
||||||
|
ctypes.pythonapi.Py_DecRef(pyobj)
|
||||||
|
|
||||||
|
# Global callback that is always alive
|
||||||
|
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
|
||||||
|
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
|
||||||
|
|
||||||
|
def convert_to_tvm_func(pyfunc):
|
||||||
|
"""Convert a python function to TVM function
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pyfunc : python function
|
||||||
|
The python function to be converted.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tvmfunc: tvm.nd.Function
|
||||||
|
The converted tvm function.
|
||||||
|
"""
|
||||||
|
local_pyfunc = pyfunc
|
||||||
|
def cfun(args, type_codes, num_args, ret, _):
|
||||||
|
""" ctypes function """
|
||||||
|
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
|
||||||
|
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
|
||||||
|
rv = local_pyfunc(*pyargs)
|
||||||
|
if rv is not None:
|
||||||
|
if isinstance(rv, tuple):
|
||||||
|
raise ValueError("PackedFunction can only support one reurn value")
|
||||||
|
temp_args = []
|
||||||
|
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
|
||||||
|
if not isinstance(ret, TVMRetValueHandle):
|
||||||
|
ret = TVMRetValueHandle(ret)
|
||||||
|
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
|
||||||
|
_ = temp_args
|
||||||
|
_ = rv
|
||||||
|
|
||||||
|
handle = FunctionHandle()
|
||||||
|
f = TVMPackedCFunc(cfun)
|
||||||
|
# NOTE: We will need to use python-api to increase ref count of the f
|
||||||
|
# TVM_FREE_PYOBJ will be called after it is no longer needed.
|
||||||
|
pyobj = ctypes.py_object(f)
|
||||||
|
ctypes.pythonapi.Py_IncRef(pyobj)
|
||||||
|
check_call(_LIB.TVMFuncCreateFromCFunc(
|
||||||
|
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
|
||||||
|
return Function(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tvm_args(args, temp_args):
|
||||||
|
"""Pack arguments into c args tvm call accept"""
|
||||||
|
num_args = len(args)
|
||||||
|
values = (TVMValue * num_args)()
|
||||||
|
type_codes = (ctypes.c_int * num_args)()
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if arg is None:
|
||||||
|
values[i].v_handle = None
|
||||||
|
type_codes[i] = TypeCode.NULL
|
||||||
|
elif isinstance(arg, NDArrayBase):
|
||||||
|
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
|
||||||
|
type_codes[i] = TypeCode.ARRAY_HANDLE
|
||||||
|
elif isinstance(arg, NodeBase):
|
||||||
|
values[i].v_handle = arg.handle
|
||||||
|
type_codes[i] = TypeCode.NODE_HANDLE
|
||||||
|
elif isinstance(arg, Integral):
|
||||||
|
values[i].v_int64 = arg
|
||||||
|
type_codes[i] = TypeCode.INT
|
||||||
|
elif isinstance(arg, Number):
|
||||||
|
values[i].v_float64 = arg
|
||||||
|
type_codes[i] = TypeCode.FLOAT
|
||||||
|
elif isinstance(arg, TVMType):
|
||||||
|
values[i].v_type = arg
|
||||||
|
type_codes[i] = TypeCode.TVM_TYPE
|
||||||
|
elif isinstance(arg, string_types):
|
||||||
|
values[i].v_str = c_str(arg)
|
||||||
|
type_codes[i] = TypeCode.STR
|
||||||
|
elif isinstance(arg, (list, tuple, dict, SliceBase)):
|
||||||
|
arg = convert_to_node(arg)
|
||||||
|
values[i].v_handle = arg.handle
|
||||||
|
type_codes[i] = TypeCode.NODE_HANDLE
|
||||||
|
temp_args.append(arg)
|
||||||
|
elif isinstance(arg, Function):
|
||||||
|
values[i].v_handle = arg.handle
|
||||||
|
type_codes[i] = TypeCode.FUNC_HANDLE
|
||||||
|
elif callable(arg):
|
||||||
|
arg = convert_to_tvm_func(arg)
|
||||||
|
values[i].v_handle = arg.handle
|
||||||
|
type_codes[i] = TypeCode.FUNC_HANDLE
|
||||||
|
temp_args.append(arg)
|
||||||
|
else:
|
||||||
|
raise TypeError("Don't know how to handle type %s" % type(arg))
|
||||||
|
return values, type_codes, num_args
|
||||||
|
|
||||||
|
|
||||||
|
class Function(object):
|
||||||
|
"""A function object at runtime."""
|
||||||
|
__slots__ = ["handle", "is_global"]
|
||||||
|
# pylint: disable=no-member
|
||||||
|
def __init__(self, handle, is_global=False):
|
||||||
|
"""Initialize the function with handle
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
handle : FunctionHandle
|
||||||
|
the handle to the underlying function.
|
||||||
|
|
||||||
|
is_global : bool, optional
|
||||||
|
Whether it is global function
|
||||||
|
"""
|
||||||
|
self.handle = handle
|
||||||
|
self.is_global = is_global
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if not self.is_global:
|
||||||
|
check_call(_LIB.TVMFuncFree(self.handle))
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
temp_args = []
|
||||||
|
values, tcodes, num_args = _make_tvm_args(args, temp_args)
|
||||||
|
ret_val = TVMValue()
|
||||||
|
ret_tcode = ctypes.c_int()
|
||||||
|
check_call(_LIB.TVMFuncCall(
|
||||||
|
self.handle, values, tcodes, ctypes.c_int(num_args),
|
||||||
|
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
|
||||||
|
_ = temp_args
|
||||||
|
_ = args
|
||||||
|
return RETURN_SWITCH[ret_tcode.value](ret_val)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_return_func(x):
|
||||||
|
"""Return function"""
|
||||||
|
handle = x.v_handle
|
||||||
|
if not isinstance(handle, FunctionHandle):
|
||||||
|
handle = FunctionHandle(handle)
|
||||||
|
return Function(handle, False)
|
||||||
|
|
||||||
|
# setup return handle for function type
|
||||||
|
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
|
||||||
|
|
||||||
|
def register_func(func_name, f=None):
|
||||||
|
"""Register global function
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
func_name : str or function
|
||||||
|
The function name
|
||||||
|
|
||||||
|
f : function
|
||||||
|
The function to be registered.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
fregister : function
|
||||||
|
Register function if f is not specified.
|
||||||
|
"""
|
||||||
|
if callable(func_name):
|
||||||
|
f = func_name
|
||||||
|
func_name = f.__name__
|
||||||
|
|
||||||
|
if not isinstance(func_name, str):
|
||||||
|
raise ValueError("expect string function name")
|
||||||
|
def register(myf):
|
||||||
|
"""internal register function"""
|
||||||
|
if not isinstance(myf, Function):
|
||||||
|
myf = convert_to_tvm_func(myf)
|
||||||
|
check_call(_LIB.TVMFuncRegisterGlobal(
|
||||||
|
c_str(func_name), myf.handle))
|
||||||
|
if f:
|
||||||
|
register(f)
|
||||||
|
else:
|
||||||
|
return register
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_func(name):
|
||||||
|
"""Get a global function by name
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the global function
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
func : tvm.nd.Function
|
||||||
|
The function to be returned.
|
||||||
|
"""
|
||||||
|
handle = FunctionHandle()
|
||||||
|
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
|
||||||
|
return Function(handle, True)
|
||||||
|
|
||||||
|
|
||||||
|
def list_global_func_names():
|
||||||
|
"""Get list of global functions registered.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
names : list
|
||||||
|
List of global functions names.
|
||||||
|
"""
|
||||||
|
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
|
size = ctypes.c_uint()
|
||||||
|
|
||||||
|
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
|
||||||
|
ctypes.byref(plist)))
|
||||||
|
fnames = []
|
||||||
|
for i in range(size.value):
|
||||||
|
fnames.append(py_str(plist[i]))
|
||||||
|
return fnames
|
||||||
|
|
||||||
|
|
||||||
|
def _init_api_functions(root_namespace):
|
||||||
|
"""List and add all the functions to current module."""
|
||||||
|
module_obj = sys.modules["%s.api" % root_namespace]
|
||||||
|
module_internal = sys.modules["%s._api_internal" % root_namespace]
|
||||||
|
namespace_match = {
|
||||||
|
"_make_": sys.modules["%s.make" % root_namespace],
|
||||||
|
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
|
||||||
|
"_codegen_": sys.modules["%s.codegen" % root_namespace],
|
||||||
|
"_schedule_": sys.modules["%s.schedule" % root_namespace]
|
||||||
|
}
|
||||||
|
for name in list_global_func_names():
|
||||||
|
fname = name
|
||||||
|
target_module = module_internal if name.startswith('_') else module_obj
|
||||||
|
for k, v in namespace_match.items():
|
||||||
|
if name.startswith(k):
|
||||||
|
fname = name[len(k):]
|
||||||
|
target_module = v
|
||||||
|
f = get_global_func(name)
|
||||||
|
setattr(target_module, fname, f)
|
|
@ -1,18 +1,15 @@
|
||||||
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
|
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
|
||||||
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
|
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
|
||||||
"""Symbolic configuration API."""
|
"""Symbolic configuration API."""
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
from numbers import Number, Integral
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .._base import _LIB
|
from .._base import _LIB, check_call
|
||||||
from .._base import c_array, c_str, string_types
|
from .._base import c_array, c_str
|
||||||
from .._base import check_call
|
from ._types import TVMType, tvm_index_t
|
||||||
from ._types import TVMValue, TypeCode, TVMType
|
|
||||||
|
|
||||||
tvm_index_t = ctypes.c_uint32
|
|
||||||
|
|
||||||
class TVMContext(ctypes.Structure):
|
class TVMContext(ctypes.Structure):
|
||||||
"""TVM context strucure."""
|
"""TVM context strucure."""
|
||||||
|
@ -39,6 +36,19 @@ class TVMContext(ctypes.Structure):
|
||||||
return ret.value != 0
|
return ret.value != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TVMArray(ctypes.Structure):
|
||||||
|
"""TVMValue in C API"""
|
||||||
|
_fields_ = [("data", ctypes.c_void_p),
|
||||||
|
("shape", ctypes.POINTER(tvm_index_t)),
|
||||||
|
("strides", ctypes.POINTER(tvm_index_t)),
|
||||||
|
("ndim", tvm_index_t),
|
||||||
|
("dtype", TVMType),
|
||||||
|
("ctx", TVMContext)]
|
||||||
|
|
||||||
|
|
||||||
|
TVMArrayHandle = ctypes.POINTER(TVMArray)
|
||||||
|
|
||||||
|
|
||||||
def cpu(dev_id=0):
|
def cpu(dev_id=0):
|
||||||
"""Construct a CPU device
|
"""Construct a CPU device
|
||||||
|
|
||||||
|
@ -72,18 +82,6 @@ def opencl(dev_id=0):
|
||||||
return TVMContext(4, dev_id)
|
return TVMContext(4, dev_id)
|
||||||
|
|
||||||
|
|
||||||
class TVMArray(ctypes.Structure):
|
|
||||||
"""TVMValue in C API"""
|
|
||||||
_fields_ = [("data", ctypes.c_void_p),
|
|
||||||
("shape", ctypes.POINTER(tvm_index_t)),
|
|
||||||
("strides", ctypes.POINTER(tvm_index_t)),
|
|
||||||
("ndim", tvm_index_t),
|
|
||||||
("dtype", TVMType),
|
|
||||||
("ctx", TVMContext)]
|
|
||||||
|
|
||||||
TVMArrayHandle = ctypes.POINTER(TVMArray)
|
|
||||||
|
|
||||||
|
|
||||||
def numpyasarray(np_data):
|
def numpyasarray(np_data):
|
||||||
"""Return a TVMArray representation of a numpy array.
|
"""Return a TVMArray representation of a numpy array.
|
||||||
"""
|
"""
|
||||||
|
@ -102,7 +100,6 @@ def numpyasarray(np_data):
|
||||||
|
|
||||||
|
|
||||||
_ndarray_cls = None
|
_ndarray_cls = None
|
||||||
_function_cls = None
|
|
||||||
|
|
||||||
|
|
||||||
def empty(shape, dtype="float32", ctx=cpu(0)):
|
def empty(shape, dtype="float32", ctx=cpu(0)):
|
||||||
|
@ -275,51 +272,6 @@ class NDArrayBase(object):
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
class FunctionBase(object):
|
def _init_ndarray_module(ndarray_class):
|
||||||
"""A function object at runtim."""
|
|
||||||
__slots__ = ["handle"]
|
|
||||||
# pylint: disable=no-member
|
|
||||||
def __init__(self, handle):
|
|
||||||
"""Initialize the function with handle
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
handle : FunctionHandle
|
|
||||||
the handle to the underlying function.
|
|
||||||
"""
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
check_call(_LIB.TVMFuncFree(self.handle))
|
|
||||||
|
|
||||||
def __call__(self, *args):
|
|
||||||
num_args = len(args)
|
|
||||||
tvm_args = (TVMValue * num_args)()
|
|
||||||
tvm_type_code = (ctypes.c_int * num_args)()
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if arg is None:
|
|
||||||
tvm_args[i].v_handle = None
|
|
||||||
tvm_type_code[i] = TypeCode.NULL
|
|
||||||
elif isinstance(arg, NDArrayBase):
|
|
||||||
tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
|
|
||||||
tvm_type_code[i] = TypeCode.HANDLE
|
|
||||||
elif isinstance(arg, Integral):
|
|
||||||
tvm_args[i].v_int64 = arg
|
|
||||||
tvm_type_code[i] = TypeCode.INT
|
|
||||||
elif isinstance(arg, Number):
|
|
||||||
tvm_args[i].v_float64 = arg
|
|
||||||
tvm_type_code[i] = TypeCode.FLOAT
|
|
||||||
elif isinstance(arg, string_types):
|
|
||||||
tvm_args[i].v_str = c_str(arg)
|
|
||||||
tvm_type_code[i] = TypeCode.STR
|
|
||||||
else:
|
|
||||||
raise TypeError("Don't know how to handle type %s" % type(arg))
|
|
||||||
check_call(_LIB.TVMFuncCall(
|
|
||||||
self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args)))
|
|
||||||
|
|
||||||
|
|
||||||
def _init_runtime_module(ndarray_class, function_class):
|
|
||||||
global _ndarray_cls
|
global _ndarray_cls
|
||||||
global _function_cls
|
|
||||||
_ndarray_cls = ndarray_class
|
_ndarray_cls = ndarray_class
|
||||||
_function_cls = function_class
|
|
|
@ -0,0 +1,184 @@
|
||||||
|
# coding: utf-8
|
||||||
|
# pylint: disable=invalid-name, protected-access
|
||||||
|
# pylint: disable=no-member, missing-docstring
|
||||||
|
"""Symbolic configuration API."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
from numbers import Number, Integral
|
||||||
|
|
||||||
|
from .._base import _LIB, check_call
|
||||||
|
from .._base import c_str, py_str, string_types
|
||||||
|
from .. import _api_internal
|
||||||
|
from ._types import TVMValue, TypeCode, RETURN_SWITCH
|
||||||
|
|
||||||
|
NodeHandle = ctypes.c_void_p
|
||||||
|
|
||||||
|
"""Maps node type to its constructor"""
|
||||||
|
NODE_TYPE = {
|
||||||
|
}
|
||||||
|
|
||||||
|
def _return_node(x):
|
||||||
|
"""Return function"""
|
||||||
|
handle = x.v_handle
|
||||||
|
if not isinstance(handle, NodeHandle):
|
||||||
|
handle = NodeHandle(handle)
|
||||||
|
ret_val = TVMValue()
|
||||||
|
ret_type_code = ctypes.c_int()
|
||||||
|
ret_success = ctypes.c_int()
|
||||||
|
check_call(_LIB.TVMNodeGetAttr(
|
||||||
|
handle, c_str("type_key"),
|
||||||
|
ctypes.byref(ret_val),
|
||||||
|
ctypes.byref(ret_type_code),
|
||||||
|
ctypes.byref(ret_success)))
|
||||||
|
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
|
||||||
|
|
||||||
|
|
||||||
|
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
|
||||||
|
|
||||||
|
|
||||||
|
class SliceBase(object):
|
||||||
|
"""base class of slice object"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NodeBase(object):
|
||||||
|
"""Symbol is symbolic graph."""
|
||||||
|
__slots__ = ["handle"]
|
||||||
|
# pylint: disable=no-member
|
||||||
|
def __init__(self, handle):
|
||||||
|
"""Initialize the function with handle
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
handle : SymbolHandle
|
||||||
|
the handle to the underlying C++ Symbol
|
||||||
|
"""
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return _api_internal._format_str(self)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
check_call(_LIB.TVMNodeFree(self.handle))
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
ret_val = TVMValue()
|
||||||
|
ret_type_code = ctypes.c_int()
|
||||||
|
ret_success = ctypes.c_int()
|
||||||
|
check_call(_LIB.TVMNodeGetAttr(
|
||||||
|
self.handle, c_str(name),
|
||||||
|
ctypes.byref(ret_val),
|
||||||
|
ctypes.byref(ret_type_code),
|
||||||
|
ctypes.byref(ret_success)))
|
||||||
|
if not ret_success.value:
|
||||||
|
raise AttributeError(
|
||||||
|
"'%s' object has no attribute '%s'" % (str(type(self)), name))
|
||||||
|
return RETURN_SWITCH[ret_type_code.value](ret_val)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return _api_internal._raw_ptr(self)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, NodeBase):
|
||||||
|
return False
|
||||||
|
return self.__hash__() == other.__hash__()
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def __dir__(self):
|
||||||
|
plist = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
|
size = ctypes.c_uint()
|
||||||
|
check_call(_LIB.TVMNodeListAttrNames(
|
||||||
|
self.handle, ctypes.byref(size), ctypes.byref(plist)))
|
||||||
|
names = []
|
||||||
|
for i in range(size.value):
|
||||||
|
names.append(py_str(plist[i]))
|
||||||
|
return names
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
return (type(self), (None,), self.__getstate__())
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
handle = self.handle
|
||||||
|
if handle is not None:
|
||||||
|
return {'handle': _api_internal._save_json(self)}
|
||||||
|
else:
|
||||||
|
return {'handle': None}
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# pylint: disable=assigning-non-slot
|
||||||
|
handle = state['handle']
|
||||||
|
if handle is not None:
|
||||||
|
json_str = handle
|
||||||
|
other = _api_internal._load_json(json_str)
|
||||||
|
self.handle = other.handle
|
||||||
|
other.handle = None
|
||||||
|
else:
|
||||||
|
self.handle = None
|
||||||
|
|
||||||
|
|
||||||
|
def const(value, dtype=None):
|
||||||
|
"""construct a constant"""
|
||||||
|
if dtype is None:
|
||||||
|
if isinstance(value, Integral):
|
||||||
|
dtype = 'int32'
|
||||||
|
else:
|
||||||
|
dtype = 'float32'
|
||||||
|
return _api_internal._const(value, dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_node(value):
|
||||||
|
"""Convert a python value to corresponding node type.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
value : str
|
||||||
|
The value to be inspected.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
node : Node
|
||||||
|
The corresponding node value.
|
||||||
|
"""
|
||||||
|
if isinstance(value, NodeBase):
|
||||||
|
return value
|
||||||
|
elif isinstance(value, Number):
|
||||||
|
return const(value)
|
||||||
|
elif isinstance(value, string_types):
|
||||||
|
return _api_internal._str(value)
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
value = [convert_to_node(x) for x in value]
|
||||||
|
return _api_internal._Array(*value)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
vlist = []
|
||||||
|
for it in value.items():
|
||||||
|
if not isinstance(it[0], NodeBase):
|
||||||
|
raise ValueError("key of map must already been a container type")
|
||||||
|
vlist.append(it[0])
|
||||||
|
vlist.append(convert_to_node(it[1]))
|
||||||
|
return _api_internal._Map(*vlist)
|
||||||
|
elif isinstance(value, SliceBase):
|
||||||
|
return value.tensor(*value.indices)
|
||||||
|
else:
|
||||||
|
raise ValueError("don't know how to convert type %s to node" % type(value))
|
||||||
|
|
||||||
|
|
||||||
|
def register_node(type_key=None):
|
||||||
|
"""register node type
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
type_key : str or cls
|
||||||
|
The type key of the node
|
||||||
|
"""
|
||||||
|
if isinstance(type_key, str):
|
||||||
|
def register(cls):
|
||||||
|
"""internal register function"""
|
||||||
|
NODE_TYPE[type_key] = cls
|
||||||
|
return cls
|
||||||
|
return register
|
||||||
|
else:
|
||||||
|
cls = type_key
|
||||||
|
NODE_TYPE[cls.__name__] = cls
|
||||||
|
return cls
|
|
@ -4,13 +4,9 @@ from __future__ import absolute_import as _abs
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from .._base import py_str
|
||||||
|
|
||||||
class TVMValue(ctypes.Union):
|
tvm_index_t = ctypes.c_uint32
|
||||||
"""TVMValue in C API"""
|
|
||||||
_fields_ = [("v_int64", ctypes.c_int64),
|
|
||||||
("v_float64", ctypes.c_double),
|
|
||||||
("v_handle", ctypes.c_void_p),
|
|
||||||
("v_str", ctypes.c_char_p)]
|
|
||||||
|
|
||||||
class TypeCode(object):
|
class TypeCode(object):
|
||||||
"""Type code used in API calls"""
|
"""Type code used in API calls"""
|
||||||
|
@ -19,9 +15,11 @@ class TypeCode(object):
|
||||||
FLOAT = 2
|
FLOAT = 2
|
||||||
HANDLE = 3
|
HANDLE = 3
|
||||||
NULL = 4
|
NULL = 4
|
||||||
NODE_HANDLE = 5
|
ARRAY_HANDLE = 5
|
||||||
STR = 6
|
TVM_TYPE = 6
|
||||||
FUNC_HANDLE = 7
|
NODE_HANDLE = 7
|
||||||
|
STR = 8
|
||||||
|
FUNC_HANDLE = 9
|
||||||
|
|
||||||
def _api_type(code):
|
def _api_type(code):
|
||||||
"""create a type accepted by API"""
|
"""create a type accepted by API"""
|
||||||
|
@ -40,13 +38,13 @@ class TVMType(ctypes.Structure):
|
||||||
CODE2STR = {
|
CODE2STR = {
|
||||||
0 : 'int',
|
0 : 'int',
|
||||||
1 : 'uint',
|
1 : 'uint',
|
||||||
2 : 'float'
|
2 : 'float',
|
||||||
|
4 : 'handle'
|
||||||
}
|
}
|
||||||
def __init__(self, type_str, lanes=1):
|
def __init__(self, type_str, lanes=1):
|
||||||
super(TVMType, self).__init__()
|
super(TVMType, self).__init__()
|
||||||
if isinstance(type_str, np.dtype):
|
if isinstance(type_str, np.dtype):
|
||||||
type_str = str(type_str)
|
type_str = str(type_str)
|
||||||
|
|
||||||
if type_str.startswith("int"):
|
if type_str.startswith("int"):
|
||||||
self.type_code = 0
|
self.type_code = 0
|
||||||
bits = int(type_str[3:])
|
bits = int(type_str[3:])
|
||||||
|
@ -56,6 +54,9 @@ class TVMType(ctypes.Structure):
|
||||||
elif type_str.startswith("float"):
|
elif type_str.startswith("float"):
|
||||||
self.type_code = 2
|
self.type_code = 2
|
||||||
bits = int(type_str[5:])
|
bits = int(type_str[5:])
|
||||||
|
elif type_str.startswith("handle"):
|
||||||
|
self.type_code = 4
|
||||||
|
bits = 64
|
||||||
else:
|
else:
|
||||||
raise ValueError("Donot know how to handle type %s" % type_str)
|
raise ValueError("Donot know how to handle type %s" % type_str)
|
||||||
|
|
||||||
|
@ -71,15 +72,61 @@ class TVMType(ctypes.Structure):
|
||||||
x += "x%d" % self.lanes
|
x += "x%d" % self.lanes
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (self.bits == other.bits and
|
||||||
|
self.type_code == other.type_code and
|
||||||
|
self.lanes == other.lanes)
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
|
||||||
|
class TVMValue(ctypes.Union):
|
||||||
|
"""TVMValue in C API"""
|
||||||
|
_fields_ = [("v_int64", ctypes.c_int64),
|
||||||
|
("v_float64", ctypes.c_double),
|
||||||
|
("v_handle", ctypes.c_void_p),
|
||||||
|
("v_str", ctypes.c_char_p),
|
||||||
|
("v_type", TVMType)]
|
||||||
|
|
||||||
|
|
||||||
TVMPackedCFunc = ctypes.CFUNCTYPE(
|
TVMPackedCFunc = ctypes.CFUNCTYPE(
|
||||||
None,
|
None,
|
||||||
ctypes.POINTER(TVMValue),
|
ctypes.POINTER(TVMValue),
|
||||||
ctypes.POINTER(ctypes.c_int),
|
ctypes.POINTER(ctypes.c_int),
|
||||||
ctypes.c_int,
|
ctypes.c_int,
|
||||||
|
ctypes.c_void_p,
|
||||||
ctypes.c_void_p)
|
ctypes.c_void_p)
|
||||||
|
|
||||||
|
|
||||||
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
|
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
|
||||||
None,
|
None,
|
||||||
ctypes.c_void_p)
|
ctypes.c_void_p)
|
||||||
|
|
||||||
|
|
||||||
|
def _return_handle(x):
|
||||||
|
"""return handle"""
|
||||||
|
handle = x.v_handle
|
||||||
|
if not isinstance(handle, ctypes.c_void_p):
|
||||||
|
handle = ctypes.c_void_p(handle)
|
||||||
|
return handle
|
||||||
|
|
||||||
|
|
||||||
|
RETURN_SWITCH = {
|
||||||
|
TypeCode.INT: lambda x: x.v_int64,
|
||||||
|
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||||
|
TypeCode.HANDLE: _return_handle,
|
||||||
|
TypeCode.NULL: lambda x: None,
|
||||||
|
TypeCode.TVM_TYPE: lambda x: x.v_type,
|
||||||
|
TypeCode.STR: lambda x: py_str(x.v_str)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
C_TO_PY_ARG_SWITCH = {
|
||||||
|
TypeCode.INT: lambda x: x.v_int64,
|
||||||
|
TypeCode.FLOAT: lambda x: x.v_float64,
|
||||||
|
TypeCode.HANDLE: _return_handle,
|
||||||
|
TypeCode.NULL: lambda x: None,
|
||||||
|
TypeCode.TVM_TYPE: lambda x: x.v_type,
|
||||||
|
TypeCode.STR: lambda x: py_str(x.v_str)
|
||||||
|
}
|
||||||
|
|
|
@ -2,16 +2,23 @@
|
||||||
# pylint: disable=redefined-builtin, undefined-variable, unused-import
|
# pylint: disable=redefined-builtin, undefined-variable, unused-import
|
||||||
"""Functions defined in TVM."""
|
"""Functions defined in TVM."""
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
|
|
||||||
from numbers import Integral as _Integral
|
from numbers import Integral as _Integral
|
||||||
from ._ctypes._api import _init_api_module, convert, register_func, get_global_func
|
|
||||||
|
from ._ctypes._types import TVMType
|
||||||
|
from ._ctypes._node import register_node, NodeBase
|
||||||
|
from ._ctypes._node import convert_to_node as _convert_to_node
|
||||||
|
from ._ctypes._function import Function
|
||||||
|
from ._ctypes._function import _init_api_functions, register_func, get_global_func
|
||||||
|
from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
|
||||||
from . import _api_internal
|
from . import _api_internal
|
||||||
from . import make as _make
|
from . import make as _make
|
||||||
from . import expr as _expr
|
from . import expr as _expr
|
||||||
from . import collections as _collections
|
from . import collections as _collections
|
||||||
|
|
||||||
int32 = "int32"
|
int32 = TVMType("int32")
|
||||||
float32 = "float32"
|
float32 = TVMType("float32")
|
||||||
handle = "handle"
|
handle = TVMType("handle")
|
||||||
|
|
||||||
def const(value, dtype=None):
|
def const(value, dtype=None):
|
||||||
"""construct a constant"""
|
"""construct a constant"""
|
||||||
|
@ -266,4 +273,25 @@ def Schedule(ops):
|
||||||
return _api_internal._Schedule(ops)
|
return _api_internal._Schedule(ops)
|
||||||
|
|
||||||
|
|
||||||
_init_api_module("tvm")
|
def convert(value):
|
||||||
|
"""Convert value to TVM node or function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
value : python value
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tvm_val : Node or function
|
||||||
|
Converted value in TVM
|
||||||
|
"""
|
||||||
|
if isinstance(value, (Function, NodeBase)):
|
||||||
|
return value
|
||||||
|
|
||||||
|
if callable(value):
|
||||||
|
return _convert_tvm_func(value)
|
||||||
|
else:
|
||||||
|
return _convert_to_node(value)
|
||||||
|
|
||||||
|
|
||||||
|
_init_api_functions("tvm")
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# pylint: disable=protected-access, no-member
|
# pylint: disable=protected-access, no-member
|
||||||
"""Collection structure in the high level DSL."""
|
"""Collection structure in the high level DSL."""
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
from ._ctypes._api import NodeBase, register_node
|
from ._ctypes._node import NodeBase, register_node
|
||||||
from . import _api_internal
|
from . import _api_internal
|
||||||
from . import expr as _expr
|
from . import expr as _expr
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# pylint: disable=protected-access, no-member, missing-docstring
|
# pylint: disable=protected-access, no-member, missing-docstring
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
from ._ctypes._api import NodeBase, register_node
|
from ._ctypes._node import NodeBase, register_node
|
||||||
from . import make as _make
|
from . import make as _make
|
||||||
|
|
||||||
class ExprOp(object):
|
class ExprOp(object):
|
||||||
|
|
|
@ -6,11 +6,11 @@ This is a simplified runtime API for quick testing and proptyping.
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
import numpy as _np
|
import numpy as _np
|
||||||
|
|
||||||
from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase
|
from ._ctypes._ndarray import TVMContext, TVMType, NDArrayBase
|
||||||
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
|
from ._ctypes._ndarray import cpu, gpu, opencl, empty, sync
|
||||||
from ._ctypes._runtime_api import _init_runtime_module
|
from ._ctypes._ndarray import _init_ndarray_module
|
||||||
from ._ctypes._runtime_api import init_opencl
|
from ._ctypes._ndarray import init_opencl
|
||||||
|
from ._ctypes._function import Function
|
||||||
|
|
||||||
class NDArray(NDArrayBase):
|
class NDArray(NDArrayBase):
|
||||||
"""Lightweight NDArray class of TVM runtime.
|
"""Lightweight NDArray class of TVM runtime.
|
||||||
|
@ -26,11 +26,6 @@ class NDArray(NDArrayBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Function(FunctionBase):
|
|
||||||
"""Function class that can executed a generated code."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def array(arr, ctx=cpu(0)):
|
def array(arr, ctx=cpu(0)):
|
||||||
"""Create an array from source arr.
|
"""Create an array from source arr.
|
||||||
|
|
||||||
|
@ -54,4 +49,4 @@ def array(arr, ctx=cpu(0)):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
_init_runtime_module(NDArray, Function)
|
_init_ndarray_module(NDArray)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# pylint: disable=protected-access, no-member
|
# pylint: disable=protected-access, no-member
|
||||||
"""Collection structure in the high level DSL."""
|
"""Collection structure in the high level DSL."""
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
from ._ctypes._api import NodeBase, register_node
|
from ._ctypes._node import NodeBase, register_node
|
||||||
from . import _api_internal
|
from . import _api_internal
|
||||||
from . import tensor as _tensor
|
from . import tensor as _tensor
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# pylint: disable=protected-access, no-member, missing-docstring
|
# pylint: disable=protected-access, no-member, missing-docstring
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
from ._ctypes._api import NodeBase, register_node
|
from ._ctypes._node import NodeBase, register_node
|
||||||
|
|
||||||
class Stmt(NodeBase):
|
class Stmt(NodeBase):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# pylint: disable=protected-access, no-member, invalid-name
|
# pylint: disable=protected-access, no-member, invalid-name
|
||||||
"""Tensor related abstractions"""
|
"""Tensor related abstractions"""
|
||||||
from __future__ import absolute_import as _abs
|
from __future__ import absolute_import as _abs
|
||||||
from ._ctypes._api import NodeBase, SliceBase, register_node, convert
|
from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node
|
||||||
from . import collections as _collections
|
from . import collections as _collections
|
||||||
from . import _api_internal
|
from . import _api_internal
|
||||||
from . import make as _make
|
from . import make as _make
|
||||||
|
@ -26,7 +26,7 @@ class Tensor(NodeBase):
|
||||||
ndim = self.ndim
|
ndim = self.ndim
|
||||||
if len(indices) != ndim:
|
if len(indices) != ndim:
|
||||||
raise ValueError("Need to provide %d index in tensor slice" % ndim)
|
raise ValueError("Need to provide %d index in tensor slice" % ndim)
|
||||||
indices = convert(indices)
|
indices = convert_to_node(indices)
|
||||||
args = []
|
args = []
|
||||||
for x in indices:
|
for x in indices:
|
||||||
if isinstance(x, _collections.IterVar):
|
if isinstance(x, _collections.IterVar):
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* Implementation of basic API functions
|
||||||
|
* \file api_base.cc
|
||||||
|
*/
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/tensor.h>
|
||||||
|
#include <tvm/api_registry.h>
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_format_str)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
CHECK(args[0].type_code() == kNodeHandle);
|
||||||
|
std::ostringstream os;
|
||||||
|
os << args[0].operator NodeRef();
|
||||||
|
*ret = os.str();
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_raw_ptr)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
CHECK(args[0].type_code() == kNodeHandle);
|
||||||
|
*ret = reinterpret_cast<int64_t>(
|
||||||
|
args[0].node_sptr().get());
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_save_json)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = SaveJSON(args[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_load_json)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = NodeRef(LoadJSON_(args[0]));
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace tvm
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2016 by Contributors
|
||||||
|
* Implementation of API functions related to Codegen
|
||||||
|
* \file c_api_codegen.cc
|
||||||
|
*/
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
#include <tvm/codegen.h>
|
||||||
|
#include <tvm/api_registry.h>
|
||||||
|
#include "../codegen/codegen_c.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_codegen_CompileToC)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = CodeGenC().Compile(args[0], args[1]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_codegen_MakeAPI)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = MakeAPI(
|
||||||
|
args[0], args[1], args[2], args[3]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_codegen_SplitHostDevice)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = SplitHostDevice(args[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
// generate a dummy packed function for testing
|
||||||
|
void DummyHelloFunction(TVMArgs args, TVMRetValue* rv) {
|
||||||
|
LOG(INFO) << args.size() << " arguments";
|
||||||
|
for (int i = 0; i < args.size(); ++i) {
|
||||||
|
switch (args.type_codes[i]) {
|
||||||
|
case kNull: LOG(INFO) << i << ":nullptr"; break;
|
||||||
|
case kFloat: LOG(INFO) << i << ": double=" << args.values[i].v_float64; break;
|
||||||
|
case kInt: LOG(INFO) << i << ": long=" << args.values[i].v_int64; break;
|
||||||
|
case kHandle: LOG(INFO) << i << ": handle=" << args.values[i].v_handle; break;
|
||||||
|
case kArrayHandle: LOG(INFO) << i << ": array_handle=" << args.values[i].v_handle; break;
|
||||||
|
default: LOG(FATAL) << "unhandled type " << runtime::TypeCode2Str(args.type_codes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_codegen_DummyHelloFunction)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = runtime::PackedFunc(DummyHelloFunction);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_codegen_BuildStackVM)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
|
*ret = BuildStackVM(args[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace tvm
|
|
@ -1,98 +1,93 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2016 by Contributors
|
* Copyright (c) 2016 by Contributors
|
||||||
* Implementation of API functions related to IR build
|
* Implementation of API functions related to IR build
|
||||||
* \file c_api_ir.cc
|
* \file api_ir.cc
|
||||||
*/
|
*/
|
||||||
#include <tvm/expr.h>
|
#include <tvm/expr.h>
|
||||||
#include <tvm/ir.h>
|
#include <tvm/ir.h>
|
||||||
#include <ir/IROperator.h>
|
#include <ir/IROperator.h>
|
||||||
#include "./c_api_registry.h"
|
#include <tvm/api_registry.h>
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace ir {
|
namespace ir {
|
||||||
|
|
||||||
using ArgStack = const std::vector<APIVariantValue>;
|
|
||||||
using RetValue = APIVariantValue;
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Var)
|
TVM_REGISTER_API(_Var)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = Variable::make(args.at(1), args.at(0));
|
*ret = Variable::make(args[1], args[0]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API(_make_For)
|
TVM_REGISTER_API(_make_For)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = For::make(args.at(0),
|
*ret = For::make(args[0],
|
||||||
args.at(1),
|
args[1],
|
||||||
args.at(2),
|
args[2],
|
||||||
static_cast<ForType>(args.at(3).operator int()),
|
static_cast<ForType>(args[3].operator int()),
|
||||||
static_cast<Halide::DeviceAPI>(args.at(4).operator int()),
|
static_cast<Halide::DeviceAPI>(args[4].operator int()),
|
||||||
args.at(5));
|
args[5]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API(_make_Realize)
|
TVM_REGISTER_API(_make_Realize)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = Realize::make(args.at(0),
|
*ret = Realize::make(args[0],
|
||||||
args.at(1),
|
args[1],
|
||||||
args.at(2),
|
args[2],
|
||||||
args.at(3),
|
args[3],
|
||||||
args.at(4),
|
args[4],
|
||||||
args.at(5));
|
args[5]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_make_Call)
|
TVM_REGISTER_API(_make_Call)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = Call::make(args.at(0),
|
*ret = Call::make(args[0],
|
||||||
args.at(1),
|
args[1],
|
||||||
args.at(2),
|
args[2],
|
||||||
static_cast<Call::CallType>(args.at(3).operator int()),
|
static_cast<Call::CallType>(args[3].operator int()),
|
||||||
args.at(4),
|
args[4],
|
||||||
args.at(5));
|
args[5]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API(_make_Allocate)
|
TVM_REGISTER_API(_make_Allocate)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = Allocate::make(args.at(0),
|
*ret = Allocate::make(args[0],
|
||||||
args.at(1),
|
args[1],
|
||||||
args.at(2),
|
args[2],
|
||||||
args.at(3),
|
args[3],
|
||||||
args.at(4));
|
args[4]);
|
||||||
});
|
});
|
||||||
|
|
||||||
// make from two arguments
|
// make from two arguments
|
||||||
#define REGISTER_MAKE1(Node) \
|
#define REGISTER_MAKE1(Node) \
|
||||||
TVM_REGISTER_API(_make_## Node) \
|
TVM_REGISTER_API(_make_## Node) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = Node::make(args.at(0)); \
|
*ret = Node::make(args[0]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_MAKE2(Node) \
|
#define REGISTER_MAKE2(Node) \
|
||||||
TVM_REGISTER_API(_make_## Node) \
|
TVM_REGISTER_API(_make_## Node) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = Node::make(args.at(0), args.at(1)); \
|
*ret = Node::make(args[0], args[1]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_MAKE3(Node) \
|
#define REGISTER_MAKE3(Node) \
|
||||||
TVM_REGISTER_API(_make_## Node) \
|
TVM_REGISTER_API(_make_## Node) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
|
*ret = Node::make(args[0], args[1], args[2]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_MAKE4(Node) \
|
#define REGISTER_MAKE4(Node) \
|
||||||
TVM_REGISTER_API(_make_## Node) \
|
TVM_REGISTER_API(_make_## Node) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \
|
*ret = Node::make(args[0], args[1], args[2], args[3]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_MAKE_BINARY_OP(Node) \
|
#define REGISTER_MAKE_BINARY_OP(Node) \
|
||||||
TVM_REGISTER_API(_make_## Node) \
|
TVM_REGISTER_API(_make_## Node) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
Expr a = args.at(0), b = args.at(1); \
|
Expr a = args[0], b = args[1]; \
|
||||||
match_types(a, b); \
|
match_types(a, b); \
|
||||||
*ret = Node::make(a, b); \
|
*ret = Node::make(a, b); \
|
||||||
}) \
|
})
|
||||||
.add_argument("lhs", "Expr", "left operand") \
|
|
||||||
.add_argument("rhs", "Expr", "right operand")
|
|
||||||
|
|
||||||
REGISTER_MAKE3(Reduce);
|
REGISTER_MAKE3(Reduce);
|
||||||
REGISTER_MAKE4(AttrStmt);
|
REGISTER_MAKE4(AttrStmt);
|
|
@ -0,0 +1,256 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2016 by Contributors
|
||||||
|
* Implementation of API functions related to Higher DSL build.
|
||||||
|
* \file api_lang.cc
|
||||||
|
*/
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
#include <tvm/tensor.h>
|
||||||
|
#include <tvm/buffer.h>
|
||||||
|
#include <tvm/schedule.h>
|
||||||
|
#include <tvm/api_registry.h>
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_const)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
if (args[0].type_code() == kInt) {
|
||||||
|
*ret = make_const(args[1], args[0].operator int64_t());
|
||||||
|
} else if (args[0].type_code() == kFloat) {
|
||||||
|
*ret = make_const(args[1], args[0].operator double());
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "only accept int or float";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_str)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = ir::StringImm::make(args[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Array)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
std::vector<std::shared_ptr<Node> > data;
|
||||||
|
for (int i = 0; i < args.size(); ++i) {
|
||||||
|
data.push_back(args[i].node_sptr());
|
||||||
|
}
|
||||||
|
auto node = std::make_shared<ArrayNode>();
|
||||||
|
node->data = std::move(data);
|
||||||
|
*ret = node;
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_ArrayGetItem)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
int64_t i = args[1];
|
||||||
|
auto& sptr = args[0].node_sptr();
|
||||||
|
CHECK(sptr->is_type<ArrayNode>());
|
||||||
|
auto* n = static_cast<const ArrayNode*>(sptr.get());
|
||||||
|
CHECK_LT(static_cast<size_t>(i), n->data.size())
|
||||||
|
<< "out of bound of array";
|
||||||
|
*ret = n->data[i];
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_ArraySize)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
auto& sptr = args[0].node_sptr();
|
||||||
|
CHECK(sptr->is_type<ArrayNode>());
|
||||||
|
*ret = static_cast<int64_t>(
|
||||||
|
static_cast<const ArrayNode*>(sptr.get())->data.size());
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Map)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
CHECK_EQ(args.size() % 2, 0);
|
||||||
|
MapNode::ContainerType data;
|
||||||
|
for (int i = 0; i < args.num_args; i += 2) {
|
||||||
|
CHECK(args[i].type_code() == kNodeHandle)
|
||||||
|
<< "need content of array to be NodeBase";
|
||||||
|
CHECK(args[i + 1].type_code() == kNodeHandle)
|
||||||
|
<< "need content of array to be NodeBase";
|
||||||
|
data.emplace(std::make_pair(args[i].node_sptr(),
|
||||||
|
args[i + 1].node_sptr()));
|
||||||
|
}
|
||||||
|
auto node = std::make_shared<MapNode>();
|
||||||
|
node->data = std::move(data);
|
||||||
|
*ret = node;
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_MapSize)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
auto& sptr = args[0].node_sptr();
|
||||||
|
CHECK(sptr->is_type<MapNode>());
|
||||||
|
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||||
|
*ret = static_cast<int64_t>(n->data.size());
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_MapGetItem)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
CHECK(args[0].type_code() == kNodeHandle);
|
||||||
|
CHECK(args[1].type_code() == kNodeHandle);
|
||||||
|
auto& sptr = args[0].node_sptr();
|
||||||
|
CHECK(sptr->is_type<MapNode>());
|
||||||
|
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||||
|
auto it = n->data.find(args[1].node_sptr());
|
||||||
|
CHECK(it != n->data.end())
|
||||||
|
<< "cannot find the corresponding key in the Map";
|
||||||
|
*ret = (*it).second;
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_MapCount)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
CHECK(args[0].type_code() == kNodeHandle);
|
||||||
|
CHECK(args[1].type_code() == kNodeHandle);
|
||||||
|
auto& sptr = args[0].node_sptr();
|
||||||
|
CHECK(sptr->is_type<MapNode>());
|
||||||
|
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||||
|
*ret = static_cast<int64_t>(
|
||||||
|
n->data.count(args[1].node_sptr()));
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_MapItems)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
auto& sptr = args[0].node_sptr();
|
||||||
|
CHECK(sptr->is_type<MapNode>());
|
||||||
|
auto* n = static_cast<const MapNode*>(sptr.get());
|
||||||
|
auto rkvs = std::make_shared<ArrayNode>();
|
||||||
|
for (const auto& kv : n->data) {
|
||||||
|
rkvs->data.push_back(kv.first);
|
||||||
|
rkvs->data.push_back(kv.second);
|
||||||
|
}
|
||||||
|
*ret = rkvs;
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(Range)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
if (args.size() == 1) {
|
||||||
|
*ret = Range(0, args[0]);
|
||||||
|
} else {
|
||||||
|
*ret = Range(args[0], args[1]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Buffer)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = BufferNode::make(args[0],
|
||||||
|
args[1],
|
||||||
|
args[2],
|
||||||
|
args[3],
|
||||||
|
args[4]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Tensor)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = TensorNode::make(args[0],
|
||||||
|
args[1],
|
||||||
|
args[2],
|
||||||
|
args[3]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_TensorEqual)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = args[0].operator Tensor() == args[1].operator Tensor();
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_TensorHash)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = static_cast<int64_t>(
|
||||||
|
std::hash<Tensor>()(args[0].operator Tensor()));
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Placeholder)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = Placeholder(args[0],
|
||||||
|
args[1],
|
||||||
|
args[2]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_ComputeOp)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = ComputeOpNode::make(args[0],
|
||||||
|
args[1],
|
||||||
|
args[2]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_OpGetOutput)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = args[0].operator Operation().output(
|
||||||
|
args[1].operator int64_t());
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_IterVar)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = IterVar(args[0], args[1], args[2]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Schedule)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
*ret = Schedule(args[0].operator Array<Operation>());
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageSetScope)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
args[0].operator Stage()
|
||||||
|
.set_scope(args[1]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageSplitByFactor)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
IterVar outer, inner;
|
||||||
|
args[0].operator Stage()
|
||||||
|
.split(args[1], &outer, &inner, args[2]);
|
||||||
|
*ret = Array<IterVar>({outer, inner});
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageSplitByOuter)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
IterVar inner;
|
||||||
|
args[0].operator Stage()
|
||||||
|
.split(args[1], args[2], &inner, args[3]);
|
||||||
|
*ret = inner;
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageFuse)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
IterVar fused;
|
||||||
|
args[0].operator Stage()
|
||||||
|
.split(args[1], args[2], &fused);
|
||||||
|
*ret = fused;
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageComputeAt)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
args[0].operator Stage()
|
||||||
|
.compute_at(args[1], args[2]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageComputeInline)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
args[0].operator Stage()
|
||||||
|
.compute_inline();
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageComputeRoot)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
args[0].operator Stage()
|
||||||
|
.compute_root();
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageReorder)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
args[0].operator Stage()
|
||||||
|
.reorder(args[1]);
|
||||||
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_StageTile)
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
IterVar x_outer, y_outer, x_inner, y_inner;
|
||||||
|
args[0].operator Stage()
|
||||||
|
.tile(args[1], args[2], &x_outer, &y_outer,
|
||||||
|
&x_inner, &y_inner, args[3], args[4]);
|
||||||
|
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace tvm
|
|
@ -1,53 +1,51 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2016 by Contributors
|
* Copyright (c) 2017 by Contributors
|
||||||
* Exposre of pass functions.
|
* Exposre of pass functions.
|
||||||
* \file c_api_pass.cc
|
* \file api_pass.cc
|
||||||
*/
|
*/
|
||||||
#include <tvm/expr.h>
|
#include <tvm/expr.h>
|
||||||
#include <tvm/ir.h>
|
#include <tvm/ir.h>
|
||||||
#include <tvm/ir_pass.h>
|
#include <tvm/ir_pass.h>
|
||||||
#include "./c_api_registry.h"
|
#include <tvm/api_registry.h>
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace ir {
|
namespace ir {
|
||||||
using ArgStack = const std::vector<APIVariantValue>;
|
|
||||||
using RetValue = APIVariantValue;
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_pass_Simplify)
|
TVM_REGISTER_API(_pass_Simplify)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
|
if (args[0].IsNodeType<Stmt>()) {
|
||||||
*ret = Simplify(args.at(0).operator Stmt());
|
*ret = Simplify(args[0].operator Stmt());
|
||||||
} else {
|
} else {
|
||||||
*ret = Simplify(args.at(0).operator Expr());
|
*ret = Simplify(args[0].operator Expr());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API(_pass_Equal)
|
TVM_REGISTER_API(_pass_Equal)
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
|
if (args[0].IsNodeType<Stmt>()) {
|
||||||
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
|
*ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
|
||||||
} else {
|
} else {
|
||||||
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
|
*ret = Equal(args[0].operator Expr(), args[1].operator Expr());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// make from two arguments
|
// make from two arguments
|
||||||
#define REGISTER_PASS1(PassName) \
|
#define REGISTER_PASS1(PassName) \
|
||||||
TVM_REGISTER_API(_pass_## PassName) \
|
TVM_REGISTER_API(_pass_## PassName) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = PassName(args.at(0)); \
|
*ret = PassName(args[0]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_PASS2(PassName) \
|
#define REGISTER_PASS2(PassName) \
|
||||||
TVM_REGISTER_API(_pass_## PassName) \
|
TVM_REGISTER_API(_pass_## PassName) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = PassName(args.at(0), args.at(1)); \
|
*ret = PassName(args[0], args[1]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_PASS4(PassName) \
|
#define REGISTER_PASS4(PassName) \
|
||||||
TVM_REGISTER_API(_pass_## PassName) \
|
TVM_REGISTER_API(_pass_## PassName) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \
|
*ret = PassName(args[0], args[1], args[2], args[3]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
REGISTER_PASS1(ConvertSSA);
|
REGISTER_PASS1(ConvertSSA);
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file api_registry.cc
|
||||||
|
*/
|
||||||
|
#include <tvm/expr.h>
|
||||||
|
#include <tvm/tensor.h>
|
||||||
|
#include <tvm/api_registry.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
|
||||||
|
struct APIManager {
|
||||||
|
std::unordered_map<std::string, std::unique_ptr<APIRegistry> > fmap;
|
||||||
|
|
||||||
|
static APIManager* Global() {
|
||||||
|
static APIManager inst;
|
||||||
|
return &inst;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
APIRegistry& APIRegistry::__REGISTER__(const std::string& name) { // NOLINT(*)
|
||||||
|
APIManager* m = APIManager::Global();
|
||||||
|
CHECK(!m->fmap.count(name))
|
||||||
|
<< "API function " << name << " has already been registered";
|
||||||
|
std::unique_ptr<APIRegistry> p(new APIRegistry());
|
||||||
|
p->name_ = name;
|
||||||
|
m->fmap[name] = std::move(p);
|
||||||
|
return *(m->fmap[name]);
|
||||||
|
}
|
||||||
|
|
||||||
|
APIRegistry& APIRegistry::set_body(PackedFunc f) { // NOLINT(*)
|
||||||
|
PackedFunc::RegisterGlobal(name_, f);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
} // namespace tvm
|
|
@ -1,30 +1,28 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2016 by Contributors
|
* Copyright (c) 2017 by Contributors
|
||||||
* Implementation of API functions related to schedule pass.
|
* Implementation of API functions related to schedule pass.
|
||||||
* \file c_api_lang.cc
|
* \file api_schedule.cc
|
||||||
*/
|
*/
|
||||||
#include <tvm/expr.h>
|
#include <tvm/expr.h>
|
||||||
#include <tvm/tensor.h>
|
#include <tvm/tensor.h>
|
||||||
#include <tvm/schedule.h>
|
#include <tvm/schedule.h>
|
||||||
#include <tvm/schedule_pass.h>
|
#include <tvm/schedule_pass.h>
|
||||||
#include "./c_api_registry.h"
|
#include <tvm/api_registry.h>
|
||||||
#include "../schedule/graph.h"
|
#include "../schedule/graph.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace schedule {
|
namespace schedule {
|
||||||
using ArgStack = const std::vector<APIVariantValue>;
|
|
||||||
using RetValue = APIVariantValue;
|
|
||||||
|
|
||||||
#define REGISTER_SCHEDULE_PASS1(PassName) \
|
#define REGISTER_SCHEDULE_PASS1(PassName) \
|
||||||
TVM_REGISTER_API(_schedule_## PassName) \
|
TVM_REGISTER_API(_schedule_## PassName) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = PassName(args.at(0)); \
|
*ret = PassName(args[0]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
#define REGISTER_SCHEDULE_PASS2(PassName) \
|
#define REGISTER_SCHEDULE_PASS2(PassName) \
|
||||||
TVM_REGISTER_API(_schedule_## PassName) \
|
TVM_REGISTER_API(_schedule_## PassName) \
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) { \
|
.set_body([](TVMArgs args, TVMRetValue *ret) { \
|
||||||
*ret = PassName(args.at(0), args.at(1)); \
|
*ret = PassName(args[0], args[1]); \
|
||||||
}) \
|
}) \
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2016 by Contributors
|
||||||
|
* Implementation of C API
|
||||||
|
* \file c_api.cc
|
||||||
|
*/
|
||||||
|
#include <dmlc/base.h>
|
||||||
|
#include <dmlc/logging.h>
|
||||||
|
#include <dmlc/thread_local.h>
|
||||||
|
#include <tvm/c_api.h>
|
||||||
|
#include <tvm/api_registry.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <exception>
|
||||||
|
#include "../runtime/runtime_base.h"
|
||||||
|
|
||||||
|
|
||||||
|
/*! \brief entry to to easily hold returning information */
|
||||||
|
struct TVMAPIThreadLocalEntry {
|
||||||
|
/*! \brief result holder for returning strings */
|
||||||
|
std::vector<std::string> ret_vec_str;
|
||||||
|
/*! \brief result holder for returning string pointers */
|
||||||
|
std::vector<const char *> ret_vec_charp;
|
||||||
|
/*! \brief result holder for retruning string */
|
||||||
|
std::string ret_str;
|
||||||
|
};
|
||||||
|
|
||||||
|
using namespace tvm;
|
||||||
|
|
||||||
|
/*! \brief Thread local store that can be used to hold return values. */
|
||||||
|
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
|
||||||
|
|
||||||
|
using TVMAPINode = std::shared_ptr<Node>;
|
||||||
|
|
||||||
|
struct APIAttrGetter : public AttrVisitor {
|
||||||
|
std::string skey;
|
||||||
|
TVMRetValue* ret;
|
||||||
|
bool found_node_ref{false};
|
||||||
|
|
||||||
|
void Visit(const char* key, double* value) final {
|
||||||
|
if (skey == key) *ret = value[0];
|
||||||
|
}
|
||||||
|
void Visit(const char* key, int64_t* value) final {
|
||||||
|
if (skey == key) *ret = value[0];
|
||||||
|
}
|
||||||
|
void Visit(const char* key, uint64_t* value) final {
|
||||||
|
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
|
||||||
|
<< "cannot return too big constant";
|
||||||
|
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, int* value) final {
|
||||||
|
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, bool* value) final {
|
||||||
|
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, Type* value) final {
|
||||||
|
if (skey == key) *ret = value[0];
|
||||||
|
}
|
||||||
|
void Visit(const char* key, std::string* value) final {
|
||||||
|
if (skey == key) *ret = value[0];
|
||||||
|
}
|
||||||
|
void Visit(const char* key, NodeRef* value) final {
|
||||||
|
if (skey == key) {
|
||||||
|
*ret = value[0];
|
||||||
|
found_node_ref = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct APIAttrDir : public AttrVisitor {
|
||||||
|
std::vector<std::string>* names;
|
||||||
|
|
||||||
|
void Visit(const char* key, double* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, int64_t* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, uint64_t* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, bool* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, int* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, Type* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, std::string* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
void Visit(const char* key, NodeRef* value) final {
|
||||||
|
names->push_back(key);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
int TVMNodeFree(NodeHandle handle) {
|
||||||
|
API_BEGIN();
|
||||||
|
delete static_cast<TVMAPINode*>(handle);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
|
int TVMNodeGetAttr(NodeHandle handle,
|
||||||
|
const char* key,
|
||||||
|
TVMValue* ret_val,
|
||||||
|
int* ret_type_code,
|
||||||
|
int* ret_success) {
|
||||||
|
API_BEGIN();
|
||||||
|
TVMRetValue rv;
|
||||||
|
APIAttrGetter getter;
|
||||||
|
getter.skey = key;
|
||||||
|
getter.ret = &rv;
|
||||||
|
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
||||||
|
if (getter.skey == "type_key") {
|
||||||
|
ret_val->v_str = (*tnode)->type_key();
|
||||||
|
*ret_type_code = kStr;
|
||||||
|
*ret_success = 1;
|
||||||
|
} else {
|
||||||
|
(*tnode)->VisitAttrs(&getter);
|
||||||
|
*ret_success = getter.found_node_ref || rv.type_code() != kNull;
|
||||||
|
if (rv.type_code() == kStr) {
|
||||||
|
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
|
||||||
|
e->ret_str = rv.operator std::string();
|
||||||
|
*ret_type_code = kStr;
|
||||||
|
ret_val->v_str = e->ret_str.c_str();
|
||||||
|
} else {
|
||||||
|
rv.MoveToCHost(ret_val, ret_type_code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
|
int TVMNodeListAttrNames(NodeHandle handle,
|
||||||
|
int *out_size,
|
||||||
|
const char*** out_array) {
|
||||||
|
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||||
|
API_BEGIN();
|
||||||
|
ret->ret_vec_str.clear();
|
||||||
|
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
||||||
|
APIAttrDir dir;
|
||||||
|
dir.names = &(ret->ret_vec_str);
|
||||||
|
(*tnode)->VisitAttrs(&dir);
|
||||||
|
ret->ret_vec_charp.clear();
|
||||||
|
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
||||||
|
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
||||||
|
}
|
||||||
|
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||||
|
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
||||||
|
API_END();
|
||||||
|
}
|
|
@ -42,5 +42,78 @@ inline Type String2Type(std::string s) {
|
||||||
return Type(code, bits, lanes);
|
return Type(code, bits, lanes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline const char* TVMTypeCode2Str(int type_code) {
|
||||||
|
switch (type_code) {
|
||||||
|
case kInt: return "int";
|
||||||
|
case kFloat: return "float";
|
||||||
|
case kStr: return "str";
|
||||||
|
case kHandle: return "Handle";
|
||||||
|
case kNull: return "NULL";
|
||||||
|
case kNodeHandle: return "NodeHandle";
|
||||||
|
default: LOG(FATAL) << "unknown type_code="
|
||||||
|
<< static_cast<int>(type_code); return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template<typename T>
|
||||||
|
struct NodeTypeChecker {
|
||||||
|
static inline bool Check(Node* sptr) {
|
||||||
|
// This is the only place in the project where RTTI is used
|
||||||
|
// It can be turned off, but will make non strict checking.
|
||||||
|
// TODO(tqchen) possibly find alternative to turn of RTTI
|
||||||
|
using ContainerType = typename T::ContainerType;
|
||||||
|
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
|
||||||
|
}
|
||||||
|
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||||
|
using ContainerType = typename T::ContainerType;
|
||||||
|
os << ContainerType::_type_key;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct NodeTypeChecker<Array<T> > {
|
||||||
|
static inline bool Check(Node* sptr) {
|
||||||
|
if (sptr == nullptr) return false;
|
||||||
|
if (!sptr->is_type<ArrayNode>()) return false;
|
||||||
|
ArrayNode* n = static_cast<ArrayNode*>(sptr);
|
||||||
|
for (const auto& p : n->data) {
|
||||||
|
if (!NodeTypeChecker<T>::Check(p.get())) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||||
|
os << "array<";
|
||||||
|
NodeTypeChecker<T>::PrintName(os);
|
||||||
|
os << ">";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename K, typename V>
|
||||||
|
struct NodeTypeChecker<Map<K, V> > {
|
||||||
|
static inline bool Check(Node* sptr) {
|
||||||
|
if (sptr == nullptr) return false;
|
||||||
|
if (!sptr->is_type<MapNode>()) return false;
|
||||||
|
MapNode* n = static_cast<MapNode*>(sptr);
|
||||||
|
for (const auto& kv : n->data) {
|
||||||
|
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
|
||||||
|
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
||||||
|
os << "map<";
|
||||||
|
NodeTypeChecker<K>::PrintName(os);
|
||||||
|
os << ',';
|
||||||
|
NodeTypeChecker<V>::PrintName(os);
|
||||||
|
os << '>';
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline std::string NodeTypeName() {
|
||||||
|
std::ostringstream os;
|
||||||
|
NodeTypeChecker<T>::PrintName(os);
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
#endif // TVM_BASE_COMMON_H_
|
#endif // TVM_BASE_COMMON_H_
|
||||||
|
|
|
@ -1,260 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2016 by Contributors
|
|
||||||
* Implementation of C API
|
|
||||||
* \file c_api.cc
|
|
||||||
*/
|
|
||||||
#include <tvm/c_api.h>
|
|
||||||
#include "./c_api_common.h"
|
|
||||||
#include "./c_api_registry.h"
|
|
||||||
|
|
||||||
/*! \brief entry to to easily hold returning information */
|
|
||||||
struct TVMAPIThreadLocalEntry {
|
|
||||||
/*! \brief result holder for returning strings */
|
|
||||||
std::vector<std::string> ret_vec_str;
|
|
||||||
/*! \brief result holder for returning string pointers */
|
|
||||||
std::vector<const char *> ret_vec_charp;
|
|
||||||
/*! \brief argument stack */
|
|
||||||
std::vector<tvm::APIVariantValue> arg_stack;
|
|
||||||
/*! \brief return value */
|
|
||||||
tvm::APIVariantValue ret_value;
|
|
||||||
// clear calling stack
|
|
||||||
inline void Clear() {
|
|
||||||
arg_stack.clear();
|
|
||||||
ret_value.sptr.reset();
|
|
||||||
}
|
|
||||||
inline void SetReturn(TVMValue* ret_val, int* ret_type_code);
|
|
||||||
};
|
|
||||||
|
|
||||||
using namespace tvm;
|
|
||||||
|
|
||||||
/*! \brief Thread local store that can be used to hold return values. */
|
|
||||||
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
|
|
||||||
|
|
||||||
using TVMAPINode = std::shared_ptr<Node>;
|
|
||||||
|
|
||||||
struct APIAttrGetter : public AttrVisitor {
|
|
||||||
std::string skey;
|
|
||||||
APIVariantValue* ret;
|
|
||||||
bool found_node_ref{false};
|
|
||||||
|
|
||||||
void Visit(const char* key, double* value) final {
|
|
||||||
if (skey == key) *ret = value[0];
|
|
||||||
}
|
|
||||||
void Visit(const char* key, int64_t* value) final {
|
|
||||||
if (skey == key) *ret = value[0];
|
|
||||||
}
|
|
||||||
void Visit(const char* key, uint64_t* value) final {
|
|
||||||
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
|
|
||||||
<< "cannot return too big constant";
|
|
||||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, int* value) final {
|
|
||||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, bool* value) final {
|
|
||||||
if (skey == key) *ret = static_cast<int64_t>(value[0]);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, Type* value) final {
|
|
||||||
if (skey == key) *ret = value[0];
|
|
||||||
}
|
|
||||||
void Visit(const char* key, std::string* value) final {
|
|
||||||
if (skey == key) *ret = value[0];
|
|
||||||
}
|
|
||||||
void Visit(const char* key, NodeRef* value) final {
|
|
||||||
if (skey == key) {
|
|
||||||
*ret = value[0];
|
|
||||||
found_node_ref = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct APIAttrDir : public AttrVisitor {
|
|
||||||
std::vector<std::string>* names;
|
|
||||||
|
|
||||||
void Visit(const char* key, double* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, int64_t* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, uint64_t* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, bool* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, int* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, Type* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, std::string* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
void Visit(const char* key, NodeRef* value) final {
|
|
||||||
names->push_back(key);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
int TVMListAPIFuncNames(int *out_size,
|
|
||||||
const char*** out_array) {
|
|
||||||
API_BEGIN();
|
|
||||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
|
||||||
ret->ret_vec_str = dmlc::Registry<APIFuncReg>::ListAllNames();
|
|
||||||
ret->ret_vec_charp.clear();
|
|
||||||
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
|
||||||
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
|
||||||
}
|
|
||||||
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
|
||||||
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMGetAPIFuncHandle(const char* fname,
|
|
||||||
APIFuncHandle* out) {
|
|
||||||
API_BEGIN();
|
|
||||||
const APIFuncReg* reg = dmlc::Registry<APIFuncReg>::Find(fname);
|
|
||||||
CHECK(reg != nullptr) << "cannot find function " << fname;
|
|
||||||
*out = (APIFuncHandle)reg;
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMGetAPIFuncInfo(APIFuncHandle handle,
|
|
||||||
const char **real_name,
|
|
||||||
const char **description,
|
|
||||||
int *num_doc_args,
|
|
||||||
const char ***arg_names,
|
|
||||||
const char ***arg_type_infos,
|
|
||||||
const char ***arg_descriptions,
|
|
||||||
const char **return_type) {
|
|
||||||
const auto *op = static_cast<const APIFuncReg *>(handle);
|
|
||||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
|
||||||
|
|
||||||
API_BEGIN();
|
|
||||||
*real_name = op->name.c_str();
|
|
||||||
*description = op->description.c_str();
|
|
||||||
*num_doc_args = static_cast<int>(op->arguments.size());
|
|
||||||
if (return_type) *return_type = nullptr;
|
|
||||||
ret->ret_vec_charp.clear();
|
|
||||||
for (size_t i = 0; i < op->arguments.size(); ++i) {
|
|
||||||
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < op->arguments.size(); ++i) {
|
|
||||||
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < op->arguments.size(); ++i) {
|
|
||||||
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
|
|
||||||
}
|
|
||||||
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
|
|
||||||
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
|
|
||||||
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMAPIPushStack(TVMValue arg,
|
|
||||||
int type_code) {
|
|
||||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
|
||||||
API_BEGIN();
|
|
||||||
ret->arg_stack.resize(ret->arg_stack.size() + 1);
|
|
||||||
APIVariantValue& v = ret->arg_stack.back();
|
|
||||||
|
|
||||||
v.type_code = type_code;
|
|
||||||
switch (type_code) {
|
|
||||||
case kInt: case kUInt: case kFloat: case kNull: {
|
|
||||||
v.v_union = arg; break;
|
|
||||||
}
|
|
||||||
case kStr: {
|
|
||||||
v.str = arg.v_str; break;
|
|
||||||
}
|
|
||||||
case kNodeHandle: {
|
|
||||||
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); break;
|
|
||||||
}
|
|
||||||
default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code);
|
|
||||||
}
|
|
||||||
API_END_HANDLE_ERROR(ret->Clear());
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMAPIFuncCall(APIFuncHandle handle,
|
|
||||||
TVMValue* ret_val,
|
|
||||||
int* ret_type_code) {
|
|
||||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
|
||||||
API_BEGIN();
|
|
||||||
const auto *op = static_cast<const APIFuncReg *>(handle);
|
|
||||||
op->body(ret->arg_stack, &(ret->ret_value));
|
|
||||||
ret->SetReturn(ret_val, ret_type_code);
|
|
||||||
ret->arg_stack.clear();
|
|
||||||
API_END_HANDLE_ERROR(ret->Clear());
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMNodeFree(NodeHandle handle) {
|
|
||||||
API_BEGIN();
|
|
||||||
delete static_cast<TVMAPINode*>(handle);
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMNodeGetAttr(NodeHandle handle,
|
|
||||||
const char* key,
|
|
||||||
TVMValue* ret_val,
|
|
||||||
int* ret_type_code,
|
|
||||||
int* ret_success) {
|
|
||||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
|
||||||
API_BEGIN();
|
|
||||||
ret->ret_value.type_code = kNull;
|
|
||||||
APIAttrGetter getter;
|
|
||||||
getter.skey = key;
|
|
||||||
getter.ret = &(ret->ret_value);
|
|
||||||
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
|
||||||
if (getter.skey == "type_key") {
|
|
||||||
ret_val->v_str = (*tnode)->type_key();
|
|
||||||
*ret_type_code = kStr;
|
|
||||||
*ret_success = 1;
|
|
||||||
} else {
|
|
||||||
(*tnode)->VisitAttrs(&getter);
|
|
||||||
if (ret->ret_value.type_code != kNull) {
|
|
||||||
ret->SetReturn(ret_val, ret_type_code);
|
|
||||||
*ret_success = 1;
|
|
||||||
} else {
|
|
||||||
*ret_success = getter.found_node_ref ? 1 : 0;
|
|
||||||
*ret_type_code = kNull;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
API_END_HANDLE_ERROR(ret->Clear());
|
|
||||||
}
|
|
||||||
|
|
||||||
int TVMNodeListAttrNames(NodeHandle handle,
|
|
||||||
int *out_size,
|
|
||||||
const char*** out_array) {
|
|
||||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
|
||||||
API_BEGIN();
|
|
||||||
ret->ret_vec_str.clear();
|
|
||||||
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
|
|
||||||
APIAttrDir dir;
|
|
||||||
dir.names = &(ret->ret_vec_str);
|
|
||||||
(*tnode)->VisitAttrs(&dir);
|
|
||||||
ret->ret_vec_charp.clear();
|
|
||||||
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
|
||||||
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
|
||||||
}
|
|
||||||
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
|
||||||
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val,
|
|
||||||
int* ret_type_code) {
|
|
||||||
APIVariantValue& rv = ret_value;
|
|
||||||
*ret_type_code = rv.type_code;
|
|
||||||
if (rv.type_code == kNodeHandle) {
|
|
||||||
if (rv.sptr.get() != nullptr) {
|
|
||||||
ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
|
|
||||||
} else {
|
|
||||||
ret_val->v_handle = nullptr;
|
|
||||||
}
|
|
||||||
} else if (rv.type_code == kFuncHandle) {
|
|
||||||
ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func));
|
|
||||||
} else {
|
|
||||||
*ret_val = rv.v_union;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,61 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2016 by Contributors
|
|
||||||
* Implementation of API functions related to Codegen
|
|
||||||
* \file c_api_codegen.cc
|
|
||||||
*/
|
|
||||||
#include <tvm/expr.h>
|
|
||||||
#include <tvm/ir.h>
|
|
||||||
#include <tvm/codegen.h>
|
|
||||||
|
|
||||||
#include "./c_api_registry.h"
|
|
||||||
#include "../codegen/codegen_c.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
namespace codegen {
|
|
||||||
|
|
||||||
using ArgStack = const std::vector<APIVariantValue>;
|
|
||||||
using RetValue = APIVariantValue;
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_codegen_CompileToC)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = CodeGenC().Compile(args.at(0), args.at(1));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_codegen_MakeAPI)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = MakeAPI(
|
|
||||||
args.at(0), args.at(1), args.at(2), args.at(3));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_codegen_SplitHostDevice)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = SplitHostDevice(args.at(0));
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
// generate a dummy packed function for testing
|
|
||||||
void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) {
|
|
||||||
LOG(INFO) << num_args << " arguments";
|
|
||||||
for (int i = 0; i < num_args; ++i) {
|
|
||||||
switch (type_code[i]) {
|
|
||||||
case kNull: LOG(INFO) << i << ":nullptr"; break;
|
|
||||||
case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break;
|
|
||||||
case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break;
|
|
||||||
case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break;
|
|
||||||
default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_codegen_DummyHelloFunction)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = runtime::PackedFunc(DummyHelloFunction);
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_codegen_BuildStackVM)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = BuildStackVM(args.at(0));
|
|
||||||
});
|
|
||||||
|
|
||||||
} // namespace codegen
|
|
||||||
} // namespace tvm
|
|
|
@ -1,19 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2016 by Contributors
|
|
||||||
* \file c_api_common.h
|
|
||||||
* \brief Common fields of all C APIs
|
|
||||||
*/
|
|
||||||
#ifndef TVM_C_API_C_API_COMMON_H_
|
|
||||||
#define TVM_C_API_C_API_COMMON_H_
|
|
||||||
|
|
||||||
#include <dmlc/base.h>
|
|
||||||
#include <dmlc/logging.h>
|
|
||||||
#include <dmlc/thread_local.h>
|
|
||||||
#include <tvm/c_api.h>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <exception>
|
|
||||||
#include "./c_api_registry.h"
|
|
||||||
#include "../runtime/runtime_base.h"
|
|
||||||
|
|
||||||
#endif // TVM_C_API_C_API_COMMON_H_
|
|
|
@ -1,47 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2016 by Contributors
|
|
||||||
* Implementation of API functions
|
|
||||||
* \file c_api_impl.cc
|
|
||||||
*/
|
|
||||||
#include <tvm/expr.h>
|
|
||||||
#include <tvm/tensor.h>
|
|
||||||
#include "./c_api_registry.h"
|
|
||||||
|
|
||||||
namespace dmlc {
|
|
||||||
DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg);
|
|
||||||
} // namespace dmlc
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
|
|
||||||
using ArgStack = const std::vector<APIVariantValue>;
|
|
||||||
using RetValue = APIVariantValue;
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_format_str)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
std::ostringstream os;
|
|
||||||
os << args.at(0).operator NodeRef();
|
|
||||||
*ret = os.str();
|
|
||||||
})
|
|
||||||
.add_argument("expr", "Node", "expression to be printed");
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_raw_ptr)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
|
|
||||||
})
|
|
||||||
.add_argument("src", "NodeBase", "the node base");
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_save_json)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = SaveJSON(args.at(0));
|
|
||||||
})
|
|
||||||
.add_argument("src", "json_str", "the node ");
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_load_json)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = NodeRef(LoadJSON_(args.at(0)));
|
|
||||||
})
|
|
||||||
.add_argument("src", "NodeBase", "the node");
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
|
@ -1,273 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2016 by Contributors
|
|
||||||
* Implementation of API functions related to Higher DSL build.
|
|
||||||
* \file c_api_lang.cc
|
|
||||||
*/
|
|
||||||
#include <tvm/expr.h>
|
|
||||||
#include <tvm/ir.h>
|
|
||||||
#include <tvm/tensor.h>
|
|
||||||
#include <tvm/buffer.h>
|
|
||||||
#include <tvm/schedule.h>
|
|
||||||
#include "./c_api_registry.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
|
|
||||||
using ArgStack = const std::vector<APIVariantValue>;
|
|
||||||
using RetValue = APIVariantValue;
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_const)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
if (args.at(0).type_code == kInt) {
|
|
||||||
*ret = make_const(args.at(1), args.at(0).operator int64_t());
|
|
||||||
} else if (args.at(0).type_code == kFloat) {
|
|
||||||
*ret = make_const(args.at(1), args.at(0).operator double());
|
|
||||||
} else {
|
|
||||||
LOG(FATAL) << "only accept int or float";
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.add_argument("src", "Number", "source number")
|
|
||||||
.add_argument("dtype", "str", "data type");
|
|
||||||
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_str)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = ir::StringImm::make(args.at(0));
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Array)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
std::vector<std::shared_ptr<Node> > data;
|
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
|
||||||
CHECK(args.at(i).type_code == kNodeHandle)
|
|
||||||
<< "need content of array to be NodeBase";
|
|
||||||
data.push_back(args.at(i).sptr);
|
|
||||||
}
|
|
||||||
auto node = std::make_shared<ArrayNode>();
|
|
||||||
node->data = std::move(data);
|
|
||||||
ret->type_code = kNodeHandle;
|
|
||||||
ret->sptr = node;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_ArrayGetItem)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
int64_t i = args.at(1);
|
|
||||||
auto& sptr = args.at(0).sptr;
|
|
||||||
CHECK(sptr->is_type<ArrayNode>());
|
|
||||||
auto* n = static_cast<const ArrayNode*>(sptr.get());
|
|
||||||
CHECK_LT(static_cast<size_t>(i), n->data.size())
|
|
||||||
<< "out of bound of array";
|
|
||||||
ret->sptr = n->data[i];
|
|
||||||
ret->type_code = kNodeHandle;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_ArraySize)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
auto& sptr = args.at(0).sptr;
|
|
||||||
CHECK(sptr->is_type<ArrayNode>());
|
|
||||||
*ret = static_cast<int64_t>(
|
|
||||||
static_cast<const ArrayNode*>(sptr.get())->data.size());
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Map)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK_EQ(args.size() % 2, 0U);
|
|
||||||
MapNode::ContainerType data;
|
|
||||||
for (size_t i = 0; i < args.size(); i += 2) {
|
|
||||||
CHECK(args.at(i).type_code == kNodeHandle)
|
|
||||||
<< "need content of array to be NodeBase";
|
|
||||||
CHECK(args.at(i + 1).type_code == kNodeHandle)
|
|
||||||
<< "need content of array to be NodeBase";
|
|
||||||
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
|
|
||||||
}
|
|
||||||
auto node = std::make_shared<MapNode>();
|
|
||||||
node->data = std::move(data);
|
|
||||||
ret->type_code = kNodeHandle;
|
|
||||||
ret->sptr = node;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_MapSize)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
auto& sptr = args.at(0).sptr;
|
|
||||||
CHECK(sptr->is_type<MapNode>());
|
|
||||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
|
||||||
*ret = static_cast<int64_t>(n->data.size());
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_MapGetItem)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
CHECK(args.at(1).type_code == kNodeHandle);
|
|
||||||
auto& sptr = args.at(0).sptr;
|
|
||||||
CHECK(sptr->is_type<MapNode>());
|
|
||||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
|
||||||
auto it = n->data.find(args.at(1).sptr);
|
|
||||||
CHECK(it != n->data.end())
|
|
||||||
<< "cannot find the corresponding key in the Map";
|
|
||||||
ret->sptr = (*it).second;
|
|
||||||
ret->type_code = kNodeHandle;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_MapCount)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
CHECK(args.at(1).type_code == kNodeHandle);
|
|
||||||
auto& sptr = args.at(0).sptr;
|
|
||||||
CHECK(sptr->is_type<MapNode>());
|
|
||||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
|
||||||
*ret = static_cast<int64_t>(n->data.count(args.at(1).sptr));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_MapItems)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
CHECK(args.at(0).type_code == kNodeHandle);
|
|
||||||
auto& sptr = args.at(0).sptr;
|
|
||||||
CHECK(sptr->is_type<MapNode>());
|
|
||||||
auto* n = static_cast<const MapNode*>(sptr.get());
|
|
||||||
auto rkvs = std::make_shared<ArrayNode>();
|
|
||||||
for (const auto& kv : n->data) {
|
|
||||||
rkvs->data.push_back(kv.first);
|
|
||||||
rkvs->data.push_back(kv.second);
|
|
||||||
}
|
|
||||||
ret->sptr = rkvs;
|
|
||||||
ret->type_code = kNodeHandle;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(Range)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
if (args.size() == 1) {
|
|
||||||
*ret = Range(0, args.at(0));
|
|
||||||
} else {
|
|
||||||
*ret = Range(args.at(0), args.at(1));
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.describe("create a domain range")
|
|
||||||
.add_argument("begin", "Expr", "beginning of the range.")
|
|
||||||
.add_argument("end", "Expr", "extent of the range");
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Buffer)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = BufferNode::make(args.at(0),
|
|
||||||
args.at(1),
|
|
||||||
args.at(2),
|
|
||||||
args.at(3),
|
|
||||||
args.at(4));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Tensor)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = TensorNode::make(args.at(0),
|
|
||||||
args.at(1),
|
|
||||||
args.at(2),
|
|
||||||
args.at(3));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_TensorEqual)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_TensorHash)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = static_cast<int64_t>(
|
|
||||||
std::hash<Tensor>()(args.at(0).operator Tensor()));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Placeholder)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = Placeholder(args.at(0),
|
|
||||||
args.at(1),
|
|
||||||
args.at(2));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_ComputeOp)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = ComputeOpNode::make(args.at(0),
|
|
||||||
args.at(1),
|
|
||||||
args.at(2));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_OpGetOutput)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = args.at(0).operator Operation().output(
|
|
||||||
args.at(1).operator int64_t());
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_IterVar)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = IterVar(args.at(0), args.at(1), args.at(2));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_Schedule)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
*ret = Schedule(args.at(0).operator Array<Operation>());
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageSetScope)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.set_scope(args.at(1));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageSplitByFactor)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
IterVar outer, inner;
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.split(args.at(1), &outer, &inner, args.at(2));
|
|
||||||
*ret = Array<IterVar>({outer, inner});
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageSplitByOuter)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
IterVar inner;
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.split(args.at(1), args.at(2), &inner, args.at(3));
|
|
||||||
*ret = inner;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageFuse)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
IterVar fused;
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.split(args.at(1), args.at(2), &fused);
|
|
||||||
*ret = fused;
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageComputeAt)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.compute_at(args.at(1), args.at(2));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageComputeInline)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.compute_inline();
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageComputeRoot)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.compute_root();
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageReorder)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.reorder(args.at(1));
|
|
||||||
});
|
|
||||||
|
|
||||||
TVM_REGISTER_API(_StageTile)
|
|
||||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
|
||||||
IterVar x_outer, y_outer, x_inner, y_inner;
|
|
||||||
args.at(0).operator Stage()
|
|
||||||
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
|
|
||||||
&x_inner, &y_inner, args.at(3), args.at(4));
|
|
||||||
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
|
|
||||||
});
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
|
@ -1,240 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2016 by Contributors
|
|
||||||
* \file c_api_registry.h
|
|
||||||
* \brief Quick registry for C API.
|
|
||||||
*/
|
|
||||||
#ifndef TVM_C_API_C_API_REGISTRY_H_
|
|
||||||
#define TVM_C_API_C_API_REGISTRY_H_
|
|
||||||
|
|
||||||
#include <tvm/base.h>
|
|
||||||
#include <tvm/expr.h>
|
|
||||||
#include <tvm/c_api.h>
|
|
||||||
#include <tvm/runtime/packed_func.h>
|
|
||||||
#include <memory>
|
|
||||||
#include <limits>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include "../base/common.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
|
|
||||||
inline const char* TVMTypeCode2Str(int type_code) {
|
|
||||||
switch (type_code) {
|
|
||||||
case kInt: return "int";
|
|
||||||
case kFloat: return "float";
|
|
||||||
case kStr: return "str";
|
|
||||||
case kHandle: return "Handle";
|
|
||||||
case kNull: return "NULL";
|
|
||||||
case kNodeHandle: return "NodeHandle";
|
|
||||||
default: LOG(FATAL) << "unknown type_code="
|
|
||||||
<< static_cast<int>(type_code); return "";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct NodeTypeChecker {
|
|
||||||
static inline bool Check(Node* sptr) {
|
|
||||||
// This is the only place in the project where RTTI is used
|
|
||||||
// It can be turned off, but will make non strict checking.
|
|
||||||
// TODO(tqchen) possibly find alternative to turn of RTTI
|
|
||||||
using ContainerType = typename T::ContainerType;
|
|
||||||
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
|
|
||||||
}
|
|
||||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
|
||||||
using ContainerType = typename T::ContainerType;
|
|
||||||
os << ContainerType::_type_key;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct NodeTypeChecker<Array<T> > {
|
|
||||||
static inline bool Check(Node* sptr) {
|
|
||||||
if (sptr == nullptr) return false;
|
|
||||||
if (!sptr->is_type<ArrayNode>()) return false;
|
|
||||||
ArrayNode* n = static_cast<ArrayNode*>(sptr);
|
|
||||||
for (const auto& p : n->data) {
|
|
||||||
if (!NodeTypeChecker<T>::Check(p.get())) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
|
||||||
os << "array<";
|
|
||||||
NodeTypeChecker<T>::PrintName(os);
|
|
||||||
os << ">";
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename K, typename V>
|
|
||||||
struct NodeTypeChecker<Map<K, V> > {
|
|
||||||
static inline bool Check(Node* sptr) {
|
|
||||||
if (sptr == nullptr) return false;
|
|
||||||
if (!sptr->is_type<MapNode>()) return false;
|
|
||||||
MapNode* n = static_cast<MapNode*>(sptr);
|
|
||||||
for (const auto& kv : n->data) {
|
|
||||||
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
|
|
||||||
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
|
|
||||||
os << "map<";
|
|
||||||
NodeTypeChecker<K>::PrintName(os);
|
|
||||||
os << ',';
|
|
||||||
NodeTypeChecker<V>::PrintName(os);
|
|
||||||
os << '>';
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline std::string NodeTypeName() {
|
|
||||||
std::ostringstream os;
|
|
||||||
NodeTypeChecker<T>::PrintName(os);
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief Variant container for API calls */
|
|
||||||
class APIVariantValue {
|
|
||||||
public:
|
|
||||||
/*! \brief the type id */
|
|
||||||
int type_code{kNull};
|
|
||||||
/*! \brief shared pointer container */
|
|
||||||
std::shared_ptr<Node> sptr;
|
|
||||||
/*! \brief string container */
|
|
||||||
std::string str;
|
|
||||||
/*! \brief the variant holder */
|
|
||||||
TVMValue v_union;
|
|
||||||
/*! \brief std::function */
|
|
||||||
runtime::PackedFunc::FType func;
|
|
||||||
// constructor
|
|
||||||
APIVariantValue() {
|
|
||||||
}
|
|
||||||
// clear value
|
|
||||||
inline void Clear() {
|
|
||||||
}
|
|
||||||
// assign op
|
|
||||||
inline APIVariantValue& operator=(double value) {
|
|
||||||
type_code = kFloat;
|
|
||||||
v_union.v_float64 = value;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(std::nullptr_t value) {
|
|
||||||
type_code = kHandle;
|
|
||||||
v_union.v_handle = value;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(int64_t value) {
|
|
||||||
type_code = kInt;
|
|
||||||
v_union.v_int64 = value;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(bool value) {
|
|
||||||
type_code = kInt;
|
|
||||||
v_union.v_int64 = value;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(std::string value) {
|
|
||||||
type_code = kStr;
|
|
||||||
str = std::move(value);
|
|
||||||
v_union.v_str = str.c_str();
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(const NodeRef& ref) {
|
|
||||||
if (ref.node_.get() == nullptr) {
|
|
||||||
type_code = kNull;
|
|
||||||
} else {
|
|
||||||
type_code = kNodeHandle;
|
|
||||||
this->sptr = ref.node_;
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(const runtime::PackedFunc& f) {
|
|
||||||
type_code = kFuncHandle;
|
|
||||||
this->func = f.body();
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline APIVariantValue& operator=(const Type& value) {
|
|
||||||
return operator=(Type2String(value));
|
|
||||||
}
|
|
||||||
template<typename T,
|
|
||||||
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
|
|
||||||
inline operator T() const {
|
|
||||||
if (type_code == kNull) return T();
|
|
||||||
CHECK_EQ(type_code, kNodeHandle);
|
|
||||||
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
|
|
||||||
<< "Did not get expected type " << NodeTypeName<T>();
|
|
||||||
return T(sptr);
|
|
||||||
}
|
|
||||||
inline operator Expr() const {
|
|
||||||
if (type_code == kNull) {
|
|
||||||
return Expr();
|
|
||||||
}
|
|
||||||
if (type_code == kInt) return Expr(operator int());
|
|
||||||
if (type_code == kFloat) {
|
|
||||||
return Expr(static_cast<float>(operator double()));
|
|
||||||
}
|
|
||||||
CHECK_EQ(type_code, kNodeHandle);
|
|
||||||
if (sptr->is_type<IterVarNode>()) {
|
|
||||||
return IterVar(sptr)->var;
|
|
||||||
} else {
|
|
||||||
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
|
|
||||||
<< "did not pass in Expr in a place need Expr";
|
|
||||||
return Expr(sptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inline operator double() const {
|
|
||||||
CHECK_EQ(type_code, kFloat);
|
|
||||||
return v_union.v_float64;
|
|
||||||
}
|
|
||||||
inline operator int64_t() const {
|
|
||||||
CHECK_EQ(type_code, kInt);
|
|
||||||
return v_union.v_int64;
|
|
||||||
}
|
|
||||||
inline operator uint64_t() const {
|
|
||||||
CHECK_EQ(type_code, kInt);
|
|
||||||
return v_union.v_int64;
|
|
||||||
}
|
|
||||||
inline operator int() const {
|
|
||||||
CHECK_EQ(type_code, kInt);
|
|
||||||
CHECK_LE(v_union.v_int64,
|
|
||||||
std::numeric_limits<int>::max());
|
|
||||||
return v_union.v_int64;
|
|
||||||
}
|
|
||||||
inline operator bool() const {
|
|
||||||
CHECK_EQ(type_code, kInt)
|
|
||||||
<< "expect boolean(int) but get "
|
|
||||||
<< TVMTypeCode2Str(type_code);
|
|
||||||
return v_union.v_int64 != 0;
|
|
||||||
}
|
|
||||||
inline operator std::string() const {
|
|
||||||
CHECK_EQ(type_code, kStr)
|
|
||||||
<< "expect Str but get "
|
|
||||||
<< TVMTypeCode2Str(type_code);
|
|
||||||
return str;
|
|
||||||
}
|
|
||||||
inline operator Type() const {
|
|
||||||
return String2Type(operator std::string());
|
|
||||||
}
|
|
||||||
inline operator runtime::PackedFunc() const {
|
|
||||||
CHECK_EQ(type_code, kFuncHandle);
|
|
||||||
return runtime::PackedFunc(func);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// common defintiion of API function.
|
|
||||||
using APIFunc = std::function<
|
|
||||||
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Registry entry for DataIterator factory functions.
|
|
||||||
*/
|
|
||||||
struct APIFuncReg
|
|
||||||
: public dmlc::FunctionRegEntryBase<APIFuncReg,
|
|
||||||
APIFunc> {
|
|
||||||
};
|
|
||||||
|
|
||||||
#define TVM_REGISTER_API(TypeName) \
|
|
||||||
DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
||||||
|
|
||||||
#endif // TVM_C_API_C_API_REGISTRY_H_
|
|
|
@ -12,19 +12,22 @@ using namespace ir;
|
||||||
|
|
||||||
runtime::PackedFunc BuildStackVM(LoweredFunc func) {
|
runtime::PackedFunc BuildStackVM(LoweredFunc func) {
|
||||||
StackVM vm = codegen::CodeGenStackVM().Compile(func);
|
StackVM vm = codegen::CodeGenStackVM().Compile(func);
|
||||||
auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) {
|
using runtime::TVMArgs;
|
||||||
LOG(INFO) << "Run stack VM";
|
using runtime::TVMRetValue;
|
||||||
|
|
||||||
|
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
|
||||||
StackVM::State* s = StackVM::ThreadLocalState();
|
StackVM::State* s = StackVM::ThreadLocalState();
|
||||||
s->sp = 0;
|
s->sp = 0;
|
||||||
s->pc = 0;
|
s->pc = 0;
|
||||||
if (s->heap.size() < vm.heap_size) {
|
if (s->heap.size() < vm.heap_size) {
|
||||||
s->heap.resize(vm.heap_size);
|
s->heap.resize(vm.heap_size);
|
||||||
}
|
}
|
||||||
s->heap[0].v_handle = (void*)args; // NOLINT(*)
|
s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
|
||||||
s->heap[1].v_handle = (void*)type_codes; // NOLINT(*)
|
s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
|
||||||
s->heap[2].v_int64 = num_args;
|
s->heap[2].v_int64 = args.num_args;
|
||||||
vm.Run(s);
|
vm.Run(s);
|
||||||
};
|
};
|
||||||
|
|
||||||
return runtime::PackedFunc(f);
|
return runtime::PackedFunc(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +121,9 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) {
|
||||||
auto it = fun_idmap_.find(name);
|
auto it = fun_idmap_.find(name);
|
||||||
if (it != fun_idmap_.end()) return it->second;
|
if (it != fun_idmap_.end()) return it->second;
|
||||||
using runtime::PackedFunc;
|
using runtime::PackedFunc;
|
||||||
|
using runtime::TVMArgs;
|
||||||
|
using runtime::TVMRetValue;
|
||||||
|
|
||||||
PackedFunc f = PackedFunc::GetGlobal(name);
|
PackedFunc f = PackedFunc::GetGlobal(name);
|
||||||
auto extern_f = [f](const TVMValue* args, int num_args) {
|
auto extern_f = [f](const TVMValue* args, int num_args) {
|
||||||
CHECK_EQ(num_args % 2, 0);
|
CHECK_EQ(num_args % 2, 0);
|
||||||
|
@ -128,7 +134,8 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) {
|
||||||
int code = (tcode >> (8 * 3)) & 255;
|
int code = (tcode >> (8 * 3)) & 255;
|
||||||
type_codes[i] = code;
|
type_codes[i] = code;
|
||||||
}
|
}
|
||||||
f.CallPacked(args, &type_codes[0], num_args);
|
TVMRetValue rv;
|
||||||
|
f.CallPacked(TVMArgs(args, &type_codes[0], num_args), &rv);
|
||||||
TVMValue r; r.v_int64 = 0;
|
TVMValue r; r.v_int64 = 0;
|
||||||
return r;
|
return r;
|
||||||
};
|
};
|
||||||
|
|
|
@ -136,7 +136,6 @@ class HostDeviceSplitter : public IRMutator {
|
||||||
public:
|
public:
|
||||||
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
|
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
|
||||||
if (op->type_key == "thread_extent") {
|
if (op->type_key == "thread_extent") {
|
||||||
LOG(INFO) << "??";
|
|
||||||
IterVar iv(op->node.node_);
|
IterVar iv(op->node.node_);
|
||||||
return SplitDeviceFunc(s);
|
return SplitDeviceFunc(s);
|
||||||
}
|
}
|
||||||
|
|
|
@ -302,7 +302,8 @@ void StackVM::Run(State* s) const {
|
||||||
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
|
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
|
||||||
}
|
}
|
||||||
case TVM_LOAD_ARG_HANDLE: {
|
case TVM_LOAD_ARG_HANDLE: {
|
||||||
STACK_VM_TVM_LOAD_ARG(tc == kHandle || tc == kNull, "handle"); break;
|
STACK_VM_TVM_LOAD_ARG(
|
||||||
|
tc == kHandle || tc == kNull || tc == kArrayHandle, "handle"); break;
|
||||||
}
|
}
|
||||||
case TVM_ARRAY_GET_DATA: {
|
case TVM_ARRAY_GET_DATA: {
|
||||||
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
|
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
|
||||||
|
@ -317,7 +318,7 @@ void StackVM::Run(State* s) const {
|
||||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
|
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
|
||||||
}
|
}
|
||||||
case TVM_ARRAY_GET_TYPE_CODE: {
|
case TVM_ARRAY_GET_TYPE_CODE: {
|
||||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.type_code); break;
|
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
|
||||||
}
|
}
|
||||||
case TVM_ARRAY_GET_TYPE_BITS: {
|
case TVM_ARRAY_GET_TYPE_BITS: {
|
||||||
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
|
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
|
||||||
|
|
|
@ -3,9 +3,11 @@
|
||||||
* \file c_runtime_api.cc
|
* \file c_runtime_api.cc
|
||||||
* \brief Device specific implementations
|
* \brief Device specific implementations
|
||||||
*/
|
*/
|
||||||
|
#include <dmlc/thread_local.h>
|
||||||
#include <tvm/runtime/c_runtime_api.h>
|
#include <tvm/runtime/c_runtime_api.h>
|
||||||
#include <tvm/runtime/packed_func.h>
|
#include <tvm/runtime/packed_func.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
#include "./runtime_base.h"
|
#include "./runtime_base.h"
|
||||||
#include "./device_api.h"
|
#include "./device_api.h"
|
||||||
|
|
||||||
|
@ -37,7 +39,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
|
||||||
|
|
||||||
inline void VerifyType(TVMType dtype) {
|
inline void VerifyType(TVMType dtype) {
|
||||||
CHECK_GE(dtype.lanes, 1U);
|
CHECK_GE(dtype.lanes, 1U);
|
||||||
if (dtype.type_code == kFloat) {
|
if (dtype.code == kFloat) {
|
||||||
CHECK_EQ(dtype.bits % 32U, 0U);
|
CHECK_EQ(dtype.bits % 32U, 0U);
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(dtype.bits % 8U, 0U);
|
CHECK_EQ(dtype.bits % 8U, 0U);
|
||||||
|
@ -65,6 +67,12 @@ inline size_t GetDataAlignment(TVMArray* arr) {
|
||||||
|
|
||||||
using namespace tvm::runtime;
|
using namespace tvm::runtime;
|
||||||
|
|
||||||
|
struct TVMRuntimeEntry {
|
||||||
|
std::string ret_str;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
|
||||||
|
|
||||||
int TVMDeviceInit(int dev_mask,
|
int TVMDeviceInit(int dev_mask,
|
||||||
const char** option_keys,
|
const char** option_keys,
|
||||||
const char** option_vals,
|
const char** option_vals,
|
||||||
|
@ -177,10 +185,31 @@ int TVMFuncFree(TVMFunctionHandle func) {
|
||||||
int TVMFuncCall(TVMFunctionHandle func,
|
int TVMFuncCall(TVMFunctionHandle func,
|
||||||
TVMValue* args,
|
TVMValue* args,
|
||||||
int* arg_type_codes,
|
int* arg_type_codes,
|
||||||
int num_args) {
|
int num_args,
|
||||||
|
TVMValue* ret_val,
|
||||||
|
int* ret_type_code) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
TVMRetValue rv;
|
||||||
(*static_cast<const PackedFunc*>(func)).CallPacked(
|
(*static_cast<const PackedFunc*>(func)).CallPacked(
|
||||||
args, arg_type_codes, num_args);
|
TVMArgs(args, arg_type_codes, num_args), &rv);
|
||||||
|
// handle return string.
|
||||||
|
if (rv.type_code() == kStr) {
|
||||||
|
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
|
||||||
|
e->ret_str = rv.operator std::string();
|
||||||
|
*ret_type_code = kStr;
|
||||||
|
ret_val->v_str = e->ret_str.c_str();
|
||||||
|
} else {
|
||||||
|
rv.MoveToCHost(ret_val, ret_type_code);
|
||||||
|
}
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
|
int TVMCFuncSetReturn(TVMRetValueHandle ret,
|
||||||
|
TVMValue value,
|
||||||
|
int type_code) {
|
||||||
|
API_BEGIN();
|
||||||
|
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
|
||||||
|
*rv = TVMArgValue(value, type_code);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,22 +220,18 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
if (fin == nullptr) {
|
if (fin == nullptr) {
|
||||||
*out = new PackedFunc(
|
*out = new PackedFunc(
|
||||||
[func, resource_handle](const TVMValue* args,
|
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
|
||||||
const int* type_codes,
|
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
||||||
int num_args) {
|
args.num_args, rv, resource_handle);
|
||||||
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
|
|
||||||
num_args, resource_handle);
|
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// wrap it in a shared_ptr, with fin as deleter.
|
// wrap it in a shared_ptr, with fin as deleter.
|
||||||
// so fin will be called when the lambda went out of scope.
|
// so fin will be called when the lambda went out of scope.
|
||||||
std::shared_ptr<void> rpack(resource_handle, fin);
|
std::shared_ptr<void> rpack(resource_handle, fin);
|
||||||
*out = new PackedFunc(
|
*out = new PackedFunc(
|
||||||
[func, rpack](const TVMValue* args,
|
[func, rpack](TVMArgs args, TVMRetValue* rv) {
|
||||||
const int* type_codes,
|
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
|
||||||
int num_args) {
|
args.num_args, rv, rpack.get());
|
||||||
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
|
|
||||||
num_args, rpack.get());
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
API_END();
|
API_END();
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
* \brief The global registry of packed function.
|
* \brief The global registry of packed function.
|
||||||
*/
|
*/
|
||||||
#include <dmlc/logging.h>
|
#include <dmlc/logging.h>
|
||||||
|
#include <dmlc/thread_local.h>
|
||||||
#include <tvm/runtime/packed_func.h>
|
#include <tvm/runtime/packed_func.h>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -58,6 +59,18 @@ std::vector<std::string> PackedFunc::ListGlobalNames() {
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
|
|
||||||
|
/*! \brief entry to to easily hold returning information */
|
||||||
|
struct TVMFuncThreadLocalEntry {
|
||||||
|
/*! \brief result holder for returning strings */
|
||||||
|
std::vector<std::string> ret_vec_str;
|
||||||
|
/*! \brief result holder for returning string pointers */
|
||||||
|
std::vector<const char *> ret_vec_charp;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief Thread local store that can be used to hold return values. */
|
||||||
|
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
|
||||||
|
|
||||||
|
|
||||||
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
|
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
|
||||||
using tvm::runtime::PackedFunc;
|
using tvm::runtime::PackedFunc;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
@ -68,6 +81,22 @@ int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
|
||||||
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
|
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
|
||||||
using tvm::runtime::PackedFunc;
|
using tvm::runtime::PackedFunc;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
*out = new PackedFunc(PackedFunc::GetGlobal(name));
|
const PackedFunc& f = PackedFunc::GetGlobal(name);
|
||||||
|
*out = (TVMFunctionHandle)(&f); // NOLINT(*)
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
|
int TVMFuncListGlobalNames(int *out_size,
|
||||||
|
const char*** out_array) {
|
||||||
|
using tvm::runtime::PackedFunc;
|
||||||
|
API_BEGIN();
|
||||||
|
TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
|
||||||
|
ret->ret_vec_str = PackedFunc::ListGlobalNames();
|
||||||
|
ret->ret_vec_charp.clear();
|
||||||
|
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
|
||||||
|
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
|
||||||
|
}
|
||||||
|
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
|
||||||
|
*out_size = static_cast<int>(ret->ret_vec_str.size());
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,24 +1,116 @@
|
||||||
#include <dmlc/logging.h>
|
#include <dmlc/logging.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <tvm/runtime/packed_func.h>
|
#include <tvm/runtime/packed_func.h>
|
||||||
|
#include <tvm/tvm.h>
|
||||||
|
#include <tvm/ir.h>
|
||||||
|
|
||||||
TEST(PackedFunc, Basic) {
|
TEST(PackedFunc, Basic) {
|
||||||
|
using namespace tvm;
|
||||||
using namespace tvm::runtime;
|
using namespace tvm::runtime;
|
||||||
int x = 0;
|
int x = 0;
|
||||||
void* handle = &x;
|
void* handle = &x;
|
||||||
TVMArray a;
|
TVMArray a;
|
||||||
|
|
||||||
PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) {
|
Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||||
CHECK(num_args == 3);
|
CHECK(args.num_args == 3);
|
||||||
CHECK(args[0].v_float64 == 1.0);
|
CHECK(args.values[0].v_float64 == 1.0);
|
||||||
CHECK(type_codes[0] == kFloat);
|
CHECK(args.type_codes[0] == kFloat);
|
||||||
CHECK(args[1].v_handle == &a);
|
CHECK(args.values[1].v_handle == &a);
|
||||||
CHECK(type_codes[1] == kHandle);
|
CHECK(args.type_codes[1] == kArrayHandle);
|
||||||
CHECK(args[2].v_handle == &x);
|
CHECK(args.values[2].v_handle == &x);
|
||||||
CHECK(type_codes[2] == kHandle);
|
CHECK(args.type_codes[2] == kHandle);
|
||||||
|
*rv = Var("a");
|
||||||
})(1.0, &a, handle);
|
})(1.0, &a, handle);
|
||||||
|
CHECK(v->name_hint == "a");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(PackedFunc, Node) {
|
||||||
|
using namespace tvm;
|
||||||
|
using namespace tvm::runtime;
|
||||||
|
Var x;
|
||||||
|
Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
CHECK(args.num_args == 1);
|
||||||
|
CHECK(args.type_codes[0] == kNodeHandle);
|
||||||
|
Var b = args[0];
|
||||||
|
CHECK(x.same_as(b));
|
||||||
|
*rv = b;
|
||||||
|
})(x);
|
||||||
|
CHECK(t.same_as(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PackedFunc, str) {
|
||||||
|
using namespace tvm;
|
||||||
|
using namespace tvm::runtime;
|
||||||
|
PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
CHECK(args.num_args == 1);
|
||||||
|
std::string x = args[0];
|
||||||
|
CHECK(x == "hello");
|
||||||
|
*rv = x;
|
||||||
|
})("hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST(PackedFunc, func) {
|
||||||
|
using namespace tvm;
|
||||||
|
using namespace tvm::runtime;
|
||||||
|
PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
*rv = args[0].operator int() + 1;
|
||||||
|
});
|
||||||
|
// function as arguments
|
||||||
|
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
PackedFunc f = args[0];
|
||||||
|
// TVMArgValue -> Arguments as function
|
||||||
|
*rv = f(args[1]).operator int();
|
||||||
|
})(addone, 1);
|
||||||
|
CHECK_EQ(r0, 2);
|
||||||
|
|
||||||
|
int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
// TVMArgValue -> TVMRetValue
|
||||||
|
*rv = args[1];
|
||||||
|
})(2, 100);
|
||||||
|
CHECK_EQ(r1, 100);
|
||||||
|
|
||||||
|
int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
// re-assignment
|
||||||
|
*rv = args[0];
|
||||||
|
// TVMRetValue -> Function argument
|
||||||
|
*rv = addone(args[0].operator PackedFunc()(args[1], 1));
|
||||||
|
})(addone, 100);
|
||||||
|
CHECK_EQ(r2, 102);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PackedFunc, Expr) {
|
||||||
|
using namespace tvm;
|
||||||
|
using namespace tvm::runtime;
|
||||||
|
// automatic conversion of int to expr
|
||||||
|
PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
Expr x = args[0];
|
||||||
|
*rv = x.as<tvm::ir::IntImm>()->value + 1;
|
||||||
|
});
|
||||||
|
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
PackedFunc f = args[0];
|
||||||
|
// TVMArgValue -> Arguments as function
|
||||||
|
*rv = f(args[1]).operator int();
|
||||||
|
})(addone, 1);
|
||||||
|
CHECK_EQ(r0, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PackedFunc, Type) {
|
||||||
|
using namespace tvm;
|
||||||
|
using namespace tvm::runtime;
|
||||||
|
auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
Type x = args[0];
|
||||||
|
*rv = x;
|
||||||
|
});
|
||||||
|
auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
*rv = args[0];
|
||||||
|
});
|
||||||
|
CHECK(get_type("int32").operator Type() == Int(32));
|
||||||
|
CHECK(get_type("float").operator Type() == Float(32));
|
||||||
|
CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
testing::InitGoogleTest(&argc, argv);
|
testing::InitGoogleTest(&argc, argv);
|
||||||
testing::FLAGS_gtest_death_test_style = "threadsafe";
|
testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||||
|
|
|
@ -2,7 +2,8 @@ import tvm
|
||||||
|
|
||||||
def test_const():
|
def test_const():
|
||||||
x = tvm.const(1)
|
x = tvm.const(1)
|
||||||
assert x.dtype == 'int32'
|
print(x.dtype)
|
||||||
|
assert x.dtype == tvm.int32
|
||||||
assert isinstance(x, tvm.expr.IntImm)
|
assert isinstance(x, tvm.expr.IntImm)
|
||||||
|
|
||||||
def test_const_saveload_json():
|
def test_const_saveload_json():
|
||||||
|
|
|
@ -17,10 +17,22 @@ def test_get_global():
|
||||||
@tvm.register_func
|
@tvm.register_func
|
||||||
def my_packed_func(*args):
|
def my_packed_func(*args):
|
||||||
assert(tuple(args) == targs)
|
assert(tuple(args) == targs)
|
||||||
|
return 10
|
||||||
# get it out from global function table
|
# get it out from global function table
|
||||||
f = tvm.get_global_func("my_packed_func")
|
f = tvm.get_global_func("my_packed_func")
|
||||||
assert isinstance(f, tvm.nd.Function)
|
assert isinstance(f, tvm.nd.Function)
|
||||||
f(*targs)
|
y = f(*targs)
|
||||||
|
assert y == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_func():
|
||||||
|
def addy(y):
|
||||||
|
def add(x):
|
||||||
|
return tvm.convert(x + y)
|
||||||
|
return add
|
||||||
|
myf = tvm.convert(addy)
|
||||||
|
f = myf(10)
|
||||||
|
assert f(11).value == 21
|
||||||
|
|
||||||
|
|
||||||
def test_convert():
|
def test_convert():
|
||||||
|
@ -38,3 +50,4 @@ if __name__ == "__main__":
|
||||||
test_function()
|
test_function()
|
||||||
test_convert()
|
test_convert()
|
||||||
test_get_global()
|
test_get_global()
|
||||||
|
test_return_func()
|
||||||
|
|
|
@ -38,10 +38,10 @@ fi
|
||||||
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
|
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
|
||||||
make all || exit -1
|
make all || exit -1
|
||||||
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
|
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
|
||||||
python -m nose tests/python/ || exit -1
|
python -m nose -v tests/python/ || exit -1
|
||||||
python3 -m nose tests/python/ || exit -1
|
python3 -m nose -v tests/python/ || exit -1
|
||||||
else
|
else
|
||||||
nosetests tests/python/ || exit -1
|
nosetests -v tests/python/ || exit -1
|
||||||
nosetests3 tests/python/ || exit -1
|
nosetests3 -v tests/python/ || exit -1
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
Загрузка…
Ссылка в новой задаче