// DirectXShaderCompiler/lib/DxilPIXPasses/DxilPIXAddTidToAmplificatio...
//
// 168 lines
// 6.6 KiB
// C++
// Source | Plain view | History

///////////////////////////////////////////////////////////////////////////////
// //
// DxilPIXAddTidToAmplificationShaderPayload.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/DXIL/DxilInstructions.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DxilPIXPasses/DxilPIXPasses.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Transforms/Utils/Local.h"
#include "PixPassHelpers.h"
using namespace llvm;
using namespace hlsl;
using namespace PIXPassHelpers;
/// Module pass that enlarges the payload handed to DispatchMesh and fills the
/// appended fields with a flattened thread id plus the DispatchMesh Y and Z
/// thread-group counts (see runOnModule).
class DxilPIXAddTidToAmplificationShaderPayload : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid

  DxilPIXAddTidToAmplificationShaderPayload() : ModulePass(ID) {}

  StringRef getPassName() const override {
    return "DXIL Add flat thread id to payload from AS to MS";
  }

  bool runOnModule(Module &M) override;
  void applyOptions(PassOptions O) override;

private:
  // Dispatch extents used when flattening the thread id; overridden by the
  // "dispatchArgY" / "dispatchArgZ" pass options in applyOptions.
  uint32_t m_DispatchArgumentY{1};
  uint32_t m_DispatchArgumentZ{1};
};
/// Reads the "dispatchArgY" and "dispatchArgZ" pass options into the dispatch
/// extents used by runOnModule to flatten the thread id.
void DxilPIXAddTidToAmplificationShaderPayload::applyOptions(PassOptions O) {
  // Fallback handed to GetPassOptionUInt32 for both options — presumably the
  // value used when an option is absent (confirm against its declaration).
  const uint32_t DefaultDispatchArgument = 1;
  GetPassOptionUInt32(O, "dispatchArgY", &m_DispatchArgumentY,
                      DefaultDispatchArgument);
  GetPassOptionUInt32(O, "dispatchArgZ", &m_DispatchArgumentZ,
                      DefaultDispatchArgument);
}
/// Emits, at B's current insertion point, an in-bounds GEP to field
/// `expandedValueIndex` of the expanded payload struct at `NewStructAlloca`,
/// followed by a store of `value` through that pointer.
///
/// \param HlslOP             Source of the i32 constants used as GEP indices.
/// \param B                  Builder whose insertion point receives the code.
/// \param expanded           Supplies the expanded payload struct type.
/// \param NewStructAlloca    Base pointer of the expanded payload.
/// \param expandedValueIndex Field index to write; also suffixes the GEP name.
/// \param value              Value stored into that field.
void AddValueToExpandedPayload(OP *HlslOP, llvm::IRBuilder<> &B,
                               ExpandedStruct &expanded,
                               AllocaInst *NewStructAlloca,
                               unsigned int expandedValueIndex, Value *value) {
  // First index steps through the base pointer; second selects the field.
  Value *FieldIndices[] = {HlslOP->GetU32Const(0),
                           HlslOP->GetU32Const(expandedValueIndex)};
  const std::string FieldPointerName =
      "PointerToEmbeddedNewValue" + std::to_string(expandedValueIndex);
  Value *FieldPointer =
      B.CreateInBoundsGEP(expanded.ExpandedPayloadStructType, NewStructAlloca,
                          FieldIndices, FieldPointerName);
  B.CreateStore(value, FieldPointer);
}
/// Widens the payload of the entry function's DispatchMesh call and, right
/// before every DispatchMesh call on the widened payload, stores three values
/// into the fields that follow the original payload's N fields:
///   field N     : flattened thread id,
///                 (ThreadId.x * dispatchArgY + ThreadId.y) * dispatchArgZ
///                 + ThreadId.z
///   field N + 1 : the call's threadGroupCountY argument
///   field N + 2 : the call's threadGroupCountZ argument
///
/// Every alloca in the entry function whose type matches the payload pointer
/// type is replaced by an alloca of the expanded struct type.
///
/// \returns false, leaving the module untouched, when the entry function has
///          no DispatchMesh call or no alloca of the payload type; true
///          otherwise.
bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) {
  DxilModule &DM = M.GetOrCreateDxilModule();
  LLVMContext &Ctx = M.getContext();
  OP *HlslOP = DM.GetOP();

  Type *OriginalPayloadStructPointerType = nullptr;
  Type *OriginalPayloadStructType = nullptr;
  ExpandedStruct expanded;

  llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction(DM);

  // Discover the payload type passed to DispatchMesh and build the expanded
  // struct type from it (the last DispatchMesh seen wins, as before).
  for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction);
       I != E; ++I) {
    Instruction *Instr = &*I;
    if (hlsl::OP::IsDxilOpFuncCallInst(Instr,
                                       hlsl::OP::OpCode::DispatchMesh)) {
      DxilInst_DispatchMesh DispatchMesh(Instr);
      OriginalPayloadStructPointerType = DispatchMesh.get_payload()->getType();
      OriginalPayloadStructType =
          OriginalPayloadStructPointerType->getPointerElementType();
      expanded = ExpandStructType(Ctx, OriginalPayloadStructType);
    }
  }

  // No DispatchMesh call: there is no payload type to expand, and `expanded`
  // was never filled in, so nothing below may use it.
  if (OriginalPayloadStructType == nullptr)
    return false;

  // Collect before rewriting: the rewrite helper (per its name) deletes the
  // old alloca, which must not happen while walking the instruction list.
  std::vector<AllocaInst *> allocasOfPayloadType;
  for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction);
       I != E; ++I) {
    if (auto *Alloca = llvm::dyn_cast<AllocaInst>(&*I)) {
      if (Alloca->getType() == OriginalPayloadStructPointerType)
        allocasOfPayloadType.push_back(Alloca);
    }
  }

  // Nothing would be rewritten, so there is no expanded payload to store into.
  if (allocasOfPayloadType.empty())
    return false;

  AllocaInst *NewStructAlloca = nullptr;
  for (AllocaInst *Alloca : allocasOfPayloadType) {
    // This builder has no insertion point, so CreateAlloca yields a detached
    // instruction that is then placed right after the alloca it replaces.
    llvm::IRBuilder<> B(Alloca->getContext());
    NewStructAlloca = B.CreateAlloca(expanded.ExpandedPayloadStructType,
                                     HlslOP->GetU32Const(1), "NewPayload");
    NewStructAlloca->setAlignment(Alloca->getAlignment());
    NewStructAlloca->insertAfter(Alloca);
    ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction(
        Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType);
  }

  auto *DispatchMeshFunc = HlslOP->GetOpFunc(
      DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType);

  // Loop-invariant operands for the instrumentation emitted per call.
  Constant *ThreadIdOpcode =
      HlslOP->GetU32Const(static_cast<unsigned>(DXIL::OpCode::ThreadId));
  Constant *Zero32Arg = HlslOP->GetU32Const(0);
  Constant *One32Arg = HlslOP->GetU32Const(1);
  Constant *Two32Arg = HlslOP->GetU32Const(2);
  Constant *DispatchArgY = HlslOP->GetU32Const(m_DispatchArgumentY);
  Constant *DispatchArgZ = HlslOP->GetU32Const(m_DispatchArgumentZ);
  // The appended fields start right after the original payload's fields.
  const unsigned FirstAppendedField =
      OriginalPayloadStructType->getStructNumElements();

  for (auto FI = DispatchMeshFunc->user_begin();
       FI != DispatchMeshFunc->user_end();) {
    auto *UserInstruction = llvm::cast<Instruction>(*FI++);
    DxilInst_DispatchMesh DispatchMesh(UserInstruction);

    // Store into the payload this particular call consumes. If its operand is
    // not directly an alloca, fall back to the last replacement alloca.
    AllocaInst *PayloadAlloca =
        llvm::dyn_cast<AllocaInst>(DispatchMesh.get_payload());
    if (PayloadAlloca == nullptr)
      PayloadAlloca = NewStructAlloca;

    // Everything below is inserted immediately before the DispatchMesh call.
    llvm::IRBuilder<> B(UserInstruction);
    auto *ThreadIdFunc =
        HlslOP->GetOpFunc(DXIL::OpCode::ThreadId, Type::getInt32Ty(Ctx));
    auto *ThreadIdX =
        B.CreateCall(ThreadIdFunc, {ThreadIdOpcode, Zero32Arg}, "ThreadIdX");
    auto *ThreadIdY =
        B.CreateCall(ThreadIdFunc, {ThreadIdOpcode, One32Arg}, "ThreadIdY");
    auto *ThreadIdZ =
        B.CreateCall(ThreadIdFunc, {ThreadIdOpcode, Two32Arg}, "ThreadIdZ");

    // flat = (x * dispatchArgY + y) * dispatchArgZ + z
    auto *XxY = B.CreateMul(ThreadIdX, DispatchArgY);
    auto *XplusY = B.CreateAdd(ThreadIdY, XxY);
    auto *XYxZ = B.CreateMul(XplusY, DispatchArgZ);
    auto *FlatThreadId = B.CreateAdd(ThreadIdZ, XYxZ);

    AddValueToExpandedPayload(HlslOP, B, expanded, PayloadAlloca,
                              FirstAppendedField, FlatThreadId);
    AddValueToExpandedPayload(HlslOP, B, expanded, PayloadAlloca,
                              FirstAppendedField + 1,
                              DispatchMesh.get_threadGroupCountY());
    AddValueToExpandedPayload(HlslOP, B, expanded, PayloadAlloca,
                              FirstAppendedField + 2,
                              DispatchMesh.get_threadGroupCountZ());
  }

  DM.ReEmitDxilResources();
  return true;
}
// LLVM identifies a pass by the address of its ID member; the stored value
// itself is not significant.
char DxilPIXAddTidToAmplificationShaderPayload::ID{0};

/// Creates a new heap-allocated instance of this pass. Ownership passes to the
/// caller (presumably a pass manager — confirm at call sites).
ModulePass *llvm::createDxilPIXAddTidToAmplificationShaderPayloadPass() {
  ModulePass *Pass = new DxilPIXAddTidToAmplificationShaderPayload();
  return Pass;
}
// Registers the pass with LLVM's pass registry under the command-line argument
// "hlsl-dxil-PIX-add-tid-to-as-payload". The two trailing `false` arguments
// are INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(DxilPIXAddTidToAmplificationShaderPayload,
"hlsl-dxil-PIX-add-tid-to-as-payload",
"HLSL DXIL Add flat thread id to payload from AS to MS", false,
false)