spirv-opt: Fix stacked CompositeExtract constant folds (#4932)

This was spotted in the Validation Layers where OpSpecConstantOp %x CompositeExtract %y 0 was being folded to a constant, but anything that was using it wasn't recognizing it as a constant, the simple fix was to add a const_mgr->MapInst(new_const_inst); so the next instruction knew it was a const
This commit is contained in:
Spencer Fricke 2022-09-23 21:45:11 +09:00 коммит произвёл GitHub
Родитель f98473ceeb
Коммит ddbee48f85
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 221 добавлений и 21 удалений

Просмотреть файл

@ -627,8 +627,7 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const { Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
if (!inst->IsFoldableByFoldScalar() && if (!inst->IsFoldableByFoldScalar() && !HasConstFoldingRule(inst)) {
!GetConstantFoldingRules().HasFoldingRule(inst)) {
return nullptr; return nullptr;
} }
// Collect the values of the constant parameters. // Collect the values of the constant parameters.

Просмотреть файл

@ -28,6 +28,7 @@ namespace opt {
Pass::Status FoldSpecConstantOpAndCompositePass::Process() { Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
bool modified = false; bool modified = false;
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
// Traverse through all the constant defining instructions. For Normal // Traverse through all the constant defining instructions. For Normal
// Constants whose values are determined and do not depend on OpUndef // Constants whose values are determined and do not depend on OpUndef
// instructions, records their values in two internal maps: id_to_const_val_ // instructions, records their values in two internal maps: id_to_const_val_
@ -62,8 +63,8 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
// used in OpSpecConstant{Composite|Op} instructions. // used in OpSpecConstant{Composite|Op} instructions.
// TODO(qining): If the constant or its type has decoration, we may need // TODO(qining): If the constant or its type has decoration, we may need
// to skip it. // to skip it.
if (context()->get_constant_mgr()->GetType(inst) && if (const_mgr->GetType(inst) &&
!context()->get_constant_mgr()->GetType(inst)->decoration_empty()) !const_mgr->GetType(inst)->decoration_empty())
continue; continue;
switch (SpvOp opcode = inst->opcode()) { switch (SpvOp opcode = inst->opcode()) {
// Records the values of Normal Constants. // Records the values of Normal Constants.
@ -80,15 +81,14 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
// Constant will be turned in to a Normal Constant. In that case, a // Constant will be turned in to a Normal Constant. In that case, a
// Constant instance should also be created successfully and recorded // Constant instance should also be created successfully and recorded
// in the id_to_const_val_ and const_val_to_id_ mapps. // in the id_to_const_val_ and const_val_to_id_ mapps.
if (auto const_value = if (auto const_value = const_mgr->GetConstantFromInst(inst)) {
context()->get_constant_mgr()->GetConstantFromInst(inst)) {
// Need to replace the OpSpecConstantComposite instruction with a // Need to replace the OpSpecConstantComposite instruction with a
// corresponding OpConstantComposite instruction. // corresponding OpConstantComposite instruction.
if (opcode == SpvOp::SpvOpSpecConstantComposite) { if (opcode == SpvOp::SpvOpSpecConstantComposite) {
inst->SetOpcode(SpvOp::SpvOpConstantComposite); inst->SetOpcode(SpvOp::SpvOpConstantComposite);
modified = true; modified = true;
} }
context()->get_constant_mgr()->MapConstantToInst(const_value, inst); const_mgr->MapConstantToInst(const_value, inst);
} }
break; break;
} }
@ -146,6 +146,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder( Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
Module::inst_iterator* inst_iter_ptr) { Module::inst_iterator* inst_iter_ptr) {
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
// If one of operands to the instruction is not a // If one of operands to the instruction is not a
// constant, then we cannot fold this spec constant. // constant, then we cannot fold this spec constant.
for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) { for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) {
@ -155,7 +156,7 @@ Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
continue; continue;
} }
uint32_t id = operand.words[0]; uint32_t id = operand.words[0];
if (context()->get_constant_mgr()->FindDeclaredConstant(id) == nullptr) { if (const_mgr->FindDeclaredConstant(id) == nullptr) {
return nullptr; return nullptr;
} }
} }
@ -202,6 +203,7 @@ Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
new_const_inst->InsertAfter(insert_pos); new_const_inst->InsertAfter(insert_pos);
get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst); get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst);
} }
const_mgr->MapInst(new_const_inst);
return new_const_inst; return new_const_inst;
} }
@ -285,8 +287,8 @@ utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
Module::inst_iterator* pos) { Module::inst_iterator* pos) {
const Instruction* inst = &**pos; const Instruction* inst = &**pos;
const analysis::Type* result_type = analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
context()->get_constant_mgr()->GetType(inst); const analysis::Type* result_type = const_mgr->GetType(inst);
SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0)); SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
// Check and collect operands. // Check and collect operands.
std::vector<const analysis::Constant*> operands; std::vector<const analysis::Constant*> operands;
@ -311,10 +313,9 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
// Scalar operation // Scalar operation
const uint32_t result_val = const uint32_t result_val =
context()->get_instruction_folder().FoldScalars(spec_opcode, operands); context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
auto result_const = context()->get_constant_mgr()->GetConstant( auto result_const = const_mgr->GetConstant(
result_type, EncodeIntegerAsWords(*result_type, result_val)); result_type, EncodeIntegerAsWords(*result_type, result_val));
return context()->get_constant_mgr()->BuildInstructionAndAddToModule( return const_mgr->BuildInstructionAndAddToModule(result_const, pos);
result_const, pos);
} else if (result_type->AsVector()) { } else if (result_type->AsVector()) {
// Vector operation // Vector operation
const analysis::Type* element_type = const analysis::Type* element_type =
@ -325,11 +326,10 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
operands); operands);
std::vector<const analysis::Constant*> result_vector_components; std::vector<const analysis::Constant*> result_vector_components;
for (const uint32_t r : result_vec) { for (const uint32_t r : result_vec) {
if (auto rc = context()->get_constant_mgr()->GetConstant( if (auto rc = const_mgr->GetConstant(
element_type, EncodeIntegerAsWords(*element_type, r))) { element_type, EncodeIntegerAsWords(*element_type, r))) {
result_vector_components.push_back(rc); result_vector_components.push_back(rc);
if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule( if (!const_mgr->BuildInstructionAndAddToModule(rc, pos)) {
rc, pos)) {
assert(false && assert(false &&
"Failed to build and insert constant declaring instruction " "Failed to build and insert constant declaring instruction "
"for the given vector component constant"); "for the given vector component constant");
@ -340,10 +340,8 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
} }
auto new_vec_const = MakeUnique<analysis::VectorConstant>( auto new_vec_const = MakeUnique<analysis::VectorConstant>(
result_type->AsVector(), result_vector_components); result_type->AsVector(), result_vector_components);
auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant( auto reg_vec_const = const_mgr->RegisterConstant(std::move(new_vec_const));
std::move(new_vec_const)); return const_mgr->BuildInstructionAndAddToModule(reg_vec_const, pos);
return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
reg_vec_const, pos);
} else { } else {
// Cannot process invalid component wise operation. The result of component // Cannot process invalid component wise operation. The result of component
// wise operation must be of integer or bool scalar or vector of // wise operation must be of integer or bool scalar or vector of

Просмотреть файл

@ -105,6 +105,209 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
builder.GetCode(), builder.GetCode(), /* skip_nop = */ true); builder.GetCode(), builder.GetCode(), /* skip_nop = */ true);
} }
// Test where OpSpecConstantOp depends on another OpSpecConstantOp with
// CompositeExtract
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedCompositeExtract) {
AssemblyBuilder builder;
builder.AppendTypesConstantsGlobals({
// clang-format off
"%uint = OpTypeInt 32 0",
"%v3uint = OpTypeVector %uint 3",
"%uint_2 = OpConstant %uint 2",
"%uint_3 = OpConstant %uint 3",
// Folding target:
"%composite_0 = OpSpecConstantComposite %v3uint %uint_2 %uint_3 %uint_2",
"%op_0 = OpSpecConstantOp %uint CompositeExtract %composite_0 0",
"%op_1 = OpSpecConstantOp %uint CompositeExtract %composite_0 1",
"%op_2 = OpSpecConstantOp %uint IMul %op_0 %op_1",
"%composite_1 = OpSpecConstantComposite %v3uint %op_0 %op_1 %op_2",
"%op_3 = OpSpecConstantOp %uint CompositeExtract %composite_1 0",
"%op_4 = OpSpecConstantOp %uint IMul %op_2 %op_3",
// clang-format on
});
std::vector<const char*> expected = {
// clang-format off
"OpCapability Shader",
"OpCapability Float64",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %void \"void\"",
"OpName %main_func_type \"main_func_type\"",
"OpName %main \"main\"",
"OpName %main_func_entry_block \"main_func_entry_block\"",
"OpName %uint \"uint\"",
"OpName %v3uint \"v3uint\"",
"OpName %uint_2 \"uint_2\"",
"OpName %uint_3 \"uint_3\"",
"OpName %composite_0 \"composite_0\"",
"OpName %op_0 \"op_0\"",
"OpName %op_1 \"op_1\"",
"OpName %op_2 \"op_2\"",
"OpName %composite_1 \"composite_1\"",
"OpName %op_3 \"op_3\"",
"OpName %op_4 \"op_4\"",
"%void = OpTypeVoid",
"%main_func_type = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%v3uint = OpTypeVector %uint 3",
"%uint_2 = OpConstant %uint 2",
"%uint_3 = OpConstant %uint 3",
"%composite_0 = OpConstantComposite %v3uint %uint_2 %uint_3 %uint_2",
"%op_0 = OpConstant %uint 2",
"%op_1 = OpConstant %uint 3",
"%op_2 = OpConstant %uint 6",
"%composite_1 = OpConstantComposite %v3uint %op_0 %op_1 %op_2",
"%op_3 = OpConstant %uint 2",
"%op_4 = OpConstant %uint 12",
"%main = OpFunction %void None %main_func_type",
"%main_func_entry_block = OpLabel",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
}
// Test where OpSpecConstantOp depends on another OpSpecConstantOp with
// VectorShuffle
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedVectorShuffle) {
AssemblyBuilder builder;
builder.AppendTypesConstantsGlobals({
// clang-format off
"%uint = OpTypeInt 32 0",
"%v3uint = OpTypeVector %uint 3",
"%uint_1 = OpConstant %uint 1",
"%uint_2 = OpConstant %uint 2",
"%uint_3 = OpConstant %uint 3",
"%uint_4 = OpConstant %uint 4",
"%uint_5 = OpConstant %uint 5",
"%uint_6 = OpConstant %uint 6",
// Folding target:
"%composite_0 = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_3",
"%composite_1 = OpSpecConstantComposite %v3uint %uint_4 %uint_5 %uint_6",
"%vecshuffle = OpSpecConstantOp %v3uint VectorShuffle %composite_0 %composite_1 0 5 3",
"%op = OpSpecConstantOp %uint CompositeExtract %vecshuffle 1",
// clang-format on
});
std::vector<const char*> expected = {
// clang-format off
"OpCapability Shader",
"OpCapability Float64",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %void \"void\"",
"OpName %main_func_type \"main_func_type\"",
"OpName %main \"main\"",
"OpName %main_func_entry_block \"main_func_entry_block\"",
"OpName %uint \"uint\"",
"OpName %v3uint \"v3uint\"",
"OpName %uint_1 \"uint_1\"",
"OpName %uint_2 \"uint_2\"",
"OpName %uint_3 \"uint_3\"",
"OpName %uint_4 \"uint_4\"",
"OpName %uint_5 \"uint_5\"",
"OpName %uint_6 \"uint_6\"",
"OpName %composite_0 \"composite_0\"",
"OpName %composite_1 \"composite_1\"",
"OpName %vecshuffle \"vecshuffle\"",
"OpName %op \"op\"",
"%void = OpTypeVoid",
"%main_func_type = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%v3uint = OpTypeVector %uint 3",
"%uint_1 = OpConstant %uint 1",
"%uint_2 = OpConstant %uint 2",
"%uint_3 = OpConstant %uint 3",
"%uint_4 = OpConstant %uint 4",
"%uint_5 = OpConstant %uint 5",
"%uint_6 = OpConstant %uint 6",
"%composite_0 = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3",
"%composite_1 = OpConstantComposite %v3uint %uint_4 %uint_5 %uint_6",
"%vecshuffle = OpConstantComposite %v3uint %uint_1 %uint_6 %uint_4",
"%op = OpConstant %uint 6",
"%main = OpFunction %void None %main_func_type",
"%main_func_entry_block = OpLabel",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
}
// Test CompositeExtract with matrix
TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) {
AssemblyBuilder builder;
builder.AppendTypesConstantsGlobals({
// clang-format off
"%uint = OpTypeInt 32 0",
"%v3uint = OpTypeVector %uint 3",
"%mat3x3 = OpTypeMatrix %v3uint 3",
"%uint_1 = OpConstant %uint 1",
"%uint_2 = OpConstant %uint 2",
"%uint_3 = OpConstant %uint 3",
// Folding target:
"%a = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_1",
"%b = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_3",
"%c = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_1",
"%op = OpSpecConstantComposite %mat3x3 %a %b %c",
"%x = OpSpecConstantOp %uint CompositeExtract %op 2 1",
"%y = OpSpecConstantOp %uint CompositeExtract %op 1 2",
// clang-format on
});
std::vector<const char*> expected = {
// clang-format off
"OpCapability Shader",
"OpCapability Float64",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %void \"void\"",
"OpName %main_func_type \"main_func_type\"",
"OpName %main \"main\"",
"OpName %main_func_entry_block \"main_func_entry_block\"",
"OpName %uint \"uint\"",
"OpName %v3uint \"v3uint\"",
"OpName %mat3x3 \"mat3x3\"",
"OpName %uint_1 \"uint_1\"",
"OpName %uint_2 \"uint_2\"",
"OpName %uint_3 \"uint_3\"",
"OpName %a \"a\"",
"OpName %b \"b\"",
"OpName %c \"c\"",
"OpName %op \"op\"",
"OpName %x \"x\"",
"OpName %y \"y\"",
"%void = OpTypeVoid",
"%main_func_type = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%v3uint = OpTypeVector %uint 3",
"%mat3x3 = OpTypeMatrix %v3uint 3",
"%uint_1 = OpConstant %uint 1",
"%uint_2 = OpConstant %uint 2",
"%uint_3 = OpConstant %uint 3",
"%a = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1",
"%b = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_3",
"%c = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_1",
"%op = OpConstantComposite %mat3x3 %a %b %c",
"%x = OpConstant %uint 2",
"%y = OpConstant %uint 3",
"%main = OpFunction %void None %main_func_type",
"%main_func_entry_block = OpLabel",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
}
// All types and some common constants that are potentially required in // All types and some common constants that are potentially required in
// FoldSpecConstantOpAndCompositeTest. // FoldSpecConstantOpAndCompositeTest.
std::vector<std::string> CommonTypesAndConstants() { std::vector<std::string> CommonTypesAndConstants() {
@ -199,7 +402,7 @@ std::string StripOpNameInstructions(const std::string& str) {
struct FoldSpecConstantOpAndCompositePassTestCase { struct FoldSpecConstantOpAndCompositePassTestCase {
// Original constants with unfolded spec constants. // Original constants with unfolded spec constants.
std::vector<std::string> original; std::vector<std::string> original;
// Expected cosntants after folding. // Expected constant after folding.
std::vector<std::string> expected; std::vector<std::string> expected;
}; };