/*! * Copyright (c) 2016 by Contributors * \file tensor.h * \brief Dataflow tensor object */ #ifndef TVM_TENSOR_H_ #define TVM_TENSOR_H_ #include #include #include #include #include #include "./base.h" #include "./expr.h" #include "./arithmetic.h" namespace tvm { // Internal node container of Tensor class TensorNode; // internal node container for Operation class OperationNode; using Halide::IR::FunctionRef; /*! * \brief Tensor structure representing a possible input, * or intermediate computation result. */ class Tensor : public NodeRef { public: /*! \brief default constructor, used internally */ Tensor() {} explicit Tensor(std::shared_ptr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const TensorNode* operator->() const; /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. * \return whether the two tensors equals each other. */ inline bool operator==(const Tensor& other) const; /*! \return The dimension of the tensor */ inline size_t ndim() const; /*! * \brief Take elements from the tensor * \param args The indices * \return the result expression representing tensor read. */ template inline Expr operator()(Args&& ...args) const { Array indices{std::forward(args)...}; return operator()(indices); } /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ Expr operator()(Array indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ Expr operator()(Array indices) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. */ class Slice { public: // construct via tensor and indices Slice(const Tensor& tensor, std::vector indices) : tensor_(tensor), indices_(indices) {} /*! * \brief get i-th slice from the current slice. * \param i the index of the coordinate * \return the subsequent slice. */ inline Slice operator[](Expr i) { std::vector other = indices_; other.emplace_back(i); return Slice(tensor_, other); } /*! * \brief Convert slice to expression. * This is only valid when all the coordinates are fully specified. * \return the corresponding expression of this slice. */ inline operator Expr() const { return tensor_(indices_); } private: const Tensor& tensor_; std::vector indices_; }; /*! * \brief get i-th slice from the current Tensor. * \param i the index of the coordinate * \return the subsequent slice. */ inline Slice operator[](Expr i) const { return Slice(*this, {i}); } /*! \brief specify container node */ using ContainerType = TensorNode; }; /*! \brief Operation that produces tensors */ class Operation : public FunctionRef { public: /*! \brief default constructor */ Operation() {} explicit Operation(std::shared_ptr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const OperationNode* operator->() const; /*! * \brief get the i-th output of the operation. * \param i the output index. * \return The i-th output. */ Tensor output(size_t i) const; /*! \brief specify container node */ using ContainerType = OperationNode; }; /*! \brief Node to represent a tensor */ class TensorNode : public Node { public: /*! \brief The shape of the tensor */ Array shape; /*! \brief data type in the content of the tensor */ Type dtype; /*! \brief the source operation, can be None */ Operation op; /*! \brief the output index from source operation */ int value_index{0}; /*! \brief constructor */ TensorNode() {} void VisitAttrs(AttrVisitor* v) final { v->Visit("shape", &shape); v->Visit("dtype", &dtype); v->Visit("op", &op); v->Visit("value_index", &value_index); } static Tensor make(Array shape, Type dtype, Operation op, int value_index); static constexpr const char* _type_key = "Tensor"; TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node); }; // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { return static_cast(node_.get()); } inline size_t Tensor::ndim() const { return (*this)->shape.size(); } inline bool Tensor::operator==(const Tensor& other) const { if (get() == other.get()) return true; if (get() == nullptr || other.get() == nullptr) return false; if ((*this)->op.defined() || other->op.defined()) { return (*this)->op == other->op && (*this)->value_index == other->value_index; } else { return false; } } // macro to turn every operation of slice to expression #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ inline Expr operator Op (const Tensor::Slice& a) { \ return Op a.operator Expr() ; \ } #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ template \ inline Expr operator Op (const Tensor::Slice& a, const T& b) { \ return a.operator Expr() Op b; \ } \ template \ inline Expr operator Op (const T& a, const Tensor::Slice& b) { \ return a Op b.operator Expr(); \ } \ inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ return a.operator Expr() Op b.operator Expr(); \ } DEFINE_OVERLOAD_SLICE_UNARY_OP(!); DEFINE_OVERLOAD_SLICE_UNARY_OP(-); DEFINE_OVERLOAD_SLICE_BINARY_OP(+); DEFINE_OVERLOAD_SLICE_BINARY_OP(-); DEFINE_OVERLOAD_SLICE_BINARY_OP(*); DEFINE_OVERLOAD_SLICE_BINARY_OP(/); DEFINE_OVERLOAD_SLICE_BINARY_OP(%); DEFINE_OVERLOAD_SLICE_BINARY_OP(==); DEFINE_OVERLOAD_SLICE_BINARY_OP(<=); DEFINE_OVERLOAD_SLICE_BINARY_OP(>=); DEFINE_OVERLOAD_SLICE_BINARY_OP(!=); DEFINE_OVERLOAD_SLICE_BINARY_OP(&&); DEFINE_OVERLOAD_SLICE_BINARY_OP(||); DEFINE_OVERLOAD_SLICE_BINARY_OP(>>); DEFINE_OVERLOAD_SLICE_BINARY_OP(<<); DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*) DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) } // namespace tvm namespace std { template <> struct hash<::tvm::Operation> { std::size_t operator()(const ::tvm::Operation& k) const { return k.hash(); } }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { if (k.defined() && k->op.defined()) { return k->op.hash(); } else{ return k.hash(); } } }; } // namespace std #endif // TVM_TENSOR_H_