diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 3c60592a4..d09964d6e 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -517,12 +517,10 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) { const uint32_t id = theBuilder.getSPIRVContext()->takeNextId(); info.setResultId(id); - if (isAlias) - // No need to dereference to get the pointer. Alias function returns - // themselves are already pointers to values. - info.setValTypeId(0); - else - // All other cases should be normal rvalues. + // No need to dereference to get the pointer. Alias function returns + // themselves are already pointers to values. All other cases should be + // normal rvalues. + if (!isAlias) info.setRValue(); // Create alias counter variable if suitable @@ -1882,25 +1880,18 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar( if (const auto *varDecl = dyn_cast(decl)) { // This method is only intended to be used to create SPIR-V variables in the // Function or Private storage class. - assert(!varDecl->isExceptionVariable() || varDecl->isStaticDataMember()); + assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember()); } const QualType type = getTypeOrFnRetType(decl); // Whether we should generate this decl as an alias variable. bool genAlias = false; - // All texture/structured/byte buffers use GLSL std430 rules. - LayoutRule rule = LayoutRule::GLSLStd430; if (const auto *buffer = dyn_cast(decl->getDeclContext())) { // For ConstantBuffer and TextureBuffer if (buffer->isConstantBufferView()) genAlias = true; - // ConstantBuffer uses GLSL std140 rules. - // TODO: do we actually want to include constant/texture buffers - // in this method? - if (buffer->isCBuffer()) - rule = LayoutRule::GLSLStd140; - } else if (TypeTranslator::isAKindOfStructuredOrByteBuffer(type)) { + } else if (TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(type)) { genAlias = true; } @@ -1910,18 +1901,8 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar( if (genAlias) { needsLegalization = true; - const uint32_t valType = typeTranslator.translateType(type, rule); - // All constant/texture/structured/byte buffers are in the Uniform - // storage class. - const auto ptrType = - theBuilder.getPointerType(valType, spv::StorageClass::Uniform); - if (info) - info->setStorageClass(spv::StorageClass::Uniform) - .setLayoutRule(rule) - .setValTypeId(ptrType); - - return ptrType; + info->setContainsAliasComponent(true); } return typeTranslator.translateType(type); diff --git a/tools/clang/lib/SPIRV/SPIRVEmitter.cpp b/tools/clang/lib/SPIRV/SPIRVEmitter.cpp index 651f59b74..582cdbd33 100644 --- a/tools/clang/lib/SPIRV/SPIRVEmitter.cpp +++ b/tools/clang/lib/SPIRV/SPIRVEmitter.cpp @@ -660,17 +660,22 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr, return info.setRValue(); } - uint32_t valType = 0; - if (valType = info.getValTypeId()) { + if (loadIfAliasVarRef(expr, info)) { // We are loading an alias variable as a whole here. This is likely for // wholesale assignments or function returns. Need to load the pointer. // // Note: legalization specific code + // TODO: It seems we should not set rvalue here since info is still + // holding a pointer. But it fails structured buffer assignment because + // of double loadIfGLValue() calls if we do not. Fix it. + return info.setRValue(); } + + uint32_t valType = 0; // TODO: Ouch. Very hacky. We need special path to get the value type if // we are loading a whole ConstantBuffer/TextureBuffer since the normal // type translation path won't work. - else if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) { + if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) { valType = declIdMapper.getCTBufferPushConstantTypeId(declContext); } else { valType = @@ -684,18 +689,34 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr, SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) { auto info = doExpr(expr); + loadIfAliasVarRef(expr, info); + return info; +} - if (const auto valTypeId = info.getValTypeId()) { - return info - // Load the pointer of the aliased-to-variable - .setResultId(theBuilder.createLoad(valTypeId, info)) - // Set the value's to zero to indicate that we've performed - // dereference over the pointer-to-pointer and now should fallback to - // the normal path - .setValTypeId(0); +bool SPIRVEmitter::loadIfAliasVarRef(const Expr *varExpr, SpirvEvalInfo &info) { + if (info.containsAliasComponent() && + TypeTranslator::isAKindOfStructuredOrByteBuffer(varExpr->getType())) { + // Aliased-to variables are all in the Uniform storage class with GLSL + // std430 layout rules. + const auto ptrType = typeTranslator.translateType(varExpr->getType()); + + // Load the pointer of the aliased-to-variable if the expression has a + // pointer to pointer type. That is, the expression itself is a lvalue. + // (Note that we translate alias function return values as pointer types, + // not pointer to pointer types.) + if (varExpr->isGLValue()) + info.setResultId(theBuilder.createLoad(ptrType, info)); + + info.setStorageClass(spv::StorageClass::Uniform) + .setLayoutRule(LayoutRule::GLSLStd430) + // Set to false to indicate that we've performed dereference over the + // pointer-to-pointer and now should fallback to the normal path + .setContainsAliasComponent(false); + + return true; } - return info; + return false; } uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType, @@ -977,7 +998,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) { else storeValue(varId, loadIfGLValue(init), decl->getType()); - // Update counter variable associatd with local variables + // Update counter variable associated with local variables tryToAssignCounterVar(decl, init); } @@ -1427,7 +1448,7 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) { void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) { if (const auto *retVal = stmt->getRetValue()) { - // Update counter variable associatd with function returns + // Update counter variable associated with function returns tryToAssignCounterVar(curFunction, retVal); const auto retInfo = doExpr(retVal); @@ -1559,7 +1580,7 @@ SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) { // For other binary operations, we need to evaluate lhs before rhs. if (opcode == BO_Assign) { if (const auto *dstDecl = getReferencedDef(expr->getLHS())) - // Update counter variable associatd with lhs of assignments + // Update counter variable associated with lhs of assignments tryToAssignCounterVar(dstDecl, expr->getRHS()); return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()), @@ -4645,6 +4666,18 @@ const Expr *SPIRVEmitter::collectArrayStructIndices( const auto thisBaseType = thisBase->getType(); const Expr *base = collectArrayStructIndices(thisBase, indices); + if (thisBaseType != base->getType() && + TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) { + // The immediate base is a kind of structured or byte buffer. It should + // be an alias variable. Break the normal index collecting chain. + // Return the immediate base as the base so that we can apply other + // hacks for legalization over it. + // + // Note: legalization specific code + indices->clear(); + base = thisBase; + } + // If the base is a StructureType, we need to push an addtional index 0 // here. This is because we created an additional OpTypeRuntimeArray // in the structure. @@ -7174,7 +7207,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl, if (const auto *init = varDecl->getInit()) { storeValue(varInfo, doExpr(init), varDecl->getType()); - // Update counter variable associatd with global variables + // Update counter variable associated with global variables tryToAssignCounterVar(varDecl, init); } else { const auto typeId = typeTranslator.translateType(varDecl->getType()); diff --git a/tools/clang/lib/SPIRV/SPIRVEmitter.h b/tools/clang/lib/SPIRV/SPIRVEmitter.h index 1506b3355..befa50d50 100644 --- a/tools/clang/lib/SPIRV/SPIRVEmitter.h +++ b/tools/clang/lib/SPIRV/SPIRVEmitter.h @@ -118,6 +118,13 @@ private: /// Note: legalization specific code SpirvEvalInfo loadIfAliasVarRef(const Expr *expr); + /// Loads the pointer of the aliased-to-variable and ajusts aliasVarInfo + /// accordingly if aliasVarExpr is referencing an alias variable. Returns true + /// if aliasVarInfo is changed, false otherwise. + /// + /// Note: legalization specific code + bool loadIfAliasVarRef(const Expr *aliasVarExpr, SpirvEvalInfo &aliasVarInfo); + private: /// Translates the given frontend binary operator into its SPIR-V equivalent /// taking consideration of the operand type. diff --git a/tools/clang/lib/SPIRV/SpirvEvalInfo.h b/tools/clang/lib/SPIRV/SpirvEvalInfo.h index 82d0a1140..91ba067e7 100644 --- a/tools/clang/lib/SPIRV/SpirvEvalInfo.h +++ b/tools/clang/lib/SPIRV/SpirvEvalInfo.h @@ -79,8 +79,8 @@ public: /// Handly implicit conversion to test whether the is valid. operator bool() const { return resultId != 0; } - inline SpirvEvalInfo &setValTypeId(uint32_t id); - uint32_t getValTypeId() const { return valTypeId; } + inline SpirvEvalInfo &setContainsAliasComponent(bool); + bool containsAliasComponent() const { return containsAlias; } inline SpirvEvalInfo &setStorageClass(spv::StorageClass sc); spv::StorageClass getStorageClass() const { return storageClass; } @@ -99,14 +99,15 @@ public: private: uint32_t resultId; - /// The value's for this variable. + /// Indicates whether this evaluation result contains alias variables /// - /// This field should only be non-zero for original alias variables, which is - /// of pointer-to-pointer type. After dereferencing the alias variable, this - /// should be set to zero to let CodeGen fall back to normal handling path. + /// This field should only be true for stand-alone alias variables, which is + /// of pointer-to-pointer type, or struct variables containing alias fields. + /// After dereferencing the alias variable, this should be set to false to let + /// CodeGen fall back to normal handling path. /// /// Note: legalization specific code - uint32_t valTypeId; + bool containsAlias; spv::StorageClass storageClass; LayoutRule layoutRule; @@ -117,9 +118,9 @@ private: }; SpirvEvalInfo::SpirvEvalInfo(uint32_t id) - : resultId(id), valTypeId(0), storageClass(spv::StorageClass::Function), - layoutRule(LayoutRule::Void), isRValue_(false), isConstant_(false), - isRelaxedPrecision_(false) {} + : resultId(id), containsAlias(false), + storageClass(spv::StorageClass::Function), layoutRule(LayoutRule::Void), + isRValue_(false), isConstant_(false), isRelaxedPrecision_(false) {} SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) { resultId = id; @@ -132,8 +133,8 @@ SpirvEvalInfo SpirvEvalInfo::substResultId(uint32_t newId) const { return info; } -SpirvEvalInfo &SpirvEvalInfo::setValTypeId(uint32_t id) { - valTypeId = id; +SpirvEvalInfo &SpirvEvalInfo::setContainsAliasComponent(bool contains) { + containsAlias = contains; return *this; } diff --git a/tools/clang/lib/SPIRV/TypeTranslator.cpp b/tools/clang/lib/SPIRV/TypeTranslator.cpp index 32e8b0ae2..4203f9693 100644 --- a/tools/clang/lib/SPIRV/TypeTranslator.cpp +++ b/tools/clang/lib/SPIRV/TypeTranslator.cpp @@ -488,6 +488,22 @@ bool TypeTranslator::isAKindOfStructuredOrByteBuffer(QualType type) { return false; } +bool TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(QualType type) { + if (const RecordType *recordType = type->getAs()) { + StringRef name = recordType->getDecl()->getName(); + if (name == "StructuredBuffer" || name == "RWStructuredBuffer" || + name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" || + name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") + return true; + + for (const auto *field : recordType->getDecl()->fields()) { + if (isOrContainsAKindOfStructuredOrByteBuffer(field->getType())) + return true; + } + } + return false; +} + bool TypeTranslator::isStructuredBuffer(QualType type) { const auto *recordType = type->getAs(); if (!recordType) @@ -836,10 +852,20 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) { if (name == "StructuredBuffer" || name == "RWStructuredBuffer" || name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") { - auto &context = *theBuilder.getSPIRVContext(); // StructureBuffer will be translated into an OpTypeStruct with one // field, which is an OpTypeRuntimeArray of OpTypeStruct (S). + // If layout rule is void, it means these resource types are used for + // declaring local resources, which should be created as alias variables. + // The aliased-to variable should surely be in the Uniform storage class, + // which has layout decorations. + bool asAlias = false; + if (rule == LayoutRule::Void) { + asAlias = true; + rule = LayoutRule::GLSLStd430; + } + + auto &context = *theBuilder.getSPIRVContext(); const auto s = hlsl::GetHLSLResourceResultType(type); const uint32_t structType = translateType(s, rule); std::string structName; @@ -864,16 +890,36 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) { decorations.push_back(Decoration::getNonWritable(context, 0)); decorations.push_back(Decoration::getBufferBlock(context)); const std::string typeName = "type." + name.str() + "." + structName; - return theBuilder.getStructType(raType, typeName, {}, decorations); + const auto valType = + theBuilder.getStructType(raType, typeName, {}, decorations); + + if (asAlias) { + // All structured buffers are in the Uniform storage class. + return theBuilder.getPointerType(valType, spv::StorageClass::Uniform); + } else { + return valType; + } } // ByteAddressBuffer types. if (name == "ByteAddressBuffer") { - return theBuilder.getByteAddressBufferType(/*isRW*/ false); + const auto bufferType = theBuilder.getByteAddressBufferType(/*isRW*/ false); + if (rule == LayoutRule::Void) { + // All byte address buffers are in the Uniform storage class. + return theBuilder.getPointerType(bufferType, spv::StorageClass::Uniform); + } else { + return bufferType; + } } // RWByteAddressBuffer types. if (name == "RWByteAddressBuffer") { - return theBuilder.getByteAddressBufferType(/*isRW*/ true); + const auto bufferType = theBuilder.getByteAddressBufferType(/*isRW*/ true); + if (rule == LayoutRule::Void) { + // All byte address buffers are in the Uniform storage class. + return theBuilder.getPointerType(bufferType, spv::StorageClass::Uniform); + } else { + return bufferType; + } } // Buffer and RWBuffer types diff --git a/tools/clang/lib/SPIRV/TypeTranslator.h b/tools/clang/lib/SPIRV/TypeTranslator.h index c07654ad3..10ba678ee 100644 --- a/tools/clang/lib/SPIRV/TypeTranslator.h +++ b/tools/clang/lib/SPIRV/TypeTranslator.h @@ -95,6 +95,11 @@ public: /// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer. static bool isAKindOfStructuredOrByteBuffer(QualType type); + /// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer, + /// (RW)ByteAddressBuffer, {Append|Consume}StructuredBuffer, or a struct + /// containing one of the above. + static bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type); + /// \brief Returns true if the given type is the HLSL Buffer type. static bool isBuffer(QualType type); diff --git a/tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl b/tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl new file mode 100644 index 000000000..9add8679e --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.struct.hlsl @@ -0,0 +1,86 @@ +// Run: %dxc -T ps_6_0 -E main + +struct Basic { + float3 a; + float4 b; +}; + +// CHECK: %S = OpTypeStruct %_ptr_Uniform_type_AppendStructuredBuffer_v4float %_ptr_Uniform_type_AppendStructuredBuffer_v4float +struct S { + AppendStructuredBuffer append; + ConsumeStructuredBuffer consume; +}; + +// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic +struct T { + StructuredBuffer ro; + RWStructuredBuffer rw; +}; + +// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer +struct Combine { + S s; + T t; + ByteAddressBuffer ro; + RWByteAddressBuffer rw; +}; + + StructuredBuffer gSBuffer; + RWStructuredBuffer gRWSBuffer; + AppendStructuredBuffer gASBuffer; +ConsumeStructuredBuffer gCSBuffer; + ByteAddressBuffer gBABuffer; + RWByteAddressBuffer gRWBABuffer; + +float4 foo(Combine comb); + +float4 main() : SV_Target { + Combine c; + +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_v4float %c %int_0 %int_0 +// CHECK-NEXT: OpStore [[ptr]] %gASBuffer + c.s.append = gASBuffer; +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_v4float %c %int_0 %int_1 +// CHECK-NEXT: OpStore [[ptr]] %gCSBuffer + c.s.consume = gCSBuffer; + + T t; +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_StructuredBuffer_Basic %t %int_0 +// CHECK-NEXT: OpStore [[ptr]] %gSBuffer + t.ro = gSBuffer; +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_Basic %t %int_1 +// CHECK-NEXT: OpStore [[ptr]] %gRWSBuffer + t.rw = gRWSBuffer; +// CHECK: [[val:%\d+]] = OpLoad %T %t +// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_T %c %int_1 +// CHECK-NEXT: OpStore [[ptr]] [[val]] + c.t = t; + +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %c %int_2 +// CHECK-NEXT: OpStore [[ptr]] %gBABuffer + c.ro = gBABuffer; +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_RWByteAddressBuffer %c %int_3 +// CHECK-NEXT: OpStore [[ptr]] %gRWBABuffer + c.rw = gRWBABuffer; + +// CHECK: [[val:%\d+]] = OpLoad %Combine %c +// CHECK-NEXT: OpStore %param_var_comb [[val]] + return foo(c); +} +float4 foo(Combine comb) { + // TODO: add support for associated counters of struct fields + // comb.s.append.Append(float4(1, 2, 3, 4)); + // float4 val = comb.s.consume.Consume(); + // comb.t.rw[5].a = 4.2; + +// CHECK: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %comb %int_2 +// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer [[ptr1]] +// CHECK-NEXT: [[idx:%\d+]] = OpShiftRightLogical %uint %uint_5 %uint_2 +// CHECK-NEXT: {{%\d+}} = OpAccessChain %_ptr_Uniform_uint [[ptr2]] %uint_0 [[idx]] + uint val = comb.ro.Load(5); + +// CHECK: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_StructuredBuffer_Basic %comb %int_1 %int_0 +// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_StructuredBuffer_Basic [[ptr1]] +// CHECK-NEXT: {{%\d+}} = OpAccessChain %_ptr_Uniform_v4float [[ptr2]] %int_0 %uint_0 %int_1 + return comb.t.ro[0].b; +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp index a83b4575f..531030696 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp @@ -1003,6 +1003,11 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) { // The generated SPIR-V needs legalization. /*runValidation=*/false); } +TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) { + runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success, + // The generated SPIR-V needs legalization. + /*runValidation=*/false); +} TEST_F(FileTest, SpirvLegalizationConstantBuffer) { runFileTest("spirv.legal.cbuffer.hlsl"); }