// NOTE(review): web-viewer residue was here ("470 строки / 14 KiB / C", i.e. 470 lines, 14 KiB, C).
/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/relay/op.h
 * \brief Primitive operator definition.
 */
|
||
|
#ifndef TVM_RELAY_OP_H_
|
||
|
#define TVM_RELAY_OP_H_
|
||
|
|
||
|
#include <functional>
|
||
|
#include <limits>
|
||
|
#include <string>
|
||
|
#include <typeinfo>
|
||
|
#include <utility>
|
||
|
#include <vector>
|
||
|
|
||
|
#include "../attrs.h"
|
||
|
#include "./base.h"
|
||
|
#include "./expr.h"
|
||
|
#include "./type.h"
|
||
|
|
||
|
namespace tvm {
|
||
|
namespace relay {
|
||
|
|
||
|
// Forward declarations: these classes are referenced (as friends, return
// types, or members) by OpNode and Op before their definitions below.
class OpRegistry;
class GenericOpMap;
template <typename ValueType>
class OpMap;
|
||
|
|
||
|
/*!
|
||
|
* \brief Node container of operator structure.
|
||
|
*/
|
||
|
class OpNode : public relay::ExprNode {
|
||
|
public:
|
||
|
/*! \brief name of the operator */
|
||
|
std::string name;
|
||
|
/*! \brief the type of the operator */
|
||
|
mutable FuncType op_type;
|
||
|
/*!
|
||
|
* \brief detailed description of the operator
|
||
|
* This can be used to generate docstring automatically for the operator.
|
||
|
*/
|
||
|
std::string description;
|
||
|
/* \brief Information of input arguments to the operator */
|
||
|
Array<AttrFieldInfo> arguments;
|
||
|
/*!
|
||
|
* \brief The type key of the attribute field
|
||
|
* This can be empty, in which case it defaults to
|
||
|
*/
|
||
|
std::string attrs_type_key;
|
||
|
/*!
|
||
|
* \brief number of input arguments to the operator,
|
||
|
* -1 means it is variable length
|
||
|
*/
|
||
|
int32_t num_inputs = -1;
|
||
|
/*!
|
||
|
* \brief support level of the operator,
|
||
|
* The lower the more priority it contains.
|
||
|
* This is in analogies to BLAS levels.
|
||
|
*/
|
||
|
int32_t support_level = 10;
|
||
|
|
||
|
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||
|
v->Visit("name", &name);
|
||
|
v->Visit("op_type", &op_type);
|
||
|
v->Visit("description", &description);
|
||
|
v->Visit("arguments", &arguments);
|
||
|
v->Visit("attrs_type_key", &attrs_type_key);
|
||
|
v->Visit("num_inputs", &num_inputs);
|
||
|
v->Visit("support_level", &support_level);
|
||
|
}
|
||
|
|
||
|
static constexpr const char* _type_key = "relay.Op";
|
||
|
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
|
||
|
|
||
|
private:
|
||
|
// friend class
|
||
|
friend class GenericOpMap;
|
||
|
friend class OpRegistry;
|
||
|
// Program internal unique index of operator.
|
||
|
// Used to help index the program.
|
||
|
uint32_t index_{0};
|
||
|
};
|
||
|
|
||
|
/*!
|
||
|
* \brief Operator reference class.
|
||
|
*/
|
||
|
class Op : public relay::Expr {
|
||
|
public:
|
||
|
/*! \brief default constructor */
|
||
|
Op() {}
|
||
|
/*! \brief constructor from node pointer */
|
||
|
explicit Op(std::shared_ptr<Node> n) : Expr(n) {}
|
||
|
/*!
|
||
|
* \brief access the internal node container
|
||
|
* \return the pointer to the internal node container
|
||
|
*/
|
||
|
inline const OpNode* operator->() const;
|
||
|
/*!
|
||
|
* \brief Get additional registered attribute about operators.
|
||
|
* If nothing has been registered, an empty OpMap will be returned.
|
||
|
* \param attr_name The name of the attribute.
|
||
|
* \return An OpMap of specified attr_name.
|
||
|
* \tparam ValueType The type of the attribute.
|
||
|
*/
|
||
|
template <typename ValueType>
|
||
|
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
|
||
|
/*!
|
||
|
* \brief Get an Op for a given operator name.
|
||
|
* Will raise an error if the op has not been registered.
|
||
|
* \param op_name Name of the operator.
|
||
|
* \return Pointer to a Op, valid throughout program lifetime.
|
||
|
*/
|
||
|
TVM_DLL static const Op& Get(const std::string& op_name);
|
||
|
|
||
|
/*! \brief specify container node */
|
||
|
using ContainerType = OpNode;
|
||
|
|
||
|
private:
|
||
|
/*!
|
||
|
* \brief Get generic attrmap given attr name
|
||
|
* \param key The attribute key
|
||
|
* \return reference to GenericOpMap
|
||
|
*/
|
||
|
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
|
||
|
};
|
||
|
|
||
|
/*! \brief Helper structure to register operators */
|
||
|
class OpRegistry {
|
||
|
public:
|
||
|
/*! \return the operator */
|
||
|
const Op& op() const { return op_; }
|
||
|
/*!
|
||
|
* \brief setter function during registration
|
||
|
* Set the description of operator
|
||
|
* \param descr the description string.
|
||
|
* \return reference to self.
|
||
|
*/
|
||
|
inline OpRegistry& describe(const std::string& descr); // NOLINT(*)
|
||
|
/*!
|
||
|
* \brief Add argument information to the function.
|
||
|
* \param name Name of the argument.
|
||
|
* \param type Type of the argument.
|
||
|
* \param description Description of the argument.
|
||
|
* \return reference to self.
|
||
|
*/
|
||
|
inline OpRegistry& add_argument(const std::string& name,
|
||
|
const std::string& type,
|
||
|
const std::string& description);
|
||
|
/*!
|
||
|
* \brief Attach the type function corresponding to the return type.
|
||
|
* \param rel_name The type relation name to register.
|
||
|
* \param type_rel_func The backing relation function which can solve an arbitrary
|
||
|
* relation on variables.
|
||
|
* \return reference to self.
|
||
|
*/
|
||
|
inline OpRegistry& add_type_rel(
|
||
|
const std::string& rel_name,
|
||
|
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func);
|
||
|
|
||
|
/*!
|
||
|
* \brief Set the type key of attributes.
|
||
|
* \param type_key The type of of the attrs field.x
|
||
|
* \return reference to self.
|
||
|
*/
|
||
|
inline OpRegistry& set_attrs_type_key(const std::string& type_key);
|
||
|
/*!
|
||
|
* \brief Set the num_inputs
|
||
|
* \param n The number of inputs to be set.
|
||
|
* \return reference to self.
|
||
|
*/
|
||
|
inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*)
|
||
|
/*!
|
||
|
* \brief Set the support level of op.
|
||
|
* \param level The support level.
|
||
|
* \return reference to self.
|
||
|
*/
|
||
|
inline OpRegistry& set_support_level(int32_t level); // NOLINT(*)
|
||
|
/*!
|
||
|
* \brief Register additional attributes to operator.
|
||
|
* \param attr_name The name of the attribute.
|
||
|
* \param value The value to be set.
|
||
|
* \param plevel The priority level of this set,
|
||
|
* an higher priority level attribute
|
||
|
* will replace lower priority level attribute.
|
||
|
* Must be bigger than 0.
|
||
|
*
|
||
|
* Cannot set with same plevel twice in the code.
|
||
|
*
|
||
|
* \tparam ValueType The type of the value to be set.
|
||
|
*/
|
||
|
template <typename ValueType>
|
||
|
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
|
||
|
const ValueType& value, int plevel = 10);
|
||
|
|
||
|
// set the name of the op to be the same as registry
|
||
|
inline OpRegistry& set_name() { // NOLINT(*)
|
||
|
if (get()->name.length() == 0) {
|
||
|
get()->name = name;
|
||
|
}
|
||
|
return *this;
|
||
|
}
|
||
|
/*! \return The global single registry */
|
||
|
TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
|
||
|
|
||
|
private:
|
||
|
friend class ::dmlc::Registry<OpRegistry>;
|
||
|
// the name
|
||
|
std::string name;
|
||
|
/*! \brief The operator */
|
||
|
Op op_;
|
||
|
// private constructor
|
||
|
OpRegistry();
|
||
|
// return internal pointer to op.
|
||
|
inline OpNode* get();
|
||
|
// update the attribute OpMap
|
||
|
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
|
||
|
int plevel);
|
||
|
};
|
||
|
|
||
|
/*!
|
||
|
* \brief Generic map to store additional information of Op.
|
||
|
*/
|
||
|
class GenericOpMap {
|
||
|
public:
|
||
|
/*!
|
||
|
* \brief Check if the map has op as key.
|
||
|
* \param op The key to the map
|
||
|
* \return 1 if op is contained in map, 0 otherwise.
|
||
|
*/
|
||
|
inline int count(const Op& op) const;
|
||
|
/*!
|
||
|
* \brief get the corresponding value element at op
|
||
|
* \param op The key to the map
|
||
|
* \return the const reference to the content value.
|
||
|
*/
|
||
|
inline const TVMRetValue& operator[](const Op& op) const;
|
||
|
/*!
|
||
|
* \brief get the corresponding value element at op with default value.
|
||
|
* \param op The key to the map
|
||
|
* \param def_value The default value when the key does not exist.
|
||
|
* \return the const reference to the content value.
|
||
|
* \tparam ValueType The content value type.
|
||
|
*/
|
||
|
template <typename ValueType>
|
||
|
inline ValueType get(const Op& op, ValueType def_value) const;
|
||
|
|
||
|
private:
|
||
|
friend class OpRegistry;
|
||
|
// the attribute field.
|
||
|
std::string attr_name_;
|
||
|
// internal data
|
||
|
std::vector<std::pair<TVMRetValue, int> > data_;
|
||
|
// The value
|
||
|
GenericOpMap() = default;
|
||
|
};
|
||
|
|
||
|
/*!
|
||
|
* \brief Map<Op,ValueType> used to store meta-information about Op.
|
||
|
* \tparam ValueType The type of the value stored in map.
|
||
|
*/
|
||
|
template <typename ValueType>
|
||
|
class OpMap {
|
||
|
public:
|
||
|
/*!
|
||
|
* \brief Check if the map has op as key.
|
||
|
* \param op The key to the map
|
||
|
* \return 1 if op is contained in map, 0 otherwise.
|
||
|
*/
|
||
|
inline int count(const Op& op) const;
|
||
|
/*!
|
||
|
* \brief get the corresponding value element at op
|
||
|
* \param op The key to the map
|
||
|
* \return the const reference to the content value.
|
||
|
*/
|
||
|
inline ValueType operator[](const Op& op) const;
|
||
|
/*!
|
||
|
* \brief get the corresponding value element at op with default value.
|
||
|
* \param op The key to the map
|
||
|
* \param def_value The default value when the key does not exist.
|
||
|
* \return the const reference to the content value.
|
||
|
*/
|
||
|
inline ValueType get(const Op& op, ValueType def_value) const;
|
||
|
|
||
|
private:
|
||
|
friend class Op;
|
||
|
// constructor
|
||
|
explicit OpMap(const GenericOpMap& map) : map_(map) {}
|
||
|
/*! \brief The internal map field */
|
||
|
const GenericOpMap& map_;
|
||
|
};
|
||
|
|
||
|
// Internal helper macro used by RELAY_REGISTER_OP: declares an unused static
// OpRegistry reference. RELAY_REGISTER_OP pastes __COUNTER__ onto the trailing
// token so that every registration gets a distinct variable name.
#define RELAY_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
|
||
|
|
||
|
/*!
 * \def RELAY_REGISTER_OP
 * \brief Register a new operator, or set attributes of the corresponding op.
 *
 *  Expands to a uniquely named static OpRegistry reference bound to the
 *  registry entry for OpName (obtained through __REGISTER_OR_GET__), followed
 *  by set_name(), which copies the entry name into the op if it has no name yet.
 *  Further setters can be chained onto the expansion.
 *
 * \param OpName The name of registry
 *
 * \code
 *
 *  RELAY_REGISTER_OP("add")
 *  .describe("add two inputs together")
 *  .set_num_inputs(2)
 *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
 *
 * \endcode
 */
#define RELAY_REGISTER_OP(OpName)                        \
  DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::relay::OpRegistry::Registry()               \
          ->__REGISTER_OR_GET__(OpName)                  \
          .set_name()
|
||
|
|
||
|
// implementations
|
||
|
// Read-only access to the underlying node, viewed as an OpNode.
inline const OpNode* Op::operator->() const {
  const auto* raw_node = node_.get();
  return static_cast<const OpNode*>(raw_node);
}
|
||
|
|
||
|
template <typename ValueType>
|
||
|
inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
|
||
|
return OpMap<ValueType>(Op::GetGenericAttr(key));
|
||
|
}
|
||
|
|
||
|
// Mutable access to the op's node, used by the registration setters to
// write fields of the op (Op itself only exposes a const pointer).
inline OpNode* OpRegistry::get() {
  const OpNode* node = op_.operator->();
  return const_cast<OpNode*>(node);
}
|
||
|
|
||
|
// Store the human-readable description on the op; returns *this for chaining.
inline OpRegistry& OpRegistry::describe(const std::string& descr) {  // NOLINT(*)
  OpNode* node = get();
  node->description = descr;
  return *this;
}
|
||
|
|
||
|
// Append an AttrFieldInfo record (name, type, description) to the op's
// argument list; returns *this for chaining.
inline OpRegistry& OpRegistry::add_argument(const std::string& name,
                                            const std::string& type,
                                            const std::string& description) {
  auto info = std::make_shared<AttrFieldInfoNode>();
  info->name = name;
  info->type_info = type;
  info->description = description;
  get()->arguments.push_back(AttrFieldInfo(info));
  return *this;
}
|
||
|
|
||
|
// Build the op's function type from a named type relation.
//
// The relation is exposed as the global function
// "tvm.relay.type_relation.<rel_name>". If a function is already registered
// under that name it is reused and `type_rel_func` is ignored.
//
// The resulting FuncType has one type parameter per input ("in0", "in1", ...)
// plus an output parameter "out", constrained by the relation over
// (inputs..., out). The input count is read from num_inputs at call time,
// so set_num_inputs must be called first; with the default of -1 no inputs
// are generated.
inline OpRegistry& OpRegistry::add_type_rel(
    const std::string& rel_name,
    std::function<Array<Type>(const Array<Type>&, int)> type_rel_func) {
  const std::string func_name = "tvm.relay.type_relation." + rel_name;

  // Register on first use only; both paths then fetch the same EnvFunc.
  if (!runtime::Registry::Get(func_name)) {
    runtime::Registry::Register(func_name)
        .set_body_typed<Array<Type>(const Array<Type>&, int)>(type_rel_func);
  }
  TypedEnvFunc<Array<Type>(const Array<Type>&, int)> env_type_rel_func;
  env_type_rel_func = EnvFunc::Get(func_name);

  // One fresh type parameter per declared input.
  std::vector<TypeParam> type_params;
  std::vector<Type> arg_types;
  const std::string input_name_prefix = "in";
  for (int idx = 0; idx < get()->num_inputs; ++idx) {
    auto input_param = TypeParamNode::make(
        input_name_prefix + std::to_string(idx), TypeParamNode::Kind::kType);
    type_params.push_back(input_param);
    arg_types.push_back(input_param);
  }

  // The relation is applied to the inputs followed by the output.
  Array<Type> rel_args(arg_types);
  auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
  type_params.push_back(out_param);
  rel_args.push_back(out_param);

  TypeConstraint type_rel =
      TypeRelationNode::make(rel_name, env_type_rel_func, rel_args);
  get()->op_type =
      FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
  return *this;
}
|
||
|
|
||
|
// Record the number of inputs (-1 = variable length); returns *this for chaining.
inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
  OpNode* node = get();
  node->num_inputs = n;
  return *this;
}
|
||
|
|
||
|
// Record the type key of the op's attributes; returns *this for chaining.
inline OpRegistry& OpRegistry::set_attrs_type_key(  // NOLINT(*)
    const std::string& type_key) {
  OpNode* node = get();
  node->attrs_type_key = type_key;
  return *this;
}
|
||
|
|
||
|
// Record the op's support level (lower = higher priority); returns *this.
inline OpRegistry& OpRegistry::set_support_level(int32_t level) {  // NOLINT(*)
  OpNode* node = get();
  node->support_level = level;
  return *this;
}
|
||
|
|
||
|
template <typename ValueType>
|
||
|
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
|
||
|
const std::string& attr_name, const ValueType& value, int plevel) {
|
||
|
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
|
||
|
TVMRetValue rv;
|
||
|
rv = value;
|
||
|
UpdateAttr(attr_name, rv, plevel);
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// member functions of GenericOpMap and OpMap
|
||
|
inline int GenericOpMap::count(const Op& op) const {
|
||
|
if (op.defined()) {
|
||
|
const uint32_t idx = op->index_;
|
||
|
return idx < data_.size() ? (data_[idx].second != 0) : 0;
|
||
|
} else {
|
||
|
return 0;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Look up the stored value for `op`. Fails a CHECK if `op` is undefined or
// the attribute has not been registered for it.
inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
  CHECK(op.defined());
  const uint32_t idx = op->index_;
  CHECK(idx < data_.size() && data_[idx].second != 0)
      << "Attribute " << attr_name_ << " has not been registered for Operator "
      << op->name;
  const TVMRetValue& stored = data_[idx].first;
  return stored;
}
|
||
|
|
||
|
template <typename ValueType>
|
||
|
inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
|
||
|
CHECK(op.defined());
|
||
|
const uint32_t idx = op->index_;
|
||
|
if (idx < data_.size() && data_[idx].second != 0) {
|
||
|
return data_[idx].first;
|
||
|
} else {
|
||
|
return value;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template <typename ValueType>
|
||
|
inline int OpMap<ValueType>::count(const Op& op) const {
|
||
|
return map_.count(op);
|
||
|
}
|
||
|
|
||
|
template <typename ValueType>
|
||
|
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
|
||
|
return map_[op];
|
||
|
}
|
||
|
template <typename ValueType>
|
||
|
inline ValueType OpMap<ValueType>::get(const Op& op,
|
||
|
ValueType def_value) const {
|
||
|
return map_.get<ValueType>(op, def_value);
|
||
|
}
|
||
|
|
||
|
} // namespace relay
|
||
|
} // namespace tvm
|
||
|
#endif // TVM_RELAY_OP_H_
|