check stmt in
This commit is contained in:
Родитель
dac6b528bd
Коммит
151707e03f
|
@ -22,13 +22,28 @@ class NodeRef;
|
|||
class UnaryOp;
|
||||
class BinaryOp;
|
||||
|
||||
/*! \brief pointer type mask */
|
||||
const int kPtrTypeMask = 16;
|
||||
|
||||
/*! \brief list of all supported data types */
|
||||
enum DataType : int {
|
||||
kUnknown = 0,
|
||||
kInt32 = 1,
|
||||
kFloat32 = 2
|
||||
kFloat32 = 2,
|
||||
kInt32Buffer = kInt32 | kPtrTypeMask,
|
||||
kFloat32Buffer = kFloat32 | kPtrTypeMask
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief convert pointer type to data type
|
||||
* \param ptr_type The pointer type.
|
||||
* \return The corresponding data type.
|
||||
*/
|
||||
inline DataType Ptr2DataType(DataType ptr_type) {
|
||||
CHECK_GE(ptr_type, kPtrTypeMask);
|
||||
return static_cast<DataType>(ptr_type & (kPtrTypeMask -1));
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief List of subset node types used for quick runtime switch.
|
||||
*
|
||||
|
@ -45,6 +60,7 @@ enum NodeType {
|
|||
kBinaryOpNode,
|
||||
kReduceNode,
|
||||
kTensorReadNode,
|
||||
kBufferReadNode,
|
||||
// stmt nodes
|
||||
kStoreNode,
|
||||
kForRangeNode,
|
||||
|
@ -157,6 +173,8 @@ class NodeRef {
|
|||
inline bool operator!=(const NodeRef& other) const;
|
||||
/*! \return the hash function for NodeRef */
|
||||
inline size_t hash() const;
|
||||
/*! \return the raw internal pointer of the node */
|
||||
inline Node* node_ptr() const;
|
||||
|
||||
protected:
|
||||
template<typename T, typename>
|
||||
|
@ -217,7 +235,11 @@ inline bool NodeRef::operator!=(const NodeRef& other) const {
|
|||
}
|
||||
|
||||
inline size_t NodeRef::hash() const {
|
||||
return std::hash<Node*>()(node_.get());
|
||||
return std::hash<Node*>()(node_ptr());
|
||||
}
|
||||
|
||||
inline Node* NodeRef::node_ptr() const {
|
||||
return node_.get();
|
||||
}
|
||||
|
||||
} // namespace tvm
|
||||
|
|
|
@ -10,8 +10,9 @@
|
|||
#include "./base.h"
|
||||
|
||||
namespace tvm {
|
||||
// forward declare Expr
|
||||
// Forward declare Expr
|
||||
class Expr;
|
||||
class Var;
|
||||
|
||||
/*!
|
||||
* \brief create a constant expression
|
||||
|
@ -23,35 +24,34 @@ template<typename T,
|
|||
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
|
||||
inline Expr constant(T value);
|
||||
|
||||
/*!
|
||||
* \brief create a integer expression
|
||||
* \param value The value to the expression
|
||||
* \return the expression.
|
||||
*/
|
||||
Expr IntConstant(int64_t value);
|
||||
|
||||
/*!
|
||||
* \brief create a float expression.
|
||||
* \param value The value to the expression
|
||||
* \return the expression.
|
||||
*/
|
||||
Expr FloatConstant(double value);
|
||||
|
||||
/*!
|
||||
* \brief create a float expression.
|
||||
* \param value The value to the expression
|
||||
* \return the expression.
|
||||
*/
|
||||
Expr BufferRead(Var buffer, Expr offset);
|
||||
|
||||
/*!
|
||||
* \brief a expression type, holds a ref to root of an AST
|
||||
*/
|
||||
class Expr : public NodeRef {
|
||||
public:
|
||||
/*! \brief default constructor */
|
||||
Expr() = default;
|
||||
/*!
|
||||
* \brief copy constructor
|
||||
* \param other the input
|
||||
*/
|
||||
Expr(const Expr& other) = default;
|
||||
/*!
|
||||
* \brief move constructor
|
||||
* \param other the input
|
||||
*/
|
||||
Expr(Expr&& other) = default;
|
||||
/*!
|
||||
* \brief assign operator.
|
||||
* \param other the input.
|
||||
* \return reference to self
|
||||
*/
|
||||
Expr& operator=(const Expr& other) = default;
|
||||
/*!
|
||||
* \brief assign move operator.
|
||||
* \param other the input.
|
||||
* \return reference to self
|
||||
*/
|
||||
Expr& operator=(Expr&& other) = default;
|
||||
Expr() {}
|
||||
/*!
|
||||
* \brief constructor from constant value
|
||||
* \param value the constant value
|
||||
|
@ -82,15 +82,17 @@ class Expr : public NodeRef {
|
|||
void Print(std::ostream& os) const; // NOLINT(*)
|
||||
};
|
||||
|
||||
/*! \brief Variable class */
|
||||
/*!
|
||||
* \brief Variable class to represent the symbolic placeholder
|
||||
* in the DSL, internally it is a VarNode.
|
||||
*
|
||||
* The Variable is uniquely identified by the address of VarNode.
|
||||
*/
|
||||
class Var : public Expr {
|
||||
public:
|
||||
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
|
||||
};
|
||||
|
||||
Expr IntConstant(int64_t value);
|
||||
Expr FloatConstant(double value);
|
||||
|
||||
/*! \brief base of expression node */
|
||||
class ExprNode : public Node {
|
||||
public:
|
||||
|
@ -98,7 +100,7 @@ class ExprNode : public Node {
|
|||
DataType dtype_{kUnknown};
|
||||
};
|
||||
|
||||
// inline implementations
|
||||
// implementations
|
||||
inline DataType Expr::dtype() const {
|
||||
return static_cast<const ExprNode*>(node_.get())->dtype_;
|
||||
}
|
||||
|
|
|
@ -12,10 +12,8 @@
|
|||
#include "./expr.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
/*! \brief variable node for symbolic variables */
|
||||
class VarNode : public ExprNode {
|
||||
public:
|
||||
struct VarNode : public ExprNode {
|
||||
/*! \brief hint name of the variable */
|
||||
std::string name;
|
||||
/*! \brief constructor */
|
||||
|
@ -32,7 +30,7 @@ class VarNode : public ExprNode {
|
|||
};
|
||||
|
||||
/*! \brief integer constant node */
|
||||
class IntNode : public ExprNode {
|
||||
struct IntNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief the value field */
|
||||
int64_t value;
|
||||
|
@ -51,8 +49,7 @@ class IntNode : public ExprNode {
|
|||
};
|
||||
|
||||
/*! \brief float constant node */
|
||||
class FloatNode : public ExprNode {
|
||||
public:
|
||||
struct FloatNode : public ExprNode {
|
||||
/*! \brief the value field */
|
||||
double value;
|
||||
/*! \brief constructor */
|
||||
|
@ -61,7 +58,7 @@ class FloatNode : public ExprNode {
|
|||
dtype_ = kFloat32;
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "IntNode";
|
||||
return "FloatNode";
|
||||
}
|
||||
void VisitAttrs(AttrVisitor* visitor) override {
|
||||
visitor->Visit("value", &value);
|
||||
|
@ -70,8 +67,7 @@ class FloatNode : public ExprNode {
|
|||
};
|
||||
|
||||
/*! \brief Unary mapping operator */
|
||||
class UnaryOpNode : public ExprNode {
|
||||
public:
|
||||
struct UnaryOpNode : public ExprNode {
|
||||
/*! \brief The operator */
|
||||
const UnaryOp* op;
|
||||
/*! \brief The source expression */
|
||||
|
@ -105,7 +101,6 @@ class UnaryOpNode : public ExprNode {
|
|||
|
||||
/*! \brief Binary mapping operator */
|
||||
struct BinaryOpNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief The operator */
|
||||
const BinaryOp* op;
|
||||
/*! \brief The left operand */
|
||||
|
@ -143,7 +138,6 @@ struct BinaryOpNode : public ExprNode {
|
|||
|
||||
/*! \brief Reduction operator operator */
|
||||
struct ReduceNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief The operator */
|
||||
const BinaryOp* op;
|
||||
/*! \brief The source operand */
|
||||
|
@ -180,7 +174,6 @@ struct ReduceNode : public ExprNode {
|
|||
|
||||
/*! \brief Tensor read operator */
|
||||
struct TensorReadNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief The tensor to be read from */
|
||||
Tensor tensor;
|
||||
/*! \brief The indices of read */
|
||||
|
@ -215,6 +208,32 @@ struct TensorReadNode : public ExprNode {
|
|||
}
|
||||
};
|
||||
|
||||
/*! \brief Buffer read node */
|
||||
struct BufferReadNode : public ExprNode {
|
||||
/*! \brief The buffer variable to be read from */
|
||||
Var buffer;
|
||||
/*! \brief The offset to be read from */
|
||||
Expr offset;
|
||||
/*! \brief constructor, do not use constructor */
|
||||
BufferReadNode() {
|
||||
node_type_ = kBufferReadNode;
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "BufferReadNode";
|
||||
}
|
||||
void Verify() const override {
|
||||
CHECK_EQ(dtype_, Ptr2DataType(buffer.dtype()));
|
||||
CHECK_EQ(offset.dtype(), kInt32);
|
||||
}
|
||||
void VisitAttrs(AttrVisitor* visitor) override {
|
||||
visitor->Visit("dtype", &dtype_);
|
||||
}
|
||||
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
|
||||
fvisit("buffer", &buffer);
|
||||
fvisit("offset", &offset);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_EXPR_NODE_H_
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file stmt.h
|
||||
* \brief The statement creation functions.
|
||||
* The underlying container are defined in stmt_node.h
|
||||
*/
|
||||
#ifndef TVM_STMT_H_
|
||||
#define TVM_STMT_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include "./base.h"
|
||||
#include "./domain.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
/*!
|
||||
* \brief a expression type, holds a ref to root of an AST
|
||||
*/
|
||||
class Stmt : public NodeRef {
|
||||
public:
|
||||
/*! \brief default constructor */
|
||||
Stmt() {}
|
||||
/*!
|
||||
* \brief constructor from node pointer
|
||||
* \param nptr Another node shared pointer
|
||||
*/
|
||||
explicit Stmt(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
|
||||
CHECK(node_.get() != nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief construct Store Stmt.
|
||||
* \param buffer The variable representing the buffer.
|
||||
* \param offset The offset in the buffer
|
||||
* \param src The source expression.
|
||||
*/
|
||||
Stmt Store(Var buffer, Expr offset, Expr src);
|
||||
|
||||
/*!
|
||||
* \brief construct ForRange Stmt
|
||||
* \param loop_var The loop variable
|
||||
* \param range The loop range
|
||||
* \param body The loop body
|
||||
*/
|
||||
Stmt ForRange(Var loop_var, Range range, Stmt body);
|
||||
|
||||
/*!
|
||||
* \brief construct a IfThenElse
|
||||
* \param cond The condition.
|
||||
* \param then_body The body to go to in then condition.
|
||||
* \param else_body The body to go to in else condition.
|
||||
*/
|
||||
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body);
|
||||
|
||||
} // namespace tvm
|
||||
#endif // TVM_STMT_H_
|
|
@ -6,8 +6,15 @@
|
|||
#ifndef TVM_STMT_NODE_H_
|
||||
#define TVM_STMT_NODE_H_
|
||||
|
||||
#include "./base.h"
|
||||
#include "./domain.h"
|
||||
|
||||
namespace tvm {
|
||||
|
||||
/*!
|
||||
* \brief The internal base class of StmtNode
|
||||
* So far no extra stuffs in here.
|
||||
*/
|
||||
struct StmtNode : public Node {
|
||||
};
|
||||
|
||||
|
@ -23,11 +30,18 @@ struct StoreNode : public StmtNode {
|
|||
StoreNode() {
|
||||
node_type_ = kStoreNode;
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "StoreNode";
|
||||
}
|
||||
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
|
||||
fvisit("buffer", &buffer);
|
||||
fvisit("offset", &offset);
|
||||
fvisit("src", &src);
|
||||
}
|
||||
void Verify() const override {
|
||||
CHECK_EQ(Ptr2DataType(buffer.dtype()), src.dtype());
|
||||
CHECK_EQ(offset.dtype(), kInt32);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief for loop in range */
|
||||
|
@ -42,11 +56,19 @@ struct ForRangeNode : public StmtNode {
|
|||
ForRangeNode() {
|
||||
node_type_ = kForRangeNode;
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "ForRangeNode";
|
||||
}
|
||||
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
|
||||
fvisit("loop_var", &loop_var);
|
||||
fvisit("range", &range);
|
||||
fvisit("body", &body);
|
||||
}
|
||||
void Verify() const override {
|
||||
CHECK_EQ(loop_var.dtype(), kInt32);
|
||||
CHECK_EQ(this->range->begin.dtype(), loop_var.dtype());
|
||||
CHECK_EQ(this->range->end.dtype(), loop_var.dtype());
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief conditional expression */
|
||||
|
@ -61,13 +83,19 @@ struct IfThenElseNode : public StmtNode {
|
|||
IfThenElseNode() {
|
||||
node_type_ = kIfThenElseNode;
|
||||
}
|
||||
const char* type_key() const override {
|
||||
return "IfThenElseNode";
|
||||
}
|
||||
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
|
||||
fvisit("cond", &cond);
|
||||
fvisit("then_body", &then_body);
|
||||
fvisit("else_body", &else_body);
|
||||
}
|
||||
void Verify() const override {
|
||||
CHECK_EQ(cond.dtype(), kInt32);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_CODEGEN_H_
|
||||
#endif // TVM_STMT_NODE_H_
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Code organization
|
||||
|
||||
- c_api C API related functions
|
||||
- lang The definition of DSL related data structure
|
||||
- schedule The Schedule->Stmt generation logic
|
||||
- codegen Backend code generation related
|
|
@ -5,7 +5,6 @@
|
|||
#include <tvm/expr.h>
|
||||
#include <tvm/op.h>
|
||||
#include <tvm/expr_node.h>
|
||||
#include <cctype>
|
||||
|
||||
namespace tvm {
|
||||
|
||||
|
@ -28,4 +27,12 @@ Expr FloatConstant(double value) {
|
|||
return Expr(std::move(nptr));
|
||||
}
|
||||
|
||||
Expr BufferRead(Var buffer, Expr offset) {
|
||||
auto nptr = std::make_shared<BufferReadNode>();
|
||||
nptr->buffer = std::move(buffer);
|
||||
nptr->offset = std::move(offset);
|
||||
nptr->Verify();
|
||||
return Expr(std::move(nptr));
|
||||
}
|
||||
|
||||
} // namespace tvm
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* \file stmt.cc
|
||||
*/
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/stmt.h>
|
||||
#include <tvm/stmt_node.h>
|
||||
|
||||
namespace tvm {
|
||||
|
||||
Stmt Store(Var buffer, Expr offset, Expr src) {
|
||||
auto nptr = std::make_shared<StoreNode>();
|
||||
nptr->buffer = std::move(buffer);
|
||||
nptr->offset = std::move(offset);
|
||||
nptr->src = std::move(src);
|
||||
nptr->Verify();
|
||||
return Stmt(std::move(nptr));
|
||||
}
|
||||
|
||||
Stmt ForRange(Var loop_var, Range range, Stmt body) {
|
||||
auto nptr = std::make_shared<ForRangeNode>();
|
||||
nptr->loop_var = std::move(loop_var);
|
||||
nptr->range = std::move(range);
|
||||
nptr->body = std::move(body);
|
||||
nptr->Verify();
|
||||
return Stmt(std::move(nptr));
|
||||
}
|
||||
|
||||
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body) {
|
||||
auto nptr = std::make_shared<IfThenElseNode>();
|
||||
nptr->cond = std::move(cond);
|
||||
nptr->then_body = std::move(then_body);
|
||||
nptr->else_body = std::move(else_body);
|
||||
nptr->Verify();
|
||||
return Stmt(std::move(nptr));
|
||||
}
|
||||
|
||||
} // namespace tvm
|
Загрузка…
Ссылка в новой задаче