DirectXShaderCompiler/lib/DXIL/DxilUtil.cpp

503 lines
15 KiB
C++

///////////////////////////////////////////////////////////////////////////////
// //
// DxilUtil.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
// Dxil helper functions. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "llvm/IR/GlobalVariable.h"
#include "dxc/DXIL/DxilTypeSystem.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/DXIL/DxilModule.h"
#include "llvm/Bitcode/ReaderWriter.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "dxc/Support/Global.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
using namespace llvm;
using namespace hlsl;
namespace hlsl {
namespace dxilutil {
// Prefix that marks a mangled function name: the byte 0x01 followed by '?'.
// DemangleFunctionName and ReplaceFunctionName test names against it.
const char ManglingPrefix[] = "\01?";
// Name prefix recognized by ReplaceFunctionName, which preserves it when
// renaming (presumably marks entry-point functions — confirm against callers).
const char EntryPrefix[] = "dx.entry.";
// Returns the innermost element type of Ty. A single pointer level is looked
// through first, then every nested array dimension is peeled off. Types that
// are neither pointers nor arrays are returned unchanged.
Type *GetArrayEltTy(Type *Ty) {
  Type *EltTy = isa<PointerType>(Ty) ? Ty->getPointerElementType() : Ty;
  while (ArrayType *AT = dyn_cast<ArrayType>(EltTy))
    EltTy = AT->getElementType();
  return EltTy;
}
// Returns true if any GetElementPtrInst that directly uses V has at least one
// index operand that is not a ConstantInt, i.e. V is indexed by a value only
// known at run time. Only direct GEP instruction users are inspected; other
// users (including constant GEP expressions) are ignored.
bool HasDynamicIndexing(Value *V) {
  for (User *U : V->users()) {
    GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U);
    if (!GEP)
      continue;
    for (auto Idx = GEP->idx_begin(), E = GEP->idx_end(); Idx != E; ++Idx) {
      // Idx is an operand iterator (a Use*); test the index Value it refers
      // to rather than handing the iterator itself to isa<>.
      if (!isa<ConstantInt>(Idx->get()))
        return true;
    }
  }
  return false;
}
// Computes the size in bytes of a single element of a constant-buffer field
// under the legacy cbuffer layout. Array dimensions on Ty are stripped first,
// so for an array field the result is the size of one array element.
//   - scalar: component size (8 for 64-bit, 2 for 16-bit when the type system
//             does not use min-precision, otherwise 4)
//   - vector: component size * number of vector elements
//   - struct: the CBuffer size recorded in the struct's annotation, if any;
//             otherwise, when the field has a matrix annotation,
//             (rows - 1) * 16 + cols * 4, with rows/cols swapped for
//             column-major matrices.
// Returns 0 when the size cannot be determined: a struct with neither a struct
// nor a matrix annotation, an unrecognized matrix orientation, or a matrix
// annotation with zero rows or columns.
unsigned
GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation,
                                 llvm::Type *Ty,
                                 DxilTypeSystem &typeSys) {
  while (isa<ArrayType>(Ty)) {
    Ty = Ty->getArrayElementType();
  }
  // Component size in bytes.
  CompType compType = fieldAnnotation.GetCompType();
  unsigned compSize = 4;
  if (compType.Is64Bit())
    compSize = 8;
  else if (compType.Is16Bit() && !typeSys.UseMinPrecision())
    compSize = 2;

  unsigned fieldSize = compSize;
  if (Ty->isVectorTy()) {
    fieldSize *= Ty->getVectorNumElements();
  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
    if (DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST)) {
      fieldSize = EltAnnotation->GetCBufferSize();
    } else if (fieldAnnotation.HasMatrixAnnotation()) {
      // No struct annotation: derive the size from the matrix annotation.
      const DxilMatrixAnnotation &matAnnotation =
          fieldAnnotation.GetMatrixAnnotation();
      unsigned rows = matAnnotation.Rows;
      unsigned cols = matAnnotation.Cols;
      if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) {
        rows = matAnnotation.Cols;
        cols = matAnnotation.Rows;
      } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) {
        // Invalid matrix orientation: the size is unknown. Return directly;
        // falling through would overwrite this with the formula below.
        return 0;
      }
      // Guard against unsigned wrap-around in (rows - 1) for a degenerate
      // annotation.
      if (rows == 0 || cols == 0)
        return 0;
      // NOTE(review): the last row uses 4 bytes per component regardless of
      // compSize; confirm this is intended for 64-bit / native 16-bit types.
      fieldSize = (rows - 1) * 16 + cols * 4;
    } else {
      // Cannot find struct annotation.
      fieldSize = 0;
    }
  }
  return fieldSize;
}
// Returns true if GV has internal linkage and its pointer type is in the
// default DXIL address space (DXIL::kDefaultAddrSpace).
bool IsStaticGlobal(GlobalVariable *GV) {
  if (GV->getLinkage() != GlobalValue::LinkageTypes::InternalLinkage)
    return false;
  const unsigned AddrSpace = GV->getType()->getPointerAddressSpace();
  return AddrSpace == DXIL::kDefaultAddrSpace;
}
// Returns true if GV's pointer type lives in DXIL::kTGSMAddrSpace (presumably
// the thread-group shared memory address space).
bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) {
  const unsigned AddrSpace = GV->getType()->getPointerAddressSpace();
  return AddrSpace == DXIL::kTGSMAddrSpace;
}
// Erases every function in M that has no users, except EntryFunc and
// PatchConstantFunc. When IsLib is true only declarations are candidates;
// otherwise unused definitions are erased as well.
// Returns true if at least one function was erased.
bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
                           Function *PatchConstantFunc, bool IsLib) {
  // Gather first and erase afterwards, so the module's function list is not
  // modified while it is being iterated.
  std::vector<Function *> toErase;
  for (Function &Fn : M.functions()) {
    const bool isPinned = &Fn == EntryFunc || &Fn == PatchConstantFunc;
    const bool isCandidate = !IsLib || Fn.isDeclaration();
    if (!isPinned && isCandidate && Fn.user_empty())
      toErase.push_back(&Fn);
  }
  for (Function *Fn : toErase)
    Fn->eraseFromParent();
  return !toErase.empty();
}
// Diagnostic-handler callback. Context must point to a DiagnosticPrinter;
// DI is printed into it.
void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
  DiagnosticPrinter &Printer = *static_cast<DiagnosticPrinter *>(Context);
  DI.print(Printer);
}
// Returns the bare function name embedded in a mangled name: the text between
// ManglingPrefix and the first '@'. Names that do not start with
// ManglingPrefix are returned as-is.
StringRef DemangleFunctionName(StringRef name) {
  if (name.startswith(ManglingPrefix)) {
    // Length of ManglingPrefix without its terminating NUL.
    const size_t prefixLen = sizeof(ManglingPrefix) - 1;
    const size_t atPos = name.find('@');
    DXASSERT(atPos != StringRef::npos, "else Name isn't mangled but has \01?");
    return name.substr(prefixLen, atPos - prefixLen);
  }
  // Name isn't mangled.
  return name;
}
// Builds a new function name from originalName with its base name replaced by
// newName:
//   - ManglingPrefix names keep the prefix and everything from the first '@';
//   - EntryPrefix names keep the prefix;
//   - any other name becomes newName.
std::string ReplaceFunctionName(StringRef originalName, StringRef newName) {
  if (!originalName.startswith(ManglingPrefix)) {
    if (originalName.startswith(EntryPrefix))
      return (Twine(EntryPrefix) + newName).str();
    return newName.str();
  }
  StringRef mangledTail = originalName.substr(originalName.find('@'));
  return (Twine(ManglingPrefix) + newName + mangledTail).str();
}
// Adapted from AsmWriter.cpp.
// PrintEscapedString - Writes Name to Out byte by byte. Non-printable bytes,
// backslashes and double quotes are written as a backslash followed by two
// hex digits (produced by hexdigit); everything else is written unchanged.
void PrintEscapedString(StringRef Name, raw_ostream &Out) {
  for (char Ch : Name) {
    unsigned char C = static_cast<unsigned char>(Ch);
    const bool needsEscape = !isprint(C) || C == '\\' || C == '"';
    if (needsEscape)
      Out << '\\' << hexdigit(C >> 4) << hexdigit(C & 0x0F);
    else
      Out << C;
  }
}
// Inverse of PrintEscapedString: writes Name to Out, decoding backslash
// escapes.
//   - A backslash followed by two hex digits ("\XY") produces the byte 0xXY.
//   - A backslash followed by a non-hex character produces that character.
//   - A backslash that is the last character of Name is written literally.
// A hex digit that is not followed by a second hex digit trips the assert; in
// release builds it is decoded as a single-digit value.
void PrintUnescapedString(StringRef Name, raw_ostream &Out) {
  for (unsigned i = 0, e = Name.size(); i != e; ++i) {
    unsigned char C = Name[i];
    // Check i + 1 so a trailing backslash does not read past the end of Name.
    if (C == '\\' && i + 1 != e) {
      C = Name[++i];
      unsigned value = hexDigitValue(C);
      if (value != -1U) {
        C = (unsigned char)value;
        // Bounds-check the second digit too; it may not exist at end of Name.
        unsigned value2 = (i + 1 != e) ? hexDigitValue(Name[i + 1]) : -1U;
        assert(value2 != -1U && "otherwise, not a two digit hex escape");
        if (value2 != -1U) {
          C = (C << 4) + (unsigned char)value2;
          ++i;
        }
      } // else, the next character (in C) should be the escaped character
    }
    Out << C;
  }
}
// Parses the bitcode held in MB into a new Module in context Ctx.
// Returns nullptr if parsing fails; in that case DiagStr is set to the
// parser's error message (it is left untouched on success).
std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
                                                    llvm::LLVMContext &Ctx,
                                                    std::string &DiagStr) {
  ErrorOr<std::unique_ptr<llvm::Module>> pModule(
      llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx));
  if (std::error_code ec = pModule.getError()) {
    // Report why parsing failed instead of silently discarding the error.
    DiagStr = ec.message();
    return nullptr;
  }
  return std::unique_ptr<llvm::Module>(pModule.get().release());
}
// Convenience overload: wraps the bitcode bytes in BC in a MemoryBuffer and
// parses them with the MemoryBuffer overload. DiagStr is passed through.
std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
                                                    llvm::LLVMContext &Ctx,
                                                    std::string &DiagStr) {
  // The buffer is only needed for the duration of the parse below.
  auto bitcodeBuffer = llvm::MemoryBuffer::getMemBuffer(BC, "", false);
  llvm::MemoryBuffer *rawBuffer = bitcodeBuffer.get();
  return LoadModuleFromBitcode(rawBuffer, Ctx, DiagStr);
}
// Helper for EmitErrorOnInstruction. If I has a debug location, Msg is
// reported through EmitErrorOnInstruction(I, Msg). Otherwise, when I is a phi
// or select, its instruction users are searched recursively for one that has
// a debug location. Only phi/select nodes are recursed through, and the search
// gives up once depth exceeds 4 to bound the work.
// Returns true if an error was emitted.
static bool EmitErrorOnInstructionFollowPhiSelect(
    Instruction *I, StringRef Msg, unsigned depth=0) {
  if (depth > 4)
    return false;
  if (I->getDebugLoc().get()) {
    EmitErrorOnInstruction(I, Msg);
    return true;
  }
  if (!isa<PHINode>(I) && !isa<SelectInst>(I))
    return false;
  for (User *U : I->users()) {
    Instruction *UserInst = dyn_cast<Instruction>(U);
    if (UserInst &&
        EmitErrorOnInstructionFollowPhiSelect(UserInst, Msg, depth + 1))
      return true;
  }
  return false;
}
// Emits Msg as an error through I's LLVMContext.
// - If I has a debug location, the printed location is prefixed to Msg.
// - Otherwise, for a phi/select, a location is looked up on its users via
//   EmitErrorOnInstructionFollowPhiSelect.
// - If no location is found, Msg is suffixed with a hint to compile with /Zi.
void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
  const DebugLoc &Loc = I->getDebugLoc();
  if (!Loc.get()) {
    const bool isPhiOrSelect = isa<PHINode>(I) || isa<SelectInst>(I);
    if (isPhiOrSelect && EmitErrorOnInstructionFollowPhiSelect(I, Msg))
      return;
    I->getContext().emitError(Twine(Msg) + " Use /Zi for source location.");
    return;
  }
  std::string locString;
  raw_string_ostream os(locString);
  Loc.print(os);
  I->getContext().emitError(os.str() + ": " + Twine(Msg));
}
// Error text reported by EmitResMappingError.
const StringRef kResourceMapErrorMsg =
"local resource not guaranteed to map to unique global resource.";
// Reports kResourceMapErrorMsg as an error on the instruction Res.
void EmitResMappingError(Instruction *Res) {
  const StringRef Msg = kResourceMapErrorMsg;
  EmitErrorOnInstruction(Res, Msg);
}
// Adds Inst to selectSet if it is a phi or select, then recursively adds every
// phi/select instruction reachable through its operands. For a select, the
// condition (operand 0) is not followed. Non-phi/select instructions are
// ignored, and instructions already in the set are not visited again.
void CollectSelect(llvm::Instruction *Inst,
                   std::unordered_set<llvm::Instruction *> &selectSet) {
  unsigned firstOperand;
  if (isa<SelectInst>(Inst))
    firstOperand = 1; // Skip the condition operand.
  else if (isa<PHINode>(Inst))
    firstOperand = 0;
  else
    return; // Only phi and select are collected.

  // insert() reports whether Inst was newly added; stop if already visited.
  if (!selectSet.insert(Inst).second)
    return;

  for (unsigned i = firstOperand, e = Inst->getNumOperands(); i < e; ++i) {
    if (Instruction *OpInst = dyn_cast<Instruction>(Inst->getOperand(i)))
      CollectSelect(OpInst, selectSet);
  }
}
// If operands [startOpIdx, numOperands) of SelInst all refer to the same
// value, replaces all uses of SelInst with that value, erases SelInst and
// returns the value. Returns nullptr, leaving SelInst in place, when the
// operands differ or the range is empty.
Value *MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx,
                              unsigned numOperands) {
  if (startOpIdx >= numOperands)
    return nullptr;
  Value *common = SelInst->getOperand(startOpIdx);
  for (unsigned i = startOpIdx + 1; i < numOperands; ++i) {
    if (SelInst->getOperand(i) != common)
      return nullptr;
  }
  if (!common)
    return nullptr;
  SelInst->replaceAllUsesWith(common);
  SelInst->eraseFromParent();
  return common;
}
// Pushes Inst through the select or phi that feeds its operand operandIdx.
// Every other operand of Inst must be a Constant; otherwise nullptr is
// returned.
//   - select: clones of Inst are built on the true and false values (inserted
//     before the select), and a new select of the two clones is returned.
//   - phi: a clone of Inst is built in each incoming block on that block's
//     incoming value, and a new phi of the clones is returned.
// Returns nullptr if operand operandIdx is neither a select nor a phi.
// Inst itself is not modified or erased (presumably the caller replaces its
// uses with the returned value — confirm at call sites).
Value *SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx) {
  for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i) {
    if (i == operandIdx)
      continue;
    if (!isa<Constant>(Inst->getOperand(i)))
      return nullptr;
  }

  Value *Src = Inst->getOperand(operandIdx);
  if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
    IRBuilder<> Builder(SI);
    Instruction *trueClone = Inst->clone();
    trueClone->setOperand(operandIdx, SI->getTrueValue());
    Builder.Insert(trueClone);
    Instruction *falseClone = Inst->clone();
    falseClone->setOperand(operandIdx, SI->getFalseValue());
    Builder.Insert(falseClone);
    return Builder.CreateSelect(SI->getCondition(), trueClone, falseClone);
  }

  if (PHINode *Phi = dyn_cast<PHINode>(Src)) {
    const unsigned numIncoming = Phi->getNumIncomingValues();
    IRBuilder<> Builder(Phi);
    PHINode *newPhi = Builder.CreatePHI(Inst->getType(), numIncoming);
    for (unsigned i = 0; i < numIncoming; i++) {
      BasicBlock *BB = Phi->getIncomingBlock(i);
      // A phi may list the same predecessor more than once (e.g. several
      // switch cases branching to the same block). LLVM requires those
      // entries to carry the same value, so reuse the clone already made.
      int existing = newPhi->getBasicBlockIndex(BB);
      if (existing >= 0) {
        newPhi->addIncoming(newPhi->getIncomingValue(existing), BB);
        continue;
      }
      Instruction *iClone = Inst->clone();
      iClone->setOperand(operandIdx, Phi->getIncomingValue(i));
      // Insert immediately before the terminator. The incoming value may be
      // defined by the last non-terminator instruction of BB, so inserting
      // any earlier could place a use before its definition, and a block
      // holding only a terminator has no previous instruction at all.
      IRBuilder<> iBuilder(BB->getTerminator());
      iBuilder.Insert(iClone);
      newPhi->addIncoming(iClone, BB);
    }
    return newPhi;
  }
  return nullptr;
}
// Returns the first instruction at or after I that is not an AllocaInst.
// Returns nullptr if I is null or the walk runs off the end of the block.
llvm::Instruction *SkipAllocas(llvm::Instruction *I) {
  llvm::Instruction *Cur = I;
  while (Cur) {
    if (!isa<AllocaInst>(Cur))
      break;
    Cur = Cur->getNextNode();
  }
  return Cur;
}
// Returns the first insertion point of the entry block of the function that
// contains I. If I's block is not attached to a function, returns the first
// insertion point of I's own block instead.
llvm::Instruction *FindAllocaInsertionPt(llvm::Instruction* I) {
  BasicBlock *BB = I->getParent();
  if (Function *F = BB->getParent())
    return &*F->getEntryBlock().getFirstInsertionPt();
  // BB with no parent function
  return &*BB->getFirstInsertionPt();
}
// Returns the first insertion point of F's entry block.
llvm::Instruction *FindAllocaInsertionPt(llvm::Function* F) {
  BasicBlock &Entry = F->getEntryBlock();
  return &*Entry.getFirstInsertionPt();
}
// Starting from FindAllocaInsertionPt(I), skips any allocas and returns the
// first non-alloca instruction (nullptr if there is none).
llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Instruction* I) {
  llvm::Instruction *AllocaPt = FindAllocaInsertionPt(I);
  return SkipAllocas(AllocaPt);
}
// Returns the first non-alloca instruction at or after BB's first insertion
// point (nullptr if there is none).
llvm::Instruction *FirstNonAllocaInsertionPt(llvm::BasicBlock* BB) {
  llvm::Instruction *FirstPt = &*BB->getFirstInsertionPt();
  return SkipAllocas(FirstPt);
}
// Returns the first non-alloca instruction at or after the first insertion
// point of F's entry block (nullptr if there is none).
llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
  llvm::BasicBlock *Entry = &F->getEntryBlock();
  return FirstNonAllocaInsertionPt(Entry);
}
// Returns true if Ty is a struct type recognized, purely by its name, as an
// HLSL object: wave_t, samplers, stream-output objects, append/consume and
// constant buffers, raytracing acceleration structures, and (RW /
// RasterizerOrdered) buffer and texture types.
// An optional "class." or "struct." prefix is removed before matching. For
// buffer/texture names, an optional "RasterizerOrdered" or "RW" prefix is
// also removed.
bool IsHLSLObjectType(llvm::Type *Ty) {
  llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty);
  if (!ST)
    return false;

  StringRef name = ST->getName();
  // TODO: don't check names.
  if (name.startswith("dx.types.wave_t"))
    return true;
  if (name.endswith("_slice_type"))
    return false;

  // Removes Prefix from the front of Str if present. StringRef::ltrim must
  // not be used for this: it strips any leading characters that occur in
  // the given *set*, so e.g. ltrim("RasterizerOrdered") would turn a user
  // type named "eBuffer<...>" into "Buffer<...>" and misclassify it.
  auto dropPrefix = [](StringRef &Str, StringRef Prefix) {
    if (Str.startswith(Prefix))
      Str = Str.substr(Prefix.size());
  };

  dropPrefix(name, "class.");
  dropPrefix(name, "struct.");

  if (name == "SamplerState" || name == "SamplerComparisonState" ||
      name == "RaytracingAccelerationStructure")
    return true;

  static const char *const kObjectPrefixes[] = {
      "TriangleStream<",         "PointStream<",
      "LineStream<",             "AppendStructuredBuffer<",
      "ConsumeStructuredBuffer<", "ConstantBuffer<"};
  for (const char *Prefix : kObjectPrefixes)
    if (name.startswith(Prefix))
      return true;

  // Buffer and texture types may carry a RasterizerOrdered or RW qualifier.
  dropPrefix(name, "RasterizerOrdered");
  dropPrefix(name, "RW");

  if (name == "ByteAddressBuffer")
    return true;

  static const char *const kResourcePrefixes[] = {
      "Buffer<",          "StructuredBuffer<", "Texture1D<",
      "Texture1DArray<",  "Texture2D<",        "Texture2DArray<",
      "Texture3D<",       "TextureCube<",      "TextureCubeArray<",
      "Texture2DMS<",     "Texture2DMSArray<"};
  for (const char *Prefix : kResourcePrefixes)
    if (name.startswith(Prefix))
      return true;

  return false;
}
// Returns true if Ty, after looking through all pointer levels and then all
// array levels, is a struct that is an HLSL object type or contains one in
// any (recursively examined) element. A struct counts as an object type if
// its name starts with "dx.types." or IsHLSLObjectType accepts it.
bool ContainsHLSLObjectType(llvm::Type *Ty) {
  // Unwrap pointers, then arrays.
  while (llvm::PointerType *PT = llvm::dyn_cast<llvm::PointerType>(Ty))
    Ty = PT->getElementType();
  while (llvm::ArrayType *AT = llvm::dyn_cast<llvm::ArrayType>(Ty))
    Ty = AT->getElementType();

  llvm::StructType *ST = llvm::dyn_cast<llvm::StructType>(Ty);
  if (!ST)
    return false;
  if (ST->getName().startswith("dx.types."))
    return true;
  // TODO: How is this supposed to check for Input/OutputPatch types if
  // these have already been eliminated in function arguments during CG?
  if (IsHLSLObjectType(ST))
    return true;
  // Otherwise, recurse into the elements of the user-defined type.
  for (llvm::Type *ElemTy : ST->elements()) {
    if (ContainsHLSLObjectType(ElemTy))
      return true;
  }
  return false;
}
// Based on the implementation available in LLVM's trunk:
// http://llvm.org/doxygen/Constants_8cpp_source.html#l02734
// Returns true if every element of cdv has exactly the same raw bytes as
// element 0.
bool IsSplat(llvm::ConstantDataVector *cdv) {
  const unsigned EltSize = cdv->getElementByteSize();
  const unsigned NumElts = cdv->getNumElements();
  const char *First = cdv->getRawDataValues().data();
  const char *Cur = First + EltSize;
  for (unsigned i = 1; i < NumElts; ++i, Cur += EltSize) {
    if (memcmp(First, Cur, EltSize) != 0)
      return false;
  }
  return true;
}
}
}
///////////////////////////////////////////////////////////////////////////////
namespace {
/// Module pass that makes sure M has a DxilModule. If M does not have one
/// yet, it is created with M.GetOrCreateDxilModule() and the pass reports a
/// change; otherwise the pass does nothing.
class DxilLoadMetadata : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit DxilLoadMetadata () : ModulePass(ID) {}

  const char *getPassName() const override {
    return "HLSL load DxilModule from metadata";
  }

  bool runOnModule(Module &M) override {
    if (M.HasDxilModule())
      return false; // Already present: module left unchanged.
    (void)M.GetOrCreateDxilModule();
    return true;
  }
};
}
// Definition of the static pass ID declared in DxilLoadMetadata.
char DxilLoadMetadata::ID = 0;
// Factory that heap-allocates a new DxilLoadMetadata pass (ownership
// presumably passes to the caller / pass manager — confirm).
ModulePass *llvm::createDxilLoadMetadataPass() {
return new DxilLoadMetadata();
}
// Registers the pass via INITIALIZE_PASS with argument "hlsl-dxilload" and the
// same description string used by getPassName(); the two trailing flags are
// presumably CFG-only / is-analysis — confirm against the macro definition.
INITIALIZE_PASS(DxilLoadMetadata, "hlsl-dxilload", "HLSL load DxilModule from metadata", false, false)