246 строки
7.2 KiB
C++
246 строки
7.2 KiB
C++
/*!
|
|
* Copyright (c) 2016 by Contributors
|
|
* \file tensor.h
|
|
* \brief Dataflow tensor object
|
|
*/
|
|
#ifndef TVM_TENSOR_H_
|
|
#define TVM_TENSOR_H_
|
|
|
|
#include <tvm/container.h>
|
|
#include <ir/FunctionBase.h>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <type_traits>
|
|
|
|
#include "./base.h"
|
|
#include "./expr.h"
|
|
#include "./arithmetic.h"
|
|
|
|
namespace tvm {
|
|
|
|
// Internal node container of Tensor
|
|
class TensorNode;
|
|
// internal node container for Operation
|
|
class OperationNode;
|
|
|
|
using Halide::IR::FunctionRef;
|
|
|
|
/*!
|
|
* \brief Tensor structure representing a possible input,
|
|
* or intermediate computation result.
|
|
*/
|
|
class Tensor : public NodeRef {
|
|
public:
|
|
/*! \brief default constructor, used internally */
|
|
Tensor() {}
|
|
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const TensorNode* operator->() const;
|
|
/*!
|
|
* \brief check if two tensors equals each other.
|
|
* \param other tensor to be checked.
|
|
* \return whether the two tensors equals each other.
|
|
*/
|
|
inline bool operator==(const Tensor& other) const;
|
|
/*! \return The dimension of the tensor */
|
|
inline size_t ndim() const;
|
|
/*!
|
|
* \brief Take elements from the tensor
|
|
* \param args The indices
|
|
* \return the result expression representing tensor read.
|
|
*/
|
|
template<typename... Args>
|
|
inline Expr operator()(Args&& ...args) const {
|
|
Array<Expr> indices{std::forward<Args>(args)...};
|
|
return operator()(indices);
|
|
}
|
|
/*!
|
|
* \brief Take elements from the tensor
|
|
* \param indices the indices.
|
|
* \return the result expression representing tensor read.
|
|
*/
|
|
Expr operator()(Array<Expr> indices) const;
|
|
/*!
|
|
* \brief Take elements from the tensor
|
|
* \param indices the indices.
|
|
* \return the result expression representing tensor read.
|
|
*/
|
|
Expr operator()(Array<Var> indices) const;
|
|
/*!
|
|
* \brief data structure to represent a slice that fixes first k coordinates.
|
|
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
|
|
*/
|
|
class Slice {
|
|
public:
|
|
// construct via tensor and indices
|
|
Slice(const Tensor& tensor, std::vector<Expr> indices)
|
|
: tensor_(tensor), indices_(indices) {}
|
|
/*!
|
|
* \brief get i-th slice from the current slice.
|
|
* \param i the index of the coordinate
|
|
* \return the subsequent slice.
|
|
*/
|
|
inline Slice operator[](Expr i) {
|
|
std::vector<Expr> other = indices_;
|
|
other.emplace_back(i);
|
|
return Slice(tensor_, other);
|
|
}
|
|
/*!
|
|
* \brief Convert slice to expression.
|
|
* This is only valid when all the coordinates are fully specified.
|
|
* \return the corresponding expression of this slice.
|
|
*/
|
|
inline operator Expr() const {
|
|
return tensor_(indices_);
|
|
}
|
|
|
|
private:
|
|
const Tensor& tensor_;
|
|
std::vector<Expr> indices_;
|
|
};
|
|
/*!
|
|
* \brief get i-th slice from the current Tensor.
|
|
* \param i the index of the coordinate
|
|
* \return the subsequent slice.
|
|
*/
|
|
inline Slice operator[](Expr i) const {
|
|
return Slice(*this, {i});
|
|
}
|
|
/*! \brief specify container node */
|
|
using ContainerType = TensorNode;
|
|
};
|
|
|
|
/*! \brief Operation that produces tensors */
|
|
class Operation : public FunctionRef {
|
|
public:
|
|
/*! \brief default constructor */
|
|
Operation() {}
|
|
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
|
|
/*!
|
|
* \brief access the internal node container
|
|
* \return the pointer to the internal node container
|
|
*/
|
|
inline const OperationNode* operator->() const;
|
|
/*!
|
|
* \brief get the i-th output of the operation.
|
|
* \param i the output index.
|
|
* \return The i-th output.
|
|
*/
|
|
Tensor output(size_t i) const;
|
|
/*! \brief specify container node */
|
|
using ContainerType = OperationNode;
|
|
};
|
|
|
|
/*! \brief Node to represent a tensor */
|
|
class TensorNode : public Node {
|
|
public:
|
|
/*! \brief The shape of the tensor */
|
|
Array<Expr> shape;
|
|
/*! \brief data type in the content of the tensor */
|
|
Type dtype;
|
|
/*! \brief the source operation, can be None */
|
|
Operation op;
|
|
/*! \brief the output index from source operation */
|
|
int value_index{0};
|
|
/*! \brief constructor */
|
|
TensorNode() {}
|
|
|
|
void VisitAttrs(AttrVisitor* v) final {
|
|
v->Visit("shape", &shape);
|
|
v->Visit("dtype", &dtype);
|
|
v->Visit("op", &op);
|
|
v->Visit("value_index", &value_index);
|
|
}
|
|
static Tensor make(Array<Expr> shape,
|
|
Type dtype,
|
|
Operation op,
|
|
int value_index);
|
|
|
|
static constexpr const char* _type_key = "Tensor";
|
|
TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node);
|
|
};
|
|
|
|
|
|
// Implementations of inline functions
|
|
inline const TensorNode* Tensor::operator->() const {
|
|
return static_cast<const TensorNode*>(node_.get());
|
|
}
|
|
|
|
inline size_t Tensor::ndim() const {
|
|
return (*this)->shape.size();
|
|
}
|
|
|
|
inline bool Tensor::operator==(const Tensor& other) const {
|
|
if (get() == other.get()) return true;
|
|
if (get() == nullptr || other.get() == nullptr) return false;
|
|
if ((*this)->op.defined() || other->op.defined()) {
|
|
return (*this)->op == other->op &&
|
|
(*this)->value_index == other->value_index;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// macro to turn every operation of slice to expression
|
|
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
|
|
inline Expr operator Op (const Tensor::Slice& a) { \
|
|
return Op a.operator Expr() ; \
|
|
}
|
|
|
|
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
|
|
template<typename T> \
|
|
inline Expr operator Op (const Tensor::Slice& a, const T& b) { \
|
|
return a.operator Expr() Op b; \
|
|
} \
|
|
template<typename T> \
|
|
inline Expr operator Op (const T& a, const Tensor::Slice& b) { \
|
|
return a Op b.operator Expr(); \
|
|
} \
|
|
inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
|
|
return a.operator Expr() Op b.operator Expr(); \
|
|
}
|
|
|
|
DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
|
|
DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(/);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(%);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*)
|
|
DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
|
|
|
|
} // namespace tvm
|
|
|
|
namespace std {
|
|
template <>
|
|
struct hash<::tvm::Operation> {
|
|
std::size_t operator()(const ::tvm::Operation& k) const {
|
|
return k.hash();
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct hash<::tvm::Tensor> {
|
|
std::size_t operator()(const ::tvm::Tensor& k) const {
|
|
if (k.defined() && k->op.defined()) {
|
|
return k->op.hash();
|
|
} else{
|
|
return k.hash();
|
|
}
|
|
}
|
|
};
|
|
} // namespace std
|
|
#endif // TVM_TENSOR_H_
|