Only propagate WaveSensitive when target BB not post dom current BB. (#1648)

This commit is contained in:
Xiang Li 2018-10-31 16:30:27 -07:00 коммит произвёл GitHub
Родитель 3f671b38fd
Коммит 0df5e31b43
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 53 добавлений и 5 удалений

Просмотреть файл

@ -19,13 +19,14 @@ class FunctionPass;
class Instruction;
class PassRegistry;
class StringRef;
struct PostDominatorTree;
}
namespace hlsl {
class DxilResourceBase;
class WaveSensitivityAnalysis {
public:
static WaveSensitivityAnalysis* create();
static WaveSensitivityAnalysis* create(llvm::PostDominatorTree &PDT);
virtual ~WaveSensitivityAnalysis() { }
virtual void Analyze(llvm::Function *F) = 0;
virtual bool IsWaveSensitive(llvm::Instruction *op) = 0;

Просмотреть файл

@ -2679,7 +2679,9 @@ static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<
return;
}
std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create());
PostDominatorTree PDT;
PDT.runOnFunction(*F);
std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create(PDT));
WaveVal->Analyze(F);
for (CallInst *op : ops) {
if (WaveVal->IsWaveSensitive(op)) {

Просмотреть файл

@ -31,6 +31,8 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Analysis/PostDominators.h"
#ifdef _WIN32
#include <winerror.h>
#endif
@ -42,6 +44,14 @@ using namespace std;
namespace hlsl {
// WaveSensitivityAnalysis is created to validate Gradient operations.
// Gradient operations require all neighbor lanes to be active when calculated,
// compiler will enable lanes to meet this requirement. If a wave operation
// contributed to gradient operation, it will get unexpected result because the
// active lanes are modified.
// To avoid unexpected result, validation will fail if gradient operations
// are dependent on wave-sensitive data or control flow.
class WaveSensitivityAnalyzer : public WaveSensitivityAnalysis {
private:
enum WaveSensitivity {
@ -49,6 +59,7 @@ private:
KnownNotSensitive,
Unknown
};
PostDominatorTree *pPDT;
map<Instruction *, WaveSensitivity> InstState;
map<BasicBlock *, WaveSensitivity> BBState;
std::vector<Instruction *> InstWorkList;
@ -59,12 +70,13 @@ private:
void UpdateInst(Instruction *I, WaveSensitivity WS);
void VisitInst(Instruction *I);
public:
WaveSensitivityAnalyzer(PostDominatorTree &PDT) : pPDT(&PDT) {}
void Analyze(Function *F);
bool IsWaveSensitive(Instruction *op);
};
WaveSensitivityAnalysis* WaveSensitivityAnalysis::create() {
return new WaveSensitivityAnalyzer();
WaveSensitivityAnalysis* WaveSensitivityAnalysis::create(PostDominatorTree &PDT) {
return new WaveSensitivityAnalyzer(PDT);
}
void WaveSensitivityAnalyzer::Analyze(Function *F) {
@ -132,9 +144,14 @@ void WaveSensitivityAnalyzer::UpdateInst(Instruction *I, WaveSensitivity WS) {
InstState[I] = WS;
InstWorkList.push_back(I);
if (TerminatorInst * TI = dyn_cast<TerminatorInst>(I)) {
BasicBlock *CurBB = TI->getParent();
for (unsigned i = 0; i < TI->getNumSuccessors(); ++i) {
BasicBlock *BB = TI->getSuccessor(i);
UpdateBlock(BB, WS);
// Only propagate WS when BB not post dom CurBB.
WaveSensitivity TmpWS = pPDT->properlyDominates(BB, CurBB)
? WaveSensitivity::KnownNotSensitive
: WS;
UpdateBlock(BB, TmpWS);
}
}
}
@ -153,11 +170,24 @@ void WaveSensitivityAnalyzer::VisitInst(Instruction *I) {
}
}
if (CheckBBState(I->getParent(), KnownSensitive)) {
UpdateInst(I, KnownSensitive);
return;
}
// Catch control flow wave sensitive for phi.
if (PHINode *Phi = dyn_cast<PHINode>(I)) {
for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) {
BasicBlock *BB = Phi->getIncomingBlock(i);
WaveSensitivity WS = GetInstState(BB->getTerminator());
if (WS == KnownSensitive) {
UpdateInst(I, KnownSensitive);
return;
}
}
}
bool allKnownNotSensitive = true;
for (unsigned i = firstArg; i < I->getNumOperands(); ++i) {
Value *V = I->getOperand(i);

Просмотреть файл

@ -0,0 +1,15 @@
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
// CHECK: Sin
float main ( uint mask:M, float a:A) : SV_Target
{
float r = a;
mask = WaveActiveBitOr ( mask ) ;
if (mask & 0xf) {
r += sin(r);
}
float dd = ddx(a);
return r + dd;
}