diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index f348bbe3..01ed5b88 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -629,12 +629,39 @@ bool InlinePass::GenInlineCode( return true; } -bool InlinePass::IsInlinableFunctionCall(const Instruction* inst) { +bool InlinePass::IsInlinableFunctionCall(Instruction* inst) { if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false; const uint32_t calleeFnId = inst->GetSingleWordOperand(kSpvFunctionCallFunctionId); const auto ci = inlinable_.find(calleeFnId); - return ci != inlinable_.cend(); + if (ci == inlinable_.cend()) { + return false; + } + + if (funcs_with_opkill_.count(calleeFnId) == 0) { + return true; + } + + // We cannot inline into a continue construct if the function has an OpKill. + auto* cfg_analysis = context()->GetStructuredCFGAnalysis(); + BasicBlock* bb = context()->get_instr_block(inst); + uint32_t loop_header_id = cfg_analysis->ContainingLoop(bb->id()); + if (loop_header_id == 0) { + // Not in a loop, so we can inline. + return true; + } + BasicBlock* loop_header_bb = context()->get_instr_block(loop_header_id); + uint32_t loop_continue = + loop_header_bb->GetLoopMergeInst()->GetSingleWordOperand(1); + + Function* caller_func = bb->GetParent(); + DominatorAnalysis* dom = context()->GetDominatorAnalysis(caller_func); + if (dom->Dominates(loop_continue, bb->id())) { + // The function call is the continue construct and the callee contains an + // OpKill. + return false; + } + return true; } void InlinePass::UpdateSucceedingPhis( @@ -711,6 +738,9 @@ bool InlinePass::IsInlinableFunction(Function* func) { // the returns as a branch to the loop's merge block. However, this can only // done validly if the return was not in a loop in the original function. // Also remember functions with multiple (early) returns. + + // Do not inline functions with an OpKill because they may be inlined into a + // continue construct. AnalyzeReturns(func); if (no_return_in_loop_.find(func->result_id()) == no_return_in_loop_.cend()) { return false; @@ -741,6 +771,13 @@ void InlinePass::InitializeInline() { } // Compute inlinability if (IsInlinableFunction(&fn)) inlinable_.insert(fn.result_id()); + + bool has_opkill = !fn.WhileEachInst( + [](Instruction* inst) { return inst->opcode() != SpvOpKill; }); + + if (has_opkill) { + funcs_with_opkill_.insert(fn.result_id()); + } } } diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h index ecfe964f..e17dddb6 100644 --- a/source/opt/inline_pass.h +++ b/source/opt/inline_pass.h @@ -122,7 +122,7 @@ class InlinePass : public Pass { UptrVectorIterator call_block_itr); // Return true if |inst| is a function call that can be inlined. - bool IsInlinableFunctionCall(const Instruction* inst); + bool IsInlinableFunctionCall(Instruction* inst); // Return true if |func| does not have a return that is // nested in a structured if, switch or loop. @@ -159,6 +159,9 @@ class InlinePass : public Pass { // Set of ids of functions with no returns in loop std::set no_return_in_loop_; + // Set of ids of functions with no returns in loop + std::unordered_set funcs_with_opkill_; + // Set of ids of inlinable functions std::set inlinable_; diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp index 71708121..fd38e477 100644 --- a/test/opt/inline_test.cpp +++ b/test/opt/inline_test.cpp @@ -3112,6 +3112,121 @@ OpFunctionEnd SinglePassRunAndCheck(test, test, false, true); } +TEST_F(InlineTest, DontInlineFuncWithOpKill) { + const std::string test = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 330 +OpName %main "main" +OpName %kill_ "kill(" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%main = OpFunction %void None %3 +%5 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %11 %12 None +OpBranch %13 +%13 = OpLabel +OpBranchConditional %true %10 %11 +%10 = OpLabel +OpBranch %12 +%12 = OpLabel +%16 = OpFunctionCall %void %kill_ +OpBranch %9 +%11 = OpLabel +OpReturn +OpFunctionEnd +%kill_ = OpFunction %void None %3 +%7 = OpLabel +OpKill +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, false, true); +} + +TEST_F(InlineTest, InlineFuncWithOpKill) { + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 330 +OpName %main "main" +OpName %kill_ "kill(" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%main = OpFunction %void None %3 +%5 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %11 %12 None +OpBranch %13 +%13 = OpLabel +OpBranchConditional %true %10 %11 +%10 = OpLabel +%16 = OpFunctionCall %void %kill_ +OpBranch %12 +%12 = OpLabel +OpBranch %9 +%11 = OpLabel +OpReturn +OpFunctionEnd +%kill_ = OpFunction %void None %3 +%7 = OpLabel +OpKill +OpFunctionEnd +)"; + const std::string after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 330 +OpName %main "main" +OpName %kill_ "kill(" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%main = OpFunction %void None %3 +%5 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %11 %12 None +OpBranch %13 +%13 = OpLabel +OpBranchConditional %true %10 %11 +%10 = OpLabel +OpKill +%17 = OpLabel +OpBranch %12 +%12 = OpLabel +OpBranch %9 +%11 = OpLabel +OpReturn +OpFunctionEnd +%kill_ = OpFunction %void None %3 +%7 = OpLabel +OpKill +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, false, true); +} + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Empty modules