From 6f9c107b78c1161cef9c44bec714e081ddbae275 Mon Sep 17 00:00:00 2001 From: Zhengxing li Date: Tue, 21 May 2024 11:55:40 -0700 Subject: [PATCH] Reassociate: add global reassociation algorithm (#6598) This PR pulls the upstream change, Reassociate: add global reassociation algorithm (https://github.com/llvm/llvm-project/commit/b8a330c42ab43879119cd3a305756d28aefe9fe6), into DXC with minimal changes. For the code below: foo = (a * b) * c bar = (a * d) * c As the upstream change states, it can identify that a*c is a common factor and is computed redundantly. This is part 1 of the fix for #6593. --- lib/Transforms/Scalar/Reassociate.cpp | 124 ++++++++++++++++++++++- test/Transforms/Reassociate/basictest.ll | 15 +++ 2 files changed, 137 insertions(+), 2 deletions(-) diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index d1acf785d..b5b0f7fa0 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -20,11 +20,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -37,6 +37,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include using namespace llvm; @@ -161,6 +162,13 @@ namespace { DenseMap RankMap; DenseMap, unsigned> ValueRankMap; SetVector > RedoInsts; + + // Arbitrary, but prevents quadratic behavior. 
+ static const unsigned GlobalReassociateLimit = 10; + static const unsigned NumBinaryOps = + Instruction::BinaryOpsEnd - Instruction::BinaryOpsBegin; + DenseMap, unsigned> PairMap[NumBinaryOps]; + bool MadeChange; public: static char ID; // Pass identification, replacement for typeid @@ -196,6 +204,7 @@ namespace { void EraseInst(Instruction *I); void OptimizeInst(Instruction *I); Instruction *canonicalizeNegConstExpr(Instruction *I); + void BuildPairMap(ReversePostOrderTraversal &RPOT); }; } @@ -2234,11 +2243,103 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) { return; } + if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) { + // Find the pair with the highest count in the pairmap and move it to the + // back of the list so that it can later be CSE'd. + // example: + // a*b*c*d*e + // if c*e is the most "popular" pair, we can express this as + // (((c*e)*d)*b)*a + unsigned Max = 1; + unsigned BestRank = 0; + std::pair BestPair; + unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; + for (unsigned i = 0; i < Ops.size() - 1; ++i) + for (unsigned j = i + 1; j < Ops.size(); ++j) { + unsigned Score = 0; + Value *Op0 = Ops[i].Op; + Value *Op1 = Ops[j].Op; + if (std::less()(Op1, Op0)) + std::swap(Op0, Op1); + auto it = PairMap[Idx].find({Op0, Op1}); + if (it != PairMap[Idx].end()) + Score += it->second; + + unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); + if (Score > Max || (Score == Max && MaxRank < BestRank)) { + BestPair = {i, j}; + Max = Score; + BestRank = MaxRank; + } + } + if (Max > 1) { + auto Op0 = Ops[BestPair.first]; + auto Op1 = Ops[BestPair.second]; + Ops.erase(&Ops[BestPair.second]); + Ops.erase(&Ops[BestPair.first]); + Ops.push_back(Op0); + Ops.push_back(Op1); + } + } // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. 
RewriteExprTree(I, Ops); } +void Reassociate::BuildPairMap(ReversePostOrderTraversal &RPOT) { + // Make a "pairmap" of how often each operand pair occurs. + for (BasicBlock *BI : RPOT) { + for (Instruction &I : *BI) { + if (!I.isAssociative()) + continue; + + // Ignore nodes that aren't at the root of trees. + if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode()) + continue; + + // Collect all operands in a single reassociable expression. + // Since Reassociate has already been run once, we can assume things + // are already canonical according to Reassociation's regime. + SmallVector Worklist = {I.getOperand(0), I.getOperand(1)}; + SmallVector Ops; + while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) { + Value *Op = Worklist.pop_back_val(); + Instruction *OpI = dyn_cast(Op); + if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) { + Ops.push_back(Op); + continue; + } + // Be paranoid about self-referencing expressions in unreachable code. + if (OpI->getOperand(0) != OpI) + Worklist.push_back(OpI->getOperand(0)); + if (OpI->getOperand(1) != OpI) + Worklist.push_back(OpI->getOperand(1)); + } + // Skip extremely long expressions. + if (Ops.size() > GlobalReassociateLimit) + continue; + + // Add all pairwise combinations of operands to the pair map. + unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin; + SmallSet, 32> Visited; + for (unsigned i = 0; i < Ops.size() - 1; ++i) { + for (unsigned j = i + 1; j < Ops.size(); ++j) { + // Canonicalize operand orderings. 
+ Value *Op0 = Ops[i]; + Value *Op1 = Ops[j]; + if (std::less()(Op1, Op0)) + std::swap(Op0, Op1); + if (!Visited.insert({Op0, Op1}).second) + continue; + auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1}); + if (!res.second) + ++res.first->second; + } + } + } + } +} + bool Reassociate::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; @@ -2246,6 +2347,23 @@ bool Reassociate::runOnFunction(Function &F) { // Calculate the rank map for F BuildRankMap(F); + // Build the pair map before running reassociate. + // Technically this would be more accurate if we did it after one round + // of reassociation, but in practice it doesn't seem to help much on + // real-world code, so don't waste the compile time running reassociate + // twice. + // If a user wants, they could explicitly run reassociate twice in their + // pass pipeline for further potential gains. + // It might also be possible to update the pair map during runtime, but the + // overhead of that may be large if there are many reassociable chains. + // TODO: RPOT + // Get the function's basic blocks in Reverse Post Order. This order is used by + // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic + // blocks (it has been seen that the analysis in this pass could hang when + // analysing dead basic blocks). + ReversePostOrderTraversal RPOT(&F); + BuildPairMap(RPOT); + MadeChange = false; for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { // Optimize every instruction in the basic block. @@ -2268,9 +2386,11 @@ bool Reassociate::runOnFunction(Function &F) { } } - // We are done with the rank map. + // We are done with the rank map and pair map. 
RankMap.clear(); ValueRankMap.clear(); + for (auto &Entry : PairMap) + Entry.clear(); return MadeChange; } diff --git a/test/Transforms/Reassociate/basictest.ll b/test/Transforms/Reassociate/basictest.ll index c557017b4..277758426 100644 --- a/test/Transforms/Reassociate/basictest.ll +++ b/test/Transforms/Reassociate/basictest.ll @@ -221,3 +221,18 @@ define i32 @test15(i32 %X1, i32 %X2, i32 %X3) { ; CHECK-LABEL: @test15 ; CHECK: and i1 %A, %B } + +; CHECK-LABEL: @test17 +; CHECK: %[[A:.*]] = mul i32 %X4, %X3 +; CHECK-NEXT: %[[C:.*]] = mul i32 %[[A]], %X1 +; CHECK-NEXT: %[[D:.*]] = mul i32 %[[A]], %X2 +; CHECK-NEXT: %[[E:.*]] = xor i32 %[[C]], %[[D]] +; CHECK-NEXT: ret i32 %[[E]] +define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) { + %A = mul i32 %X3, %X1 + %B = mul i32 %X3, %X2 + %C = mul i32 %A, %X4 + %D = mul i32 %B, %X4 + %E = xor i32 %C, %D + ret i32 %E +} \ No newline at end of file