diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index fb06060ca..b8b9d6a79 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -2193,6 +2193,36 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) { return processCall(callExpr); } +SpirvInstruction *SpirvEmitter::getBaseOfMemberFunction(QualType objectType, + SpirvInstruction * objInstr, + const CXXMethodDecl* memberFn, + SourceLocation loc) { + // If objectType is different from the parent of memberFn, memberFn should be + // defined in a base struct/class of objectType. We create OpAccessChain with + // index 0 while iterating bases of objectType until we find the base with + // the definition of memberFn. + if (const auto *ptrType = objectType->getAs()) { + if (const auto *recordType = ptrType->getPointeeType()->getAs()) { + const auto *parentDeclOfMemberFn = memberFn->getParent(); + if (recordType->getDecl() != parentDeclOfMemberFn) { + const auto *cxxRecordDecl = dyn_cast(recordType->getDecl()); + auto *zero = + spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); + for (auto baseItr = cxxRecordDecl->bases_begin(), itrEnd = cxxRecordDecl->bases_end(); + baseItr != itrEnd; baseItr++) { + const auto *baseType = baseItr->getType()->getAs(); + objectType = astContext.getPointerType(baseType->desugar()); + objInstr = spvBuilder.createAccessChain(objectType, + objInstr, {zero}, + loc); + if (baseType->getDecl() == parentDeclOfMemberFn) return objInstr; + } + } + } + } + return nullptr; +} + SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) { const FunctionDecl *callee = getCalleeDefinition(callExpr); @@ -2243,6 +2273,10 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) { objectType = object->getType(); objInstr = doExpr(object); + if (auto *accessToBaseInstr = getBaseOfMemberFunction(objectType, objInstr, memberFn, memberCall->getExprLoc())) { + objInstr = accessToBaseInstr; + objectType = accessToBaseInstr->getAstResultType(); + } // If not already a variable, we need to create a temporary variable and // pass the object pointer to the function. Example: diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 120f6bc87..8bf010571 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -975,6 +975,15 @@ private: SpirvInstruction *sampleIndex, SourceLocation loc); + /// \brief Returns OpAccessChain to the struct/class object that defines + /// memberFn when the struct/class is a base struct/class of objectType. + /// If the struct/class that defines memberFn is not a base of objectType, + /// returns nullptr. + SpirvInstruction *getBaseOfMemberFunction(QualType objectType, + SpirvInstruction * objInstr, + const CXXMethodDecl* memberFn, + SourceLocation loc); + private: /// \brief Takes a vector of size 4, and returns a vector of size 1 or 2 or 3 /// or 4. Creates a CompositeExtract or VectorShuffle instruction to extract diff --git a/tools/clang/test/CodeGenSPIRV/oo.call.method.with.same.base.method.name.hlsl b/tools/clang/test/CodeGenSPIRV/oo.call.method.with.same.base.method.name.hlsl new file mode 100644 index 000000000..74155b655 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/oo.call.method.with.same.base.method.name.hlsl @@ -0,0 +1,38 @@ +// Run: %dxc -T vs_5_0 -E main -fspv-target-env=vulkan1.1 + +// CHECK: %bar = OpTypeStruct %empty %mat4v4float +// CHECK: %foo = OpTypeStruct %bar + +// CHECK: OpFunctionCall %v4float %foo_get %x + +struct empty { +}; + +struct bar : empty { + float4x4 trans; + float4 value() { + return 0; + } +}; + +struct foo : bar { + float4 value() { + return 1; + } + +// When foo calls foo::value(), it must use +// this->value() instead of (this's 0th object)->value() + +// CHECK: %foo_get = OpFunction +// CHECK: %param_this = OpFunctionParameter %_ptr_Function_foo +// CHECK: OpFunctionCall %v4float %foo_value %param_this + + float4 get() { + return value(); + } +}; + +void main(out float4 Position : SV_Position) { + foo x; + Position = x.get(); +} diff --git a/tools/clang/test/CodeGenSPIRV/oo.inheritance.call.base.method.hlsl b/tools/clang/test/CodeGenSPIRV/oo.inheritance.call.base.method.hlsl new file mode 100644 index 000000000..902ea9ec4 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/oo.inheritance.call.base.method.hlsl @@ -0,0 +1,38 @@ +// Run: %dxc -T vs_5_0 -E main -fspv-target-env=vulkan1.1 + +// CHECK: %bar = OpTypeStruct %empty %mat4v4float +// CHECK: %foo = OpTypeStruct %bar + +// CHECK: OpFunctionCall %v4float %foo_get %x + +struct empty { +}; + +struct bar : empty { + float4x4 trans; + float4 value() { + return mul(float4(1,1,1,0), trans); + } +}; + +struct foo : bar { + +// When foo calls bar::value(), it must use +// (this's 0th object)->value() instead of this->value() +// because this's 0th object is the object of the base struct in SPIR-V. + +// CHECK: %foo_get = OpFunction +// CHECK: %param_this = OpFunctionParameter %_ptr_Function_foo +// CHECK: [[bar:%\w+]] = OpAccessChain %_ptr_Function_bar %param_this %uint_0 +// CHECK: OpFunctionCall %v4float %bar_value [[bar]] + + float4 get() { + return value(); + } +}; + +void main(out float4 Position : SV_Position) +{ + foo x; + Position = x.get(); +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index 0dbf87734..b52d1fd72 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -652,6 +652,13 @@ TEST_F(FileTest, InheritanceLayoutDifferences) { TEST_F(FileTest, InheritanceLayoutEmptyStruct) { runFileTest("oo.inheritance.layout.empty-struct.hlsl"); } +TEST_F(FileTest, InheritanceCallMethodOfBase) { + runFileTest("oo.inheritance.call.base.method.hlsl", Expect::Success, + /* runValidation */ false); +} +TEST_F(FileTest, InheritanceCallMethodWithSameBaseMethodName) { + runFileTest("oo.call.method.with.same.base.method.name.hlsl"); +} // For semantics // SV_Position, SV_ClipDistance, and SV_CullDistance are covered in