From 1e6d05ac59e2e6d734e6f43b0b19f594d8f26294 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 11 Aug 2017 12:16:27 -0400 Subject: [PATCH] [spirv] Add test for struct accessing and assignment (#551) --- tools/clang/lib/SPIRV/SPIRVEmitter.cpp | 42 +++++++---- tools/clang/lib/SPIRV/SPIRVEmitter.h | 6 ++ .../test/CodeGenSPIRV/binary-op.assign.hlsl | 20 ++++- .../test/CodeGenSPIRV/op.struct.access.hlsl | 75 +++++++++++++++++++ .../unittests/SPIRV/CodeGenSPIRVTest.cpp | 3 + 5 files changed, 132 insertions(+), 14 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/op.struct.access.hlsl diff --git a/tools/clang/lib/SPIRV/SPIRVEmitter.cpp b/tools/clang/lib/SPIRV/SPIRVEmitter.cpp index 9b351218a..d2a44c075 100644 --- a/tools/clang/lib/SPIRV/SPIRVEmitter.cpp +++ b/tools/clang/lib/SPIRV/SPIRVEmitter.cpp @@ -1588,20 +1588,15 @@ uint32_t SPIRVEmitter::doInitListExpr(const InitListExpr *expr) { } uint32_t SPIRVEmitter::doMemberExpr(const MemberExpr *expr) { - const uint32_t base = doExpr(expr->getBase()); - const auto *memberDecl = expr->getMemberDecl(); - if (const auto *fieldDecl = dyn_cast(memberDecl)) { - const auto index = theBuilder.getConstantInt32(fieldDecl->getFieldIndex()); - const uint32_t fieldType = - typeTranslator.translateType(fieldDecl->getType()); - const uint32_t ptrType = theBuilder.getPointerType( - fieldType, declIdMapper.resolveStorageClass(expr->getBase())); - return theBuilder.createAccessChain(ptrType, base, {index}); - } + llvm::SmallVector indices; - emitError("Decl '%0' in MemberExpr is not supported yet.") - << memberDecl->getDeclKindName(); - return 0; + const Expr *baseExpr = collectStructIndices(expr, &indices); + const uint32_t base = doExpr(baseExpr); + + const uint32_t fieldType = typeTranslator.translateType(expr->getType()); + const uint32_t ptrType = theBuilder.getPointerType( + fieldType, declIdMapper.resolveStorageClass(baseExpr)); + return theBuilder.createAccessChain(ptrType, base, indices); } uint32_t SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) { @@ -2391,6 +2386,27 @@ uint32_t SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs, return 0; } +const Expr * +SPIRVEmitter::collectStructIndices(const MemberExpr *expr, + llvm::SmallVectorImpl *indices) { + const Expr *base = expr->getBase(); + if (const auto *memExpr = dyn_cast(base)) { + base = collectStructIndices(memExpr, indices); + } else { + indices->clear(); + } + + const auto *memberDecl = expr->getMemberDecl(); + if (const auto *fieldDecl = dyn_cast(memberDecl)) { + indices->push_back(theBuilder.getConstantInt32(fieldDecl->getFieldIndex())); + } else { + emitError("Decl '%0' in MemberExpr is not supported yet.") + << memberDecl->getDeclKindName(); + } + + return base; +} + uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType, QualType toBoolType) { if (isSameScalarOrVecType(fromType, toBoolType)) diff --git a/tools/clang/lib/SPIRV/SPIRVEmitter.h b/tools/clang/lib/SPIRV/SPIRVEmitter.h index 50b33930e..cd3fb83ba 100644 --- a/tools/clang/lib/SPIRV/SPIRVEmitter.h +++ b/tools/clang/lib/SPIRV/SPIRVEmitter.h @@ -210,6 +210,12 @@ private: uint32_t processMatrixBinaryOp(const Expr *lhs, const Expr *rhs, const BinaryOperatorKind opcode); + /// Collects all indices (SPIR-V constant values) from consecutive MemberExprs + /// and writes into indices. Returns the real base (the first Expr that is not + /// a MemberExpr). + const Expr *collectStructIndices(const MemberExpr *expr, + llvm::SmallVectorImpl *indices); + private: /// Processes the given expr, casts the result into the given bool (vector) /// type and returns the of the casted value. diff --git a/tools/clang/test/CodeGenSPIRV/binary-op.assign.hlsl b/tools/clang/test/CodeGenSPIRV/binary-op.assign.hlsl index fe21c5f42..795258d54 100644 --- a/tools/clang/test/CodeGenSPIRV/binary-op.assign.hlsl +++ b/tools/clang/test/CodeGenSPIRV/binary-op.assign.hlsl @@ -1,6 +1,13 @@ // Run: %dxc -T ps_6_0 -E main -// TODO: assignment for composite types +struct S { + float x; +}; + +struct T { + float y; + S z; +}; void main() { int a, b, c; @@ -20,4 +27,15 @@ void main() { // CHECK-NEXT: OpStore %a [[a1]] // CHECK-NEXT: OpStore %a [[a1]] a = a = a; + + T p, q; + +// CHECK-NEXT: [[q:%\d+]] = OpLoad %T %q +// CHECK-NEXT: OpStore %p [[q]] + p = q; // assign as a whole +// CHECK-NEXT: [[q1ptr:%\d+]] = OpAccessChain %_ptr_Function_S %q %int_1 +// CHECK-NEXT: [[q1val:%\d+]] = OpLoad %S [[q1ptr]] +// CHECK-NEXT: [[p1ptr:%\d+]] = OpAccessChain %_ptr_Function_S %p %int_1 +// CHECK-NEXT: OpStore [[p1ptr]] [[q1val]] + p.z = q.z; // assign nested struct } diff --git a/tools/clang/test/CodeGenSPIRV/op.struct.access.hlsl b/tools/clang/test/CodeGenSPIRV/op.struct.access.hlsl new file mode 100644 index 000000000..4dbd90fe5 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/op.struct.access.hlsl @@ -0,0 +1,75 @@ +// Run: %dxc -T vs_6_0 -E main + +struct S { + bool a; + uint2 b; + float2x3 c; +}; + +struct T { + int h; // Nested struct + S i; +}; + +void main() { + T t; + +// CHECK: [[h:%\d+]] = OpAccessChain %_ptr_Function_int %t %int_0 +// CHECK-NEXT: {{%\d+}} = OpLoad %int [[h]] + int v1 = t.h; +// CHECK: [[a:%\d+]] = OpAccessChain %_ptr_Function_bool %t %int_1 %int_0 +// CHECK-NEXT: {{%\d+}} = OpLoad %bool [[a]] + bool v2 = t.i.a; + +// CHECK: [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1 +// CHECK-NEXT: [[b0:%\d+]] = OpAccessChain %_ptr_Function_uint [[b]] %uint_0 +// CHECK-NEXT: {{%\d+}} = OpLoad %uint [[b0]] + uint v3 = t.i.b[0]; +// CHECK: [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1 +// CHECK-NEXT: {{%\d+}} = OpLoad %v2uint [[b]] + uint2 v4 = t.i.b.rg; + +// CHECK: [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2 +// CHECK-NEXT: [[c00p:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_0 %int_0 +// CHECK-NEXT: [[c00v:%\d+]] = OpLoad %float [[c00p]] +// CHECK-NEXT: [[c11p:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_1 %int_1 +// CHECK-NEXT: [[c11v:%\d+]] = OpLoad %float [[c11p]] +// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %v2float [[c00v]] [[c11v]] + float2 v5 = t.i.c._11_22; +// CHECK: [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2 +// CHECK-NEXT: [[c1:%\d+]] = OpAccessChain %_ptr_Function_v3float [[c]] %uint_1 +// CHECK-NEXT: {{%\d+}} = OpLoad %v3float [[c1]] + float3 v6 = t.i.c[1]; + +// CHECK: [[h:%\d+]] = OpAccessChain %_ptr_Function_int %t %int_0 +// CHECK-NEXT: OpStore [[h]] {{%\d+}} + t.h = v1; +// CHECK: [[a:%\d+]] = OpAccessChain %_ptr_Function_bool %t %int_1 %int_0 +// CHECK-NEXT: OpStore [[a]] {{%\d+}} + t.i.a = v2; + +// CHECK: [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1 +// CHECK-NEXT: [[b1:%\d+]] = OpAccessChain %_ptr_Function_uint [[b]] %uint_1 +// CHECK-NEXT: OpStore [[b1]] {{%\d+}} + t.i.b[1] = v3; +// CHECK: [[v4:%\d+]] = OpLoad %v2uint %v4 +// CHECK-NEXT: [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1 +// CHECK-NEXT: [[bv:%\d+]] = OpLoad %v2uint [[b]] +// CHECK-NEXT: [[gr:%\d+]] = OpVectorShuffle %v2uint [[bv]] [[v4]] 3 2 +// CHECK-NEXT: OpStore [[b]] [[gr]] + t.i.b.gr = v4; + +// CHECK: [[v5:%\d+]] = OpLoad %v2float %v5 +// CHECK-NEXT: [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2 +// CHECK-NEXT: [[v50:%\d+]] = OpCompositeExtract %float [[v5]] 0 +// CHECK-NEXT: [[c11:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_1 %int_1 +// CHECK-NEXT: OpStore [[c11]] [[v50]] +// CHECK-NEXT: [[v51:%\d+]] = OpCompositeExtract %float [[v5]] 1 +// CHECK-NEXT: [[c00:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_0 %int_0 +// CHECK-NEXT: OpStore [[c00]] [[v51]] + t.i.c._22_11 = v5; +// CHECK: [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2 +// CHECK-NEXT: [[c0:%\d+]] = OpAccessChain %_ptr_Function_v3float [[c]] %uint_0 +// CHECK-NEXT: OpStore [[c0]] {{%\d+}} + t.i.c[0] = v6; +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp index 752fcddb7..9b8996b93 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp @@ -178,6 +178,9 @@ TEST_F(FileTest, OpMatrixAccess1x1) { runFileTest("op.matrix.access.1x1.hlsl"); } +// For struct accessing operator +TEST_F(FileTest, OpStructAccess) { runFileTest("op.struct.access.hlsl"); } + // For casting TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); } TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }