[NODE] Keep base node system in HalideIR (#1793)
This commit is contained in:
Родитель
06108bed43
Коммит
46363d0a74
|
@ -1 +1 @@
|
||||||
Subproject commit cf6090aeaeb782d1daff54b0ca5c2c281d7008db
|
Subproject commit 2f3ecdfdedf3efa7e45a3945dca63a25856c4674
|
|
@ -11,7 +11,7 @@
|
||||||
#include "base.h"
|
#include "base.h"
|
||||||
#include "expr.h"
|
#include "expr.h"
|
||||||
#include "ir_operator.h"
|
#include "ir_operator.h"
|
||||||
#include "node/container.h"
|
#include "tvm/node/container.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
#ifndef TVM_IR_FUNCTOR_EXT_H_
|
#ifndef TVM_IR_FUNCTOR_EXT_H_
|
||||||
#define TVM_IR_FUNCTOR_EXT_H_
|
#define TVM_IR_FUNCTOR_EXT_H_
|
||||||
|
|
||||||
#include "node/ir_functor.h"
|
#include "tvm/node/ir_functor.h"
|
||||||
#include "ir.h"
|
#include "ir.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "expr.h"
|
#include "expr.h"
|
||||||
#include "ir.h"
|
#include "ir.h"
|
||||||
#include "node/ir_functor.h"
|
#include "tvm/node/ir_functor.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace ir {
|
namespace ir {
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#define TVM_IR_VISITOR_H_
|
#define TVM_IR_VISITOR_H_
|
||||||
|
|
||||||
#include "ir.h"
|
#include "ir.h"
|
||||||
#include "node/ir_functor.h"
|
#include "tvm/node/ir_functor.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace ir {
|
namespace ir {
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include "base.h"
|
#include "base.h"
|
||||||
#include "expr.h"
|
#include "expr.h"
|
||||||
#include "tensor.h"
|
#include "tensor.h"
|
||||||
#include "node/container.h"
|
#include "tvm/node/container.h"
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
|
||||||
|
|
|
@ -1,586 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2018 by Contributors
|
|
||||||
* \file tvm/node/container.h
|
|
||||||
* \brief Array/Map container in the DSL graph.
|
|
||||||
*/
|
|
||||||
#ifndef TVM_NODE_CONTAINER_H_
|
|
||||||
#define TVM_NODE_CONTAINER_H_
|
|
||||||
|
|
||||||
#include <type_traits>
|
|
||||||
#include <vector>
|
|
||||||
#include <initializer_list>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
|
||||||
#include <string>
|
|
||||||
#include "node.h"
|
|
||||||
#include "memory.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
|
|
||||||
/*! \brief array node content in array */
|
|
||||||
class ArrayNode : public Node {
|
|
||||||
public:
|
|
||||||
/*! \brief the data content */
|
|
||||||
std::vector<NodePtr<Node> > data;
|
|
||||||
|
|
||||||
void VisitAttrs(AttrVisitor* visitor) final {
|
|
||||||
// Visitor to array have no effect.
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr const char* _type_key = "Array";
|
|
||||||
TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node);
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief map node content */
|
|
||||||
class MapNode : public Node {
|
|
||||||
public:
|
|
||||||
void VisitAttrs(AttrVisitor* visitor) final {
|
|
||||||
// Visitor to map have no effect.
|
|
||||||
}
|
|
||||||
// hash function
|
|
||||||
struct Hash {
|
|
||||||
size_t operator()(const NodePtr<Node>& n) const {
|
|
||||||
return std::hash<Node*>()(n.get());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// comparator
|
|
||||||
struct Equal {
|
|
||||||
bool operator()(
|
|
||||||
const NodePtr<Node>& a,
|
|
||||||
const NodePtr<Node>& b) const {
|
|
||||||
return a.get() == b.get();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief The corresponding conatiner type */
|
|
||||||
using ContainerType = std::unordered_map<
|
|
||||||
NodePtr<Node>,
|
|
||||||
NodePtr<Node>,
|
|
||||||
Hash, Equal>;
|
|
||||||
|
|
||||||
/*! \brief the data content */
|
|
||||||
ContainerType data;
|
|
||||||
|
|
||||||
static constexpr const char* _type_key = "Map";
|
|
||||||
TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/*! \brief specialized map node with string as key */
|
|
||||||
class StrMapNode : public Node {
|
|
||||||
public:
|
|
||||||
void VisitAttrs(AttrVisitor* visitor) final {
|
|
||||||
// Visitor to map have no effect.
|
|
||||||
}
|
|
||||||
/*! \brief The corresponding conatiner type */
|
|
||||||
using ContainerType = std::unordered_map<
|
|
||||||
std::string,
|
|
||||||
NodePtr<Node> >;
|
|
||||||
|
|
||||||
/*! \brief the data content */
|
|
||||||
ContainerType data;
|
|
||||||
|
|
||||||
static constexpr const char* _type_key = "StrMap";
|
|
||||||
TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node);
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief iterator adapter that adapts TIter to return another type.
|
|
||||||
* \tparam Converter a struct that contains converting function
|
|
||||||
* \tparam TIter the content iterator type.
|
|
||||||
*/
|
|
||||||
template<typename Converter,
|
|
||||||
typename TIter>
|
|
||||||
class IterAdapter {
|
|
||||||
public:
|
|
||||||
explicit IterAdapter(TIter iter) : iter_(iter) {}
|
|
||||||
inline IterAdapter& operator++() { // NOLINT(*)
|
|
||||||
++iter_;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline IterAdapter& operator++(int) { // NOLINT(*)
|
|
||||||
++iter_;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
inline IterAdapter operator+(int offset) const { // NOLINT(*)
|
|
||||||
return IterAdapter(iter_ + offset);
|
|
||||||
}
|
|
||||||
inline bool operator==(IterAdapter other) const {
|
|
||||||
return iter_ == other.iter_;
|
|
||||||
}
|
|
||||||
inline bool operator!=(IterAdapter other) const {
|
|
||||||
return !(*this == other);
|
|
||||||
}
|
|
||||||
inline const typename Converter::ResultType operator*() const {
|
|
||||||
return Converter::convert(*iter_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
TIter iter_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Array container of NodeRef in DSL graph.
|
|
||||||
* Array implements copy on write semantics, which means array is mutable
|
|
||||||
* but copy will happen when array is referenced in more than two places.
|
|
||||||
*
|
|
||||||
* operator[] only provide const acces, use Set to mutate the content.
|
|
||||||
* \tparam T The content NodeRef type.
|
|
||||||
*/
|
|
||||||
template<typename T,
|
|
||||||
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
|
|
||||||
class Array : public NodeRef {
|
|
||||||
public:
|
|
||||||
/*!
|
|
||||||
* \brief default constructor
|
|
||||||
*/
|
|
||||||
Array() {
|
|
||||||
node_ = make_node<ArrayNode>();
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief move constructor
|
|
||||||
* \param other source
|
|
||||||
*/
|
|
||||||
Array(Array<T> && other) { // NOLINT(*)
|
|
||||||
node_ = std::move(other.node_);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief copy constructor
|
|
||||||
* \param other source
|
|
||||||
*/
|
|
||||||
Array(const Array<T> &other) { // NOLINT(*)
|
|
||||||
node_ = other.node_;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from pointer
|
|
||||||
* \param n the container pointer
|
|
||||||
*/
|
|
||||||
explicit Array(NodePtr<Node> n) : NodeRef(n) {}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from iterator
|
|
||||||
* \param begin begin of iterator
|
|
||||||
* \param end end of iterator
|
|
||||||
* \tparam IterType The type of iterator
|
|
||||||
*/
|
|
||||||
template<typename IterType>
|
|
||||||
Array(IterType begin, IterType end) {
|
|
||||||
assign(begin, end);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from initializer list
|
|
||||||
* \param init The initalizer list
|
|
||||||
*/
|
|
||||||
Array(std::initializer_list<T> init) { // NOLINT(*)
|
|
||||||
assign(init.begin(), init.end());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from vector
|
|
||||||
* \param init The vector
|
|
||||||
*/
|
|
||||||
Array(const std::vector<T>& init) { // NOLINT(*)
|
|
||||||
assign(init.begin(), init.end());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief move assign operator
|
|
||||||
* \param other The source of assignment
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
Array<T>& operator=(Array<T> && other) {
|
|
||||||
node_ = std::move(other.node_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief copy assign operator
|
|
||||||
* \param other The source of assignment
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
Array<T>& operator=(const Array<T> & other) {
|
|
||||||
node_ = other.node_;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief reset the array to content from iterator.
|
|
||||||
* \param begin begin of iterator
|
|
||||||
* \param end end of iterator
|
|
||||||
* \tparam IterType The type of iterator
|
|
||||||
*/
|
|
||||||
template<typename IterType>
|
|
||||||
void assign(IterType begin, IterType end) {
|
|
||||||
auto n = make_node<ArrayNode>();
|
|
||||||
for (IterType it = begin; it != end; ++it) {
|
|
||||||
n->data.push_back((*it).node_);
|
|
||||||
}
|
|
||||||
node_ = std::move(n);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief Read i-th element from array.
|
|
||||||
* \param i The index
|
|
||||||
* \return the i-th element.
|
|
||||||
*/
|
|
||||||
inline const T operator[](size_t i) const {
|
|
||||||
return T(static_cast<const ArrayNode*>(node_.get())->data[i]);
|
|
||||||
}
|
|
||||||
/*! \return The size of the array */
|
|
||||||
inline size_t size() const {
|
|
||||||
if (node_.get() == nullptr) return 0;
|
|
||||||
return static_cast<const ArrayNode*>(node_.get())->data.size();
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief copy on write semantics
|
|
||||||
* Do nothing if current handle is the unique copy of the array.
|
|
||||||
* Otherwise make a new copy of the array to ensure the current handle
|
|
||||||
* hold a unique copy.
|
|
||||||
*
|
|
||||||
* \return Handle to the internal node container(which ganrantees to be unique)
|
|
||||||
*/
|
|
||||||
inline ArrayNode* CopyOnWrite() {
|
|
||||||
if (node_.get() == nullptr || !node_.unique()) {
|
|
||||||
NodePtr<ArrayNode> n = make_node<ArrayNode>();
|
|
||||||
n->data = static_cast<ArrayNode*>(node_.get())->data;
|
|
||||||
NodePtr<Node>(std::move(n)).swap(node_);
|
|
||||||
}
|
|
||||||
return static_cast<ArrayNode*>(node_.get());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief push a new item to the back of the list
|
|
||||||
* \param item The item to be pushed.
|
|
||||||
*/
|
|
||||||
inline void push_back(const T& item) {
|
|
||||||
ArrayNode* n = this->CopyOnWrite();
|
|
||||||
n->data.push_back(item.node_);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set i-th element of the array.
|
|
||||||
* \param i The index
|
|
||||||
* \param value The value to be setted.
|
|
||||||
*/
|
|
||||||
inline void Set(size_t i, const T& value) {
|
|
||||||
ArrayNode* n = this->CopyOnWrite();
|
|
||||||
n->data[i] = value.node_;
|
|
||||||
}
|
|
||||||
/*! \return whether array is empty */
|
|
||||||
inline bool empty() const {
|
|
||||||
return size() == 0;
|
|
||||||
}
|
|
||||||
/*! \brief specify container node */
|
|
||||||
using ContainerType = ArrayNode;
|
|
||||||
|
|
||||||
struct Ptr2NodeRef {
|
|
||||||
using ResultType = T;
|
|
||||||
static inline T convert(const NodePtr<Node>& n) {
|
|
||||||
return T(n);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
using iterator = IterAdapter<Ptr2NodeRef,
|
|
||||||
std::vector<NodePtr<Node> >::const_iterator>;
|
|
||||||
|
|
||||||
using reverse_iterator = IterAdapter<
|
|
||||||
Ptr2NodeRef,
|
|
||||||
std::vector<NodePtr<Node> >::const_reverse_iterator>;
|
|
||||||
|
|
||||||
/*! \return begin iterator */
|
|
||||||
inline iterator begin() const {
|
|
||||||
return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin());
|
|
||||||
}
|
|
||||||
/*! \return end iterator */
|
|
||||||
inline iterator end() const {
|
|
||||||
return iterator(static_cast<const ArrayNode*>(node_.get())->data.end());
|
|
||||||
}
|
|
||||||
/*! \return rbegin iterator */
|
|
||||||
inline reverse_iterator rbegin() const {
|
|
||||||
return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin());
|
|
||||||
}
|
|
||||||
/*! \return rend iterator */
|
|
||||||
inline reverse_iterator rend() const {
|
|
||||||
return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Map container of NodeRef->NodeRef in DSL graph.
|
|
||||||
* Map implements copy on write semantics, which means map is mutable
|
|
||||||
* but copy will happen when array is referenced in more than two places.
|
|
||||||
*
|
|
||||||
* operator[] only provide const acces, use Set to mutate the content.
|
|
||||||
* \tparam K The key NodeRef type.
|
|
||||||
* \tparam V The value NodeRef type.
|
|
||||||
*/
|
|
||||||
template<typename K,
|
|
||||||
typename V,
|
|
||||||
typename = typename std::enable_if<
|
|
||||||
std::is_base_of<NodeRef, K>::value ||
|
|
||||||
std::is_base_of<std::string, K>::value >::type,
|
|
||||||
typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type>
|
|
||||||
class Map : public NodeRef {
|
|
||||||
public:
|
|
||||||
/*!
|
|
||||||
* \brief default constructor
|
|
||||||
*/
|
|
||||||
Map() {
|
|
||||||
node_ = make_node<MapNode>();
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief move constructor
|
|
||||||
* \param other source
|
|
||||||
*/
|
|
||||||
Map(Map<K, V> && other) { // NOLINT(*)
|
|
||||||
node_ = std::move(other.node_);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief copy constructor
|
|
||||||
* \param other source
|
|
||||||
*/
|
|
||||||
Map(const Map<K, V> &other) { // NOLINT(*)
|
|
||||||
node_ = other.node_;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from pointer
|
|
||||||
* \param n the container pointer
|
|
||||||
*/
|
|
||||||
explicit Map(NodePtr<Node> n) : NodeRef(n) {}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from iterator
|
|
||||||
* \param begin begin of iterator
|
|
||||||
* \param end end of iterator
|
|
||||||
* \tparam IterType The type of iterator
|
|
||||||
*/
|
|
||||||
template<typename IterType>
|
|
||||||
Map(IterType begin, IterType end) {
|
|
||||||
assign(begin, end);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from initializer list
|
|
||||||
* \param init The initalizer list
|
|
||||||
*/
|
|
||||||
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
|
|
||||||
assign(init.begin(), init.end());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief constructor from vector
|
|
||||||
* \param init The vector
|
|
||||||
*/
|
|
||||||
template<typename Hash, typename Equal>
|
|
||||||
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
|
|
||||||
assign(init.begin(), init.end());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief move assign operator
|
|
||||||
* \param other The source of assignment
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
Map<K, V>& operator=(Map<K, V> && other) {
|
|
||||||
node_ = std::move(other.node_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief copy assign operator
|
|
||||||
* \param other The source of assignment
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
Map<K, V>& operator=(const Map<K, V> & other) {
|
|
||||||
node_ = other.node_;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief reset the array to content from iterator.
|
|
||||||
* \param begin begin of iterator
|
|
||||||
* \param end end of iterator
|
|
||||||
* \tparam IterType The type of iterator
|
|
||||||
*/
|
|
||||||
template<typename IterType>
|
|
||||||
void assign(IterType begin, IterType end) {
|
|
||||||
NodePtr<MapNode> n = make_node<MapNode>();
|
|
||||||
for (IterType i = begin; i != end; ++i) {
|
|
||||||
n->data.emplace(std::make_pair(i->first.node_,
|
|
||||||
i->second.node_));
|
|
||||||
}
|
|
||||||
node_ = std::move(n);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief Read element from map.
|
|
||||||
* \param key The key
|
|
||||||
* \return the corresonding element.
|
|
||||||
*/
|
|
||||||
inline const V operator[](const K& key) const {
|
|
||||||
return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief Read element from map.
|
|
||||||
* \param key The key
|
|
||||||
* \return the corresonding element.
|
|
||||||
*/
|
|
||||||
inline const V at(const K& key) const {
|
|
||||||
return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
|
|
||||||
}
|
|
||||||
/*! \return The size of the array */
|
|
||||||
inline size_t size() const {
|
|
||||||
if (node_.get() == nullptr) return 0;
|
|
||||||
return static_cast<const MapNode*>(node_.get())->data.size();
|
|
||||||
}
|
|
||||||
/*! \return The size of the array */
|
|
||||||
inline size_t count(const K& key) const {
|
|
||||||
if (node_.get() == nullptr) return 0;
|
|
||||||
return static_cast<const MapNode*>(node_.get())->data.count(key.node_);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief copy on write semantics
|
|
||||||
* Do nothing if current handle is the unique copy of the array.
|
|
||||||
* Otherwise make a new copy of the array to ensure the current handle
|
|
||||||
* hold a unique copy.
|
|
||||||
*
|
|
||||||
* \return Handle to the internal node container(which ganrantees to be unique)
|
|
||||||
*/
|
|
||||||
inline MapNode* CopyOnWrite() {
|
|
||||||
if (node_.get() == nullptr || !node_.unique()) {
|
|
||||||
NodePtr<MapNode> n = make_node<MapNode>();
|
|
||||||
n->data = static_cast<const MapNode*>(node_.get())->data;
|
|
||||||
NodePtr<Node>(std::move(n)).swap(node_);
|
|
||||||
}
|
|
||||||
return static_cast<MapNode*>(node_.get());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set the Map.
|
|
||||||
* \param key The index key.
|
|
||||||
* \param value The value to be setted.
|
|
||||||
*/
|
|
||||||
inline void Set(const K& key, const V& value) {
|
|
||||||
MapNode* n = this->CopyOnWrite();
|
|
||||||
n->data[key.node_] = value.node_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \return whether array is empty */
|
|
||||||
inline bool empty() const {
|
|
||||||
return size() == 0;
|
|
||||||
}
|
|
||||||
/*! \brief specify container node */
|
|
||||||
using ContainerType = MapNode;
|
|
||||||
|
|
||||||
struct Ptr2NodeRef {
|
|
||||||
using ResultType = std::pair<K, V>;
|
|
||||||
static inline ResultType convert(const std::pair<
|
|
||||||
NodePtr<Node>,
|
|
||||||
NodePtr<Node> >& n) {
|
|
||||||
return std::make_pair(K(n.first), V(n.second));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using iterator = IterAdapter<
|
|
||||||
Ptr2NodeRef, MapNode::ContainerType::const_iterator>;
|
|
||||||
|
|
||||||
/*! \return begin iterator */
|
|
||||||
inline iterator begin() const {
|
|
||||||
return iterator(static_cast<const MapNode*>(node_.get())->data.begin());
|
|
||||||
}
|
|
||||||
/*! \return end iterator */
|
|
||||||
inline iterator end() const {
|
|
||||||
return iterator(static_cast<const MapNode*>(node_.get())->data.end());
|
|
||||||
}
|
|
||||||
/*! \return begin iterator */
|
|
||||||
inline iterator find(const K& key) const {
|
|
||||||
return iterator(static_cast<const MapNode*>(node_.get())->data.find(key.node_));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// specialize of string map
|
|
||||||
template<typename V, typename T1, typename T2>
|
|
||||||
class Map<std::string, V, T1, T2> : public NodeRef {
|
|
||||||
public:
|
|
||||||
// for code reuse
|
|
||||||
Map() {
|
|
||||||
node_ = make_node<StrMapNode>();
|
|
||||||
}
|
|
||||||
Map(Map<std::string, V> && other) { // NOLINT(*)
|
|
||||||
node_ = std::move(other.node_);
|
|
||||||
}
|
|
||||||
Map(const Map<std::string, V> &other) { // NOLINT(*)
|
|
||||||
node_ = other.node_;
|
|
||||||
}
|
|
||||||
explicit Map(NodePtr<Node> n) : NodeRef(n) {}
|
|
||||||
template<typename IterType>
|
|
||||||
Map(IterType begin, IterType end) {
|
|
||||||
assign(begin, end);
|
|
||||||
}
|
|
||||||
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
|
|
||||||
assign(init.begin(), init.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename Hash, typename Equal>
|
|
||||||
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
|
|
||||||
assign(init.begin(), init.end());
|
|
||||||
}
|
|
||||||
Map<std::string, V>& operator=(Map<std::string, V> && other) {
|
|
||||||
node_ = std::move(other.node_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
Map<std::string, V>& operator=(const Map<std::string, V> & other) {
|
|
||||||
node_ = other.node_;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
template<typename IterType>
|
|
||||||
void assign(IterType begin, IterType end) {
|
|
||||||
auto n = make_node<StrMapNode>();
|
|
||||||
for (IterType i = begin; i != end; ++i) {
|
|
||||||
n->data.emplace(std::make_pair(i->first,
|
|
||||||
i->second.node_));
|
|
||||||
}
|
|
||||||
node_ = std::move(n);
|
|
||||||
}
|
|
||||||
inline const V operator[](const std::string& key) const {
|
|
||||||
return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
|
|
||||||
}
|
|
||||||
inline const V at(const std::string& key) const {
|
|
||||||
return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
|
|
||||||
}
|
|
||||||
inline size_t size() const {
|
|
||||||
if (node_.get() == nullptr) return 0;
|
|
||||||
return static_cast<const StrMapNode*>(node_.get())->data.size();
|
|
||||||
}
|
|
||||||
inline size_t count(const std::string& key) const {
|
|
||||||
if (node_.get() == nullptr) return 0;
|
|
||||||
return static_cast<const StrMapNode*>(node_.get())->data.count(key);
|
|
||||||
}
|
|
||||||
inline StrMapNode* CopyOnWrite() {
|
|
||||||
if (node_.get() == nullptr || !node_.unique()) {
|
|
||||||
NodePtr<StrMapNode> n = make_node<StrMapNode>();
|
|
||||||
n->data = static_cast<const StrMapNode*>(node_.get())->data;
|
|
||||||
NodePtr<Node>(std::move(n)).swap(node_);
|
|
||||||
}
|
|
||||||
return static_cast<StrMapNode*>(node_.get());
|
|
||||||
}
|
|
||||||
inline void Set(const std::string& key, const V& value) {
|
|
||||||
StrMapNode* n = this->CopyOnWrite();
|
|
||||||
n->data[key] = value.node_;
|
|
||||||
}
|
|
||||||
inline bool empty() const {
|
|
||||||
return size() == 0;
|
|
||||||
}
|
|
||||||
using ContainerType = StrMapNode;
|
|
||||||
|
|
||||||
struct Ptr2NodeRef {
|
|
||||||
using ResultType = std::pair<std::string, V>;
|
|
||||||
static inline ResultType convert(const std::pair<
|
|
||||||
std::string,
|
|
||||||
NodePtr<Node> >& n) {
|
|
||||||
return std::make_pair(n.first, V(n.second));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using iterator = IterAdapter<
|
|
||||||
Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>;
|
|
||||||
|
|
||||||
/*! \return begin iterator */
|
|
||||||
inline iterator begin() const {
|
|
||||||
return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin());
|
|
||||||
}
|
|
||||||
/*! \return end iterator */
|
|
||||||
inline iterator end() const {
|
|
||||||
return iterator(static_cast<const StrMapNode*>(node_.get())->data.end());
|
|
||||||
}
|
|
||||||
/*! \return begin iterator */
|
|
||||||
inline iterator find(const std::string& key) const {
|
|
||||||
return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
||||||
#endif // TVM_NODE_CONTAINER_H_
|
|
|
@ -1,254 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2018 by Contributors
|
|
||||||
* \file tvm/node/ir_functor.h
|
|
||||||
* \brief Defines the IRFunctor data structures.
|
|
||||||
*/
|
|
||||||
#ifndef TVM_NODE_IR_FUNCTOR_H_
|
|
||||||
#define TVM_NODE_IR_FUNCTOR_H_
|
|
||||||
|
|
||||||
#include <dmlc/logging.h>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <functional>
|
|
||||||
#include "node.h"
|
|
||||||
#include "../runtime/registry.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
/*!
|
|
||||||
* \brief A dynamical dispatched functor on NodeRef in the first argument.
|
|
||||||
*
|
|
||||||
* \code
|
|
||||||
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
|
|
||||||
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
|
|
||||||
* return prefix + "Add";
|
|
||||||
* });
|
|
||||||
* tostr.set_dispatch<IntImm>([](const IntImm* op) {
|
|
||||||
* return prefix + "IntImm"
|
|
||||||
* });
|
|
||||||
*
|
|
||||||
* Expr x = make_const(1);
|
|
||||||
* Expr y = x + x;
|
|
||||||
* // dispatch to IntImm, outputs "MyIntImm"
|
|
||||||
* LOG(INFO) << tostr(x, "My");
|
|
||||||
* // dispatch to IntImm, outputs "MyAdd"
|
|
||||||
* LOG(INFO) << tostr(y, "My");
|
|
||||||
* \endcode
|
|
||||||
*
|
|
||||||
* \tparam FType function signiture
|
|
||||||
* This type if only defined for FType with function signiture
|
|
||||||
*/
|
|
||||||
template<typename FType>
|
|
||||||
class IRFunctor;
|
|
||||||
|
|
||||||
template<typename R, typename ...Args>
|
|
||||||
class IRFunctor<R(const NodeRef& n, Args...)> {
|
|
||||||
private:
|
|
||||||
using Function = std::function<R (const NodeRef&n, Args...)>;
|
|
||||||
using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
|
|
||||||
/*! \brief internal function table */
|
|
||||||
std::vector<Function> func_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
/*! \brief the result type of this functor */
|
|
||||||
using result_type = R;
|
|
||||||
/*!
|
|
||||||
* \brief Whether the functor can dispatch the corresponding Node
|
|
||||||
* \param n The node to be dispatched
|
|
||||||
* \return Whether dispatching function is registered for n's type.
|
|
||||||
*/
|
|
||||||
inline bool can_dispatch(const NodeRef& n) const {
|
|
||||||
uint32_t type_index = n.type_index();
|
|
||||||
return type_index < func_.size() && func_[type_index] != nullptr;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief invoke the functor , dispatch on type of n
|
|
||||||
* \param n The Node argument
|
|
||||||
* \param args The additional arguments
|
|
||||||
* \return The result.
|
|
||||||
*/
|
|
||||||
inline R operator()(const NodeRef& n, Args... args) const {
|
|
||||||
uint32_t type_index = n.type_index();
|
|
||||||
CHECK(type_index < func_.size() &&
|
|
||||||
func_[type_index] != nullptr)
|
|
||||||
<< "IRFunctor calls un-registered function on type "
|
|
||||||
<< Node::TypeIndex2Key(type_index);
|
|
||||||
return func_[type_index](n, std::forward<Args>(args)...);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set the dispacher for type TNode
|
|
||||||
* \param f The function to be set.
|
|
||||||
* \tparam TNode the type of Node to be dispatched.
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
template<typename TNode>
|
|
||||||
inline TSelf& set_dispatch(Function f) { // NOLINT(*)
|
|
||||||
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
|
|
||||||
if (func_.size() <= tindex) {
|
|
||||||
func_.resize(tindex + 1, nullptr);
|
|
||||||
}
|
|
||||||
CHECK(func_[tindex] == nullptr)
|
|
||||||
<< "Dispatch for " << Node::TypeIndex2Key(tindex)
|
|
||||||
<< " is already set";
|
|
||||||
func_[tindex] = f;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set the dispacher for type TNode
|
|
||||||
* This allows f to used detailed const Node pointer to replace NodeRef
|
|
||||||
*
|
|
||||||
* \param f The function to be set.
|
|
||||||
* \tparam TNode the type of Node to be dispatched.
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
template<typename TNode>
|
|
||||||
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
|
|
||||||
Function fun = [f](const NodeRef& n, Args... args) {
|
|
||||||
return f(static_cast<const TNode*>(n.node_.get()),
|
|
||||||
std::forward<Args>(args)...);
|
|
||||||
};
|
|
||||||
return this->set_dispatch<TNode>(fun);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief unset the dispacher for type TNode
|
|
||||||
*
|
|
||||||
* \tparam TNode the type of Node to be dispatched.
|
|
||||||
* \return reference to self.
|
|
||||||
*/
|
|
||||||
template<typename TNode>
|
|
||||||
inline TSelf& clear_dispatch() { // NOLINT(*)
|
|
||||||
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
|
|
||||||
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
|
|
||||||
func_[tindex] = nullptr;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#define TVM_REGISTER_VAR_DEF(ClsName) \
|
|
||||||
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Useful macro to set IRFunctor dispatch in a global static field.
|
|
||||||
*
|
|
||||||
* \code
|
|
||||||
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
|
|
||||||
* // vtable allows easy patch in of new Node types, without changing
|
|
||||||
* // interface of IRPrinter.
|
|
||||||
*
|
|
||||||
* class IRPrinter {
|
|
||||||
* public:
|
|
||||||
* std::ostream& stream;
|
|
||||||
* // the dispatch function.
|
|
||||||
* void print(Expr e) {
|
|
||||||
* const static FType& f = *vtable();
|
|
||||||
* f(e, this);
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
|
|
||||||
* // function to return global function table
|
|
||||||
* static FType& vtable();
|
|
||||||
* };
|
|
||||||
*
|
|
||||||
* // in cpp/cc file
|
|
||||||
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*0
|
|
||||||
* static FType inst; return inst;
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
|
|
||||||
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
|
|
||||||
* p->print(n->a);
|
|
||||||
* p->stream << '+'
|
|
||||||
* p->print(n->b);
|
|
||||||
* });
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* \endcode
|
|
||||||
*
|
|
||||||
* \param ClsName The name of the class
|
|
||||||
* \param FField The static function that returns a singleton of IRFunctor.
|
|
||||||
*/
|
|
||||||
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
|
|
||||||
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
|
|
||||||
ClsName::FField()
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief A container for a list of callbacks. All callbacks are invoked when
|
|
||||||
* the object is destructed.
|
|
||||||
*/
|
|
||||||
class IRFunctorCleanList {
|
|
||||||
public:
|
|
||||||
~IRFunctorCleanList() {
|
|
||||||
for (auto &f : clean_items) {
|
|
||||||
f();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void append(std::function<void()> func) {
|
|
||||||
clean_items.push_back(func);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector< std::function<void()> > clean_items;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
|
|
||||||
* and make a corresponding call to clear_dispatch when the last copy of
|
|
||||||
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
|
|
||||||
* this can be used by NNVM and other libraries to unregister callbacks when
|
|
||||||
* the library is unloaded. This prevents crashes when the underlying IRFunctor
|
|
||||||
* is destructed as it will no longer contain std::function instances allocated
|
|
||||||
* by a library that has been unloaded.
|
|
||||||
*/
|
|
||||||
template<typename FType>
|
|
||||||
class IRFunctorStaticRegistry;
|
|
||||||
|
|
||||||
template<typename R, typename ...Args>
|
|
||||||
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
|
|
||||||
private:
|
|
||||||
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
|
|
||||||
std::shared_ptr<IRFunctorCleanList> free_list;
|
|
||||||
|
|
||||||
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
|
|
||||||
|
|
||||||
public:
|
|
||||||
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
|
|
||||||
irf_ = irf;
|
|
||||||
free_list = std::make_shared<IRFunctorCleanList>();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename TNode>
|
|
||||||
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
|
|
||||||
irf_->template set_dispatch<TNode>(f);
|
|
||||||
auto irf_copy = irf_;
|
|
||||||
free_list.get()->append([irf_copy] {
|
|
||||||
irf_copy->template clear_dispatch<TNode>();
|
|
||||||
});
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
|
|
||||||
* the compiler to deduce the template types.
|
|
||||||
*/
|
|
||||||
template<typename R, typename ...Args>
|
|
||||||
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
|
|
||||||
IRFunctor<R(const NodeRef& n, Args...)> *irf) {
|
|
||||||
return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
|
|
||||||
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
|
|
||||||
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
|
|
||||||
* TVM_STATIC_IR_FUNCTOR.
|
|
||||||
*/
|
|
||||||
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
|
|
||||||
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
|
|
||||||
MakeIRFunctorStaticRegistry(&ClsName::FField())
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
||||||
#endif // TVM_NODE_IR_FUNCTOR_H_
|
|
|
@ -1,59 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2018 by Contributors
|
|
||||||
* \file tvm/node/memory.h
|
|
||||||
* \brief Node memory management.
|
|
||||||
*/
|
|
||||||
#ifndef TVM_NODE_MEMORY_H_
|
|
||||||
#define TVM_NODE_MEMORY_H_
|
|
||||||
|
|
||||||
#include "node.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
/*!
|
|
||||||
* \brief Allocate a node object.
|
|
||||||
* \param args arguments to the constructor.
|
|
||||||
* \tparam T the node type.
|
|
||||||
* \return The NodePtr to the allocated object.
|
|
||||||
*/
|
|
||||||
template<typename T, typename... Args>
|
|
||||||
inline NodePtr<T> make_node(Args&&... args);
|
|
||||||
|
|
||||||
// Detail implementations after this
|
|
||||||
//
|
|
||||||
// The current design allows swapping the
|
|
||||||
// allocator pattern when necessary.
|
|
||||||
//
|
|
||||||
// Possible future allocator optimizations:
|
|
||||||
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
|
|
||||||
// - Thread-local object pools: one pool per size and alignment requirement.
|
|
||||||
// - Can specialize by type of object to give the specific allocator to each object.
|
|
||||||
//
|
|
||||||
template<typename T>
|
|
||||||
class SimpleNodeAllocator {
|
|
||||||
public:
|
|
||||||
template<typename... Args>
|
|
||||||
static T* New(Args&&... args) {
|
|
||||||
return new T(std::forward<Args>(args)...);
|
|
||||||
}
|
|
||||||
static NodeBase::FDeleter Deleter() {
|
|
||||||
return Deleter_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static void Deleter_(NodeBase* ptr) {
|
|
||||||
delete static_cast<T*>(ptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T, typename... Args>
|
|
||||||
inline NodePtr<T> make_node(Args&&... args) {
|
|
||||||
using Allocator = SimpleNodeAllocator<T>;
|
|
||||||
static_assert(std::is_base_of<NodeBase, T>::value,
|
|
||||||
"make_node can only be used to create NodeBase");
|
|
||||||
T* node = Allocator::New(std::forward<Args>(args)...);
|
|
||||||
node->deleter_ = Allocator::Deleter();
|
|
||||||
return NodePtr<T>(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
||||||
#endif // TVM_NODE_MEMORY_H_
|
|
|
@ -1,337 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2018 by Contributors
|
|
||||||
* \file tvm/node/node.h
|
|
||||||
* \brief Node system data structure.
|
|
||||||
*/
|
|
||||||
#ifndef TVM_NODE_NODE_H_
|
|
||||||
#define TVM_NODE_NODE_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <type_traits>
|
|
||||||
#include "base/Type.h"
|
|
||||||
#include "../runtime/node_base.h"
|
|
||||||
#include "../runtime/c_runtime_api.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
using HalideIR::Type;
|
|
||||||
// forward declaration
|
|
||||||
class Node;
|
|
||||||
class NodeRef;
|
|
||||||
|
|
||||||
namespace runtime {
|
|
||||||
// forward declaration
|
|
||||||
class NDArray;
|
|
||||||
} // namespace runtime
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Visitor class to each node content.
|
|
||||||
* The content is going to be called for each field.
|
|
||||||
*/
|
|
||||||
class TVM_DLL AttrVisitor {
|
|
||||||
public:
|
|
||||||
//! \cond Doxygen_Suppress
|
|
||||||
virtual void Visit(const char* key, double* value) = 0;
|
|
||||||
virtual void Visit(const char* key, int64_t* value) = 0;
|
|
||||||
virtual void Visit(const char* key, uint64_t* value) = 0;
|
|
||||||
virtual void Visit(const char* key, int* value) = 0;
|
|
||||||
virtual void Visit(const char* key, bool* value) = 0;
|
|
||||||
virtual void Visit(const char* key, std::string* value) = 0;
|
|
||||||
virtual void Visit(const char* key, void** value) = 0;
|
|
||||||
virtual void Visit(const char* key, Type* value) = 0;
|
|
||||||
virtual void Visit(const char* key, NodeRef* value) = 0;
|
|
||||||
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
|
|
||||||
template<typename ENum,
|
|
||||||
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
|
|
||||||
void Visit(const char* key, ENum* ptr) {
|
|
||||||
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
|
|
||||||
"declare enum to be enum int to use visitor");
|
|
||||||
this->Visit(key, reinterpret_cast<int*>(ptr));
|
|
||||||
}
|
|
||||||
//! \endcond
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief base class of node container in DSL AST.
|
|
||||||
* All object's internal is stored as std::shared_ptr<Node>
|
|
||||||
*/
|
|
||||||
class TVM_DLL Node : public NodeBase {
|
|
||||||
public:
|
|
||||||
/*! \brief virtual destructor */
|
|
||||||
virtual ~Node() {}
|
|
||||||
/*! \return The unique type key of the node */
|
|
||||||
virtual const char* type_key() const = 0;
|
|
||||||
/*!
|
|
||||||
* \brief Apply visitor to each field of the Node
|
|
||||||
* Visitor could mutate the content of the node.
|
|
||||||
* override if Node contains attribute fields.
|
|
||||||
* \param visitor The visitor
|
|
||||||
*/
|
|
||||||
virtual void VisitAttrs(AttrVisitor* visitor) {}
|
|
||||||
/*! \return the type index of the node */
|
|
||||||
virtual const uint32_t type_index() const = 0;
|
|
||||||
/*!
|
|
||||||
* \brief Whether this node derives from node with type_index=tid.
|
|
||||||
* Implemented by TVM_DECLARE_NODE_TYPE_INFO
|
|
||||||
*
|
|
||||||
* \param tid The type index.
|
|
||||||
* \return the check result.
|
|
||||||
*/
|
|
||||||
virtual const bool _DerivedFrom(uint32_t tid) const;
|
|
||||||
/*!
|
|
||||||
* \brief get a runtime unique type index given a type key
|
|
||||||
* \param type_key Type key of a type.
|
|
||||||
* \return the corresponding type index.
|
|
||||||
*/
|
|
||||||
static uint32_t TypeKey2Index(const char* type_key);
|
|
||||||
/*!
|
|
||||||
* \brief get type key from type index.
|
|
||||||
* \param index The type index
|
|
||||||
* \return the corresponding type key.
|
|
||||||
*/
|
|
||||||
static const char* TypeIndex2Key(uint32_t index);
|
|
||||||
/*!
|
|
||||||
* \return whether the type is derived from
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline bool derived_from() const;
|
|
||||||
/*!
|
|
||||||
* \return whether the node is of type T
|
|
||||||
* \tparam The type to be checked.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline bool is_type() const;
|
|
||||||
/*!
|
|
||||||
* \brief Get a NodePtr that holds reference to this Node.
|
|
||||||
* \return the NodePtr
|
|
||||||
*/
|
|
||||||
inline NodePtr<Node> GetNodePtr() const;
|
|
||||||
// node ref can see this
|
|
||||||
friend class NodeRef;
|
|
||||||
static constexpr const char* _type_key = "Node";
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief Base class of all node reference object */
|
|
||||||
class NodeRef {
|
|
||||||
public:
|
|
||||||
/*! \brief type indicate the container type */
|
|
||||||
using ContainerType = Node;
|
|
||||||
/*!
|
|
||||||
* \brief Comparator
|
|
||||||
* \param other Another node ref.
|
|
||||||
* \return the compare result.
|
|
||||||
*/
|
|
||||||
inline bool operator==(const NodeRef& other) const;
|
|
||||||
/*!
|
|
||||||
* \brief Comparator
|
|
||||||
* \param other Another node ref.
|
|
||||||
* \return the compare result.
|
|
||||||
*/
|
|
||||||
inline bool same_as(const NodeRef& other) const;
|
|
||||||
/*!
|
|
||||||
* \brief Comparator
|
|
||||||
* \param other Another node ref.
|
|
||||||
* \return the compare result.
|
|
||||||
*/
|
|
||||||
inline bool operator<(const NodeRef& other) const;
|
|
||||||
/*!
|
|
||||||
* \brief Comparator
|
|
||||||
* \param other Another node ref.
|
|
||||||
* \return the compare result.
|
|
||||||
*/
|
|
||||||
inline bool operator!=(const NodeRef& other) const;
|
|
||||||
/*! \return the hash function for NodeRef */
|
|
||||||
inline size_t hash() const;
|
|
||||||
/*! \return whether the expression is null */
|
|
||||||
inline bool defined() const;
|
|
||||||
/*! \return the internal type index of IRNode */
|
|
||||||
inline uint32_t type_index() const;
|
|
||||||
/*! \return the internal node pointer */
|
|
||||||
inline const Node* get() const;
|
|
||||||
/*! \return the internal node pointer */
|
|
||||||
inline const Node* operator->() const;
|
|
||||||
/*!
|
|
||||||
* \brief Downcast this ir node to its actual type (e.g. Add, or
|
|
||||||
* Select). This returns nullptr if the node is not of the requested
|
|
||||||
* type. Example usage:
|
|
||||||
*
|
|
||||||
* if (const Add *add = node->as<Add>()) {
|
|
||||||
* // This is an add node
|
|
||||||
* }
|
|
||||||
* \tparam T the target type, must be subtype of IRNode
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline const T *as() const;
|
|
||||||
/*!
|
|
||||||
* \brief A more powerful version of as that also works with
|
|
||||||
* intermediate base types.
|
|
||||||
* \tparam T the target type, must be subtype of IRNode
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline const T *as_derived() const;
|
|
||||||
/*! \brief default constructor */
|
|
||||||
NodeRef() = default;
|
|
||||||
explicit NodeRef(NodePtr<Node> node) : node_(node) {}
|
|
||||||
/*! \brief the internal node object, do not touch */
|
|
||||||
NodePtr<Node> node_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Get a reference type from a Node ptr type
|
|
||||||
*
|
|
||||||
* It is always important to get a reference type
|
|
||||||
* if we want to return a value as reference or keep
|
|
||||||
* the node alive beyond the scope of the function.
|
|
||||||
*
|
|
||||||
* \param ptr The node pointer
|
|
||||||
* \tparam RefType The reference type
|
|
||||||
* \tparam NodeType The node type
|
|
||||||
* \return The corresponding RefType
|
|
||||||
*/
|
|
||||||
template <typename RefType, typename NodeType>
|
|
||||||
inline RefType GetRef(const NodeType* ptr);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Downcast a base reference type to a more specific type.
|
|
||||||
*
|
|
||||||
* \param ref The inptut reference
|
|
||||||
* \return The corresponding SubRef.
|
|
||||||
* \tparam SubRef The target specific reference type.
|
|
||||||
* \tparam BaseRef the current reference type.
|
|
||||||
*/
|
|
||||||
template <typename SubRef, typename BaseRef>
|
|
||||||
inline SubRef Downcast(BaseRef ref);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief helper macro to declare type information in a base node.
|
|
||||||
*/
|
|
||||||
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
|
|
||||||
const bool _DerivedFrom(uint32_t tid) const override { \
|
|
||||||
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
|
|
||||||
if (tidx == tid) return true; \
|
|
||||||
return Parent::_DerivedFrom(tid); \
|
|
||||||
}
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief helper macro to declare type information in a terminal node
|
|
||||||
*/
|
|
||||||
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
|
|
||||||
const char* type_key() const final { \
|
|
||||||
return TypeName::_type_key; \
|
|
||||||
} \
|
|
||||||
const uint32_t type_index() const final { \
|
|
||||||
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
|
|
||||||
return tidx; \
|
|
||||||
} \
|
|
||||||
const bool _DerivedFrom(uint32_t tid) const final { \
|
|
||||||
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
|
|
||||||
if (tidx == tid) return true; \
|
|
||||||
return Parent::_DerivedFrom(tid); \
|
|
||||||
}
|
|
||||||
|
|
||||||
// implementations of inline functions after this
|
|
||||||
template<typename T>
|
|
||||||
inline bool Node::is_type() const {
|
|
||||||
// use static field so query only happens once.
|
|
||||||
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
|
|
||||||
return type_id == this->type_index();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline bool Node::derived_from() const {
|
|
||||||
// use static field so query only happens once.
|
|
||||||
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
|
|
||||||
return this->_DerivedFrom(type_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline NodePtr<Node> Node::GetNodePtr() const {
|
|
||||||
return NodePtr<Node>(const_cast<Node*>(this));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename RefType, typename NodeType>
|
|
||||||
inline RefType GetRef(const NodeType* ptr) {
|
|
||||||
static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
|
|
||||||
"Can only cast to the ref of same container type");
|
|
||||||
return RefType(ptr->GetNodePtr());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SubRef, typename BaseRef>
|
|
||||||
inline SubRef Downcast(BaseRef ref) {
|
|
||||||
CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
|
|
||||||
ref->template derived_from<typename SubRef::ContainerType>())
|
|
||||||
<< "Downcast from " << ref->type_key() << " to "
|
|
||||||
<< SubRef::ContainerType::_type_key << " failed.";
|
|
||||||
return SubRef(std::move(ref.node_));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline const Node* NodeRef::get() const {
|
|
||||||
return node_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline const Node* NodeRef::operator->() const {
|
|
||||||
return node_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool NodeRef::defined() const {
|
|
||||||
return node_.get() != nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool NodeRef::operator==(const NodeRef& other) const {
|
|
||||||
return node_.get() == other.node_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool NodeRef::same_as(const NodeRef& other) const {
|
|
||||||
return node_.get() == other.node_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool NodeRef::operator<(const NodeRef& other) const {
|
|
||||||
return node_.get() < other.node_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool NodeRef::operator!=(const NodeRef& other) const {
|
|
||||||
return node_.get() != other.node_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t NodeRef::hash() const {
|
|
||||||
return std::hash<Node*>()(node_.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
inline uint32_t NodeRef::type_index() const {
|
|
||||||
CHECK(node_.get() != nullptr)
|
|
||||||
<< "null type";
|
|
||||||
return get()->type_index();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline const T* NodeRef::as() const {
|
|
||||||
const Node* ptr = static_cast<const Node*>(get());
|
|
||||||
if (ptr && ptr->is_type<T>()) {
|
|
||||||
return static_cast<const T*>(ptr);
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline const T* NodeRef::as_derived() const {
|
|
||||||
const Node* ptr = static_cast<const Node*>(get());
|
|
||||||
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
|
|
||||||
return static_cast<const T*>(ptr);
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief The hash function for nodes */
|
|
||||||
struct NodeHash {
|
|
||||||
size_t operator()(const NodeRef& a) const {
|
|
||||||
return a.hash();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief The equal comparator for nodes */
|
|
||||||
struct NodeEqual {
|
|
||||||
bool operator()(const NodeRef& a, const NodeRef& b) const {
|
|
||||||
return a.get() == b.get();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace tvm
|
|
||||||
#endif // TVM_NODE_NODE_H_
|
|
|
@ -7,6 +7,7 @@
|
||||||
#define TVM_TENSOR_H_
|
#define TVM_TENSOR_H_
|
||||||
|
|
||||||
#include <ir/FunctionBase.h>
|
#include <ir/FunctionBase.h>
|
||||||
|
#include <tvm/node/container.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
@ -15,7 +16,6 @@
|
||||||
#include "expr.h"
|
#include "expr.h"
|
||||||
#include "ir_operator.h"
|
#include "ir_operator.h"
|
||||||
#include "arithmetic.h"
|
#include "arithmetic.h"
|
||||||
#include "node/container.h"
|
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
/*!
|
|
||||||
* Copyright (c) 2018 by Contributors
|
|
||||||
* Implementation of IR Node API
|
|
||||||
* \file node.cc
|
|
||||||
*/
|
|
||||||
#include <tvm/node/node.h>
|
|
||||||
#include <memory>
|
|
||||||
#include <atomic>
|
|
||||||
#include <mutex>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
namespace tvm {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// single manager of operator information.
|
|
||||||
struct TypeManager {
|
|
||||||
// mutex to avoid registration from multiple threads.
|
|
||||||
// recursive is needed for trigger(which calls UpdateAttrMap)
|
|
||||||
std::mutex mutex;
|
|
||||||
std::atomic<uint32_t> type_counter{0};
|
|
||||||
std::unordered_map<std::string, uint32_t> key2index;
|
|
||||||
std::vector<std::string> index2key;
|
|
||||||
// get singleton of the
|
|
||||||
static TypeManager* Global() {
|
|
||||||
static TypeManager inst;
|
|
||||||
return &inst;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
const bool Node::_DerivedFrom(uint32_t tid) const {
|
|
||||||
static uint32_t tindex = TypeKey2Index(Node::_type_key);
|
|
||||||
return tid == tindex;
|
|
||||||
}
|
|
||||||
|
|
||||||
// this is slow, usually caller always hold the result in a static variable.
|
|
||||||
uint32_t Node::TypeKey2Index(const char* key) {
|
|
||||||
TypeManager *t = TypeManager::Global();
|
|
||||||
std::lock_guard<std::mutex>(t->mutex);
|
|
||||||
std::string skey = key;
|
|
||||||
auto it = t->key2index.find(skey);
|
|
||||||
if (it != t->key2index.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
uint32_t tid = ++(t->type_counter);
|
|
||||||
t->key2index[skey] = tid;
|
|
||||||
t->index2key.push_back(skey);
|
|
||||||
return tid;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char* Node::TypeIndex2Key(uint32_t index) {
|
|
||||||
TypeManager *t = TypeManager::Global();
|
|
||||||
std::lock_guard<std::mutex>(t->mutex);
|
|
||||||
internal_assert(index != 0);
|
|
||||||
return t->index2key.at(index - 1).c_str();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tvm
|
|
Загрузка…
Ссылка в новой задаче