From ddbee48f85e3cb977695835de364e6f05e82dd62 Mon Sep 17 00:00:00 2001 From: Spencer Fricke Date: Fri, 23 Sep 2022 21:45:11 +0900 Subject: [PATCH] spirv-opt: Fix stacked CompositeExtract constant folds (#4932) This was spotted in the Validation Layers where OpSpecConstantOp %x CompositeExtract %y 0 was being folded to a constant, but anything that was using it wasn't recognizing it as a constant, the simple fix was to add a const_mgr->MapInst(new_const_inst); so the next instruction knew it was a const --- source/opt/fold.cpp | 3 +- ...ld_spec_constant_op_and_composite_pass.cpp | 34 ++- .../opt/fold_spec_const_op_composite_test.cpp | 205 +++++++++++++++++- 3 files changed, 221 insertions(+), 21 deletions(-) diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp index b903da6a..315741ad 100644 --- a/source/opt/fold.cpp +++ b/source/opt/fold.cpp @@ -627,8 +627,7 @@ Instruction* InstructionFolder::FoldInstructionToConstant( Instruction* inst, std::function id_map) const { analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); - if (!inst->IsFoldableByFoldScalar() && - !GetConstantFoldingRules().HasFoldingRule(inst)) { + if (!inst->IsFoldableByFoldScalar() && !HasConstFoldingRule(inst)) { return nullptr; } // Collect the values of the constant parameters. diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index 8d68850a..7a518701 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -28,6 +28,7 @@ namespace opt { Pass::Status FoldSpecConstantOpAndCompositePass::Process() { bool modified = false; + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); // Traverse through all the constant defining instructions. For Normal // Constants whose values are determined and do not depend on OpUndef // instructions, records their values in two internal maps: id_to_const_val_ @@ -62,8 +63,8 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() { // used in OpSpecConstant{Composite|Op} instructions. // TODO(qining): If the constant or its type has decoration, we may need // to skip it. - if (context()->get_constant_mgr()->GetType(inst) && - !context()->get_constant_mgr()->GetType(inst)->decoration_empty()) + if (const_mgr->GetType(inst) && + !const_mgr->GetType(inst)->decoration_empty()) continue; switch (SpvOp opcode = inst->opcode()) { // Records the values of Normal Constants. @@ -80,15 +81,14 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() { // Constant will be turned in to a Normal Constant. In that case, a // Constant instance should also be created successfully and recorded // in the id_to_const_val_ and const_val_to_id_ mapps. - if (auto const_value = - context()->get_constant_mgr()->GetConstantFromInst(inst)) { + if (auto const_value = const_mgr->GetConstantFromInst(inst)) { // Need to replace the OpSpecConstantComposite instruction with a // corresponding OpConstantComposite instruction. if (opcode == SpvOp::SpvOpSpecConstantComposite) { inst->SetOpcode(SpvOp::SpvOpConstantComposite); modified = true; } - context()->get_constant_mgr()->MapConstantToInst(const_value, inst); + const_mgr->MapConstantToInst(const_value, inst); } break; } @@ -146,6 +146,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder( Module::inst_iterator* inst_iter_ptr) { + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); // If one of operands to the instruction is not a // constant, then we cannot fold this spec constant. for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) { @@ -155,7 +156,7 @@ Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder( continue; } uint32_t id = operand.words[0]; - if (context()->get_constant_mgr()->FindDeclaredConstant(id) == nullptr) { + if (const_mgr->FindDeclaredConstant(id) == nullptr) { return nullptr; } } @@ -202,6 +203,7 @@ Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder( new_const_inst->InsertAfter(insert_pos); get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst); } + const_mgr->MapInst(new_const_inst); return new_const_inst; } @@ -285,8 +287,8 @@ utils::SmallVector EncodeIntegerAsWords(const analysis::Type& type, Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( Module::inst_iterator* pos) { const Instruction* inst = &**pos; - const analysis::Type* result_type = - context()->get_constant_mgr()->GetType(inst); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + const analysis::Type* result_type = const_mgr->GetType(inst); SpvOp spec_opcode = static_cast(inst->GetSingleWordInOperand(0)); // Check and collect operands. std::vector operands; @@ -311,10 +313,9 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( // Scalar operation const uint32_t result_val = context()->get_instruction_folder().FoldScalars(spec_opcode, operands); - auto result_const = context()->get_constant_mgr()->GetConstant( + auto result_const = const_mgr->GetConstant( result_type, EncodeIntegerAsWords(*result_type, result_val)); - return context()->get_constant_mgr()->BuildInstructionAndAddToModule( - result_const, pos); + return const_mgr->BuildInstructionAndAddToModule(result_const, pos); } else if (result_type->AsVector()) { // Vector operation const analysis::Type* element_type = @@ -325,11 +326,10 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( operands); std::vector result_vector_components; for (const uint32_t r : result_vec) { - if (auto rc = context()->get_constant_mgr()->GetConstant( + if (auto rc = const_mgr->GetConstant( element_type, EncodeIntegerAsWords(*element_type, r))) { result_vector_components.push_back(rc); - if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule( - rc, pos)) { + if (!const_mgr->BuildInstructionAndAddToModule(rc, pos)) { assert(false && "Failed to build and insert constant declaring instruction " "for the given vector component constant"); @@ -340,10 +340,8 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( } auto new_vec_const = MakeUnique( result_type->AsVector(), result_vector_components); - auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant( - std::move(new_vec_const)); - return context()->get_constant_mgr()->BuildInstructionAndAddToModule( - reg_vec_const, pos); + auto reg_vec_const = const_mgr->RegisterConstant(std::move(new_vec_const)); + return const_mgr->BuildInstructionAndAddToModule(reg_vec_const, pos); } else { // Cannot process invalid component wise operation. The result of component // wise operation must be of integer or bool scalar or vector of diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp index 7eddf7e9..c98a44c3 100644 --- a/test/opt/fold_spec_const_op_composite_test.cpp +++ b/test/opt/fold_spec_const_op_composite_test.cpp @@ -105,6 +105,209 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, builder.GetCode(), builder.GetCode(), /* skip_nop = */ true); } +// Test where OpSpecConstantOp depends on another OpSpecConstantOp with +// CompositeExtract +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedCompositeExtract) { + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals({ + // clang-format off + "%uint = OpTypeInt 32 0", + "%v3uint = OpTypeVector %uint 3", + "%uint_2 = OpConstant %uint 2", + "%uint_3 = OpConstant %uint 3", + // Folding target: + "%composite_0 = OpSpecConstantComposite %v3uint %uint_2 %uint_3 %uint_2", + "%op_0 = OpSpecConstantOp %uint CompositeExtract %composite_0 0", + "%op_1 = OpSpecConstantOp %uint CompositeExtract %composite_0 1", + "%op_2 = OpSpecConstantOp %uint IMul %op_0 %op_1", + "%composite_1 = OpSpecConstantComposite %v3uint %op_0 %op_1 %op_2", + "%op_3 = OpSpecConstantOp %uint CompositeExtract %composite_1 0", + "%op_4 = OpSpecConstantOp %uint IMul %op_2 %op_3", + // clang-format on + }); + + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %uint \"uint\"", + "OpName %v3uint \"v3uint\"", + "OpName %uint_2 \"uint_2\"", + "OpName %uint_3 \"uint_3\"", + "OpName %composite_0 \"composite_0\"", + "OpName %op_0 \"op_0\"", + "OpName %op_1 \"op_1\"", + "OpName %op_2 \"op_2\"", + "OpName %composite_1 \"composite_1\"", + "OpName %op_3 \"op_3\"", + "OpName %op_4 \"op_4\"", + "%void = OpTypeVoid", +"%main_func_type = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%v3uint = OpTypeVector %uint 3", + "%uint_2 = OpConstant %uint 2", + "%uint_3 = OpConstant %uint 3", +"%composite_0 = OpConstantComposite %v3uint %uint_2 %uint_3 %uint_2", + "%op_0 = OpConstant %uint 2", + "%op_1 = OpConstant %uint 3", + "%op_2 = OpConstant %uint 6", +"%composite_1 = OpConstantComposite %v3uint %op_0 %op_1 %op_2", +"%op_3 = OpConstant %uint 2", + "%op_4 = OpConstant %uint 12", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true); +} + +// Test where OpSpecConstantOp depends on another OpSpecConstantOp with +// VectorShuffle +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedVectorShuffle) { + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals({ + // clang-format off + "%uint = OpTypeInt 32 0", + "%v3uint = OpTypeVector %uint 3", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", + "%uint_3 = OpConstant %uint 3", + "%uint_4 = OpConstant %uint 4", + "%uint_5 = OpConstant %uint 5", + "%uint_6 = OpConstant %uint 6", + // Folding target: + "%composite_0 = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_3", + "%composite_1 = OpSpecConstantComposite %v3uint %uint_4 %uint_5 %uint_6", + "%vecshuffle = OpSpecConstantOp %v3uint VectorShuffle %composite_0 %composite_1 0 5 3", + "%op = OpSpecConstantOp %uint CompositeExtract %vecshuffle 1", + // clang-format on + }); + + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %uint \"uint\"", + "OpName %v3uint \"v3uint\"", + "OpName %uint_1 \"uint_1\"", + "OpName %uint_2 \"uint_2\"", + "OpName %uint_3 \"uint_3\"", + "OpName %uint_4 \"uint_4\"", + "OpName %uint_5 \"uint_5\"", + "OpName %uint_6 \"uint_6\"", + "OpName %composite_0 \"composite_0\"", + "OpName %composite_1 \"composite_1\"", + "OpName %vecshuffle \"vecshuffle\"", + "OpName %op \"op\"", + "%void = OpTypeVoid", +"%main_func_type = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%v3uint = OpTypeVector %uint 3", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", + "%uint_3 = OpConstant %uint 3", + "%uint_4 = OpConstant %uint 4", + "%uint_5 = OpConstant %uint 5", + "%uint_6 = OpConstant %uint 6", +"%composite_0 = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3", +"%composite_1 = OpConstantComposite %v3uint %uint_4 %uint_5 %uint_6", +"%vecshuffle = OpConstantComposite %v3uint %uint_1 %uint_6 %uint_4", + "%op = OpConstant %uint 6", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true); +} + +// Test CompositeExtract with matrix +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) { + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals({ + // clang-format off + "%uint = OpTypeInt 32 0", + "%v3uint = OpTypeVector %uint 3", + "%mat3x3 = OpTypeMatrix %v3uint 3", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", + "%uint_3 = OpConstant %uint 3", + // Folding target: + "%a = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_1", + "%b = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_3", + "%c = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_1", + "%op = OpSpecConstantComposite %mat3x3 %a %b %c", + "%x = OpSpecConstantOp %uint CompositeExtract %op 2 1", + "%y = OpSpecConstantOp %uint CompositeExtract %op 1 2", + // clang-format on + }); + + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %uint \"uint\"", + "OpName %v3uint \"v3uint\"", + "OpName %mat3x3 \"mat3x3\"", + "OpName %uint_1 \"uint_1\"", + "OpName %uint_2 \"uint_2\"", + "OpName %uint_3 \"uint_3\"", + "OpName %a \"a\"", + "OpName %b \"b\"", + "OpName %c \"c\"", + "OpName %op \"op\"", + "OpName %x \"x\"", + "OpName %y \"y\"", + "%void = OpTypeVoid", +"%main_func_type = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%v3uint = OpTypeVector %uint 3", + "%mat3x3 = OpTypeMatrix %v3uint 3", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", + "%uint_3 = OpConstant %uint 3", + "%a = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1", + "%b = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_3", + "%c = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_1", + "%op = OpConstantComposite %mat3x3 %a %b %c", + "%x = OpConstant %uint 2", + "%y = OpConstant %uint 3", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true); +} + // All types and some common constants that are potentially required in // FoldSpecConstantOpAndCompositeTest. std::vector CommonTypesAndConstants() { @@ -199,7 +402,7 @@ std::string StripOpNameInstructions(const std::string& str) { struct FoldSpecConstantOpAndCompositePassTestCase { // Original constants with unfolded spec constants. std::vector original; - // Expected cosntants after folding. + // Expected constant after folding. std::vector expected; };