Only propagate WaveSensitive when target BB not post dom current BB. (#1648)
This commit is contained in:
Родитель
3f671b38fd
Коммит
0df5e31b43
|
@ -19,13 +19,14 @@ class FunctionPass;
|
|||
class Instruction;
|
||||
class PassRegistry;
|
||||
class StringRef;
|
||||
struct PostDominatorTree;
|
||||
}
|
||||
|
||||
namespace hlsl {
|
||||
class DxilResourceBase;
|
||||
class WaveSensitivityAnalysis {
|
||||
public:
|
||||
static WaveSensitivityAnalysis* create();
|
||||
static WaveSensitivityAnalysis* create(llvm::PostDominatorTree &PDT);
|
||||
virtual ~WaveSensitivityAnalysis() { }
|
||||
virtual void Analyze(llvm::Function *F) = 0;
|
||||
virtual bool IsWaveSensitive(llvm::Instruction *op) = 0;
|
||||
|
|
|
@ -2679,7 +2679,9 @@ static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<
|
|||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create());
|
||||
PostDominatorTree PDT;
|
||||
PDT.runOnFunction(*F);
|
||||
std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create(PDT));
|
||||
WaveVal->Analyze(F);
|
||||
for (CallInst *op : ops) {
|
||||
if (WaveVal->IsWaveSensitive(op)) {
|
||||
|
|
|
@ -31,6 +31,8 @@
|
|||
#include "llvm/IR/DiagnosticInfo.h"
|
||||
#include "llvm/IR/DiagnosticPrinter.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/Analysis/PostDominators.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winerror.h>
|
||||
#endif
|
||||
|
@ -42,6 +44,14 @@ using namespace std;
|
|||
|
||||
namespace hlsl {
|
||||
|
||||
// WaveSensitivityAnalysis is created to validate Gradient operations.
|
||||
// Gradient operations require all neighbor lanes to be active when calculated,
|
||||
// compiler will enable lanes to meet this requirement. If a wave operation
|
||||
// contributed to gradient operation, it will get unexpected result because the
|
||||
// active lanes are modified.
|
||||
// To avoid unexpected result, validation will fail if gradient operations
|
||||
// are dependent on wave-sensitive data or control flow.
|
||||
|
||||
class WaveSensitivityAnalyzer : public WaveSensitivityAnalysis {
|
||||
private:
|
||||
enum WaveSensitivity {
|
||||
|
@ -49,6 +59,7 @@ private:
|
|||
KnownNotSensitive,
|
||||
Unknown
|
||||
};
|
||||
PostDominatorTree *pPDT;
|
||||
map<Instruction *, WaveSensitivity> InstState;
|
||||
map<BasicBlock *, WaveSensitivity> BBState;
|
||||
std::vector<Instruction *> InstWorkList;
|
||||
|
@ -59,12 +70,13 @@ private:
|
|||
void UpdateInst(Instruction *I, WaveSensitivity WS);
|
||||
void VisitInst(Instruction *I);
|
||||
public:
|
||||
WaveSensitivityAnalyzer(PostDominatorTree &PDT) : pPDT(&PDT) {}
|
||||
void Analyze(Function *F);
|
||||
bool IsWaveSensitive(Instruction *op);
|
||||
};
|
||||
|
||||
WaveSensitivityAnalysis* WaveSensitivityAnalysis::create() {
|
||||
return new WaveSensitivityAnalyzer();
|
||||
WaveSensitivityAnalysis* WaveSensitivityAnalysis::create(PostDominatorTree &PDT) {
|
||||
return new WaveSensitivityAnalyzer(PDT);
|
||||
}
|
||||
|
||||
void WaveSensitivityAnalyzer::Analyze(Function *F) {
|
||||
|
@ -132,9 +144,14 @@ void WaveSensitivityAnalyzer::UpdateInst(Instruction *I, WaveSensitivity WS) {
|
|||
InstState[I] = WS;
|
||||
InstWorkList.push_back(I);
|
||||
if (TerminatorInst * TI = dyn_cast<TerminatorInst>(I)) {
|
||||
BasicBlock *CurBB = TI->getParent();
|
||||
for (unsigned i = 0; i < TI->getNumSuccessors(); ++i) {
|
||||
BasicBlock *BB = TI->getSuccessor(i);
|
||||
UpdateBlock(BB, WS);
|
||||
// Only propagate WS when BB not post dom CurBB.
|
||||
WaveSensitivity TmpWS = pPDT->properlyDominates(BB, CurBB)
|
||||
? WaveSensitivity::KnownNotSensitive
|
||||
: WS;
|
||||
UpdateBlock(BB, TmpWS);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -153,11 +170,24 @@ void WaveSensitivityAnalyzer::VisitInst(Instruction *I) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
if (CheckBBState(I->getParent(), KnownSensitive)) {
|
||||
UpdateInst(I, KnownSensitive);
|
||||
return;
|
||||
}
|
||||
|
||||
// Catch control flow wave sensitive for phi.
|
||||
if (PHINode *Phi = dyn_cast<PHINode>(I)) {
|
||||
for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) {
|
||||
BasicBlock *BB = Phi->getIncomingBlock(i);
|
||||
WaveSensitivity WS = GetInstState(BB->getTerminator());
|
||||
if (WS == KnownSensitive) {
|
||||
UpdateInst(I, KnownSensitive);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool allKnownNotSensitive = true;
|
||||
for (unsigned i = firstArg; i < I->getNumOperands(); ++i) {
|
||||
Value *V = I->getOperand(i);
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
|
||||
|
||||
// CHECK: Sin
|
||||
|
||||
float main ( uint mask:M, float a:A) : SV_Target
|
||||
{
|
||||
float r = a;
|
||||
mask = WaveActiveBitOr ( mask ) ;
|
||||
if (mask & 0xf) {
|
||||
r += sin(r);
|
||||
}
|
||||
|
||||
float dd = ddx(a);
|
||||
return r + dd;
|
||||
}
|
Загрузка…
Ссылка в новой задаче