diff --git a/tools/clang/lib/SPIRV/AstTypeProbe.cpp b/tools/clang/lib/SPIRV/AstTypeProbe.cpp index aaea7a396..0c66dfd88 100644 --- a/tools/clang/lib/SPIRV/AstTypeProbe.cpp +++ b/tools/clang/lib/SPIRV/AstTypeProbe.cpp @@ -648,6 +648,30 @@ bool isSameType(const ASTContext &astContext, QualType type1, QualType type2) { arrType2->getElementType()); } + { // Two structures with identical fields + if (const auto *structType1 = type1->getAs()) { + if (const auto *structType2 = type2->getAs()) { + llvm::SmallVector fieldTypes1; + llvm::SmallVector fieldTypes2; + for (const auto *field : structType1->getDecl()->fields()) + fieldTypes1.push_back(field->getType()); + for (const auto *field : structType2->getDecl()->fields()) + fieldTypes2.push_back(field->getType()); + // Note: We currently do NOT consider such cases as equal types: + // struct s1 { int x; int y; } + // struct s2 { int2 x; } + // Therefore if two structs have different number of members, we + // consider them different. + if (fieldTypes1.size() != fieldTypes2.size()) + return false; + for (auto i = 0; i < fieldTypes1.size(); ++i) + if (!isSameType(astContext, fieldTypes1[i], fieldTypes2[i])) + return false; + return true; + } + } + } + // TODO: support other types if needed return false; diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 502f4861b..5c5831182 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -2329,15 +2329,24 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) { } // For assigning one array instance to another one with the same array type // (regardless of constness and literalness), the rhs will be wrapped in a - // FlatConversion: + // FlatConversion. Similarly for assigning a struct to another struct with + // identical members. // |- // `- ImplicitCastExpr // `- ImplicitCastExpr // `- - // This FlatConversion does not affect CodeGen, so that we can ignore it. 
- else if (subExprType->isArrayType() && - isSameType(astContext, expr->getType(), subExprType)) { - return doExpr(subExpr); + else if (isSameType(astContext, toType, evalType) || + // We can have casts changing the shape but without affecting + // memory order, e.g., `float4 a[2]; float b[8] = (float[8])a;`. + // This is also represented as FlatConversion. For such cases, we + // can rely on the InitListHandler, which can decompose + // vectors/matrices. + subExprType->isArrayType()) { + auto *valInstr = + InitListHandler(astContext, *this).processCast(toType, subExpr); + if (valInstr) + valInstr->setRValue(); + return valInstr; } // We can have casts changing the shape but without affecting memory order, // e.g., `float4 a[2]; float b[8] = (float[8])a;`. This is also represented diff --git a/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct-to-struct.hlsl b/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct-to-struct.hlsl new file mode 100644 index 000000000..b28f0a28e --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct-to-struct.hlsl @@ -0,0 +1,60 @@ +// Run: %dxc -T cs_6_0 -E main + +// Processing FlatConversion when source and destination +// are both structures with identical members. + +struct FirstStruct { + float3 anArray[4]; + float2x3 mats[1]; + int2 ints[3]; +}; + +struct SecondStruct { + float3 anArray[4]; + float2x3 mats[1]; + int2 ints[3]; +}; + +RWStructuredBuffer rwBuf : register(u0); +[ numthreads ( 16 , 16 , 1 ) ] +void main() { + SecondStruct values; + FirstStruct v; + +// Yes, this is a FlatConversion! 
+// CHECK: [[v0ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v3float_uint_4_0 %values %int_0 +// CHECK-NEXT: [[v0:%\d+]] = OpLoad %_arr_v3float_uint_4_0 [[v0ptr]] +// CHECK-NEXT: [[v1ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_mat2v3float_uint_1_0 %values %int_1 +// CHECK-NEXT: [[v1:%\d+]] = OpLoad %_arr_mat2v3float_uint_1_0 [[v1ptr]] +// CHECK-NEXT: [[v2ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v2int_uint_3_0 %values %int_2 +// CHECK-NEXT: [[v2:%\d+]] = OpLoad %_arr_v2int_uint_3_0 [[v2ptr]] +// CHECK-NEXT: [[v:%\d+]] = OpCompositeConstruct %FirstStruct_0 [[v0]] [[v1]] [[v2]] +// CHECK-NEXT: OpStore %v [[v]] + v = values; + +// CHECK: [[v0ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v3float_uint_4_0 %values %int_0 +// CHECK-NEXT: [[v0:%\d+]] = OpLoad %_arr_v3float_uint_4_0 [[v0ptr]] +// CHECK-NEXT: [[v1ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_mat2v3float_uint_1_0 %values %int_1 +// CHECK-NEXT: [[v1:%\d+]] = OpLoad %_arr_mat2v3float_uint_1_0 [[v1ptr]] +// CHECK-NEXT: [[v2ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v2int_uint_3_0 %values %int_2 +// CHECK-NEXT: [[v2:%\d+]] = OpLoad %_arr_v2int_uint_3_0 [[v2ptr]] +// CHECK-NEXT: [[values:%\d+]] = OpCompositeConstruct %FirstStruct_0 [[v0]] [[v1]] [[v2]] +// CHECK-NEXT: [[rwBuf_ptr:%\d+]] = OpAccessChain %_ptr_Uniform_FirstStruct %rwBuf %int_0 %uint_0 +// CHECK-NEXT: [[anArray:%\d+]] = OpCompositeExtract %_arr_v3float_uint_4_0 [[values]] 0 +// CHECK-NEXT: [[anArray1:%\d+]] = OpCompositeExtract %v3float [[anArray]] 0 +// CHECK-NEXT: [[anArray2:%\d+]] = OpCompositeExtract %v3float [[anArray]] 1 +// CHECK-NEXT: [[anArray3:%\d+]] = OpCompositeExtract %v3float [[anArray]] 2 +// CHECK-NEXT: [[anArray4:%\d+]] = OpCompositeExtract %v3float [[anArray]] 3 +// CHECK-NEXT: [[res1:%\d+]] = OpCompositeConstruct %_arr_v3float_uint_4 [[anArray1]] [[anArray2]] [[anArray3]] [[anArray4]] +// CHECK-NEXT: [[mats:%\d+]] = OpCompositeExtract %_arr_mat2v3float_uint_1_0 [[values]] 1 +// CHECK-NEXT: [[mat:%\d+]] = 
OpCompositeExtract %mat2v3float [[mats]] 0 +// CHECK-NEXT: [[res2:%\d+]] = OpCompositeConstruct %_arr_mat2v3float_uint_1 [[mat]] +// CHECK-NEXT: [[ints:%\d+]] = OpCompositeExtract %_arr_v2int_uint_3_0 [[values]] 2 +// CHECK-NEXT: [[ints1:%\d+]] = OpCompositeExtract %v2int [[ints]] 0 +// CHECK-NEXT: [[ints2:%\d+]] = OpCompositeExtract %v2int [[ints]] 1 +// CHECK-NEXT: [[ints3:%\d+]] = OpCompositeExtract %v2int [[ints]] 2 +// CHECK-NEXT: [[res3:%\d+]] = OpCompositeConstruct %_arr_v2int_uint_3 [[ints1]] [[ints2]] [[ints3]] +// CHECK-NEXT: [[result:%\d+]] = OpCompositeConstruct %FirstStruct [[res1]] [[res2]] [[res3]] +// CHECK-NEXT: OpStore [[rwBuf_ptr]] [[result]] + rwBuf[0] = values; +} diff --git a/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct.hlsl b/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct.hlsl index 895e969cf..55ff5af1a 100644 --- a/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct.hlsl +++ b/tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct.hlsl @@ -16,9 +16,11 @@ float4 main(float4 a: A) : SV_Target { // CHECK-NEXT: OpStore %s [[s]] S s = (S)a; -// CHECK: [[s:%\d+]] = OpLoad %S %s -// CHECK-NEXT: [[t:%\d+]] = OpCompositeConstruct %T [[s]] -// CHECK-NEXT: OpStore %t [[t]] +// CHECK: [[valptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %s %int_0 +// CHECK-NEXT: [[val:%\d+]] = OpLoad %v4float [[valptr]] +// CHECK-NEXT: [[s:%\d+]] = OpCompositeConstruct %S [[val]] +// CHECK-NEXT: [[t:%\d+]] = OpCompositeConstruct %T [[s]] +// CHECK-NEXT: OpStore %t [[t]] T t = (T)s; return s.val + t.val.val; diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index ad9dfab6f..20195b12a 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -395,6 +395,9 @@ TEST_F(FileTest, CastFlatConversionStruct) { TEST_F(FileTest, CastFlatConversionNoOp) { runFileTest("cast.flat-conversion.no-op.hlsl"); } +TEST_F(FileTest, 
CastFlatConversionStructToStruct) { + runFileTest("cast.flat-conversion.struct-to-struct.hlsl"); +} TEST_F(FileTest, CastFlatConversionLiteralInitializer) { runFileTest("cast.flat-conversion.literal-initializer.hlsl"); }