DirectXShaderCompiler/lib/HLSL/DxilPatchShaderRecordBindin...

1210 строки
46 KiB
C++

///////////////////////////////////////////////////////////////////////////////
// //
// DxilPatchShaderRecordBindings.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
// Provides a pass used by the RayTracing Fallback Layer to modify //
// bindings to pull local root signature parameters from a global //
// "shader table" buffer instead //
// //
///////////////////////////////////////////////////////////////////////////////
#include "dxc/DXIL/DxilFunctionProps.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilSignatureElement.h"
#include "dxc/HLSL/DxilFallbackLayerPass.h"
#include "dxc/HLSL/DxilGenerationPass.h"
#include "dxc/Support/Global.h"
#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilInstructions.h"
#include "dxc/DXIL/DxilTypeSystem.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/DxilRootSignature/DxilRootSignature.h"
#include "dxc/HLSL/DxilSpanAllocator.h"
#include "dxc/Support/Unicode.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <array>
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
#include "DxilPatchShaderRecordBindingsShared.h"
using namespace llvm;
using namespace hlsl;
// Equality for ViewKey: two keys compare equal only when their complete
// object representations match byte for byte.
bool operator==(const ViewKey &a, const ViewKey &b) {
  const int byteDifference = memcmp(&a, &b, sizeof(ViewKey));
  return byteDifference == 0;
}
// Byte size used for a D3D12 GPU virtual address: one 64-bit value.
const size_t SizeofD3D12GpuVA = sizeof(uint64_t);
// Byte size used for a D3D12 GPU descriptor handle: one 64-bit value.
const size_t SizeofD3D12GpuDescriptorHandle = sizeof(uint64_t);
// Creates a new externally-linked function called Name in llvmModule with the
// same function type as Orig, clones Orig's body into it, and copies Orig's
// DXIL function annotation onto the clone. If Orig has DXIL function props,
// its entry props are cloned as well. Returns the newly created function.
Function *CloneFunction(Function *Orig, const llvm::Twine &Name,
                        llvm::Module *llvmModule) {
  Function *F = Function::Create(Orig->getFunctionType(),
                                 GlobalValue::LinkageTypes::ExternalLinkage,
                                 Name, llvmModule);
  SmallVector<ReturnInst *, 2> Returns;
  ValueToValueMapTy vmap;
  // Map each parameter of Orig to the matching parameter of the clone.
  auto entryParamIt = F->arg_begin();
  for (Argument &param : Orig->args()) {
    vmap[&param] = (entryParamIt++);
  }

  DxilModule &DM = llvmModule->GetOrCreateDxilModule();
  llvm::CloneFunctionInto(F, Orig, vmap, /*ModuleLevelChanges*/ false, Returns);
  DM.GetTypeSystem().CopyFunctionAnnotation(F, Orig, DM.GetTypeSystem());

  // The props to copy belong to the source function. F was created a few
  // lines above, so testing F here would never find any props and the clone
  // call would be dead code.
  if (DM.HasDxilFunctionProps(Orig)) {
    DM.CloneDxilEntryProps(Orig, F);
  }
  return F;
}
struct ShaderRecordEntry {
DxilRootParameterType ParameterType;
unsigned int RecordOffsetInBytes;
unsigned int OffsetInDescriptors; // Only valid for descriptor tables
static ShaderRecordEntry InvalidEntry() {
return {(DxilRootParameterType)-1, (unsigned int)-1, 0};
}
bool IsInvalid() { return (unsigned int)ParameterType == (unsigned int)-1; }
};
struct D3D12_VERSIONED_ROOT_SIGNATURE_DESC;
// Module pass that patches the resource bindings of a ray tracing shader so
// that parameters found in the supplied root signature are read from a shader
// table buffer. The shader description (ShaderInfo) is handed to the pass
// through the "root-signature" pass option.
class DxilPatchShaderRecordBindings : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilPatchShaderRecordBindings() : ModulePass(ID) {}
  StringRef getPassName() const override {
    return "DXIL Patch Shader Record Binding";
  }
  // Reads the "root-signature" option and stores the ShaderInfo pointer.
  void applyOptions(PassOptions O) override;
  // Entry point of the pass; always reports the module as modified.
  bool runOnModule(Module &M) override;

private:
  // Throws when the shader info or its root signature is missing.
  void ValidateParameters();
  // Inserts the shader table and dispatch constants handles at the top of the
  // entry function and initializes BaseShaderRecordOffset.
  void AddInputBinding(Module &M);
  void PatchShaderBindings(Module &M);
  void InitializeViewTable();

  // Resource registration helpers. Each returns the ID of the resource that
  // was added to the DxilModule.
  unsigned int AddSRVRawBuffer(Module &M, unsigned int registerIndex,
                               unsigned int registerSpace,
                               const std::string &bufferName);
  unsigned int
  AddHandle(Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize,
            unsigned int registerSpace, DXIL::ResourceClass resClass,
            DXIL::ResourceKind resKind, const std::string &bufferName,
            llvm::Type *type = nullptr, unsigned int constantBufferSize = 0);
  unsigned int AddAliasedHandle(Module &M, unsigned int baseRegisterIndex,
                                unsigned int registerSpace,
                                DXIL::ResourceClass resClass,
                                DXIL::ResourceKind resKind,
                                const std::string &bufferName,
                                llvm::Type *type);
  unsigned int AddCBufferAliasedHandle(Module &M,
                                       unsigned int baseRegisterIndex,
                                       unsigned int registerSpace,
                                       const std::string &bufferName);

  // IR emission helpers.
  llvm::Value *CreateOffsetToShaderRecord(Module &M, IRBuilder<> &Builder,
                                          unsigned int RecordOffsetInBytes,
                                          llvm::Value *CbufferOffsetInBytes);
  llvm::Value *
  CreateShaderRecordBufferLoad(Module &M, IRBuilder<> &Builder,
                               llvm::Value *ShaderRecordOffsetInBytes,
                               llvm::Type *type);
  llvm::Value *CreateCBufferLoadOffsetInBytes(Module &M, IRBuilder<> &Builder,
                                              llvm::Instruction *instruction);
  llvm::Value *CreateCBufferLoadLegacy(Module &M, IRBuilder<> &Builder,
                                       llvm::Value *ResourceHandle,
                                       unsigned int RowToLoad = 0);
  llvm::Value *LoadShaderRecordData(Module &M, IRBuilder<> &Builder,
                                    llvm::Value *offsetToShaderRecord,
                                    unsigned int dataOffsetInShaderRecord);
  void PatchCreateHandleToUseDescriptorIndex(
      Module &M, IRBuilder<> &Builder, DXIL::ResourceKind &resourceKind,
      DXIL::ResourceClass &resourceClass, llvm::Type *resourceType,
      llvm::Value *descriptorIndex,
      DxilInst_CreateHandleForLib &createHandleInstr);

  // Looks up the resource behind a CreateHandleForLib call; returns false if
  // no resource in the module matches.
  bool GetHandleInfo(Module &M,
                     DxilInst_CreateHandleForLib &createHandleStructForLib,
                     unsigned int &shaderRegister, unsigned int &registerSpace,
                     DXIL::ResourceKind &kind, DXIL::ResourceClass &resClass,
                     llvm::Type *&resType);
  llvm::Value *GetAliasedDescriptorHeapHandle(Module &M, llvm::Type *,
                                              DXIL::ResourceClass resClass,
                                              DXIL::ResourceKind resKind);
  unsigned int GetConstantBufferOffsetToShaderRecord();
  bool IsCBufferLoad(llvm::Instruction *instruction);

  // Unlike the LLVM version of this function, this does not require the
  // InstructionToReplace and the ValueToReplaceWith to be the same instruction
  // type
  static void ReplaceUsesOfWith(llvm::Instruction *InstructionToReplace,
                                llvm::Value *ValueToReplaceWith);
  static ShaderRecordEntry FindRootSignatureDescriptor(
      const DxilVersionedRootSignatureDesc &rootSignatureDescriptor,
      unsigned int ShaderRecordIdentifierSizeInBytes,
      DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex,
      unsigned int registerSpace);

  // TODO: I would like to see these prefixed with m_
  llvm::Value *ShaderTableHandle = nullptr;
  llvm::Value *DispatchRaysConstantsHandle = nullptr;
  llvm::Value *BaseShaderRecordOffset = nullptr;

  // Size of TypeToAliasedDescriptorHeap, which is indexed by a
  // DXIL::ResourceClass value cast to unsigned int.
  static const unsigned int NumViewTypes = 4;

  // Hash for ViewKey built from its ViewType and StructuredStride fields.
  struct ViewKeyHasher {
  public:
    std::size_t operator()(const ViewKey &x) const {
      return std::hash<unsigned int>()((unsigned int)x.ViewType) ^
             std::hash<unsigned int>()((unsigned int)x.StructuredStride);
    }
  };
  // Cache of descriptor heap global symbols already created, per resource
  // class and view key.
  std::unordered_map<ViewKey, llvm::Value *, ViewKeyHasher>
      TypeToAliasedDescriptorHeap[NumViewTypes];

  // These are null/Invalid until applyOptions / runOnModule fill them in, so
  // that ValidateParameters' null checks are meaningful even when the
  // "root-signature" option was never supplied.
  llvm::Function *EntryPointFunction = nullptr;
  ShaderInfo *pInputShaderInfo = nullptr;
  const DxilVersionedRootSignatureDesc *pRootSignatureDesc = nullptr;
  DXIL::ShaderKind ShaderKind = DXIL::ShaderKind::Invalid;
};
// Definition of the pass's static ID member; the constructor passes it to
// ModulePass.
char DxilPatchShaderRecordBindings::ID = 0;
// TODO: Find the right thing to do on failure
// Shared failure path: raises a default-constructed std::exception.
void ThrowFailure() {
  std::exception failure;
  throw failure;
}
// TODO: Stolen from Brandon's code, merge
// Strips the decoration from names of the form "\x1?<name>@@...": the
// two-character prefix and everything from the first "@@" onwards are
// dropped. Names without that prefix, or without "@@", come back unchanged.
static inline std::string GetUnmangledName(StringRef name) {
  if (name.startswith("\x1?")) {
    const size_t suffixStart = name.find("@@");
    if (suffixStart != name.npos)
      return name.substr(2, suffixStart - 2);
  }
  return name;
}
// Returns the first function in M whose unmangled name, converted to a wide
// string, equals exportName; nullptr when no function matches.
static Function *getFunctionFromName(Module &M,
                                     const std::wstring &exportName) {
  for (Function &candidate : M) {
    const std::string unmangled = GetUnmangledName(candidate.getName());
    const std::wstring wideName =
        Unicode::UTF8ToWideStringOrThrow(unmangled.c_str());
    if (wideName == exportName)
      return &candidate;
  }
  return nullptr;
}
// Factory: allocates a new DxilPatchShaderRecordBindings pass and returns it
// to the caller.
ModulePass *llvm::createDxilPatchShaderRecordBindingsPass() {
  ModulePass *pass = new DxilPatchShaderRecordBindings();
  return pass;
}
// Pass registration: command-line name "hlsl-dxil-patch-shader-record-bindings"
// plus its description; the two trailing flags are both false.
INITIALIZE_PASS(DxilPatchShaderRecordBindings,
                "hlsl-dxil-patch-shader-record-bindings",
                "Patch shader record bindings to instead pull from the "
                "fallback provided bindings",
                false, false)
// Consumes the pass options. The "root-signature" option carries the address
// of a ShaderInfo structure encoded as a hexadecimal string; it is decoded
// into pInputShaderInfo, and pRootSignatureDesc is taken from that structure.
// If the address decodes to null, pRootSignatureDesc is set to null as well.
void DxilPatchShaderRecordBindings::applyOptions(PassOptions O) {
  for (const auto &option : O) {
    if (0 != option.first.compare("root-signature"))
      continue;

    const unsigned int cHexRadix = 16;
    // strtoull needs a null-terminated string, but the option value may be a
    // non-terminated view (e.g. an llvm::StringRef), so parse an owned copy.
    const std::string pointerText(option.second.data(), option.second.size());
    pInputShaderInfo =
        (ShaderInfo *)strtoull(pointerText.c_str(), nullptr, cHexRadix);

    // Never dereference a null ShaderInfo; leave the descriptor null so that
    // later validation can report the problem.
    pRootSignatureDesc = pInputShaderInfo
                             ? (const DxilVersionedRootSignatureDesc *)
                                   pInputShaderInfo->pRootSignatureDesc
                             : nullptr;
  }
}
// Registers a struct annotation for StructTy unless the type system already
// has one. The new annotation gets a cbuffer size of numFields 32-bit values;
// field i is given offset 4*i, component type I32 and the name FieldName
// followed by the decimal index i.
void AddAnnoationsIfNeeded(DxilModule &DM, llvm::StructType *StructTy,
                           const std::string &FieldName,
                           unsigned int numFields = 1) {
  if (DM.GetTypeSystem().GetStructAnnotation(StructTy) != nullptr)
    return;

  auto pNewAnnotation = DM.GetTypeSystem().AddStructAnnotation(StructTy);
  pNewAnnotation->SetCBufferSize(sizeof(uint32_t) * numFields);
  for (unsigned int fieldIndex = 0; fieldIndex < numFields; fieldIndex++) {
    auto &field = pNewAnnotation->GetFieldAnnotation(fieldIndex);
    field.SetCBufferOffset(sizeof(uint32_t) * fieldIndex);
    field.SetCompType(hlsl::DXIL::ComponentType::I32);
    field.SetFieldName(FieldName + std::to_string(fieldIndex));
  }
}
// Adds a new resource of class resClass and kind resKind to the DxilModule of
// M and returns its ID.
//
//   baseRegisterIndex / rangeSize / registerSpace - binding of the resource.
//   bufferName - name of the global variable backing the resource; an existing
//                global with that name is reused, otherwise one is created.
//   type       - type of the global; when null, "struct.ByteAddressBuffer"
//                (a struct holding one i32) is looked up or created.
//   constantBufferSize - size applied to the resource for the CBuffer class.
//
// Calls ThrowFailure() for a resource class other than SRV, UAV, CBuffer or
// Sampler.
unsigned int DxilPatchShaderRecordBindings::AddHandle(
    Module &M, unsigned int baseRegisterIndex, unsigned int rangeSize,
    unsigned int registerSpace, DXIL::ResourceClass resClass,
    DXIL::ResourceKind resKind, const std::string &bufferName, llvm::Type *type,
    unsigned int constantBufferSize) {
  LLVMContext &Ctx = M.getContext();
  DxilModule &DM = M.GetOrCreateDxilModule();

  // Create the resource object for the requested class. The expected ID is
  // the current number of resources of that class.
  unsigned int resourceHandle = 0;
  std::unique_ptr<DxilResource> pHandle;
  std::unique_ptr<DxilCBuffer> pCBuf;
  std::unique_ptr<DxilSampler> pSampler;
  DxilResourceBase *pBaseHandle = nullptr;
  switch (resClass) {
  case DXIL::ResourceClass::SRV:
    resourceHandle = static_cast<unsigned int>(DM.GetSRVs().size());
    pHandle = llvm::make_unique<DxilResource>();
    pHandle->SetRW(false);
    pBaseHandle = pHandle.get();
    break;
  case DXIL::ResourceClass::UAV:
    resourceHandle = static_cast<unsigned int>(DM.GetUAVs().size());
    pHandle = llvm::make_unique<DxilResource>();
    pHandle->SetRW(true);
    pBaseHandle = pHandle.get();
    break;
  case DXIL::ResourceClass::CBuffer:
    resourceHandle = static_cast<unsigned int>(DM.GetCBuffers().size());
    pCBuf = llvm::make_unique<DxilCBuffer>();
    pCBuf->SetSize(constantBufferSize);
    pBaseHandle = pCBuf.get();
    break;
  case DXIL::ResourceClass::Sampler:
    resourceHandle = static_cast<unsigned int>(DM.GetSamplers().size());
    pSampler = llvm::make_unique<DxilSampler>();
    // TODO: Is this okay? What if one of the samplers in the table is a
    // comparison sampler?
    pSampler->SetSamplerKind(DxilSampler::SamplerKind::Default);
    pBaseHandle = pSampler.get();
    break;
  default:
    // Without this, pBaseHandle would be used uninitialized below.
    ThrowFailure();
    llvm_unreachable("invalid resource class");
  }

  // Fall back to the byte address buffer struct type when no type was given.
  if (!type) {
    SmallVector<llvm::Type *, 1> Elements{Type::getInt32Ty(Ctx)};
    std::string ByteAddressBufferName = "struct.ByteAddressBuffer";
    type = M.getTypeByName(ByteAddressBufferName);
    if (!type) {
      StructType *StructTy;
      type = StructTy = StructType::create(Elements, ByteAddressBufferName);
      AddAnnoationsIfNeeded(DM, StructTy, ByteAddressBufferName);
    }
  }

  GlobalVariable *GV = M.getGlobalVariable(bufferName);
  if (!GV) {
    GV = cast<GlobalVariable>(M.getOrInsertGlobal(bufferName, type));
  }

  pBaseHandle->SetGlobalName(bufferName.c_str());
  pBaseHandle->SetGlobalSymbol(GV);
  pBaseHandle->SetID(resourceHandle);
  pBaseHandle->SetSpaceID(registerSpace);
  pBaseHandle->SetLowerBound(baseRegisterIndex);
  pBaseHandle->SetRangeSize(rangeSize);
  pBaseHandle->SetKind(resKind);

  // pHandle is only set for the SRV and UAV classes.
  if (pHandle) {
    pHandle->SetGloballyCoherent(false);
    pHandle->SetHasCounter(false);
    pHandle->SetCompType(CompType::getF32()); // TODO: Need to handle all types
  }

  // Hand ownership of the resource over to the module.
  unsigned int ID = 0;
  switch (resClass) {
  case DXIL::ResourceClass::SRV:
    ID = DM.AddSRV(std::move(pHandle));
    break;
  case DXIL::ResourceClass::UAV:
    ID = DM.AddUAV(std::move(pHandle));
    break;
  case DXIL::ResourceClass::CBuffer:
    ID = DM.AddCBuffer(std::move(pCBuf));
    break;
  case DXIL::ResourceClass::Sampler:
    ID = DM.AddSampler(std::move(pSampler));
    break;
  default:
    // Already rejected by the first switch.
    llvm_unreachable("invalid resource class");
  }
  assert(ID == resourceHandle);
  return ID;
}
// Returns a byte offset into DispatchRaysConstants chosen by ShaderKind:
// the offset of HitGroupShaderRecordStride for closest-hit, any-hit and
// intersection shaders, the offset of MissShaderRecordStride for miss
// shaders. Any other kind goes through ThrowFailure().
unsigned int
DxilPatchShaderRecordBindings::GetConstantBufferOffsetToShaderRecord() {
  const bool isHitGroupShader = ShaderKind == DXIL::ShaderKind::ClosestHit ||
                                ShaderKind == DXIL::ShaderKind::AnyHit ||
                                ShaderKind == DXIL::ShaderKind::Intersection;
  if (isHitGroupShader)
    return offsetof(DispatchRaysConstants, HitGroupShaderRecordStride);

  if (ShaderKind == DXIL::ShaderKind::Miss)
    return offsetof(DispatchRaysConstants, MissShaderRecordStride);

  ThrowFailure();
  return -1;
}
// Adds a single-register SRV of kind RawBuffer named bufferName at
// (registerIndex, registerSpace) and returns its resource ID.
unsigned int DxilPatchShaderRecordBindings::AddSRVRawBuffer(
    Module &M, unsigned int registerIndex, unsigned int registerSpace,
    const std::string &bufferName) {
  const unsigned int singleRegisterRange = 1;
  const DXIL::ResourceClass srvClass = DXIL::ResourceClass::SRV;
  const DXIL::ResourceKind rawBufferKind = DXIL::ResourceKind::RawBuffer;
  return AddHandle(M, registerIndex, singleRegisterRange, registerSpace,
                   srvClass, rawBufferKind, bufferName);
}
// Creates a struct type named bufferName holding a single i32, wraps it in an
// array type whose element count is given as -1, and returns an undef value
// of pointer-to-that-array type.
llvm::Constant *GetArraySymbol(Module &M, const std::string &bufferName) {
  LLVMContext &context = M.getContext();
  SmallVector<llvm::Type *, 1> structFields{Type::getInt32Ty(context)};
  llvm::StructType *elementTy =
      llvm::StructType::create(structFields, bufferName);
  llvm::ArrayType *arrayOfElementsTy = ArrayType::get(elementTy, -1);
  llvm::PointerType *pointerToArrayTy = arrayOfElementsTy->getPointerTo();
  return UndefValue::get(pointerToArrayTy);
}
// Adds a CBuffer resource with range size UINT_MAX and a size of 4096 * 16
// bytes, typed as the pointer type produced by GetArraySymbol for bufferName.
// Returns the new resource ID.
unsigned int DxilPatchShaderRecordBindings::AddCBufferAliasedHandle(
    Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace,
    const std::string &bufferName) {
  const unsigned int maxConstantBufferSize = 4096 * 16;
  llvm::Type *aliasedArrayTy = GetArraySymbol(M, bufferName)->getType();
  return AddHandle(M, baseRegisterIndex, UINT_MAX, registerSpace,
                   DXIL::ResourceClass::CBuffer, DXIL::ResourceKind::CBuffer,
                   bufferName, aliasedArrayTy, maxConstantBufferSize);
}
// Adds a resource of the given class, kind and type with range size UINT_MAX
// and returns its resource ID.
unsigned int DxilPatchShaderRecordBindings::AddAliasedHandle(
    Module &M, unsigned int baseRegisterIndex, unsigned int registerSpace,
    DXIL::ResourceClass resClass, DXIL::ResourceKind resKind,
    const std::string &bufferName, llvm::Type *type) {
  const unsigned int unboundedRangeSize = UINT_MAX;
  const unsigned int newResourceId =
      AddHandle(M, baseRegisterIndex, unboundedRangeSize, registerSpace,
                resClass, resKind, bufferName, type);
  return newResourceId;
}
// TODO: Stolen from Brandon's code
// Determines the shader kind of F: RayGeneration when F carries the
// "exp-shader" attribute; otherwise the kind recorded in F's DXIL function
// props when those exist and describe a ray shader; otherwise Invalid.
DXIL::ShaderKind GetRayShaderKindCopy(Function *F) {
  if (F->hasFnAttribute("exp-shader"))
    return DXIL::ShaderKind::RayGeneration;

  DxilModule &dxilModule = F->getParent()->GetDxilModule();
  const bool isRayShader = dxilModule.HasDxilFunctionProps(F) &&
                           dxilModule.GetDxilFunctionProps(F).IsRay();
  if (!isRayShader)
    return DXIL::ShaderKind::Invalid;
  return dxilModule.GetDxilFunctionProps(F).shaderKind;
}
// Runs the pass on M: validates the pass inputs, resolves the entry point
// (by export name when one is supplied, otherwise the module's entry
// function), determines its shader kind, patches the bindings and re-emits
// the DXIL resources. Always returns true.
// Throws through ValidateParameters()/ThrowFailure() when the shader info is
// missing or no entry function can be found.
bool DxilPatchShaderRecordBindings::runOnModule(Module &M) {
  // Validate first: everything below dereferences pInputShaderInfo.
  ValidateParameters();

  DxilModule &DM = M.GetOrCreateDxilModule();
  EntryPointFunction =
      pInputShaderInfo->ExportName
          ? getFunctionFromName(M, pInputShaderInfo->ExportName)
          : DM.GetEntryFunction();
  // getFunctionFromName returns nullptr when the export is not found;
  // GetRayShaderKindCopy would dereference it.
  if (!EntryPointFunction) {
    ThrowFailure();
  }

  ShaderKind = GetRayShaderKindCopy(EntryPointFunction);
  InitializeViewTable();
  PatchShaderBindings(M);
  DM.ReEmitDxilResources();
  return true;
}
// Throws std::exception unless both the shader info pointer and the root
// signature descriptor it carries are non-null.
void DxilPatchShaderRecordBindings::ValidateParameters() {
  const bool hasRootSignature =
      pInputShaderInfo && pInputShaderInfo->pRootSignatureDesc;
  if (hasRootSignature)
    return;
  throw std::exception();
}
// Returns the resource with the given id from the DM list that corresponds
// to resClass (CBuffer, SRV, UAV or Sampler). Any other class goes through
// ThrowFailure().
DxilResourceBase &GetResourceFromID(DxilModule &DM,
                                    DXIL::ResourceClass resClass,
                                    unsigned int id) {
  if (resClass == DXIL::ResourceClass::CBuffer)
    return DM.GetCBuffer(id);
  if (resClass == DXIL::ResourceClass::SRV)
    return DM.GetSRV(id);
  if (resClass == DXIL::ResourceClass::UAV)
    return DM.GetUAV(id);
  if (resClass == DXIL::ResourceClass::Sampler)
    return DM.GetSampler(id);

  ThrowFailure();
  llvm_unreachable("invalid resource class");
}
// Returns the index of key within the first numViews entries of pViewList.
// When key is absent it is appended at index numViews and numViews is
// incremented; ThrowFailure() is called if that index would reach maxViews.
unsigned int FindOrInsertViewIntoList(const ViewKey &key, ViewKey *pViewList,
                                      unsigned int &numViews,
                                      unsigned int maxViews) {
  for (unsigned int existingIndex = 0; existingIndex < numViews;
       existingIndex++) {
    if (pViewList[existingIndex] == key)
      return existingIndex;
  }

  // Not present yet: append at the end of the used range.
  const unsigned int appendedIndex = numViews;
  if (appendedIndex >= maxViews) {
    ThrowFailure();
  }
  pViewList[appendedIndex] = key;
  numViews++;
  return appendedIndex;
}
// Returns the global symbol of the descriptor heap resource that aliases
// views of the given class, kind and type, creating and caching it on first
// use.
//
// A ViewKey is built from resKind plus either the structured stride (for
// structured buffers) or the component type of the view's element (for
// anything that is not a raw buffer). Previously created heaps are cached per
// resource class in TypeToAliasedDescriptorHeap. For SRVs and UAVs the key is
// also recorded in the register-space arrays of pInputShaderInfo, and the
// index it receives there is added to the register space of the new resource.
llvm::Value *DxilPatchShaderRecordBindings::GetAliasedDescriptorHeapHandle(
    Module &M, llvm::Type *type, DXIL::ResourceClass resClass,
    DXIL::ResourceKind resKind) {
  DxilModule &DM = M.GetOrCreateDxilModule();
  // The resource class value is used directly as the cache array index.
  unsigned int resClassIndex = (unsigned int)resClass;
  ViewKey key = {};
  key.ViewType = (unsigned int)resKind;
  if (DXIL::IsStructuredBuffer(resKind)) {
    key.StructuredStride = type->getPrimitiveSizeInBits();
  } else if (resKind != DXIL::ResourceKind::RawBuffer) {
    auto containedType = type->getContainedType(0);
    // If it's a vector, get the type of just a single element
    if (containedType->getNumContainedTypes() > 0) {
      assert(containedType->getNumContainedTypes() <= 4);
      containedType = containedType->getContainedType(0);
    }
    key.SRVComponentType =
        (unsigned int)CompType::GetCompType(containedType).GetKind();
  }
  auto aliasedDescriptorHeapHandle =
      TypeToAliasedDescriptorHeap[resClassIndex].find(key);
  if (aliasedDescriptorHeapHandle ==
      TypeToAliasedDescriptorHeap[resClassIndex].end()) {
    // Cache miss: pick a name and register space offset, then add the
    // resource.
    unsigned int registerSpaceOffset = 0;
    std::string HandleName;
    if (resClass == DXIL::ResourceClass::SRV) {
      registerSpaceOffset = FindOrInsertViewIntoList(
          key, pInputShaderInfo->pSRVRegisterSpaceArray,
          *pInputShaderInfo->pNumSRVSpaces,
          FallbackLayerNumDescriptorHeapSpacesPerView);
      HandleName = std::string("SRVDescriptorHeapTable") +
                   std::to_string(registerSpaceOffset);
    } else if (resClass == DXIL::ResourceClass::UAV) {
      registerSpaceOffset = FindOrInsertViewIntoList(
          key, pInputShaderInfo->pUAVRegisterSpaceArray,
          *pInputShaderInfo->pNumUAVSpaces,
          FallbackLayerNumDescriptorHeapSpacesPerView);
      if (registerSpaceOffset == 0) {
        // Using the descriptor heap declared by the fallback for handling
        // emulated pointers, make sure the name is an exact match
        assert(key.ViewType ==
               (unsigned int)hlsl::DXIL::ResourceKind::RawBuffer);
        HandleName =
            "\01?DescriptorHeapBufferTable@@3PAURWByteAddressBuffer@@A";
      } else {
        HandleName = std::string("UAVDescriptorHeapTable") +
                     std::to_string(registerSpaceOffset);
      }
    } else if (resClass == DXIL::ResourceClass::CBuffer) {
      // CBuffers and samplers use one fixed name and a register space offset
      // of 0.
      HandleName = std::string("CBVDescriptorHeapTable");
    } else {
      HandleName = std::string("SamplerDescriptorHeapTable");
    }
    // The heap is declared as a zero-length array of the view type.
    llvm::ArrayType *descriptorHeapType = ArrayType::get(type, 0);
    unsigned int id = AddAliasedHandle(
        M, FallbackLayerDescriptorHeapTable,
        FallbackLayerRegisterSpace + FallbackLayerDescriptorHeapSpaceOffset +
            registerSpaceOffset,
        resClass, resKind, HandleName, descriptorHeapType);
    TypeToAliasedDescriptorHeap[resClassIndex][key] =
        GetResourceFromID(DM, resClass, id).GetGlobalSymbol();
  }
  return TypeToAliasedDescriptorHeap[resClassIndex][key];
}
// Inserts the inputs needed to read from the shader table at the top of the
// entry function's entry block:
//   * an SRV raw buffer for the shader table that matches ShaderKind, with
//     its handle stored in ShaderTableHandle;
//   * the "Constants" cbuffer sized to DispatchRaysConstants, with its handle
//     stored in DispatchRaysConstantsHandle;
//   * BaseShaderRecordOffset, which is the constant 0 for ray generation
//     shaders and otherwise the result of a call to the
//     Fallback_ShaderRecordOffset function.
// Calls ThrowFailure() when ShaderKind has no shader table.
void DxilPatchShaderRecordBindings::AddInputBinding(Module &M) {
  DxilModule &DM = M.GetOrCreateDxilModule();
  auto &EntryBlock = EntryPointFunction->getEntryBlock();
  auto &Instructions = EntryBlock.getInstList();

  // Select the shader table buffer that belongs to this shader kind.
  std::string bufferName;
  unsigned int bufferRegister = 0;
  switch (ShaderKind) {
  case DXIL::ShaderKind::AnyHit:
  case DXIL::ShaderKind::ClosestHit:
  case DXIL::ShaderKind::Intersection:
    bufferRegister = FallbackLayerHitGroupRecordByteAddressBufferRegister;
    bufferName = "\01?HitGroupShaderTable@@3UByteAddressBuffer@@A";
    break;
  case DXIL::ShaderKind::Miss:
    bufferRegister = FallbackLayerMissShaderRecordByteAddressBufferRegister;
    bufferName = "\01?MissShaderTable@@3UByteAddressBuffer@@A";
    break;
  case DXIL::ShaderKind::RayGeneration:
    bufferRegister = FallbackLayerRayGenShaderRecordByteAddressBufferRegister;
    bufferName = "\01?RayGenShaderTable@@3UByteAddressBuffer@@A";
    break;
  case DXIL::ShaderKind::Callable:
    bufferRegister = FallbackLayerCallableShaderRecordByteAddressBufferRegister;
    bufferName = "\01?CallableShaderTable@@3UByteAddressBuffer@@A";
    break;
  default:
    // Any other kind would otherwise register a buffer with an
    // uninitialized register and an empty name.
    ThrowFailure();
    llvm_unreachable("shader kind has no shader table");
  }

  unsigned int ShaderRecordID = AddSRVRawBuffer(
      M, bufferRegister, FallbackLayerRegisterSpace, bufferName);

  // All new instructions go in front of the first instruction of the entry
  // block.
  auto It = Instructions.begin();
  OP *HlslOP = DM.GetOP();
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> Builder(It);

  // Handle for the shader table buffer.
  {
    auto ShaderTableName = "ShaderTableHandle";
    llvm::Value *Symbol = DM.GetSRV(ShaderRecordID).GetGlobalSymbol();
    llvm::Value *Load = Builder.CreateLoad(Symbol, "LoadShaderTableHandle");
    Function *CreateHandleForLib =
        HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
    Constant *CreateHandleOpcodeArg =
        HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
    ShaderTableHandle = Builder.CreateCall(
        CreateHandleForLib, {CreateHandleOpcodeArg, Load}, ShaderTableName);
  }

  // Handle for the dispatch constants cbuffer. Its struct type is a flat
  // list of i32 fields covering sizeof(DispatchRaysConstants).
  {
    auto CbufferName = "Constants";
    const unsigned int sizeOfConstantsInBytes = sizeof(DispatchRaysConstants);
    llvm::StructType *StructTy = M.getTypeByName(CbufferName);
    if (!StructTy) {
      const unsigned int numUintsInConstants =
          sizeOfConstantsInBytes / sizeof(unsigned int);
      SmallVector<llvm::Type *, numUintsInConstants> Elements(
          numUintsInConstants);
      for (unsigned int i = 0; i < numUintsInConstants; i++) {
        Elements[i] = Type::getInt32Ty(Ctx);
      }
      StructTy = llvm::StructType::create(Elements, CbufferName);
      AddAnnoationsIfNeeded(DM, StructTy, std::string(CbufferName),
                            numUintsInConstants);
    }
    unsigned int handle =
        AddHandle(M, FallbackLayerDispatchConstantsRegister, 1,
                  FallbackLayerRegisterSpace, DXIL::ResourceClass::CBuffer,
                  DXIL::ResourceKind::CBuffer, CbufferName, StructTy,
                  sizeOfConstantsInBytes);
    llvm::Value *Symbol = DM.GetCBuffer(handle).GetGlobalSymbol();
    llvm::Value *Load = Builder.CreateLoad(Symbol, "DispatchRaysConstants");
    Function *CreateHandleForLib =
        HlslOP->GetOpFunc(DXIL::OpCode::CreateHandleForLib, Load->getType());
    Constant *CreateHandleOpcodeArg =
        HlslOP->GetU32Const((unsigned)DXIL::OpCode::CreateHandleForLib);
    DispatchRaysConstantsHandle = Builder.CreateCall(
        CreateHandleForLib, {CreateHandleOpcodeArg, Load}, CbufferName);
  }

  // Raygen always reads from the start so no offset calculations needed
  if (ShaderKind != DXIL::ShaderKind::RayGeneration) {
    std::string ShaderRecordOffsetFuncName =
        "\x1?Fallback_ShaderRecordOffset@@YAIXZ";
    Function *ShaderRecordOffsetFunc =
        M.getFunction(ShaderRecordOffsetFuncName);
    if (!ShaderRecordOffsetFunc) {
      // Declare the function (i32 return, no parameters) if the module does
      // not already have it.
      FunctionType *ShaderRecordOffsetFuncType =
          FunctionType::get(llvm::Type::getInt32Ty(Ctx), {}, false);
      ShaderRecordOffsetFunc =
          Function::Create(ShaderRecordOffsetFuncType,
                           GlobalValue::LinkageTypes::ExternalLinkage,
                           ShaderRecordOffsetFuncName, &M);
    }
    BaseShaderRecordOffset =
        Builder.CreateCall(ShaderRecordOffsetFunc, {}, "shaderRecordOffset");
  } else {
    BaseShaderRecordOffset = HlslOP->GetU32Const(0);
  }
}
// Emits an add named "ShaderRecordOffsetInBytes" that sums
// CbufferOffsetInBytes with the constant RecordOffsetInBytes, and returns the
// result.
llvm::Value *DxilPatchShaderRecordBindings::CreateOffsetToShaderRecord(
    Module &M, IRBuilder<> &Builder, unsigned int RecordOffsetInBytes,
    llvm::Value *CbufferOffsetInBytes) {
  OP *dxilOps = M.GetOrCreateDxilModule().GetOP();
  // Offset of constants in shader record buffer
  Constant *recordOffsetConst = dxilOps->GetU32Const(RecordOffsetInBytes);
  return Builder.CreateAdd(CbufferOffsetInBytes, recordOffsetConst,
                           "ShaderRecordOffsetInBytes");
}
// Emits a call to the i32 overload of the CBufferLoadLegacy DXIL op, named
// "ConstantBuffer", reading row RowToLoad through ResourceHandle. Returns the
// call.
llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadLegacy(
    Module &M, IRBuilder<> &Builder, llvm::Value *ResourceHandle,
    unsigned int RowToLoad) {
  OP *dxilOps = M.GetOrCreateDxilModule().GetOP();
  llvm::Type *i32Ty = Type::getInt32Ty(M.getContext());

  Function *loadLegacyFn =
      dxilOps->GetOpFunc(DXIL::OpCode::CBufferLoadLegacy, i32Ty);
  Constant *opcodeArg =
      dxilOps->GetU32Const((unsigned)DXIL::OpCode::CBufferLoadLegacy);
  Constant *rowArg = dxilOps->GetU32Const(RowToLoad);

  return Builder.CreateCall(loadLegacyFn, {opcodeArg, ResourceHandle, rowArg},
                            "ConstantBuffer");
}
// Emits a BufferLoad DXIL op call named "ShaderRecordBuffer" that reads from
// ShaderTableHandle at ShaderRecordOffsetInBytes; the trailing coordinate
// argument is an undef i32. The op overload is `type`, or the first contained
// type of `type` when it contains more than one type.
llvm::Value *DxilPatchShaderRecordBindings::CreateShaderRecordBufferLoad(
    Module &M, IRBuilder<> &Builder, llvm::Value *ShaderRecordOffsetInBytes,
    llvm::Type *type) {
  OP *dxilOps = M.GetOrCreateDxilModule().GetOP();
  LLVMContext &context = M.getContext();

  // TODO: Buffer loads aren't legal with container types, check if this is
  // the right way to handle this
  llvm::Type *overloadTy =
      type->getNumContainedTypes() > 1 ? type->getContainedType(0) : type;

  // TODO Do I need to check the result? Hopefully not
  Function *bufferLoadFn =
      dxilOps->GetOpFunc(DXIL::OpCode::BufferLoad, overloadTy);
  Constant *opcodeArg =
      dxilOps->GetU32Const((unsigned)DXIL::OpCode::BufferLoad);
  Constant *unusedCoord = UndefValue::get(llvm::Type::getInt32Ty(context));

  return Builder.CreateCall(
      bufferLoadFn,
      {opcodeArg, ShaderTableHandle, ShaderRecordOffsetInBytes, unusedCoord},
      "ShaderRecordBuffer");
}
// Rewrites every user of InstructionToReplace to use ValueToReplaceWith
// instead, then erases InstructionToReplace from its parent.
void DxilPatchShaderRecordBindings::ReplaceUsesOfWith(
    llvm::Instruction *InstructionToReplace, llvm::Value *ValueToReplaceWith) {
  // Snapshot the users before rewriting anything. replaceUsesOfWith rewrites
  // all matching operands of a user in one go, so a user that references the
  // instruction in more than one operand would otherwise unlink the use an
  // already-advanced iterator points at, and some users would be skipped.
  SmallVector<User *, 8> users(InstructionToReplace->user_begin(),
                               InstructionToReplace->user_end());
  for (User *user : users) {
    // A user listed twice is simply a no-op the second time round.
    user->replaceUsesOfWith(InstructionToReplace, ValueToReplaceWith);
  }
  InstructionToReplace->eraseFromParent();
}
// Returns the byte offset read by a cbuffer load instruction. For
// CBufferLoad this is its byteOffset operand; for CBufferLoadLegacy a
// multiply of its regIndex operand by 16 is emitted through Builder. Any
// other instruction goes through ThrowFailure().
llvm::Value *DxilPatchShaderRecordBindings::CreateCBufferLoadOffsetInBytes(
    Module &M, IRBuilder<> &Builder, llvm::Instruction *instruction) {
  OP *dxilOps = M.GetOrCreateDxilModule().GetOP();
  DxilInst_CBufferLoad asByteLoad(instruction);
  DxilInst_CBufferLoadLegacy asLegacyLoad(instruction);

  if (asByteLoad)
    return asByteLoad.get_byteOffset();

  if (asLegacyLoad) {
    Constant *bytesPerRow = dxilOps->GetU32Const(16);
    return Builder.CreateMul(asLegacyLoad.get_regIndex(), bytesPerRow);
  }

  ThrowFailure();
  return nullptr;
}
// True when the instruction matches either the CBufferLoad or the
// CBufferLoadLegacy DXIL op wrapper.
bool DxilPatchShaderRecordBindings::IsCBufferLoad(
    llvm::Instruction *instruction) {
  DxilInst_CBufferLoad asByteLoad(instruction);
  if (asByteLoad)
    return true;
  DxilInst_CBufferLoadLegacy asLegacyLoad(instruction);
  if (asLegacyLoad)
    return true;
  return false;
}
// Returns the zero-extended value of rangeIdVal when it is a ConstantInt.
// A non-constant value trips an assert and yields 0. resClass is not used.
unsigned int GetResolvedRangeID(DXIL::ResourceClass resClass,
                                Value *rangeIdVal) {
  ConstantInt *constantRangeId = dyn_cast<ConstantInt>(rangeIdVal);
  if (!constantRangeId) {
    assert(false);
    return 0;
  }
  return constantRangeId->getZExtValue();
}
// TODO: This code is quite inefficient
bool DxilPatchShaderRecordBindings::GetHandleInfo(
Module &M, DxilInst_CreateHandleForLib &createHandleStructForLib,
unsigned int &shaderRegister, unsigned int &registerSpace,
DXIL::ResourceKind &kind, DXIL::ResourceClass &resClass,
llvm::Type *&resType) {
DxilModule &DM = M.GetOrCreateDxilModule();
LoadInst *loadRangeId =
cast<LoadInst>(createHandleStructForLib.get_Resource());
Value *ResourceSymbol = loadRangeId->getPointerOperand();
DXIL::ResourceClass resourceClasses[] = {
DXIL::ResourceClass::CBuffer, DXIL::ResourceClass::SRV,
DXIL::ResourceClass::UAV, DXIL::ResourceClass::Sampler};
hlsl::DxilResourceBase *Resource = nullptr;
for (auto &resourceClass : resourceClasses) {
switch (resourceClass) {
case DXIL::ResourceClass::CBuffer: {
auto &cbuffers = DM.GetCBuffers();
for (auto &cbuffer : cbuffers) {
if (cbuffer->GetGlobalSymbol() == ResourceSymbol) {
Resource = cbuffer.get();
break;
}
}
break;
}
case DXIL::ResourceClass::SRV:
case DXIL::ResourceClass::UAV: {
auto &viewList = resourceClass == DXIL::ResourceClass::SRV ? DM.GetSRVs()
: DM.GetUAVs();
for (auto &view : viewList) {
if (view->GetGlobalSymbol() == ResourceSymbol) {
Resource = view.get();
break;
}
}
break;
}
case DXIL::ResourceClass::Sampler: {
auto &samplers = DM.GetSamplers();
for (auto &sampler : samplers) {
if (sampler->GetGlobalSymbol() == ResourceSymbol) {
Resource = sampler.get();
break;
}
}
break;
}
}
}
if (Resource) {
registerSpace = Resource->GetSpaceID();
shaderRegister = Resource->GetLowerBound();
kind = Resource->GetKind();
resClass = Resource->GetClass();
resType = Resource->GetHLSLType()->getPointerElementType();
}
return Resource != nullptr;
}
// Emits an add of the constant dataOffsetInShaderRecord and
// offsetToShaderRecord, then an i32 shader record buffer load at the
// resulting offset. Returns the load.
llvm::Value *DxilPatchShaderRecordBindings::LoadShaderRecordData(
    Module &M, IRBuilder<> &Builder, llvm::Value *offsetToShaderRecord,
    unsigned int dataOffsetInShaderRecord) {
  OP *dxilOps = M.GetOrCreateDxilModule().GetOP();
  llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());

  Constant *offsetWithinRecord = dxilOps->GetU32Const(dataOffsetInShaderRecord);
  Value *absoluteOffset =
      Builder.CreateAdd(offsetWithinRecord, offsetToShaderRecord);
  return CreateShaderRecordBufferLoad(M, Builder, absoluteOffset, i32Ty);
}
// Redirects a CreateHandleForLib call to a descriptor heap entry: emits a GEP
// (marked non-uniform) into the aliased descriptor heap symbol at
// descriptorIndex, loads it, switches the call to the CreateHandleForLib
// overload for the loaded type, and sets the load as the call's resource
// operand.
void DxilPatchShaderRecordBindings::PatchCreateHandleToUseDescriptorIndex(
    Module &M, IRBuilder<> &Builder, DXIL::ResourceKind &resourceKind,
    DXIL::ResourceClass &resourceClass, llvm::Type *resourceType,
    llvm::Value *descriptorIndex,
    DxilInst_CreateHandleForLib &createHandleInstr) {
  OP *dxilOps = M.GetOrCreateDxilModule().GetOP();

  llvm::Value *heapSymbol = GetAliasedDescriptorHeapHandle(
      M, resourceType, resourceClass, resourceKind);
  llvm::Value *heapEntryPtr = Builder.CreateGEP(
      heapSymbol, {dxilOps->GetU32Const(0), descriptorIndex}, "IndexIntoDH");
  DxilMDHelper::MarkNonUniform(cast<Instruction>(heapEntryPtr));
  llvm::Value *heapEntry = Builder.CreateLoad(heapEntryPtr);

  // Re-target the existing call at the overload for the heap entry type.
  Function *createHandleFn = dxilOps->GetOpFunc(
      DXIL::OpCode::CreateHandleForLib, heapEntry->getType());
  CallInst *createHandleCall = cast<CallInst>(createHandleInstr.Instr);
  createHandleCall->setCalledFunction(createHandleFn);
  createHandleInstr.set_Resource(heapEntry);
}
// Seeds the UAV register space list with a RawBuffer view key when the list
// is still empty, so that this key occupies index 0.
void DxilPatchShaderRecordBindings::InitializeViewTable() {
  // The Fallback Layer declares a bindless raw buffer that spans the entire
  // descriptor heap, manually add it to the list of UAV register spaces used
  if (*pInputShaderInfo->pNumUAVSpaces != 0)
    return;

  ViewKey key = {(unsigned int)hlsl::DXIL::ResourceKind::RawBuffer, {0}};
  unsigned int index =
      FindOrInsertViewIntoList(key, pInputShaderInfo->pUAVRegisterSpaceArray,
                               *pInputShaderInfo->pNumUAVSpaces,
                               FallbackLayerNumDescriptorHeapSpacesPerView);
  (void)index;
  assert(index == 0);
}
void DxilPatchShaderRecordBindings::PatchShaderBindings(Module &M) {
DxilModule &DM = M.GetOrCreateDxilModule();
OP *HlslOP = DM.GetOP();
// Don't erase instructions until the very end because it throws off the
// iterator
std::vector<llvm::Instruction *> instructionsToRemove;
for (BasicBlock &block : EntryPointFunction->getBasicBlockList()) {
auto &Instructions = block.getInstList();
for (auto &instr : Instructions) {
DxilInst_CreateHandleForLib createHandleForLib(&instr);
if (createHandleForLib) {
DXIL::ResourceClass resourceClass;
unsigned int registerSpace;
unsigned int registerIndex;
DXIL::ResourceKind kind;
llvm::Type *resType;
bool resourceIsResolved = true;
resourceIsResolved =
GetHandleInfo(M, createHandleForLib, registerIndex, registerSpace,
kind, resourceClass, resType);
if (!resourceIsResolved)
continue; // TODO: This shouldn't actually be happening?
ShaderRecordEntry shaderRecord = FindRootSignatureDescriptor(
*pRootSignatureDesc,
pInputShaderInfo->ShaderRecordIdentifierSizeInBytes, resourceClass,
registerIndex, registerSpace);
const bool IsBindingSpecifiedInLocalRootSignature =
!shaderRecord.IsInvalid();
if (IsBindingSpecifiedInLocalRootSignature) {
if (!DispatchRaysConstantsHandle) {
AddInputBinding(M);
}
switch (shaderRecord.ParameterType) {
case DxilRootParameterType::Constants32Bit: {
for (User *U : instr.users()) {
llvm::Instruction *instruction = cast<CallInst>(U);
if (IsCBufferLoad(instruction)) {
llvm::Instruction *cbufferLoadInstr = instruction;
IRBuilder<> Builder(cbufferLoadInstr);
llvm::Value *cbufferOffsetInBytes =
CreateCBufferLoadOffsetInBytes(M, Builder,
cbufferLoadInstr);
llvm::Value *LocalOffsetToRootConstant =
CreateOffsetToShaderRecord(M, Builder,
shaderRecord.RecordOffsetInBytes,
cbufferOffsetInBytes);
llvm::Value *GlobalOffsetToRootConstant = Builder.CreateAdd(
LocalOffsetToRootConstant, BaseShaderRecordOffset);
llvm::Value *srvBufferLoad = CreateShaderRecordBufferLoad(
M, Builder, GlobalOffsetToRootConstant,
cbufferLoadInstr->getType());
ReplaceUsesOfWith(cbufferLoadInstr, srvBufferLoad);
} else {
ThrowFailure();
}
}
instructionsToRemove.push_back(&instr);
break;
}
case DxilRootParameterType::DescriptorTable: {
IRBuilder<> Builder(&instr);
llvm::Value *srvBufferLoad =
LoadShaderRecordData(M, Builder, BaseShaderRecordOffset,
shaderRecord.RecordOffsetInBytes);
llvm::Value *DescriptorTableEntryLo = Builder.CreateExtractValue(
srvBufferLoad, 0, "DescriptorTableHandleLo");
unsigned int offsetToLoadInUints =
offsetof(DispatchRaysConstants, SrvCbvUavDescriptorHeapStart) /
sizeof(uint32_t);
unsigned int uintsPerRow = 4;
unsigned int rowToLoad = offsetToLoadInUints / uintsPerRow;
unsigned int extractValueOffset = offsetToLoadInUints % uintsPerRow;
llvm::Value *DescHeapConstants = CreateCBufferLoadLegacy(
M, Builder, DispatchRaysConstantsHandle, rowToLoad);
llvm::Value *DescriptorHeapStartAddressLo =
Builder.CreateExtractValue(DescHeapConstants,
extractValueOffset,
"DescriptorHeapStartHandleLo");
// TODO: The hi bits can only be ignored if the difference between the two
// handles is guaranteed to fit in 32 bits. This is an unsafe assumption,
// particularly given large descriptor sizes.
llvm::Value *DescriptorTableOffsetInBytes = Builder.CreateSub(
DescriptorTableEntryLo, DescriptorHeapStartAddressLo,
"TableOffsetInBytes");
Constant *DescriptorSizeInBytes = HlslOP->GetU32Const(
pInputShaderInfo->SrvCbvUavDescriptorSizeInBytes);
llvm::Value *DescriptorTableStartIndex = Builder.CreateExactUDiv(
DescriptorTableOffsetInBytes, DescriptorSizeInBytes,
"TableStartIndex");
Constant *RecordOffset =
HlslOP->GetU32Const(shaderRecord.OffsetInDescriptors);
llvm::Value *BaseDescriptorIndex = Builder.CreateAdd(
DescriptorTableStartIndex, RecordOffset, "BaseDescriptorIndex");
// TODO: Dynamic indexing is not supported yet; the index should be pulled
// from CreateHandleForLib. If dynamic indexing is being used, add the
// app's index on top of the calculated index.
llvm::Value *DynamicIndex = HlslOP->GetU32Const(0);
llvm::Value *DescriptorIndex = Builder.CreateAdd(
BaseDescriptorIndex, DynamicIndex, "DescriptorIndex");
PatchCreateHandleToUseDescriptorIndex(
M, Builder, kind, resourceClass, resType, DescriptorIndex,
createHandleForLib);
break;
}
case DxilRootParameterType::CBV:
case DxilRootParameterType::SRV:
case DxilRootParameterType::UAV: {
IRBuilder<> Builder(&instr);
llvm::Value *srvBufferLoad =
LoadShaderRecordData(M, Builder, BaseShaderRecordOffset,
shaderRecord.RecordOffsetInBytes);
llvm::Value *DescriptorIndex = Builder.CreateExtractValue(
srvBufferLoad, 1, "DescriptorHeapIndex");
// TODO: Handle offset in bytes
// llvm::Value *OffsetInBytes = Builder.CreateExtractValue(
// srvBufferLoad, 0, "OffsetInBytes");
PatchCreateHandleToUseDescriptorIndex(
M, Builder, kind, resourceClass, resType, DescriptorIndex,
createHandleForLib);
break;
}
default:
ThrowFailure();
break;
}
}
}
}
}
for (auto instruction : instructionsToRemove) {
instruction->eraseFromParent();
}
}
// Reports whether a root parameter of kind `parameterType` can bind a
// resource of class `resourceClass`.
//
// A descriptor table is accepted for every resource class. 32-bit root
// constants and root CBVs accept only constant buffers, root SRVs only SRVs,
// and root UAVs only UAVs. Any other parameter type calls ThrowFailure().
bool IsParameterTypeCompatibleWithResourceClass(
    DXIL::ResourceClass resourceClass, DxilRootParameterType parameterType) {
  if (parameterType == DxilRootParameterType::DescriptorTable)
    return true;

  // Every remaining parameter kind accepts exactly one resource class.
  DXIL::ResourceClass requiredClass;
  switch (parameterType) {
  case DxilRootParameterType::Constants32Bit:
  case DxilRootParameterType::CBV:
    requiredClass = DXIL::ResourceClass::CBuffer;
    break;
  case DxilRootParameterType::SRV:
    requiredClass = DXIL::ResourceClass::SRV;
    break;
  case DxilRootParameterType::UAV:
    requiredClass = DXIL::ResourceClass::UAV;
    break;
  default:
    ThrowFailure();
    return false;
  }
  return resourceClass == requiredClass;
}
// Maps a root parameter type onto the DXIL root parameter type enum.
//
// Both sides of the conversion are DxilRootParameterType here, so every known
// enumerator maps to itself. A value outside the known set trips the assert
// and, when asserts are compiled out, yields the sentinel value -1.
DxilRootParameterType
ConvertD3D12ParameterTypeToDxil(DxilRootParameterType parameter) {
  switch (parameter) {
  case DxilRootParameterType::Constants32Bit:
  case DxilRootParameterType::DescriptorTable:
  case DxilRootParameterType::CBV:
  case DxilRootParameterType::SRV:
  case DxilRootParameterType::UAV:
    return parameter;
  }
  assert(false);
  return static_cast<DxilRootParameterType>(-1);
}
// Translates a descriptor range type into the matching DXIL resource class
// (SRV -> SRV, UAV -> UAV, CBV -> CBuffer, Sampler -> Sampler).
//
// An unrecognized range type trips the assert and, when asserts are compiled
// out, yields the sentinel value -1.
DXIL::ResourceClass
ConvertD3D12RangeTypeToDxil(DxilDescriptorRangeType rangeType) {
  if (rangeType == DxilDescriptorRangeType::SRV)
    return DXIL::ResourceClass::SRV;
  if (rangeType == DxilDescriptorRangeType::UAV)
    return DXIL::ResourceClass::UAV;
  if (rangeType == DxilDescriptorRangeType::CBV)
    return DXIL::ResourceClass::CBuffer;
  if (rangeType == DxilDescriptorRangeType::Sampler)
    return DXIL::ResourceClass::Sampler;

  assert(false);
  return static_cast<DXIL::ResourceClass>(-1);
}
// Returns the alignment, in bytes, that a root parameter of the given type
// requires:
//   - descriptor table:     SizeofD3D12GpuDescriptorHandle
//   - 32-bit constants:     sizeof(uint32_t)
//   - root CBV / SRV / UAV: SizeofD3D12GpuVA
// Any other parameter type yields UINT_MAX.
unsigned int GetParameterTypeAlignment(DxilRootParameterType parameterType) {
  unsigned int alignmentInBytes = UINT_MAX;
  switch (parameterType) {
  case DxilRootParameterType::DescriptorTable:
    alignmentInBytes =
        static_cast<unsigned int>(SizeofD3D12GpuDescriptorHandle);
    break;
  case DxilRootParameterType::Constants32Bit:
    alignmentInBytes = static_cast<unsigned int>(sizeof(uint32_t));
    break;
  case DxilRootParameterType::CBV:
  case DxilRootParameterType::SRV:
  case DxilRootParameterType::UAV:
    alignmentInBytes = static_cast<unsigned int>(SizeofD3D12GpuVA);
    break;
  default:
    break;
  }
  return alignmentInBytes;
}
// Searches a root signature description for the root parameter that binds the
// resource identified by (resourceClass, baseRegisterIndex, registerSpace).
//
// The parameters are walked in declaration order while a running byte offset
// is maintained. The offset starts at ShaderRecordIdentifierSizeInBytes, and
// each parameter is aligned to GetParameterTypeAlignment() before its size is
// added.
//
// Template parameter:
//   TD3D12_ROOT_SIGNATURE_DESC - root signature description type; it must
//     expose NumParameters and pParameters (instantiated below in this file
//     with Desc_1_0 and Desc_1_1).
//
// Returns a ShaderRecordEntry holding:
//   - the DXIL parameter type of the matching root parameter,
//   - the aligned byte offset of that parameter, and
//   - for descriptor tables, the offset in descriptors from the table start
//     to the requested register (0 for the other parameter types).
// Returns ShaderRecordEntry::InvalidEntry() when nothing matches, or when
// registerSpace is FallbackLayerRegisterSpace.
template <typename TD3D12_ROOT_SIGNATURE_DESC>
ShaderRecordEntry FindRootSignatureDescriptorHelper(
    const TD3D12_ROOT_SIGNATURE_DESC &rootSignatureDescriptor,
    unsigned int ShaderRecordIdentifierSizeInBytes,
    DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex,
    unsigned int registerSpace) {
  // Automatically fail if it's looking for a fallback binding as these never
  // need to be patched
  if (registerSpace != FallbackLayerRegisterSpace) {
    unsigned int recordOffset = ShaderRecordIdentifierSizeInBytes;
    for (unsigned int rootParamIndex = 0;
         rootParamIndex < rootSignatureDescriptor.NumParameters;
         rootParamIndex++) {
      auto &rootParam = rootSignatureDescriptor.pParameters[rootParamIndex];
      auto dxilParamType =
          ConvertD3D12ParameterTypeToDxil(rootParam.ParameterType);
// NOTE: this macro is intentionally left unchanged and is not #undef'd; it
// stays defined for the remainder of the translation unit, and code after
// this function may rely on it.
#define ALIGN(alignment, num) (((num + alignment - 1) / alignment) * alignment)
      // Round the running offset up to this parameter's required alignment.
      recordOffset = ALIGN(GetParameterTypeAlignment(rootParam.ParameterType),
                           recordOffset);
      switch (rootParam.ParameterType) {
      case DxilRootParameterType::Constants32Bit:
        if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
                                                       dxilParamType) &&
            baseRegisterIndex == rootParam.Constants.ShaderRegister &&
            registerSpace == rootParam.Constants.RegisterSpace) {
          return {dxilParamType, recordOffset, 0};
        }
        recordOffset += rootParam.Constants.Num32BitValues * sizeof(uint32_t);
        break;
      case DxilRootParameterType::DescriptorTable: {
        auto &descriptorTable = rootParam.DescriptorTable;
        unsigned int rangeOffsetInDescriptors = 0;
        for (unsigned int rangeIndex = 0;
             rangeIndex < descriptorTable.NumDescriptorRanges; rangeIndex++) {
          auto &range = descriptorTable.pDescriptorRanges[rangeIndex];
          // An offset of (unsigned)-1 means "append after the previous
          // range", so the running offset is kept; any other value is an
          // explicit offset from the start of the table.
          if (range.OffsetInDescriptorsFromTableStart != (unsigned)-1) {
            rangeOffsetInDescriptors = range.OffsetInDescriptorsFromTableStart;
          }
          // Test membership with a subtraction instead of computing
          // BaseShaderRegister + NumDescriptors: that sum wraps around for
          // unbounded ranges (NumDescriptors == UINT_MAX) and for ranges that
          // end near the top of the register space, which made such ranges
          // impossible to match. The subtraction cannot underflow because
          // the lower-bound check is evaluated first.
          const bool registerIsInRange =
              range.BaseShaderRegister <= baseRegisterIndex &&
              baseRegisterIndex - range.BaseShaderRegister <
                  range.NumDescriptors;
          if (ConvertD3D12RangeTypeToDxil(range.RangeType) == resourceClass &&
              range.RegisterSpace == registerSpace && registerIsInRange) {
            rangeOffsetInDescriptors +=
                baseRegisterIndex - range.BaseShaderRegister;
            return {dxilParamType, recordOffset, rangeOffsetInDescriptors};
          }
          rangeOffsetInDescriptors += range.NumDescriptors;
        }
        // The whole table contributes SizeofD3D12GpuDescriptorHandle bytes
        // to the running offset, however many ranges it has.
        recordOffset += SizeofD3D12GpuDescriptorHandle;
        break;
      }
      case DxilRootParameterType::CBV:
      case DxilRootParameterType::SRV:
      case DxilRootParameterType::UAV:
        if (IsParameterTypeCompatibleWithResourceClass(resourceClass,
                                                       dxilParamType) &&
            baseRegisterIndex == rootParam.Descriptor.ShaderRegister &&
            registerSpace == rootParam.Descriptor.RegisterSpace) {
          return {dxilParamType, recordOffset, 0};
        }
        // Each root descriptor contributes SizeofD3D12GpuVA bytes.
        recordOffset += SizeofD3D12GpuVA;
        break;
      }
    }
  }
  return ShaderRecordEntry::InvalidEntry();
}
// Version-dispatching front end for FindRootSignatureDescriptorHelper.
//
// Picks the Desc_1_0 or Desc_1_1 member of the versioned root signature
// according to its Version field and forwards all remaining arguments
// unchanged. Any other version calls ThrowFailure().
//
// TODO: Consider pre-calculating this into a map
ShaderRecordEntry DxilPatchShaderRecordBindings::FindRootSignatureDescriptor(
    const DxilVersionedRootSignatureDesc &rootSignatureDescriptor,
    unsigned int ShaderRecordIdentifierSizeInBytes,
    DXIL::ResourceClass resourceClass, unsigned int baseRegisterIndex,
    unsigned int registerSpace) {
  const auto version = rootSignatureDescriptor.Version;

  if (version == DxilRootSignatureVersion::Version_1_0) {
    return FindRootSignatureDescriptorHelper(
        rootSignatureDescriptor.Desc_1_0, ShaderRecordIdentifierSizeInBytes,
        resourceClass, baseRegisterIndex, registerSpace);
  }

  if (version == DxilRootSignatureVersion::Version_1_1) {
    return FindRootSignatureDescriptorHelper(
        rootSignatureDescriptor.Desc_1_1, ShaderRecordIdentifierSizeInBytes,
        resourceClass, baseRegisterIndex, registerSpace);
  }

  // Unknown root signature version.
  ThrowFailure();
  return ShaderRecordEntry::InvalidEntry();
}