/////////////////////////////////////////////////////////////////////////////// // // // DxilUtil.cpp // // Copyright (C) Microsoft Corporation. All rights reserved. // // This file is distributed under the University of Illinois Open Source // // License. See LICENSE.TXT for details. // // // // Dxil helper functions. // // // /////////////////////////////////////////////////////////////////////////////// #include "llvm/IR/GlobalVariable.h" #include "dxc/DXIL/DxilTypeSystem.h" #include "dxc/DXIL/DxilUtil.h" #include "dxc/DXIL/DxilModule.h" #include "llvm/Bitcode/ReaderWriter.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "dxc/Support/Global.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" using namespace llvm; using namespace hlsl; namespace hlsl { namespace dxilutil { const char ManglingPrefix[] = "\01?"; const char EntryPrefix[] = "dx.entry."; Type *GetArrayEltTy(Type *Ty) { if (isa(Ty)) Ty = Ty->getPointerElementType(); while (isa(Ty)) { Ty = Ty->getArrayElementType(); } return Ty; } bool HasDynamicIndexing(Value *V) { for (auto User : V->users()) { if (GetElementPtrInst *GEP = dyn_cast(User)) { for (auto Idx = GEP->idx_begin(); Idx != GEP->idx_end(); ++Idx) { if (!isa(Idx)) return true; } } } return false; } unsigned GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation, llvm::Type *Ty, DxilTypeSystem &typeSys) { while (isa(Ty)) { Ty = Ty->getArrayElementType(); } // Bytes. CompType compType = fieldAnnotation.GetCompType(); unsigned compSize = compType.Is64Bit() ? 8 : compType.Is16Bit() && !typeSys.UseMinPrecision() ? 2 : 4; unsigned fieldSize = compSize; if (Ty->isVectorTy()) { fieldSize *= Ty->getVectorNumElements(); } else if (StructType *ST = dyn_cast(Ty)) { DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST); if (EltAnnotation) { fieldSize = EltAnnotation->GetCBufferSize(); } else { // Calculate size when don't have annotation. if (fieldAnnotation.HasMatrixAnnotation()) { const DxilMatrixAnnotation &matAnnotation = fieldAnnotation.GetMatrixAnnotation(); unsigned rows = matAnnotation.Rows; unsigned cols = matAnnotation.Cols; if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) { rows = cols; cols = matAnnotation.Rows; } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) { // Invalid matrix orientation. fieldSize = 0; } fieldSize = (rows - 1) * 16 + cols * 4; } else { // Cannot find struct annotation. fieldSize = 0; } } } return fieldSize; } bool IsStaticGlobal(GlobalVariable *GV) { return GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage && GV->getType()->getPointerAddressSpace() == DXIL::kDefaultAddrSpace; } bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) { return GV->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace; } bool RemoveUnusedFunctions(Module &M, Function *EntryFunc, Function *PatchConstantFunc, bool IsLib) { std::vector deadList; for (auto &F : M.functions()) { if (&F == EntryFunc || &F == PatchConstantFunc) continue; if (F.isDeclaration() || !IsLib) { if (F.user_empty()) deadList.emplace_back(&F); } } bool bUpdated = deadList.size(); for (Function *F : deadList) F->eraseFromParent(); return bUpdated; } void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) { DiagnosticPrinter *printer = reinterpret_cast(Context); DI.print(*printer); } StringRef DemangleFunctionName(StringRef name) { if (!name.startswith(ManglingPrefix)) { // Name isn't mangled. return name; } size_t nameEnd = name.find_first_of("@"); DXASSERT(nameEnd != StringRef::npos, "else Name isn't mangled but has \01?"); return name.substr(2, nameEnd - 2); } std::string ReplaceFunctionName(StringRef originalName, StringRef newName) { if (originalName.startswith(ManglingPrefix)) { return (Twine(ManglingPrefix) + newName + originalName.substr(originalName.find_first_of('@'))).str(); } else if (originalName.startswith(EntryPrefix)) { return (Twine(EntryPrefix) + newName).str(); } return newName.str(); } // From AsmWriter.cpp // PrintEscapedString - Print each character of the specified string, escaping // it if it is not printable or if it is an escape char. void PrintEscapedString(StringRef Name, raw_ostream &Out) { for (unsigned i = 0, e = Name.size(); i != e; ++i) { unsigned char C = Name[i]; if (isprint(C) && C != '\\' && C != '"') Out << C; else Out << '\\' << hexdigit(C >> 4) << hexdigit(C & 0x0F); } } void PrintUnescapedString(StringRef Name, raw_ostream &Out) { for (unsigned i = 0, e = Name.size(); i != e; ++i) { unsigned char C = Name[i]; if (C == '\\') { C = Name[++i]; unsigned value = hexDigitValue(C); if (value != -1U) { C = (unsigned char)value; unsigned value2 = hexDigitValue(Name[i+1]); assert(value2 != -1U && "otherwise, not a two digit hex escape"); if (value2 != -1U) { C = (C << 4) + (unsigned char)value2; ++i; } } // else, the next character (in C) should be the escaped character } Out << C; } } std::unique_ptr LoadModuleFromBitcode(llvm::MemoryBuffer *MB, llvm::LLVMContext &Ctx, std::string &DiagStr) { // Note: the DiagStr is not used. ErrorOr> pModule( llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx)); if (std::error_code ec = pModule.getError()) { return nullptr; } return std::unique_ptr(pModule.get().release()); } std::unique_ptr LoadModuleFromBitcode(llvm::StringRef BC, llvm::LLVMContext &Ctx, std::string &DiagStr) { std::unique_ptr pBitcodeBuf( llvm::MemoryBuffer::getMemBuffer(BC, "", false)); return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr); } // If we don't have debug location and this is select/phi, // try recursing users to find instruction with debug info. // Only recurse phi/select and limit depth to prevent doing // too much work if no debug location found. static bool EmitErrorOnInstructionFollowPhiSelect( Instruction *I, StringRef Msg, unsigned depth=0) { if (depth > 4) return false; if (I->getDebugLoc().get()) { EmitErrorOnInstruction(I, Msg); return true; } if (isa(I) || isa(I)) { for (auto U : I->users()) if (Instruction *UI = dyn_cast(U)) if (EmitErrorOnInstructionFollowPhiSelect(UI, Msg, depth+1)) return true; } return false; } void EmitErrorOnInstruction(Instruction *I, StringRef Msg) { const DebugLoc &DL = I->getDebugLoc(); if (DL.get()) { std::string locString; raw_string_ostream os(locString); DL.print(os); I->getContext().emitError(os.str() + ": " + Twine(Msg)); return; } else if (isa(I) || isa(I)) { if (EmitErrorOnInstructionFollowPhiSelect(I, Msg)) return; } I->getContext().emitError(Twine(Msg) + " Use /Zi for source location."); } const StringRef kResourceMapErrorMsg = "local resource not guaranteed to map to unique global resource."; void EmitResMappingError(Instruction *Res) { EmitErrorOnInstruction(Res, kResourceMapErrorMsg); } void CollectSelect(llvm::Instruction *Inst, std::unordered_set &selectSet) { unsigned startOpIdx = 0; // Skip Cond for Select. if (isa(Inst)) { startOpIdx = 1; } else if (!isa(Inst)) { // Only check phi and select here. return; } // Already add. if (selectSet.count(Inst)) return; selectSet.insert(Inst); // Scan operand to add node which is phi/select. unsigned numOperands = Inst->getNumOperands(); for (unsigned i = startOpIdx; i < numOperands; i++) { Value *V = Inst->getOperand(i); if (Instruction *I = dyn_cast(V)) { CollectSelect(I, selectSet); } } } Value *MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx, unsigned numOperands) { Value *op0 = nullptr; for (unsigned i = startOpIdx; i < numOperands; i++) { Value *op = SelInst->getOperand(i); if (i == startOpIdx) { op0 = op; } else { if (op0 != op) return nullptr; } } if (op0) { SelInst->replaceAllUsesWith(op0); SelInst->eraseFromParent(); } return op0; } Value *SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx) { Instruction *prototype = Inst; for (unsigned i = 0; i < prototype->getNumOperands(); i++) { if (i == operandIdx) continue; if (!isa(prototype->getOperand(i))) return nullptr; } Value *V = prototype->getOperand(operandIdx); if (SelectInst *SI = dyn_cast(V)) { IRBuilder<> Builder(SI); Instruction *trueClone = Inst->clone(); trueClone->setOperand(operandIdx, SI->getTrueValue()); Builder.Insert(trueClone); Instruction *falseClone = Inst->clone(); falseClone->setOperand(operandIdx, SI->getFalseValue()); Builder.Insert(falseClone); Value *newSel = Builder.CreateSelect(SI->getCondition(), trueClone, falseClone); return newSel; } if (PHINode *Phi = dyn_cast(V)) { Type *Ty = Inst->getType(); unsigned numOperands = Phi->getNumOperands(); IRBuilder<> Builder(Phi); PHINode *newPhi = Builder.CreatePHI(Ty, numOperands); for (unsigned i = 0; i < numOperands; i++) { BasicBlock *b = Phi->getIncomingBlock(i); Value *V = Phi->getIncomingValue(i); Instruction *iClone = Inst->clone(); IRBuilder<> iBuilder(b->getTerminator()->getPrevNode()); iClone->setOperand(operandIdx, V); iBuilder.Insert(iClone); newPhi->addIncoming(iClone, b); } return newPhi; } return nullptr; } llvm::Instruction *SkipAllocas(llvm::Instruction *I) { // Step past any allocas: while (I && isa(I)) I = I->getNextNode(); return I; } llvm::Instruction *FindAllocaInsertionPt(llvm::Instruction* I) { Function *F = I->getParent()->getParent(); if (F) return &*F->getEntryBlock().getFirstInsertionPt(); else // BB with no parent function return &*I->getParent()->getFirstInsertionPt(); } llvm::Instruction *FindAllocaInsertionPt(llvm::Function* F) { return &*F->getEntryBlock().getFirstInsertionPt(); } llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Instruction* I) { return SkipAllocas(FindAllocaInsertionPt(I)); } llvm::Instruction *FirstNonAllocaInsertionPt(llvm::BasicBlock* BB) { return SkipAllocas( &*BB->getFirstInsertionPt()); } llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) { return SkipAllocas( &*F->getEntryBlock().getFirstInsertionPt()); } bool IsHLSLObjectType(llvm::Type *Ty) { if (llvm::StructType *ST = dyn_cast(Ty)) { StringRef name = ST->getName(); // TODO: don't check names. if (name.startswith("dx.types.wave_t")) return true; if (name.endswith("_slice_type")) return false; name = name.ltrim("class."); name = name.ltrim("struct."); if (name == "SamplerState") return true; if (name == "SamplerComparisonState") return true; if (name.startswith("TriangleStream<")) return true; if (name.startswith("PointStream<")) return true; if (name.startswith("LineStream<")) return true; if (name.startswith("AppendStructuredBuffer<")) return true; if (name.startswith("ConsumeStructuredBuffer<")) return true; if (name.startswith("ConstantBuffer<")) return true; if (name == "RaytracingAccelerationStructure") return true; name = name.ltrim("RasterizerOrdered"); name = name.ltrim("RW"); if (name == "ByteAddressBuffer") return true; if (name.startswith("Buffer<")) return true; if (name.startswith("StructuredBuffer<")) return true; if (name.startswith("Texture1D<")) return true; if (name.startswith("Texture1DArray<")) return true; if (name.startswith("Texture2D<")) return true; if (name.startswith("Texture2DArray<")) return true; if (name.startswith("Texture3D<")) return true; if (name.startswith("TextureCube<")) return true; if (name.startswith("TextureCubeArray<")) return true; if (name.startswith("Texture2DMS<")) return true; if (name.startswith("Texture2DMSArray<")) return true; } return false; } bool ContainsHLSLObjectType(llvm::Type *Ty) { // Unwrap pointer/array while (llvm::isa(Ty)) Ty = llvm::cast(Ty)->getPointerElementType(); while (llvm::isa(Ty)) Ty = llvm::cast(Ty)->getArrayElementType(); if (llvm::StructType *ST = llvm::dyn_cast(Ty)) { if (ST->getName().startswith("dx.types.")) return true; // TODO: How is this suppoed to check for Input/OutputPatch types if // these have already been eliminated in function arguments during CG? if (IsHLSLObjectType(Ty)) return true; // Otherwise, recurse elements of UDT for (auto ETy : ST->elements()) { if (ContainsHLSLObjectType(ETy)) return true; } } return false; } // Based on the implementation available in LLVM's trunk: // http://llvm.org/doxygen/Constants_8cpp_source.html#l02734 bool IsSplat(llvm::ConstantDataVector *cdv) { const char *Base = cdv->getRawDataValues().data(); // Compare elements 1+ to the 0'th element. unsigned EltSize = cdv->getElementByteSize(); for (unsigned i = 1, e = cdv->getNumElements(); i != e; ++i) if (memcmp(Base, Base + i * EltSize, EltSize)) return false; return true; } } } /////////////////////////////////////////////////////////////////////////////// namespace { class DxilLoadMetadata : public ModulePass { public: static char ID; // Pass identification, replacement for typeid explicit DxilLoadMetadata () : ModulePass(ID) {} const char *getPassName() const override { return "HLSL load DxilModule from metadata"; } bool runOnModule(Module &M) override { if (!M.HasDxilModule()) { (void)M.GetOrCreateDxilModule(); return true; } return false; } }; } char DxilLoadMetadata::ID = 0; ModulePass *llvm::createDxilLoadMetadataPass() { return new DxilLoadMetadata(); } INITIALIZE_PASS(DxilLoadMetadata, "hlsl-dxilload", "HLSL load DxilModule from metadata", false, false)