// DirectXShaderCompiler/lib/HLSL/DxilConvergent.cpp
//
// 253 lines
// 8.0 KiB
// C++

///////////////////////////////////////////////////////////////////////////////
// //
// DxilConvergent.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
// Mark convergent for hlsl. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/GenericDomTree.h"
#include "llvm/Support/raw_os_ostream.h"
#include "dxc/HLSL/DxilConstants.h"
#include "dxc/HLSL/DxilGenerationPass.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HlslIntrinsicOp.h"
using namespace llvm;
using namespace hlsl;
namespace {
// Common name prefix of the marker functions created by DxilConvergentMark;
// the printed scalar type is appended to it to form the full function name.
// DxilConvergentClear finds the markers again by this prefix.
const StringRef kConvergentFunctionPrefix("dxil.convergent.marker.");
} // namespace
///////////////////////////////////////////////////////////////////////////////
// DxilConvergent.
// Mark convergent to avoid sample coordinate calculation sinking into control
// flow.
//
namespace {
class DxilConvergentMark : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DxilConvergentMark() : ModulePass(ID) {}
const char *getPassName() const override {
return "DxilConvergentMark";
}
bool runOnModule(Module &M) override {
if (M.HasHLModule()) {
if (!M.GetHLModule().GetShaderModel()->IsPS())
return false;
}
bool bUpdated = false;
for (Function &F : M.functions()) {
if (F.isDeclaration())
continue;
// Compute postdominator relation.
DominatorTreeBase<BasicBlock> PDR(true);
PDR.recalculate(F);
for (BasicBlock &bb : F.getBasicBlockList()) {
for (auto it = bb.begin(); it != bb.end();) {
Instruction *I = (it++);
if (Value *V = FindConvergentOperand(I)) {
if (PropagateConvergent(V, &F, PDR)) {
// TODO: emit warning here.
}
bUpdated = true;
}
}
}
}
return bUpdated;
}
private:
void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
Value *FindConvergentOperand(Instruction *I);
bool PropagateConvergent(Value *V, Function *F,
DominatorTreeBase<BasicBlock> &PostDom);
};
char DxilConvergentMark::ID = 0;
/// Routes every existing use of \p V through a call to a convergent marker
/// function, emitted at \p Builder's insertion point.
///
/// The marker is declared in \p M as `T marker(T)` for the scalar type T of
/// \p V, named kConvergentFunctionPrefix followed by the printed type, and is
/// given the Convergent function attribute. Vector values are marked one
/// element at a time and reassembled. Aggregate and pointer values are left
/// untouched.
void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
                                        Module &M) {
  Type *ValTy = V->getType();
  Type *EltTy = ValTy->getScalarType();
  // Only work on vector/scalar types.
  if (EltTy->isAggregateType() || EltTy->isPointerTy())
    return;

  // Build the per-type marker name: prefix + printed scalar type.
  std::string MarkerName = kConvergentFunctionPrefix;
  raw_string_ostream NameOS(MarkerName);
  EltTy->print(NameOS);
  NameOS.flush();

  FunctionType *MarkerTy = FunctionType::get(EltTy, EltTy, false);
  Function *Marker =
      cast<Function>(M.getOrInsertFunction(MarkerName, MarkerTy));
  Marker->addFnAttr(Attribute::AttrKind::Convergent);

  VectorType *VecTy = dyn_cast<VectorType>(ValTy);
  if (!VecTy) {
    CallInst *Marked = Builder.CreateCall(Marker, {V});
    V->replaceAllUsesWith(Marked);
    // The replacement above also rewrote the marker call's own argument;
    // point it back at the original value.
    Marked->setOperand(0, V);
    return;
  }

  // Vector: extract each lane, mark it, and insert it into a fresh vector.
  const unsigned NumElts = VecTy->getNumElements();
  std::vector<ExtractElementInst *> Extracts;
  Extracts.reserve(NumElts);
  Value *Rebuilt = UndefValue::get(ValTy);
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    ExtractElementInst *Lane =
        cast<ExtractElementInst>(Builder.CreateExtractElement(V, Idx));
    Extracts.push_back(Lane);
    Value *MarkedLane = Builder.CreateCall(Marker, {Lane});
    Rebuilt = Builder.CreateInsertElement(Rebuilt, MarkedLane, Idx);
  }

  V->replaceAllUsesWith(Rebuilt);
  // The replacement above also redirected the extracts to the rebuilt vector;
  // restore their source to the original value.
  for (ExtractElementInst *Lane : Extracts)
    Lane->setOperand(0, V);
}
/// Marks \p V convergent, or, when \p V is an instruction whose block does not
/// postdominate the entry block of \p F, recurses into its operands instead.
///
/// \returns true only in the recursive case, so the caller can report a
/// warning; false when \p V was skipped (constant or phi) or handed to
/// MarkConvergent.
bool DxilConvergentMark::PropagateConvergent(
    Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
  // Skip constants, and phis, which cannot sink.
  if (isa<Constant>(V) || isa<PHINode>(V))
    return false;

  Module &M = *F->getParent();
  BasicBlock &Entry = F->getEntryBlock();

  Instruction *Inst = dyn_cast<Instruction>(V);
  if (!Inst) {
    // Not an instruction (such as a function argument): place the marker at
    // the first insertion point of the entry block.
    IRBuilder<> EntryBuilder(Entry.getFirstInsertionPt());
    MarkConvergent(V, EntryBuilder, M);
    return false;
  }

  if (PostDom.dominates(Inst->getParent(), &Entry)) {
    // The defining block postdominates the entry block: place the marker
    // right after the definition.
    IRBuilder<> Builder(Inst->getNextNode());
    MarkConvergent(Inst, Builder, M);
    return false;
  }

  // Propagate to each operand of the instruction.
  for (Use &Op : Inst->operands())
    PropagateConvergent(Op.get(), F, PostDom);

  // Return true so the caller can report a warning.
  // TODO: static indexing cbuffer is fine.
  return true;
}
/// Returns the operand of \p I that has to be marked convergent, or nullptr
/// when \p I is not a call to one of the HL intrinsics listed below.
///
/// Operand returned per intrinsic family:
///  - ddx / ddy (plain, fine, coarse): the unary source operand.
///  - Sample, SampleBias, SampleCmp, SampleCmpLevelZero,
///    CalculateLevelOfDetail(Unclamped): the sample coordinate operand.
///  - Gather* / GatherCmp*: the gather coordinate operand.
Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
  CallInst *CI = dyn_cast<CallInst>(I);
  if (!CI)
    return nullptr;

  // getCalledFunction() returns null for an indirect call. Such a call cannot
  // be an HL intrinsic, so bail out instead of handing a null function to
  // GetHLOpcodeGroup.
  Function *Callee = CI->getCalledFunction();
  if (!Callee)
    return nullptr;

  if (hlsl::GetHLOpcodeGroup(Callee) != HLOpcodeGroup::HLIntrinsic)
    return nullptr;

  IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
  switch (IOP) {
  case IntrinsicOp::IOP_ddx:
  case IntrinsicOp::IOP_ddx_fine:
  case IntrinsicOp::IOP_ddx_coarse:
  case IntrinsicOp::IOP_ddy:
  case IntrinsicOp::IOP_ddy_fine:
  case IntrinsicOp::IOP_ddy_coarse:
    return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  case IntrinsicOp::MOP_Sample:
  case IntrinsicOp::MOP_SampleBias:
  case IntrinsicOp::MOP_SampleCmp:
  case IntrinsicOp::MOP_SampleCmpLevelZero:
  case IntrinsicOp::MOP_CalculateLevelOfDetail:
  case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
    return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
  case IntrinsicOp::MOP_Gather:
  case IntrinsicOp::MOP_GatherAlpha:
  case IntrinsicOp::MOP_GatherBlue:
  case IntrinsicOp::MOP_GatherCmp:
  case IntrinsicOp::MOP_GatherCmpAlpha:
  case IntrinsicOp::MOP_GatherCmpBlue:
  case IntrinsicOp::MOP_GatherCmpGreen:
  case IntrinsicOp::MOP_GatherCmpRed:
  case IntrinsicOp::MOP_GatherGreen:
  case IntrinsicOp::MOP_GatherRed:
    return CI->getArgOperand(HLOperandIndex::kGatherCoordArgIndex);
  default:
    // No other ops have convergent operands.
    return nullptr;
  }
}
} // namespace
// Pass registration for DxilConvergentMark: argument string
// "hlsl-dxil-convergent-mark", description "Mark convergent". The two trailing
// booleans are passed to LLVM's INITIALIZE_PASS macro unchanged (presumably
// its CFG-only and is-analysis flags -- confirm against PassSupport.h).
INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
"Mark convergent", false, false)
// Factory for the marking pass. Returns a newly allocated DxilConvergentMark;
// the caller is presumably expected to take ownership (e.g. by adding it to a
// pass manager).
ModulePass *llvm::createDxilConvergentMarkPass() {
return new DxilConvergentMark();
}
namespace {
class DxilConvergentClear : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DxilConvergentClear() : ModulePass(ID) {}
const char *getPassName() const override {
return "DxilConvergentClear";
}
bool runOnModule(Module &M) override {
std::vector<Function *> convergentList;
for (Function &F : M.functions()) {
if (F.getName().startswith(kConvergentFunctionPrefix)) {
convergentList.emplace_back(&F);
}
}
for (Function *F : convergentList) {
ClearConvergent(F);
}
return convergentList.size();
}
private:
void ClearConvergent(Function *F);
};
char DxilConvergentClear::ID = 0;
/// Removes the marker function \p F: each call to it is replaced by the call's
/// first argument and erased, then \p F itself is erased from its module.
/// Every user of \p F is expected to be a CallInst (enforced by cast<>).
void DxilConvergentClear::ClearConvergent(Function *F) {
  // Drain the user list from the front; erasing a call drops its use of F,
  // so the next user becomes the new head.
  while (!F->use_empty()) {
    CallInst *Call = cast<CallInst>(*F->user_begin());
    Value *Wrapped = Call->getArgOperand(0);
    Call->replaceAllUsesWith(Wrapped);
    Call->eraseFromParent();
  }
  F->eraseFromParent();
}
} // namespace
// Pass registration for DxilConvergentClear: argument string
// "hlsl-dxil-convergent-clear", description "Clear convergent before dxil
// emit". The two trailing booleans are passed to LLVM's INITIALIZE_PASS macro
// unchanged (presumably its CFG-only and is-analysis flags -- confirm against
// PassSupport.h).
INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
"Clear convergent before dxil emit", false, false)
// Factory for the clearing pass. Returns a newly allocated
// DxilConvergentClear; the caller is presumably expected to take ownership
// (e.g. by adding it to a pass manager).
ModulePass *llvm::createDxilConvergentClearPass() {
return new DxilConvergentClear();
}