From 051303ce09291dfbed537fa33b0d8a4d92c82b75 Mon Sep 17 00:00:00 2001 From: "Tareq A. Siraj" Date: Tue, 16 Apr 2013 18:53:08 +0000 Subject: [PATCH] Implement CapturedStmt AST CapturedStmt can be used to implement generic function outlining as described in http://lists.cs.uiuc.edu/pipermail/cfe-dev/2013-January/027540.html. CapturedStmt is not exposed to the C api. Serialization and template support are pending. Author: Wei Pan Differential Revision: http://llvm-reviews.chandlerc.com/D370 git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@179615 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/clang/AST/RecursiveASTVisitor.h | 1 + include/clang/AST/Stmt.h | 166 ++++++++++++++++++++++ include/clang/Basic/StmtNodes.td | 1 + include/clang/Serialization/ASTBitCodes.h | 2 + lib/AST/Stmt.cpp | 102 +++++++++++++ lib/AST/StmtPrinter.cpp | 4 + lib/AST/StmtProfile.cpp | 4 + lib/CodeGen/CGStmt.cpp | 8 +- lib/CodeGen/CodeGenFunction.h | 1 + lib/Sema/TreeTransform.h | 6 + lib/Serialization/ASTReaderStmt.cpp | 8 ++ lib/Serialization/ASTWriterStmt.cpp | 7 + lib/StaticAnalyzer/Core/ExprEngine.cpp | 1 + tools/libclang/CXCursor.cpp | 4 + tools/libclang/RecursiveASTVisitor.h | 2 +- 15 files changed, 315 insertions(+), 2 deletions(-) diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h index 33534ecc7c..9b4e481bfd 100644 --- a/include/clang/AST/RecursiveASTVisitor.h +++ b/include/clang/AST/RecursiveASTVisitor.h @@ -2218,6 +2218,7 @@ DEF_TRAVERSE_STMT(UnresolvedMemberExpr, { DEF_TRAVERSE_STMT(SEHTryStmt, {}) DEF_TRAVERSE_STMT(SEHExceptStmt, {}) DEF_TRAVERSE_STMT(SEHFinallyStmt,{}) +DEF_TRAVERSE_STMT(CapturedStmt, {}) DEF_TRAVERSE_STMT(CXXOperatorCallExpr, { }) DEF_TRAVERSE_STMT(OpaqueValueExpr, { }) diff --git a/include/clang/AST/Stmt.h b/include/clang/AST/Stmt.h index cf8fc249c5..c2cfaa486c 100644 --- a/include/clang/AST/Stmt.h +++ b/include/clang/AST/Stmt.h @@ -33,12 +33,14 @@ namespace clang { class Attr; class Decl; class Expr; + class FunctionDecl; class IdentifierInfo; class LabelDecl; class ParmVarDecl; class PrinterHelper; struct PrintingPolicy; class QualType; + class RecordDecl; class SourceManager; class StringLiteral; class SwitchStmt; @@ -1882,6 +1884,170 @@ public: } }; +/// \brief This captures a statement into a function. For example, the following +/// pragma annotated compound statement can be represented as a CapturedStmt, +/// and this compound statement is the body of an anonymous outlined function. +/// @code +/// #pragma omp parallel +/// { +/// compute(); +/// } +/// @endcode +class CapturedStmt : public Stmt { +public: + /// \brief The different capture forms: by 'this' or by reference, etc. + enum VariableCaptureKind { + VCK_This, + VCK_ByRef + }; + + /// \brief Describes the capture of either a variable or 'this'. + class Capture { + VarDecl *Var; + SourceLocation Loc; + + public: + /// \brief Create a new capture. + /// + /// \param Loc The source location associated with this capture. + /// + /// \param Kind The kind of capture (this, ByRef, ...). + /// + /// \param Var The variable being captured, or null if capturing this. + /// + Capture(SourceLocation Loc, VariableCaptureKind Kind, VarDecl *Var = 0) + : Var(Var), Loc(Loc) { + switch (Kind) { + case VCK_This: + assert(Var == 0 && "'this' capture cannot have a variable!"); + break; + case VCK_ByRef: + assert(Var && "capturing by reference must have a variable!"); + break; + } + } + + /// \brief Determine the kind of capture. + VariableCaptureKind getCaptureKind() const { + if (capturesThis()) + return VCK_This; + + return VCK_ByRef; + } + + /// \brief Retrieve the source location at which the variable or 'this' was + /// first used. + SourceLocation getLocation() const { return Loc; } + + /// \brief Determine whether this capture handles the C++ 'this' pointer. + bool capturesThis() const { return Var == 0; } + + /// \brief Determine whether this capture handles a variable. + bool capturesVariable() const { return Var != 0; } + + /// \brief Retrieve the declaration of the variable being captured. + /// + /// This operation is only valid if this capture does not capture 'this'. + VarDecl *getCapturedVar() const { + assert(!capturesThis() && "No variable available for 'this' capture"); + return Var; + } + }; + +private: + /// \brief The number of variable captured, including 'this'. + unsigned NumCaptures; + + /// \brief The implicit outlined function. + FunctionDecl *TheFuncDecl; + + /// \brief The record for captured variables, a RecordDecl or CXXRecordDecl. + RecordDecl *TheRecordDecl; + + /// \brief Construct a captured statement. + CapturedStmt(Stmt *S, ArrayRef Captures, + ArrayRef CaptureInits, + FunctionDecl *FD, RecordDecl *RD); + + /// \brief Construct an empty captured statement. + CapturedStmt(EmptyShell Empty, unsigned NumCaptures); + + Stmt **getStoredStmts() const { + return reinterpret_cast(const_cast(this) + 1); + } + + Capture *getStoredCaptures() const; + +public: + static CapturedStmt *Create(ASTContext &Context, Stmt *S, + ArrayRef Captures, + ArrayRef CaptureInits, + FunctionDecl *FD, RecordDecl *RD); + + static CapturedStmt *CreateDeserialized(ASTContext &Context, + unsigned NumCaptures); + + /// \brief Retrieve the statement being captured. + Stmt *getCapturedStmt() { return getStoredStmts()[NumCaptures]; } + const Stmt *getCapturedStmt() const { + return const_cast(this)->getCapturedStmt(); + } + + /// \brief Retrieve the outlined function declaration. + const FunctionDecl *getCapturedFunctionDecl() const { return TheFuncDecl; } + + /// \brief Retrieve the record declaration for captured variables. + const RecordDecl *getCapturedRecordDecl() const { return TheRecordDecl; } + + /// \brief True if this variable has been captured. + bool capturesVariable(const VarDecl *Var) const; + + /// \brief An iterator that walks over the captures. + typedef const Capture *capture_iterator; + + /// \brief Retrieve an iterator pointing to the first capture. + capture_iterator capture_begin() const { return getStoredCaptures(); } + + /// \brief Retrieve an iterator pointing past the end of the sequence of + /// captures. + capture_iterator capture_end() const { + return getStoredCaptures() + NumCaptures; + } + + /// \brief Retrieve the number of captures, including 'this'. + unsigned capture_size() const { return NumCaptures; } + + /// \brief Iterator that walks over the capture initialization arguments. + typedef Expr **capture_init_iterator; + + /// \brief Retrieve the first initialization argument. + capture_init_iterator capture_init_begin() const { + return reinterpret_cast(getStoredStmts()); + } + + /// \brief Retrieve the iterator pointing one past the last initialization + /// argument. + capture_init_iterator capture_init_end() const { + return capture_init_begin() + NumCaptures; + } + + SourceLocation getLocStart() const LLVM_READONLY { + return getCapturedStmt()->getLocStart(); + } + SourceLocation getLocEnd() const LLVM_READONLY { + return getCapturedStmt()->getLocEnd(); + } + SourceRange getSourceRange() const LLVM_READONLY { + return getCapturedStmt()->getSourceRange(); + } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == CapturedStmtClass; + } + + child_range children(); +}; + } // end namespace clang #endif diff --git a/include/clang/Basic/StmtNodes.td b/include/clang/Basic/StmtNodes.td index 979555e1a6..ad25e57c89 100644 --- a/include/clang/Basic/StmtNodes.td +++ b/include/clang/Basic/StmtNodes.td @@ -27,6 +27,7 @@ def DeclStmt : Stmt; def SwitchCase : Stmt<1>; def CaseStmt : DStmt; def DefaultStmt : DStmt; +def CapturedStmt : Stmt; // Asm statements def AsmStmt : Stmt<1>; diff --git a/include/clang/Serialization/ASTBitCodes.h b/include/clang/Serialization/ASTBitCodes.h index e3f9e0643a..04d6a85860 100644 --- a/include/clang/Serialization/ASTBitCodes.h +++ b/include/clang/Serialization/ASTBitCodes.h @@ -1103,6 +1103,8 @@ namespace clang { STMT_RETURN, /// \brief A DeclStmt record. STMT_DECL, + /// \brief A CapturedStmt record. + STMT_CAPTURED, /// \brief A GCC-style AsmStmt record. STMT_GCCASM, /// \brief A MS-style AsmStmt record. diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 2ae5a1266c..e120c6a1f8 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -1023,3 +1023,105 @@ SEHFinallyStmt* SEHFinallyStmt::Create(ASTContext &C, Stmt *Block) { return new(C)SEHFinallyStmt(Loc,Block); } + +CapturedStmt::Capture *CapturedStmt::getStoredCaptures() const { + unsigned Size = sizeof(CapturedStmt) + sizeof(Stmt *) * (NumCaptures + 1); + + // Offset of the first Capture object. + unsigned FirstCaptureOffset = + llvm::RoundUpToAlignment(Size, llvm::alignOf()); + + return reinterpret_cast( + reinterpret_cast(const_cast(this)) + + FirstCaptureOffset); +} + +CapturedStmt::CapturedStmt(Stmt *S, ArrayRef Captures, + ArrayRef CaptureInits, + FunctionDecl *FD, + RecordDecl *RD) + : Stmt(CapturedStmtClass), NumCaptures(Captures.size()), + TheFuncDecl(FD), TheRecordDecl(RD) { + assert( S && "null captured statement"); + assert(FD && "null function declaration for captured statement"); + assert(RD && "null record declaration for captured statement"); + + // Copy initialization expressions. + Stmt **Stored = getStoredStmts(); + for (unsigned I = 0, N = NumCaptures; I != N; ++I) + *Stored++ = CaptureInits[I]; + + // Copy the statement being captured. + *Stored = S; + + // Copy all Capture objects. + Capture *Buffer = getStoredCaptures(); + std::copy(Captures.begin(), Captures.end(), Buffer); +} + +CapturedStmt::CapturedStmt(EmptyShell Empty, unsigned NumCaptures) + : Stmt(CapturedStmtClass, Empty), NumCaptures(NumCaptures), + TheFuncDecl(0), TheRecordDecl(0) { + getStoredStmts()[NumCaptures] = 0; +} + +CapturedStmt *CapturedStmt::Create(ASTContext &Context, Stmt *S, + ArrayRef Captures, + ArrayRef CaptureInits, + FunctionDecl *FD, + RecordDecl *RD) { + // The layout is + // + // ----------------------------------------------------------- + // | CapturedStmt, Init, ..., Init, S, Capture, ..., Capture | + // ----------------^-------------------^---------------------- + // getStoredStmts() getStoredCaptures() + // + // where S is the statement being captured. + // + assert(CaptureInits.size() == Captures.size() && "wrong number of arguments"); + + unsigned Size = sizeof(CapturedStmt) + sizeof(Stmt *) * (Captures.size() + 1); + if (!Captures.empty()) { + // Realign for the following Capture array. + Size = llvm::RoundUpToAlignment(Size, llvm::alignOf()); + Size += sizeof(Capture) * Captures.size(); + } + + void *Mem = Context.Allocate(Size); + return new (Mem) CapturedStmt(S, Captures, CaptureInits, FD, RD); +} + +CapturedStmt *CapturedStmt::CreateDeserialized(ASTContext &Context, + unsigned NumCaptures) { + unsigned Size = sizeof(CapturedStmt) + sizeof(Stmt *) * (NumCaptures + 1); + if (NumCaptures > 0) { + // Realign for the following Capture array. + Size = llvm::RoundUpToAlignment(Size, llvm::alignOf()); + Size += sizeof(Capture) * NumCaptures; + } + + void *Mem = Context.Allocate(Size); + return new (Mem) CapturedStmt(EmptyShell(), NumCaptures); +} + +Stmt::child_range CapturedStmt::children() { + // Children are captured field initilizers and the statement being captured. + return child_range(getStoredStmts(), getStoredStmts() + NumCaptures + 1); +} + +bool CapturedStmt::capturesVariable(const VarDecl *Var) const { + for (capture_iterator I = capture_begin(), + E = capture_end(); I != E; ++I) { + if (I->capturesThis()) + continue; + + // This does not handle variable redeclarations. This should be + // extended to capture variables with redeclarations, for example + // a thread-private variable in OpenMP. + if (I->getCapturedVar() == Var) + return true; + } + + return false; +} diff --git a/lib/AST/StmtPrinter.cpp b/lib/AST/StmtPrinter.cpp index a86159f49d..469c2846a6 100644 --- a/lib/AST/StmtPrinter.cpp +++ b/lib/AST/StmtPrinter.cpp @@ -450,6 +450,10 @@ void StmtPrinter::VisitMSAsmStmt(MSAsmStmt *Node) { Indent() << "}\n"; } +void StmtPrinter::VisitCapturedStmt(CapturedStmt *Node) { + PrintStmt(Node->getCapturedStmt()); +} + void StmtPrinter::VisitObjCAtTryStmt(ObjCAtTryStmt *Node) { Indent() << "@try"; if (CompoundStmt *TS = dyn_cast(Node->getTryBody())) { diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp index 5525018f79..d99400c603 100644 --- a/lib/AST/StmtProfile.cpp +++ b/lib/AST/StmtProfile.cpp @@ -215,6 +215,10 @@ void StmtProfiler::VisitSEHExceptStmt(const SEHExceptStmt *S) { VisitStmt(S); } +void StmtProfiler::VisitCapturedStmt(const CapturedStmt *S) { + VisitStmt(S); +} + void StmtProfiler::VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { VisitStmt(S); } diff --git a/lib/CodeGen/CGStmt.cpp b/lib/CodeGen/CGStmt.cpp index 3153ca8ca7..d10818cf25 100644 --- a/lib/CodeGen/CGStmt.cpp +++ b/lib/CodeGen/CGStmt.cpp @@ -134,7 +134,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S) { case Stmt::SwitchStmtClass: EmitSwitchStmt(cast(*S)); break; case Stmt::GCCAsmStmtClass: // Intentional fall-through. case Stmt::MSAsmStmtClass: EmitAsmStmt(cast(*S)); break; - + case Stmt::CapturedStmtClass: + EmitCapturedStmt(cast(*S)); + break; case Stmt::ObjCAtTryStmtClass: EmitObjCAtTryStmt(cast(*S)); break; @@ -1735,3 +1737,7 @@ void CodeGenFunction::EmitAsmStmt(const AsmStmt &S) { EmitStoreThroughLValue(RValue::get(Tmp), ResultRegDests[i]); } } + +void CodeGenFunction::EmitCapturedStmt(const CapturedStmt &S) { + llvm_unreachable("not implemented yet"); +} diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h index 645d5ff237..941eebe882 100644 --- a/lib/CodeGen/CodeGenFunction.h +++ b/lib/CodeGen/CodeGenFunction.h @@ -2133,6 +2133,7 @@ public: void EmitCaseStmt(const CaseStmt &S); void EmitCaseStmtRange(const CaseStmt &S); void EmitAsmStmt(const AsmStmt &S); + void EmitCapturedStmt(const CapturedStmt &S); void EmitObjCForCollectionStmt(const ObjCForCollectionStmt &S); void EmitObjCAtTryStmt(const ObjCAtTryStmt &S); diff --git a/lib/Sema/TreeTransform.h b/lib/Sema/TreeTransform.h index b4083e9bc7..55f1587a85 100644 --- a/lib/Sema/TreeTransform.h +++ b/lib/Sema/TreeTransform.h @@ -9377,6 +9377,12 @@ TreeTransform::RebuildCXXPseudoDestructorExpr(Expr *Base, /*TemplateArgs*/ 0); } +template +StmtResult +TreeTransform::TransformCapturedStmt(CapturedStmt *S) { + llvm_unreachable("not implement yet"); +} + } // end namespace clang #endif // LLVM_CLANG_SEMA_TREETRANSFORM_H diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp index 567d50e71d..b18114fdcb 100644 --- a/lib/Serialization/ASTReaderStmt.cpp +++ b/lib/Serialization/ASTReaderStmt.cpp @@ -324,6 +324,10 @@ void ASTStmtReader::VisitMSAsmStmt(MSAsmStmt *S) { VisitStmt(S); } +void ASTStmtReader::VisitCapturedStmt(CapturedStmt *S) { + llvm_unreachable("not implemented yet"); +} + void ASTStmtReader::VisitExpr(Expr *E) { VisitStmt(E); E->setType(Reader.readType(F, Record, Idx)); @@ -1724,6 +1728,10 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) { S = new (Context) MSAsmStmt(Empty); break; + case STMT_CAPTURED: + llvm_unreachable("not implemented yet"); + break; + case EXPR_PREDEFINED: S = new (Context) PredefinedExpr(Empty); break; diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp index 920730ffd6..61ddec0fb0 100644 --- a/lib/Serialization/ASTWriterStmt.cpp +++ b/lib/Serialization/ASTWriterStmt.cpp @@ -255,6 +255,13 @@ void ASTStmtWriter::VisitMSAsmStmt(MSAsmStmt *S) { Code = serialization::STMT_MSASM; } +void ASTStmtWriter::VisitCapturedStmt(CapturedStmt *S) { + VisitStmt(S); + Code = serialization::STMT_CAPTURED; + + llvm_unreachable("not implemented yet"); +} + void ASTStmtWriter::VisitExpr(Expr *E) { VisitStmt(E); Writer.AddTypeRef(E->getType(), Record); diff --git a/lib/StaticAnalyzer/Core/ExprEngine.cpp b/lib/StaticAnalyzer/Core/ExprEngine.cpp index cf75deb5c0..4759b51de7 100644 --- a/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -656,6 +656,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred, case Stmt::SwitchStmtClass: case Stmt::WhileStmtClass: case Expr::MSDependentExistsStmtClass: + case Stmt::CapturedStmtClass: llvm_unreachable("Stmt should not be in analyzer evaluation loop"); case Stmt::ObjCSubscriptRefExprClass: diff --git a/tools/libclang/CXCursor.cpp b/tools/libclang/CXCursor.cpp index a413903f9a..edcc85b45e 100644 --- a/tools/libclang/CXCursor.cpp +++ b/tools/libclang/CXCursor.cpp @@ -270,6 +270,10 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent, K = CXCursor_DeclStmt; break; + case Stmt::CapturedStmtClass: + K = CXCursor_UnexposedStmt; + break; + case Stmt::IntegerLiteralClass: K = CXCursor_IntegerLiteral; break; diff --git a/tools/libclang/RecursiveASTVisitor.h b/tools/libclang/RecursiveASTVisitor.h index 0312f1fbbd..592f168725 100644 --- a/tools/libclang/RecursiveASTVisitor.h +++ b/tools/libclang/RecursiveASTVisitor.h @@ -1839,7 +1839,7 @@ DEF_TRAVERSE_STMT(MSDependentExistsStmt, { DEF_TRAVERSE_STMT(ReturnStmt, { }) DEF_TRAVERSE_STMT(SwitchStmt, { }) DEF_TRAVERSE_STMT(WhileStmt, { }) - +DEF_TRAVERSE_STMT(CapturedStmt, { }) DEF_TRAVERSE_STMT(CXXDependentScopeMemberExpr, { TRY_TO(TraverseNestedNameSpecifierLoc(S->getQualifierLoc()));