/////////////////////////////////////////////////////////////////////////////// // // // DxilPIXAddTidToAmplificationShaderPayload.cpp // // Copyright (C) Microsoft Corporation. All rights reserved. // // This file is distributed under the University of Illinois Open Source // // License. See LICENSE.TXT for details. // // // /////////////////////////////////////////////////////////////////////////////// #include "dxc/DXIL/DxilOperations.h" #include "dxc/DXIL/DxilUtil.h" #include "dxc/DXIL/DxilInstructions.h" #include "dxc/DXIL/DxilModule.h" #include "dxc/DxilPIXPasses/DxilPIXPasses.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/FormattedStream.h" #include "llvm/Transforms/Utils/Local.h" #include "PixPassHelpers.h" using namespace llvm; using namespace hlsl; using namespace PIXPassHelpers; class DxilPIXAddTidToAmplificationShaderPayload : public ModulePass { uint32_t m_DispatchArgumentY = 1; uint32_t m_DispatchArgumentZ = 1; public: static char ID; // Pass identification, replacement for typeid DxilPIXAddTidToAmplificationShaderPayload() : ModulePass(ID) {} StringRef getPassName() const override { return "DXIL Add flat thread id to payload from AS to MS"; } bool runOnModule(Module &M) override; void applyOptions(PassOptions O) override; }; void DxilPIXAddTidToAmplificationShaderPayload::applyOptions(PassOptions O) { GetPassOptionUInt32(O, "dispatchArgY", &m_DispatchArgumentY, 1); GetPassOptionUInt32(O, "dispatchArgZ", &m_DispatchArgumentZ, 1); } void AddValueToExpandedPayload(OP *HlslOP, llvm::IRBuilder<> &B, ExpandedStruct &expanded, AllocaInst *NewStructAlloca, unsigned int expandedValueIndex, Value *value) { Constant *Zero32Arg = HlslOP->GetU32Const(0); SmallVector IndexToAppendedValue; IndexToAppendedValue.push_back(Zero32Arg); IndexToAppendedValue.push_back(HlslOP->GetU32Const(expandedValueIndex)); auto *PointerToEmbeddedNewValue = B.CreateInBoundsGEP( expanded.ExpandedPayloadStructType, NewStructAlloca, IndexToAppendedValue, "PointerToEmbeddedNewValue" + std::to_string(expandedValueIndex)); B.CreateStore(value, PointerToEmbeddedNewValue); } bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) { DxilModule &DM = M.GetOrCreateDxilModule(); LLVMContext &Ctx = M.getContext(); OP *HlslOP = DM.GetOP(); Type *OriginalPayloadStructPointerType = nullptr; Type *OriginalPayloadStructType = nullptr; ExpandedStruct expanded; llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction(DM); for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction); I != E; ++I) { if (auto *Instr = llvm::cast(&*I)) { if (hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::DispatchMesh)) { DxilInst_DispatchMesh DispatchMesh(Instr); OriginalPayloadStructPointerType = DispatchMesh.get_payload()->getType(); OriginalPayloadStructType = OriginalPayloadStructPointerType->getPointerElementType(); expanded = ExpandStructType(Ctx, OriginalPayloadStructType); } } } AllocaInst *OldStructAlloca = nullptr; AllocaInst *NewStructAlloca = nullptr; std::vector allocasOfPayloadType; for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction); I != E; ++I) { auto *Inst = &*I; if (llvm::isa(Inst)) { auto *Alloca = llvm::cast(Inst); if (Alloca->getType() == OriginalPayloadStructPointerType) { allocasOfPayloadType.push_back(Alloca); } } } for (auto &Alloca : allocasOfPayloadType) { OldStructAlloca = Alloca; llvm::IRBuilder<> B(Alloca->getContext()); NewStructAlloca = B.CreateAlloca(expanded.ExpandedPayloadStructType, HlslOP->GetU32Const(1), "NewPayload"); NewStructAlloca->setAlignment(Alloca->getAlignment()); NewStructAlloca->insertAfter(Alloca); ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction( Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType); } auto F = HlslOP->GetOpFunc(DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType); for (auto FI = F->user_begin(); FI != F->user_end();) { auto *FunctionUser = *FI++; auto *UserInstruction = llvm::cast(FunctionUser); DxilInst_DispatchMesh DispatchMesh(UserInstruction); llvm::IRBuilder<> B(UserInstruction); auto ThreadIdFunc = HlslOP->GetOpFunc(DXIL::OpCode::ThreadId, Type::getInt32Ty(Ctx)); Constant *Opcode = HlslOP->GetU32Const((unsigned)DXIL::OpCode::ThreadId); Constant *Zero32Arg = HlslOP->GetU32Const(0); Constant *One32Arg = HlslOP->GetU32Const(1); Constant *Two32Arg = HlslOP->GetU32Const(2); auto ThreadIdX = B.CreateCall(ThreadIdFunc, {Opcode, Zero32Arg}, "ThreadIdX"); auto ThreadIdY = B.CreateCall(ThreadIdFunc, {Opcode, One32Arg}, "ThreadIdY"); auto ThreadIdZ = B.CreateCall(ThreadIdFunc, {Opcode, Two32Arg}, "ThreadIdZ"); auto *XxY = B.CreateMul(ThreadIdX, HlslOP->GetU32Const(m_DispatchArgumentY)); auto *XplusY = B.CreateAdd(ThreadIdY, XxY); auto *XYxZ = B.CreateMul(XplusY, HlslOP->GetU32Const(m_DispatchArgumentZ)); auto *XYZ = B.CreateAdd(ThreadIdZ, XYxZ); AddValueToExpandedPayload(HlslOP, B, expanded, NewStructAlloca, OriginalPayloadStructType->getStructNumElements(), XYZ); AddValueToExpandedPayload( HlslOP, B, expanded, NewStructAlloca, OriginalPayloadStructType->getStructNumElements() + 1, DispatchMesh.get_threadGroupCountY()); AddValueToExpandedPayload( HlslOP, B, expanded, NewStructAlloca, OriginalPayloadStructType->getStructNumElements() + 2, DispatchMesh.get_threadGroupCountZ()); } DM.ReEmitDxilResources(); return true; } char DxilPIXAddTidToAmplificationShaderPayload::ID = 0; ModulePass *llvm::createDxilPIXAddTidToAmplificationShaderPayloadPass() { return new DxilPIXAddTidToAmplificationShaderPayload(); } INITIALIZE_PASS(DxilPIXAddTidToAmplificationShaderPayload, "hlsl-dxil-PIX-add-tid-to-as-payload", "HLSL DXIL Add flat thread id to payload from AS to MS", false, false)