572 строки
17 KiB
C++
572 строки
17 KiB
C++
/*!
|
|
* Copyright (c) 2018 by Contributors
|
|
* \file tvm/relay/op.h
|
|
* \brief Primitive operator definition.
|
|
*/
|
|
#ifndef TVM_RELAY_OP_H_
|
|
#define TVM_RELAY_OP_H_
|
|
|
|
#include <functional>
|
|
#include <limits>
|
|
#include <string>
|
|
#include <typeinfo>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "base.h"
|
|
#include "expr.h"
|
|
#include "type.h"
|
|
|
|
namespace tvm {
|
|
namespace relay {
|
|
|
|
// Forward declarations of the attribute-map and registry classes that are
// defined further down in this header but referenced by OpNode and Op.
class GenericOpMap;
class OpRegistry;
template <typename ValueType>
class OpMap;
|
|
|
|
/*!
|
|
* \brief Node container of operator structure.
|
|
*/
|
|
class OpNode : public relay::ExprNode {
|
|
public:
|
|
/*! \brief name of the operator */
|
|
std::string name;
|
|
/*! \brief the type of the operator */
|
|
mutable FuncType op_type;
|
|
/*!
|
|
* \brief detailed description of the operator
|
|
* This can be used to generate docstring automatically for the operator.
|
|
*/
|
|
std::string description;
|
|
/* \brief Information of input arguments to the operator */
|
|
Array<AttrFieldInfo> arguments;
|
|
/*!
|
|
* \brief The type key of the attribute field
|
|
* This can be empty, in which case it defaults to anything.
|
|
*/
|
|
std::string attrs_type_key;
|
|
/*!
|
|
* \brief attribute type index,
|
|
* this field varies in each run and is not exposed to frontend.
|
|
*/
|
|
uint32_t attrs_type_index{0};
|
|
/*!
|
|
* \brief number of input arguments to the operator,
|
|
* -1 means it is variable length
|
|
*/
|
|
int32_t num_inputs = -1;
|
|
/*!
|
|
* \brief support level of the operator,
|
|
* The lower the more priority it contains.
|
|
* This is in analogies to BLAS levels.
|
|
*/
|
|
int32_t support_level = 10;
|
|
|
|
void VisitAttrs(tvm::AttrVisitor* v) final {
|
|
v->Visit("name", &name);
|
|
v->Visit("op_type", &op_type);
|
|
v->Visit("description", &description);
|
|
v->Visit("arguments", &arguments);
|
|
v->Visit("attrs_type_key", &attrs_type_key);
|
|
v->Visit("num_inputs", &num_inputs);
|
|
v->Visit("support_level", &support_level);
|
|
}
|
|
|
|
/*!
|
|
* \brief Check that if current op is a "primtive operator".
|
|
* That is the arguments are all type variables, and there is a single
|
|
* type relation applied to the input and output types.
|
|
*/
|
|
bool IsPrimitiveOp() const {
|
|
if (is_primitive_ != -1) return is_primitive_ != 0;
|
|
is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
|
|
return is_primitive_ != 0;
|
|
}
|
|
|
|
static constexpr const char* _type_key = "relay.Op";
|
|
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
|
|
|
|
private:
|
|
// friend class
|
|
friend class GenericOpMap;
|
|
friend class OpRegistry;
|
|
friend bool IsPrimitiveOp(const Expr&);
|
|
// Program internal unique index of operator.
|
|
// Used to help index the program.
|
|
uint32_t index_{0};
|
|
// whether this is a primitive op. -1 means unknown.
|
|
mutable int is_primitive_{-1};
|
|
// Internal function to compute if it is primitive op
|
|
bool IsPrimitiveOp_() const {
|
|
const auto& fn_ty = this->op_type;
|
|
if (fn_ty->type_constraints.size() != 1) return false;
|
|
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
|
|
if (rel == nullptr) return false;
|
|
// validate if the type parameter matches up
|
|
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
|
|
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
|
|
/*!
|
|
* \brief Operator reference class.
|
|
*/
|
|
class Op : public relay::Expr {
|
|
public:
|
|
/*! \brief default constructor */
|
|
Op() {}
|
|
/*! \brief constructor from node pointer */
|
|
explicit Op(NodePtr<Node> n) : Expr(n) {}
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const OpNode* operator->() const;
|
|
/*!
|
|
* \brief Get additional registered attribute about operators.
|
|
* If nothing has been registered, an empty OpMap will be returned.
|
|
* \param attr_name The name of the attribute.
|
|
* \return An OpMap of specified attr_name.
|
|
* \tparam ValueType The type of the attribute.
|
|
*/
|
|
template <typename ValueType>
|
|
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
|
|
/*!
|
|
* \brief Get an Op for a given operator name.
|
|
* Will raise an error if the op has not been registered.
|
|
* \param op_name Name of the operator.
|
|
* \return Pointer to a Op, valid throughout program lifetime.
|
|
*/
|
|
TVM_DLL static const Op& Get(const std::string& op_name);
|
|
|
|
/*! \brief specify container node */
|
|
using ContainerType = OpNode;
|
|
|
|
private:
|
|
/*!
|
|
* \brief Get generic attrmap given attr name
|
|
* \param key The attribute key
|
|
* \return reference to GenericOpMap
|
|
*/
|
|
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
|
|
};
|
|
|
|
/*! \brief Helper structure to register operators */
|
|
class OpRegistry {
|
|
public:
|
|
/*! \return the operator */
|
|
const Op& op() const { return op_; }
|
|
/*!
|
|
* \brief setter function during registration
|
|
* Set the description of operator
|
|
* \param descr the description string.
|
|
* \return reference to self.
|
|
*/
|
|
inline OpRegistry& describe(const std::string& descr); // NOLINT(*)
|
|
/*!
|
|
* \brief Add argument information to the function.
|
|
* \param name Name of the argument.
|
|
* \param type Type of the argument.
|
|
* \param description Description of the argument.
|
|
* \return reference to self.
|
|
*/
|
|
inline OpRegistry& add_argument(const std::string& name,
|
|
const std::string& type,
|
|
const std::string& description);
|
|
/*!
|
|
* \brief Attach the type function corresponding to the return type.
|
|
* \param rel_name The type relation name to register.
|
|
* \param type_rel_func The backing relation function which can solve an arbitrary
|
|
* relation on variables.
|
|
* \return reference to self.
|
|
*/
|
|
inline OpRegistry& add_type_rel(
|
|
const std::string& rel_name,
|
|
runtime::TypedPackedFunc<bool(const Array<Type>&,
|
|
int,
|
|
const Attrs&,
|
|
const TypeReporter&)> type_rel_func);
|
|
/*!
|
|
* \brief Set the type key of attributes.
|
|
* \param type_key The type of of the attrs field.
|
|
* \return reference to self.
|
|
*/
|
|
inline OpRegistry& set_attrs_type_key(const std::string& type_key);
|
|
/*!
|
|
* \brief Set the num_inputs
|
|
* \param n The number of inputs to be set.
|
|
* \return reference to self.
|
|
*/
|
|
inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*)
|
|
/*!
|
|
* \brief Set the support level of op.
|
|
* \param level The support level.
|
|
* \return reference to self.
|
|
*/
|
|
inline OpRegistry& set_support_level(int32_t level); // NOLINT(*)
|
|
/*!
|
|
* \brief Register additional attributes to operator.
|
|
* \param attr_name The name of the attribute.
|
|
* \param value The value to be set.
|
|
* \param plevel The priority level of this set,
|
|
* an higher priority level attribute
|
|
* will replace lower priority level attribute.
|
|
* Must be bigger than 0.
|
|
*
|
|
* Cannot set with same plevel twice in the code.
|
|
*
|
|
* \tparam ValueType The type of the value to be set.
|
|
*/
|
|
template <typename ValueType>
|
|
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
|
|
const ValueType& value, int plevel = 10);
|
|
|
|
// set the name of the op to be the same as registry
|
|
inline OpRegistry& set_name() { // NOLINT(*)
|
|
if (get()->name.length() == 0) {
|
|
get()->name = name;
|
|
}
|
|
return *this;
|
|
}
|
|
/*! \return The global single registry */
|
|
TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
|
|
|
|
private:
|
|
friend class ::dmlc::Registry<OpRegistry>;
|
|
// the name
|
|
std::string name;
|
|
/*! \brief The operator */
|
|
Op op_;
|
|
// private constructor
|
|
OpRegistry();
|
|
// return internal pointer to op.
|
|
inline OpNode* get();
|
|
// update the attribute OpMap
|
|
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
|
|
int plevel);
|
|
};
|
|
|
|
/*!
|
|
* \brief Generic map to store additional information of Op.
|
|
*/
|
|
class GenericOpMap {
|
|
public:
|
|
/*!
|
|
* \brief Check if the map has op as key.
|
|
* \param op The key to the map
|
|
* \return 1 if op is contained in map, 0 otherwise.
|
|
*/
|
|
inline int count(const Op& op) const;
|
|
/*!
|
|
* \brief get the corresponding value element at op
|
|
* \param op The key to the map
|
|
* \return the const reference to the content value.
|
|
*/
|
|
inline const TVMRetValue& operator[](const Op& op) const;
|
|
/*!
|
|
* \brief get the corresponding value element at op with default value.
|
|
* \param op The key to the map
|
|
* \param def_value The default value when the key does not exist.
|
|
* \return the const reference to the content value.
|
|
* \tparam ValueType The content value type.
|
|
*/
|
|
template <typename ValueType>
|
|
inline ValueType get(const Op& op, ValueType def_value) const;
|
|
/*!
|
|
* \brief get the corresponding value element at op with default value.
|
|
* \param expr The key to the map
|
|
* \param def_value The default value when the key does not exist
|
|
* or if expr is not an Op.
|
|
* \return the const reference to the content value.
|
|
* \tparam ValueType The content value type.
|
|
*/
|
|
template <typename ValueType>
|
|
inline ValueType get(const Expr& expr, ValueType def_value) const;
|
|
|
|
private:
|
|
friend class OpRegistry;
|
|
// the attribute field.
|
|
std::string attr_name_;
|
|
// internal data
|
|
std::vector<std::pair<TVMRetValue, int> > data_;
|
|
// The value
|
|
GenericOpMap() = default;
|
|
};
|
|
|
|
/*!
|
|
* \brief Map<Op,ValueType> used to store meta-information about Op.
|
|
* \tparam ValueType The type of the value stored in map.
|
|
*/
|
|
template <typename ValueType>
|
|
class OpMap {
|
|
public:
|
|
/*!
|
|
* \brief Check if the map has op as key.
|
|
* \param op The key to the map
|
|
* \return 1 if op is contained in map, 0 otherwise.
|
|
*/
|
|
inline int count(const Op& op) const;
|
|
/*!
|
|
* \brief get the corresponding value element at op
|
|
* \param op The key to the map
|
|
* \return the const reference to the content value.
|
|
*/
|
|
inline ValueType operator[](const Op& op) const;
|
|
/*!
|
|
* \brief get the corresponding value element at op with default value.
|
|
* \param op The key to the map
|
|
* \param def_value The default value when the key does not exist.
|
|
* \return the const reference to the content value.
|
|
*/
|
|
inline ValueType get(const Op& op, ValueType def_value) const;
|
|
/*!
|
|
* \brief get the corresponding value element at op with default value.
|
|
* \param expr The key to the map
|
|
* \param def_value The default value when the key does not exist
|
|
* or if expr is not an Op.
|
|
* \return the const reference to the content value.
|
|
*/
|
|
inline ValueType get(const Expr& expr, ValueType def_value) const;
|
|
|
|
private:
|
|
friend class Op;
|
|
// constructor
|
|
explicit OpMap(const GenericOpMap& map) : map_(map) {}
|
|
/*! \brief The internal map field */
|
|
const GenericOpMap& map_;
|
|
};
|
|
|
|
// Internal helper for RELAY_REGISTER_OP: expands to the declaration of a
// static, possibly-unused OpRegistry reference named __make_RelayOp.
// RELAY_REGISTER_OP concatenates a __COUNTER__ value onto that name (via
// DMLC_STR_CONCAT) so each registration declares a distinct variable.
#define RELAY_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
|
|
|
|
/*!
 * \def RELAY_REGISTER_OP
 * \brief Register a new operator, or fetch an already-registered one so that
 *        more attributes can be attached to it.
 *
 * Expands to a uniquely named static reference initialized from the global
 * OpRegistry entry for OpName. set_name() copies the entry name into the op
 * if the op has no name yet, and the result can be chained with further
 * OpRegistry setters.
 *
 * \param OpName The name of the operator, e.g. "add".
 *
 * \code
 *
 *  RELAY_REGISTER_OP("add")
 *  .describe("add two inputs together")
 *  .set_num_inputs(2)
 *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
 *
 * \endcode
 */
#define RELAY_REGISTER_OP(OpName)                              \
  DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) =       \
      ::tvm::relay::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name()
|
|
|
|
// implementations
|
|
// Expose the stored node as the OpNode this reference wraps.
inline const OpNode* Op::operator->() const {
  const auto* op_node = static_cast<const OpNode*>(node_.get());
  return op_node;
}
|
|
|
|
template <typename ValueType>
|
|
inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
|
|
return OpMap<ValueType>(Op::GetGenericAttr(key));
|
|
}
|
|
|
|
// Mutable access to the op under registration. Op::operator-> only hands out
// a const node, so constness is cast away here for the registry's setters.
inline OpNode* OpRegistry::get() {
  const OpNode* node = op_.operator->();
  return const_cast<OpNode*>(node);
}
|
|
|
|
// Store the human-readable description on the op.
inline OpRegistry& OpRegistry::describe(const std::string& descr) {  // NOLINT(*)
  OpNode* node = get();
  node->description = descr;
  return *this;
}
|
|
|
|
// Append an AttrFieldInfo describing one input argument to the op.
inline OpRegistry& OpRegistry::add_argument(const std::string& name,
                                            const std::string& type,
                                            const std::string& description) {
  auto info = make_node<AttrFieldInfoNode>();
  info->name = name;
  info->type_info = type;
  info->description = description;
  AttrFieldInfo field(info);
  get()->arguments.push_back(field);
  return *this;
}
|
|
|
|
// Build the op's type as
//   fn<in0, ..., in{k-1}, out>(in0, ..., in{k-1}) -> out
// constrained by a single type relation over (in0, ..., in{k-1}, out), where
// k is the op's num_inputs at the time of the call.
inline OpRegistry& OpRegistry::add_type_rel(
    const std::string& rel_name,
    runtime::TypedPackedFunc<bool(const Array<Type>&,
                                  int,
                                  const Attrs&,
                                  const TypeReporter&)> type_rel_func) {
  // Type relations live in the global function registry under a namespaced
  // key. Register the backing function only if nothing is registered under
  // that key yet, then look the key up once as an EnvFunc either way.
  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
  if (!runtime::Registry::Get(func_name)) {
    runtime::Registry::Register(func_name)
        .set_body(type_rel_func.packed());
  }
  TypeRelationFn env_type_rel_func;
  env_type_rel_func = EnvFunc::Get(func_name);

  Array<TypeVar> type_params;
  Array<Type> arg_types;

  // One fresh type variable per input ("in0", "in1", ...). num_inputs is read
  // here, so set_num_inputs must be called before add_type_rel. With the
  // default of -1 (variable length) this loop adds no inputs at all.
  const std::string input_name_prefix = "in";
  for (int i = 0; i < get()->num_inputs; ++i) {
    auto param = TypeVarNode::make(input_name_prefix + std::to_string(i),
                                   TypeVarNode::Kind::kType);
    type_params.push_back(param);
    arg_types.push_back(param);
  }

  // The relation is applied to the inputs followed by the output type.
  Array<Type> ty_call_args = arg_types;
  auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
  type_params.push_back(out_param);
  // This will trigger copy on write: ty_call_args detaches from arg_types,
  // which keeps only the inputs.
  ty_call_args.push_back(out_param);

  // The relation is created with empty (null) Attrs. The attributes of a
  // primitive operator vary per call site, and they can change the result
  // type: in sum(x, axis), for example, the chosen axis affects the output
  // type. Leaving Attrs null makes the registered relation polymorphic over
  // Attrs.
  TypeConstraint type_rel =
      TypeRelationNode::make(env_type_rel_func,
                             ty_call_args,
                             arg_types.size(),
                             Attrs());

  get()->op_type =
      FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
  return *this;
}
|
|
|
|
// Record the number of inputs (-1 = variable length) on the op.
inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
  OpNode* node = get();
  node->num_inputs = n;
  return *this;
}
|
|
|
|
// Record the attrs type key on the op together with the runtime type index
// resolved from it.
inline OpRegistry& OpRegistry::set_attrs_type_key(  // NOLINT(*)
    const std::string& type_key) {
  OpNode* node = get();
  node->attrs_type_key = type_key;
  node->attrs_type_index = Node::TypeKey2Index(type_key.c_str());
  return *this;
}
|
|
|
|
// Record the support level on the op.
inline OpRegistry& OpRegistry::set_support_level(int32_t n) {  // NOLINT(*)
  OpNode* node = get();
  node->support_level = n;
  return *this;
}
|
|
|
|
template <typename ValueType>
|
|
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
|
|
const std::string& attr_name, const ValueType& value, int plevel) {
|
|
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
|
|
TVMRetValue rv;
|
|
rv = value;
|
|
UpdateAttr(attr_name, rv, plevel);
|
|
return *this;
|
|
}
|
|
|
|
// member functions of GenericOpMap and OpMap
|
|
inline int GenericOpMap::count(const Op& op) const {
|
|
if (op.defined()) {
|
|
const uint32_t idx = op->index_;
|
|
return idx < data_.size() ? (data_[idx].second != 0) : 0;
|
|
} else {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// Return the value stored for `op`. CHECK-fails if op is undefined or if no
// value of this attribute has been registered for it.
inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
  CHECK(op.defined());
  const uint32_t idx = op->index_;
  // The condition text appears in the failure message; keep it verbatim.
  CHECK(idx < data_.size() && data_[idx].second != 0)
      << "Attribute " << attr_name_ << " has not been registered for Operator "
      << op->name;
  const TVMRetValue& stored = data_[idx].first;
  return stored;
}
|
|
|
|
template <typename ValueType>
|
|
inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
|
|
CHECK(op.defined());
|
|
const uint32_t idx = op->index_;
|
|
if (idx < data_.size() && data_[idx].second != 0) {
|
|
return data_[idx].first;
|
|
} else {
|
|
return value;
|
|
}
|
|
}
|
|
|
|
template <typename ValueType>
|
|
inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
|
|
CHECK(expr.defined());
|
|
if (const OpNode* op = expr.as<OpNode>()) {
|
|
const uint32_t idx = op->index_;
|
|
if (idx < data_.size() && data_[idx].second != 0) {
|
|
return data_[idx].first;
|
|
} else {
|
|
return value;
|
|
}
|
|
} else {
|
|
return value;
|
|
}
|
|
}
|
|
|
|
template <typename ValueType>
|
|
inline int OpMap<ValueType>::count(const Op& op) const {
|
|
return map_.count(op);
|
|
}
|
|
|
|
template <typename ValueType>
|
|
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
|
|
return map_[op];
|
|
}
|
|
|
|
template <typename ValueType>
|
|
inline ValueType OpMap<ValueType>::get(const Op& op,
|
|
ValueType def_value) const {
|
|
return map_.get<ValueType>(op, def_value);
|
|
}
|
|
|
|
template <typename ValueType>
|
|
inline ValueType OpMap<ValueType>::get(const Expr& expr,
|
|
ValueType def_value) const {
|
|
return map_.get<ValueType>(expr, def_value);
|
|
}
|
|
|
|
/*!
|
|
* \brief Check that an expression is a "primtive operator".
|
|
*
|
|
* Will return true if the expression is an operator which
|
|
* matches the form of primtive operators registered directly
|
|
* by the Relay codebase.
|
|
*
|
|
* That is the arguments are all type variables, and there is a single
|
|
* type relation applied to the input and output types.
|
|
*/
|
|
inline bool IsPrimitiveOp(const Expr& expr) {
|
|
const auto* op = expr.as<OpNode>();
|
|
return op != nullptr && op->IsPrimitiveOp();
|
|
}
|
|
|
|
} // namespace relay
|
|
} // namespace tvm
|
|
#endif // TVM_RELAY_OP_H_
|