From 78e7d3af10c861678127dc9befe5dcdef5d1baea Mon Sep 17 00:00:00 2001 From: Mike Stump Date: Wed, 26 Aug 2009 01:54:35 +0000 Subject: [PATCH] Implement virtual dispatch. :-) This is self-consistent with clang, but not yet necessarily perfectly consistent with gcc. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@80064 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/CodeGen/CGCXX.cpp | 118 ++++++++++++++++++++++++++-------- lib/CodeGen/CodeGenFunction.h | 2 + test/CodeGenCXX/virt.cpp | 68 +++++++++++++++++++- 3 files changed, 160 insertions(+), 28 deletions(-) diff --git a/lib/CodeGen/CGCXX.cpp b/lib/CodeGen/CGCXX.cpp index 3efda13db8..4bf6a49774 100644 --- a/lib/CodeGen/CGCXX.cpp +++ b/lib/CodeGen/CGCXX.cpp @@ -200,15 +200,9 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE) { const FunctionProtoType *FPT = MD->getType()->getAsFunctionProtoType(); - if (MD->isVirtual()) { - ErrorUnsupported(CE, "virtual dispatch"); - } - const llvm::Type *Ty = CGM.getTypes().GetFunctionType(CGM.getTypes().getFunctionInfo(MD), FPT->isVariadic()); - llvm::Constant *Callee = CGM.GetAddrOfFunction(GlobalDecl(MD), Ty); - llvm::Value *This; if (ME->isArrow()) @@ -217,6 +211,12 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE) { LValue BaseLV = EmitLValue(ME->getBase()); This = BaseLV.getAddress(); } + + llvm::Value *Callee; + if (MD->isVirtual()) + Callee = BuildVirtualCall(MD, This, Ty); + else + Callee = CGM.GetAddrOfFunction(GlobalDecl(MD), Ty); return EmitCXXMemberCall(MD, Callee, This, CE->arg_begin(), CE->arg_end()); @@ -826,6 +826,10 @@ llvm::Constant *CodeGenModule::GenerateRtti(const CXXRecordDecl *RD) { } class VtableBuilder { +public: + /// Index_t - Vtable index type. + typedef uint64_t Index_t; +private: std::vector &methods; llvm::Type *Ptr8Ty; /// Class - The most derived class that this vtable is being built for. @@ -840,7 +844,7 @@ class VtableBuilder { CodeGenModule &CGM; // Per-module state. /// Index - Maps a method decl into a vtable index. Useful for virtual /// dispatch codegen. - llvm::DenseMap Index; + llvm::DenseMap Index; typedef CXXRecordDecl::method_iterator method_iter; public: VtableBuilder(std::vector &meth, @@ -852,6 +856,7 @@ public: Ptr8Ty = llvm::PointerType::get(llvm::Type::getInt8Ty(VMContext), 0); } + llvm::DenseMap &getIndex() { return Index; } llvm::Constant *GenerateVcall(const CXXMethodDecl *MD, const CXXRecordDecl *RD, bool VBoundary, @@ -932,17 +937,17 @@ public: SeenVBase.clear(); } - inline uint32_t nottoobig(uint64_t t) { - assert(t < (uint32_t)-1ULL || "vtable too big"); + inline Index_t nottoobig(uint64_t t) { + assert(t < (Index_t)-1ULL || "vtable too big"); return t; } #if 0 - inline uint32_t nottoobig(uint32_t t) { + inline Index_t nottoobig(Index_t t) { return t; } #endif - void AddMethod(const CXXMethodDecl *MD, int32_t FirstIndex) { + void AddMethod(const CXXMethodDecl *MD, Index_t AddressPoint) { typedef CXXMethodDecl::method_iterator meth_iter; llvm::Constant *m; @@ -963,34 +968,34 @@ public: om = CGM.GetAddrOfFunction(GlobalDecl(OMD), Ptr8Ty); om = llvm::ConstantExpr::getBitCast(om, Ptr8Ty); - for (int32_t i = FirstIndex, e = nottoobig(methods.size()); i != e; ++i) { + for (Index_t i = AddressPoint, e = methods.size(); + i != e; ++i) { // FIXME: begin_overridden_methods might be too lax, covariance */ if (methods[i] == om) { methods[i] = m; - Index[MD] = i; + Index[MD] = i - AddressPoint; return; } } } // else allocate a new slot. - Index[MD] = methods.size(); + Index[MD] = methods.size() - AddressPoint; methods.push_back(m); } - void GenerateMethods(const CXXRecordDecl *RD, int32_t FirstIndex) { + void GenerateMethods(const CXXRecordDecl *RD, Index_t AddressPoint) { for (method_iter mi = RD->method_begin(), me = RD->method_end(); mi != me; ++mi) if (mi->isVirtual()) - AddMethod(*mi, FirstIndex); + AddMethod(*mi, AddressPoint); } int64_t GenerateVtableForBase(const CXXRecordDecl *RD, bool forPrimary, bool VBoundary, int64_t Offset, - bool ForVirtualBase, - int32_t FirstIndex) { + bool ForVirtualBase) { llvm::Constant *m = llvm::Constant::getNullValue(Ptr8Ty); int64_t AddressPoint=0; @@ -1023,8 +1028,9 @@ public: if (PrimaryBaseWasVirtual) IndirectPrimary.insert(PrimaryBase); Top = false; - AddressPoint = GenerateVtableForBase(PrimaryBase, true, PrimaryBaseWasVirtual|VBoundary, - Offset, PrimaryBaseWasVirtual, FirstIndex); + AddressPoint = GenerateVtableForBase(PrimaryBase, true, + PrimaryBaseWasVirtual|VBoundary, + Offset, PrimaryBaseWasVirtual); } if (Top) { @@ -1041,7 +1047,7 @@ public: } // And add the virtuals for the class to the primary vtable. - GenerateMethods(RD, FirstIndex); + GenerateMethods(RD, AddressPoint); // and then the non-virtual bases. for (CXXRecordDecl::base_class_const_iterator i = RD->bases_begin(), @@ -1053,8 +1059,7 @@ public: if (Base != PrimaryBase || PrimaryBaseWasVirtual) { uint64_t o = Offset + Layout.getBaseClassOffset(Base); StartNewTable(); - FirstIndex = methods.size(); - GenerateVtableForBase(Base, true, false, o, false, FirstIndex); + GenerateVtableForBase(Base, true, false, o, false); } } return AddressPoint; @@ -1071,8 +1076,7 @@ public: IndirectPrimary.insert(Base); StartNewTable(); int64_t BaseOffset = BLayout.getVBaseClassOffset(Base); - int32_t FirstIndex = methods.size(); - GenerateVtableForBase(Base, false, true, BaseOffset, true, FirstIndex); + GenerateVtableForBase(Base, false, true, BaseOffset, true); } if (Base->getNumVBases()) GenerateVtableForVBases(Base, Class); @@ -1080,6 +1084,43 @@ public: } }; +class VtableInfo { +public: + typedef VtableBuilder::Index_t Index_t; +private: + CodeGenModule &CGM; // Per-module state. + /// Index_t - Vtable index type. + typedef llvm::DenseMap ElTy; + typedef llvm::DenseMap MapTy; + // FIXME: Move to Context. + static MapTy IndexFor; +public: + VtableInfo(CodeGenModule &cgm) : CGM(cgm) { } + void register_index(const CXXRecordDecl *RD, const ElTy &e) { + assert(IndexFor.find(RD) == IndexFor.end() || "Don't compute vtbl twice"); + // We own a copy of this, it will go away shortly. + new ElTy (e); + IndexFor[RD] = new ElTy (e); + } + Index_t lookup(const CXXMethodDecl *MD) { + const CXXRecordDecl *RD = MD->getParent(); + MapTy::iterator I = IndexFor.find(RD); + if (I == IndexFor.end()) { + std::vector methods; + VtableBuilder b(methods, RD, CGM); + b.GenerateVtableForBase(RD, true, false, 0, false); + b.GenerateVtableForVBases(RD, RD); + register_index(RD, b.getIndex()); + I = IndexFor.find(RD); + } + assert(I->second->find(MD)!=I->second->end() || "Can't find vtable index"); + return (*I->second)[MD]; + } +}; + +// FIXME: Move to Context. +VtableInfo::MapTy VtableInfo::IndexFor; + llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) { llvm::SmallString<256> OutName; llvm::raw_svector_ostream Out(OutName); @@ -1095,7 +1136,7 @@ llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) { VtableBuilder b(methods, RD, CGM); // First comes the vtables for all the non-virtual bases... - Offset = b.GenerateVtableForBase(RD, true, false, 0, false, 0); + Offset = b.GenerateVtableForBase(RD, true, false, 0, false); // then the vtables for all the virtual bases. b.GenerateVtableForVBases(RD, RD); @@ -1112,6 +1153,31 @@ llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) { return vtable; } +// FIXME: move to Context +static VtableInfo *vtableinfo; + +llvm::Value * +CodeGenFunction::BuildVirtualCall(const CXXMethodDecl *MD, llvm::Value *&This, + const llvm::Type *Ty) { + // FIXME: If we know the dynamic type, we don't have to do a virtual dispatch. + + // FIXME: move to Context + if (vtableinfo == 0) + vtableinfo = new VtableInfo(CGM); + + VtableInfo::Index_t Idx = vtableinfo->lookup(MD); + + Ty = llvm::PointerType::get(Ty, 0); + Ty = llvm::PointerType::get(Ty, 0); + Ty = llvm::PointerType::get(Ty, 0); + llvm::Value *vtbl = Builder.CreateBitCast(This, Ty); + vtbl = Builder.CreateLoad(vtbl); + llvm::Value *vfn = Builder.CreateConstInBoundsGEP1_64(vtbl, + Idx, "vfn"); + vfn = Builder.CreateLoad(vfn); + return vfn; +} + /// EmitClassAggrMemberwiseCopy - This routine generates code to copy a class /// array of objects from SrcValue to DestValue. Copying can be either a bitwise /// copy or via a copy constructor call. diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h index 4fc6c7c773..e8f0cc5de5 100644 --- a/lib/CodeGen/CodeGenFunction.h +++ b/lib/CodeGen/CodeGenFunction.h @@ -826,6 +826,8 @@ public: const Decl *TargetDecl = 0); RValue EmitCallExpr(const CallExpr *E); + llvm::Value *BuildVirtualCall(const CXXMethodDecl *MD, llvm::Value *&This, + const llvm::Type *Ty); RValue EmitCXXMemberCall(const CXXMethodDecl *MD, llvm::Value *Callee, llvm::Value *This, diff --git a/test/CodeGenCXX/virt.cpp b/test/CodeGenCXX/virt.cpp index 89411dea01..4583e0aa7c 100644 --- a/test/CodeGenCXX/virt.cpp +++ b/test/CodeGenCXX/virt.cpp @@ -91,6 +91,71 @@ int main() { // CHECK-LP64: movl $1, 12(%rax) // CHECK-LP64: movl $2, 8(%rax) +struct test12_A { + virtual void foo0() { } + virtual void foo() { } +} *test12_pa; + +struct test12_B : public test12_A { + virtual void foo() { } +} *test12_pb; + +struct test12_D : public test12_B { +} *test12_pd; +void test12_foo() { + test12_pa->foo0(); + test12_pb->foo0(); + test12_pd->foo0(); + test12_pa->foo(); + test12_pb->foo(); + test12_pd->foo(); +} + +// CHECK-LPOPT32:__Z10test12_foov: +// CHECK-LPOPT32: movl _test12_pa, %eax +// CHECK-LPOPT32-NEXT: movl (%eax), %ecx +// CHECK-LPOPT32-NEXT: movl %eax, (%esp) +// CHECK-LPOPT32-NEXT: call *(%ecx) +// CHECK-LPOPT32-NEXT: movl _test12_pb, %eax +// CHECK-LPOPT32-NEXT: movl (%eax), %ecx +// CHECK-LPOPT32-NEXT: movl %eax, (%esp) +// CHECK-LPOPT32-NEXT: call *(%ecx) +// CHECK-LPOPT32-NEXT: movl _test12_pd, %eax +// CHECK-LPOPT32-NEXT: movl (%eax), %ecx +// CHECK-LPOPT32-NEXT: movl %eax, (%esp) +// CHECK-LPOPT32-NEXT: call *(%ecx) +// CHECK-LPOPT32-NEXT: movl _test12_pa, %eax +// CHECK-LPOPT32-NEXT: movl (%eax), %ecx +// CHECK-LPOPT32-NEXT: movl %eax, (%esp) +// CHECK-LPOPT32-NEXT: call *4(%ecx) +// CHECK-LPOPT32-NEXT: movl _test12_pb, %eax +// CHECK-LPOPT32-NEXT: movl (%eax), %ecx +// CHECK-LPOPT32-NEXT: movl %eax, (%esp) +// CHECK-LPOPT32-NEXT: call *4(%ecx) +// CHECK-LPOPT32-NEXT: movl _test12_pd, %eax +// CHECK-LPOPT32-NEXT: movl (%eax), %ecx +// CHECK-LPOPT32-NEXT: movl %eax, (%esp) +// CHECK-LPOPT32-NEXT: call *4(%ecx) + +// CHECK-LPOPT64:__Z10test12_foov: +// CHECK-LPOPT64: movq _test12_pa(%rip), %rdi +// CHECK-LPOPT64-NEXT: movq (%rdi), %rax +// CHECK-LPOPT64-NEXT: call *(%rax) +// CHECK-LPOPT64-NEXT: movq _test12_pb(%rip), %rdi +// CHECK-LPOPT64-NEXT: movq (%rdi), %rax +// CHECK-LPOPT64-NEXT: call *(%rax) +// CHECK-LPOPT64-NEXT: movq _test12_pd(%rip), %rdi +// CHECK-LPOPT64-NEXT: movq (%rdi), %rax +// CHECK-LPOPT64-NEXT: call *(%rax) +// CHECK-LPOPT64-NEXT: movq _test12_pa(%rip), %rdi +// CHECK-LPOPT64-NEXT: movq (%rdi), %rax +// CHECK-LPOPT64-NEXT: call *8(%rax) +// CHECK-LPOPT64-NEXT: movq _test12_pb(%rip), %rdi +// CHECK-LPOPT64-NEXT: movq (%rdi), %rax +// CHECK-LPOPT64-NEXT: call *8(%rax) +// CHECK-LPOPT64-NEXT: movq _test12_pd(%rip), %rdi +// CHECK-LPOPT64-NEXT: movq (%rdi), %rax +// CHECK-LPOPT64-NEXT: call *8(%rax) struct test6_B2 { virtual void funcB2(); char b[1000]; }; struct test6_B1 : virtual test6_B2 { virtual void funcB1(); }; @@ -115,7 +180,7 @@ struct test3_B3 { virtual void funcB3(); }; struct test3_B2 : virtual test3_B3 { virtual void funcB2(); }; struct test3_B1 : virtual test3_B2 { virtual void funcB1(); }; -struct test3_D : virtual test3_B1 { +struct test3_D : virtual test3_B1 { virtual void funcD() { } }; @@ -652,7 +717,6 @@ struct test11_D : test11_B { // CHECK-LP64-NEXT: .quad __ZN8test11_D2D2Ev - // CHECK-LP64: __ZTV1B: // CHECK-LP64-NEXT: .space 8 // CHECK-LP64-NEXT: .quad __ZTI1B