diff --git a/include/dxc/HLSL/DxilPipelineStateValidation.h b/include/dxc/HLSL/DxilPipelineStateValidation.h index 6a66d995f..d740eee3d 100644 --- a/include/dxc/HLSL/DxilPipelineStateValidation.h +++ b/include/dxc/HLSL/DxilPipelineStateValidation.h @@ -158,15 +158,23 @@ struct PSVResourceBindInfo0 uint32_t UpperBound; }; -struct RuntimeDataResourceInfo : public PSVResourceBindInfo0 +struct RuntimeDataResourceInfo { - uint32_t Kind; // PSVResourceKind - uint32_t Name; // offset for string table + uint32_t ResType; // PSVResourceType + uint32_t Space; + uint32_t LowerBound; + uint32_t UpperBound; + uint32_t Kind; // PSVResourceKind + uint32_t Name; // offset for string table + uint32_t ID; // id per resource class + uint32_t flags; // flag for resource. }; + struct RuntimeDataFunctionInfo { uint32_t Name; // offset for string table uint32_t UnmangledName; // offset for string table uint32_t Resources; // index to an index table + uint32_t FunctionDependencies; // index to a list of functions that function depends on uint32_t ShaderKind; // shader kind uint32_t PayloadSizeInBytes; // payload count for miss, closest hit, any hit // shader, or parameter size for call shader @@ -219,7 +227,7 @@ struct RuntimeDataTableHeader { uint32_t offset; }; -enum RuntimeDataTableType : uint32_t { +enum RuntimeDataPartType : uint32_t { Invalid = 0, String, Function, @@ -436,6 +444,9 @@ public: return m_ResourceInfo->UpperBound; } PSVResourceKind GetResourceKind() { return (PSVResourceKind)m_ResourceInfo->Kind; } + uint32_t GetID() { + return m_ResourceInfo->ID; + } const char *GetName() { return m_Context->pStringTableReader->Get(m_ResourceInfo->Name); } @@ -462,12 +473,34 @@ public: m_CBufferCount(CBufferCount), m_SamplerCount(SamplerCount), m_SRVCount(SRVCount), m_UAVCount(UAVCount){}; - void SetResourceInfo(const RuntimeDataResourceInfo *ptr) { m_ResourceInfo = ptr; } + void SetResourceInfo(const RuntimeDataResourceInfo *ptr, uint32_t count) { + m_ResourceInfo = ptr; + // Assuming 
that resources are in order of CBuffer, Sampler, SRV, and UAV, + // count the number for each resource class + m_CBufferCount = 0; + m_SamplerCount = 0; + m_SRVCount = 0; + m_UAVCount = 0; + + for (uint32_t i = 0; i < count; ++i) { + const RuntimeDataResourceInfo *curPtr = &ptr[i]; + if (curPtr->ResType == (uint32_t)PSVResourceType::CBV) + m_CBufferCount++; + else if (curPtr->ResType == (uint32_t)PSVResourceType::Sampler) + m_SamplerCount++; + else if (curPtr->ResType == (uint32_t)PSVResourceType::SRVRaw || + curPtr->ResType == (uint32_t)PSVResourceType::SRVStructured || + curPtr->ResType == (uint32_t)PSVResourceType::SRVTyped) + m_SRVCount++; + else if (curPtr->ResType == (uint32_t)PSVResourceType::UAVRaw || + curPtr->ResType == (uint32_t)PSVResourceType::UAVStructured || + curPtr->ResType == (uint32_t)PSVResourceType::UAVStructuredWithCounter || + curPtr->ResType == (uint32_t)PSVResourceType::UAVTyped) + m_UAVCount++; + } + } + void SetContext(RuntimeDataContext *context) { m_Context = context; } - void SetCBufferCount(uint32_t count) { m_CBufferCount = count; } - void SetSamplerCount(uint32_t count) { m_SamplerCount = count; } - void SetSRVCount(uint32_t count) { m_SRVCount = count; } - void SetUAVCount(uint32_t count) { m_UAVCount = count; } uint32_t GetNumResources() { return m_CBufferCount + m_SamplerCount + m_SRVCount + m_UAVCount; @@ -477,7 +510,6 @@ public: return ResourceReader(&m_ResourceInfo[i], m_Context); } - uint32_t GetNumCBuffers() { return m_CBufferCount; } ResourceReader GetCBuffer(uint32_t i) { _Analysis_assume_(i < m_CBufferCount); @@ -528,7 +560,7 @@ public: } uint32_t GetShaderStageFlag() { return m_RuntimeDataFunctionInfo->ShaderStageFlag; } uint32_t GetMinShaderTarget() { return m_RuntimeDataFunctionInfo->MinShaderTarget; } - uint32_t FunctionReader::GetNumResources() { + uint32_t GetNumResources() { if (m_RuntimeDataFunctionInfo->Resources == UINT_MAX) return 0; return 
m_Context->pIndexTableReader->getRow(m_RuntimeDataFunctionInfo->Resources).Count(); @@ -538,6 +570,20 @@ public: return m_Context->pResourceTableReader->GetItem(resIndex); } + uint32_t GetNumDependencies() { + if (m_RuntimeDataFunctionInfo->FunctionDependencies == UINT_MAX) + return 0; + return m_Context->pIndexTableReader + ->getRow(m_RuntimeDataFunctionInfo->FunctionDependencies).Count(); + } + + const char *GetDependency(uint32_t i) { + uint32_t resIndex = + m_Context->pIndexTableReader + ->getRow(m_RuntimeDataFunctionInfo->FunctionDependencies).At(i); + return m_Context->pStringTableReader->Get(resIndex); + } + uint32_t GetPayloadSizeInBytes() { return m_RuntimeDataFunctionInfo->PayloadSizeInBytes; } uint32_t GetAttributeSizeInBytes() { return m_RuntimeDataFunctionInfo->AttributeSizeInBytes; } // payload (hit shaders) and parameters (call shaders) are mutually exclusive @@ -580,6 +626,8 @@ public: m_FunctionTableReader(), m_IndexTableReader(), m_Context() { m_Context = {&m_StringReader, &m_IndexTableReader, &m_ResourceTableReader, &m_FunctionTableReader}; + m_ResourceTableReader.SetContext(&m_Context); + m_FunctionTableReader.SetContext(&m_Context); } DxilRuntimeData(const char *ptr) { InitFromRDAT(ptr); @@ -592,32 +640,24 @@ public: for (uint32_t i = 0; i < TableCount; ++i) { RuntimeDataTableHeader *curRecord = &records[i]; switch (curRecord->tableType) { - case RuntimeDataTableType::Resource: { - uint32_t cBufferCount = *(uint32_t*)(ptr + curRecord->offset); - uint32_t samplerCount = *(uint32_t*)(ptr + curRecord->offset + 4); - uint32_t srvCount = *(uint32_t*)(ptr + curRecord->offset + 8); - uint32_t uavCount = *(uint32_t*)(ptr + curRecord->offset + 12); - m_ResourceTableReader.SetResourceInfo((RuntimeDataResourceInfo*)(ptr + curRecord->offset + 16)); - m_ResourceTableReader.SetCBufferCount(cBufferCount); - m_ResourceTableReader.SetSamplerCount(samplerCount); - m_ResourceTableReader.SetSRVCount(srvCount); - m_ResourceTableReader.SetUAVCount(uavCount); - 
m_ResourceTableReader.SetContext(&m_Context); + case RuntimeDataPartType::Resource: { + m_ResourceTableReader.SetResourceInfo( + (RuntimeDataResourceInfo *)(ptr + curRecord->offset), + curRecord->size / sizeof(RuntimeDataResourceInfo)); break; } - case RuntimeDataTableType::String: { + case RuntimeDataPartType::String: { m_StringReader = PSVStringTable(ptr + curRecord->offset, curRecord->size); break; } - case RuntimeDataTableType::Function: { - RuntimeDataFunctionInfo *funcInfo = - (RuntimeDataFunctionInfo *)(ptr + curRecord->offset); - m_FunctionTableReader.SetFunctionInfo(funcInfo); - m_FunctionTableReader.SetCount(curRecord->size / sizeof(RuntimeDataFunctionInfo)); - m_FunctionTableReader.SetContext(&m_Context); + case RuntimeDataPartType::Function: { + m_FunctionTableReader.SetFunctionInfo( + (RuntimeDataFunctionInfo *)(ptr + curRecord->offset)); + m_FunctionTableReader.SetCount(curRecord->size / + sizeof(RuntimeDataFunctionInfo)); break; } - case RuntimeDataTableType::Index: { + case RuntimeDataPartType::Index: { m_IndexTableReader = IndexTableReader( (uint32_t *)(ptr + curRecord->offset), curRecord->size / 4); break; diff --git a/include/dxc/HLSL/DxilUtil.h b/include/dxc/HLSL/DxilUtil.h index 5ac2b6a4c..c3794a898 100644 --- a/include/dxc/HLSL/DxilUtil.h +++ b/include/dxc/HLSL/DxilUtil.h @@ -20,6 +20,9 @@ class Module; class MemoryBuffer; class LLVMContext; class DiagnosticInfo; +class Value; +class Instruction; +class StringRef; } namespace hlsl { diff --git a/lib/HLSL/DxilContainerAssembler.cpp b/lib/HLSL/DxilContainerAssembler.cpp index db4d2283c..c65937b86 100644 --- a/lib/HLSL/DxilContainerAssembler.cpp +++ b/lib/HLSL/DxilContainerAssembler.cpp @@ -20,6 +20,7 @@ #include "dxc/HLSL/DxilRootSignature.h" #include "dxc/HLSL/DxilUtil.h" #include "dxc/HLSL/DxilFunctionProps.h" +#include "dxc/HLSL/DxilOperations.h" #include "dxc/Support/Global.h" #include "dxc/Support/Unicode.h" #include "dxc/Support/WinIncludes.h" @@ -699,154 +700,54 @@ public: } }; -class 
RDATTable { +// Like DXIL container, RDAT itself is a mini container that contains multiple RDAT parts +class RDATPart { public: - virtual uint32_t GetBlobSize() const { return 0; } - virtual void write(void *ptr) {} - virtual RuntimeDataTableType GetType() const { return RuntimeDataTableType::Invalid; } + virtual uint32_t GetPartSize() const { return 0; } + virtual void Write(void *ptr) {} + virtual RuntimeDataPartType GetType() const { return RuntimeDataPartType::Invalid; } + virtual ~RDATPart() {} +}; + +// Most RDAT parts are tables each containing a list of structures of same type. +// Exceptions are string table and index table because each string or list of +// indices can be of different sizes. +template +class RDATTable : public RDATPart { +protected: + std::vector m_rows; +public: + virtual void Insert(T *data) {} virtual ~RDATTable() {} -}; -class ResourceTable : public RDATTable { -private: - uint32_t m_Version; - std::vector> CBufferToOffset; - std::vector> SamplerToOffset; - std::vector> SRVToOffset; - std::vector> UAVToOffset; - - void UpdateResourceInfo(const DxilResourceBase *res, uint32_t offset, - RuntimeDataResourceInfo *info, char **pCur) { - info->Kind = static_cast(res->GetKind()); - info->Space = res->GetSpaceID(); - info->LowerBound = res->GetLowerBound(); - info->UpperBound = res->GetUpperBound(); - info->Name = offset; - memcpy(*pCur, info, sizeof(RuntimeDataResourceInfo)); - *pCur += sizeof(RuntimeDataResourceInfo); + void Insert(const T &data) { + m_rows.push_back(data); } -public: - ResourceTable(uint32_t version) : m_Version(version), CBufferToOffset(), SamplerToOffset(), SRVToOffset(), UAVToOffset() {} - void AddCBuffer(const DxilCBuffer *resource, uint32_t offset) { - CBufferToOffset.emplace_back( - std::pair(resource, offset)); - } - void AddSampler(const DxilSampler *resource, uint32_t offset) { - SamplerToOffset.emplace_back( - std::pair(resource, offset)); - } - void AddSRV(const DxilResource *resource, uint32_t offset) { - 
SRVToOffset.emplace_back( - std::pair(resource, offset)); - } - void AddUAV(const DxilResource *resource, uint32_t offset) { - UAVToOffset.emplace_back( - std::pair(resource, offset)); - } - uint32_t NumResources() const { - return CBufferToOffset.size() + SamplerToOffset.size() + - SRVToOffset.size() + UAVToOffset.size(); - } - RuntimeDataTableType GetType() const { return RuntimeDataTableType::Resource; } - uint32_t GetBlobSize() const { - return NumResources() * sizeof(RuntimeDataResourceInfo) + - 4 * sizeof(uint32_t); - } - void write(void *ptr) { - // Only impelemented for RDAT for now - if (m_Version == 0) { - char *pCur = (char*)ptr; - // count for each resource class - uint32_t cBufferCount = CBufferToOffset.size(); - uint32_t samplerCount = SamplerToOffset.size(); - uint32_t srvCount = SRVToOffset.size(); - uint32_t uavCount = UAVToOffset.size(); - memcpy(pCur, &cBufferCount, sizeof(uint32_t)); - pCur += sizeof(uint32_t); - memcpy(pCur, &samplerCount, sizeof(uint32_t)); - pCur += sizeof(uint32_t); - memcpy(pCur, &srvCount, sizeof(uint32_t)); - pCur += sizeof(uint32_t); - memcpy(pCur, &uavCount, sizeof(uint32_t)); - pCur += sizeof(uint32_t); - - for (auto pair : CBufferToOffset) { - RuntimeDataResourceInfo info = {}; - info.ResType = static_cast(PSVResourceType::CBV); - UpdateResourceInfo(pair.first, pair.second, &info, &pCur); - } - for (auto pair : SamplerToOffset) { - RuntimeDataResourceInfo info = {}; - info.ResType = static_cast(PSVResourceType::Sampler); - UpdateResourceInfo(pair.first, pair.second, &info, &pCur); - } - for (auto pair : SRVToOffset) { - RuntimeDataResourceInfo info = {}; - auto res = pair.first; - if (res->IsStructuredBuffer()) { - info.ResType = (UINT)PSVResourceType::SRVStructured; - } else if (res->IsRawBuffer()) { - info.ResType = (UINT)PSVResourceType::SRVRaw; - } else { - info.ResType = (UINT)PSVResourceType::SRVTyped; - } - UpdateResourceInfo(pair.first, pair.second, &info, &pCur); - } - for (auto pair : UAVToOffset) { - 
RuntimeDataResourceInfo info = {}; - auto res = pair.first; - if (res->IsStructuredBuffer()) { - if (res->HasCounter()) - info.ResType = (UINT)PSVResourceType::UAVStructuredWithCounter; - else - info.ResType = (UINT)PSVResourceType::UAVStructured; - } else if (res->IsRawBuffer()) { - info.ResType = (UINT)PSVResourceType::UAVRaw; - } else { - info.ResType = (UINT)PSVResourceType::UAVTyped; - } - UpdateResourceInfo(res, pair.second, &info, &pCur); - } + void Write(void *ptr) { + char *pCur = (char*)ptr; + for (auto row : m_rows) { + memcpy(pCur, &row, sizeof(T)); + pCur += sizeof(T); } - } + }; + + uint32_t GetPartSize() const { return m_rows.size() * sizeof(T); } }; -class FunctionTable : public RDATTable { -private: - std::vector> FuncToInfo; +// Resource table will contain a list of RuntimeDataResourceInfo in order of +// CBuffer, Sampler, SRV, and UAV resource classes. +class ResourceTable : public RDATTable { public: - FunctionTable(): FuncToInfo() {} - uint32_t NumFunctions() const { return FuncToInfo.size(); } - void AddFunction(const llvm::Function *func, uint32_t mangledOfffset, - uint32_t unmangledOffset, uint32_t shaderKind, uint32_t resourceIndex, - uint32_t payloadSizeInBytes, uint32_t attrSizeInBytes, ShaderFlags flags) { - RuntimeDataFunctionInfo info = {}; - info.Name = mangledOfffset; - info.UnmangledName = unmangledOffset; - info.ShaderKind = shaderKind; - info.Resources = resourceIndex; - info.PayloadSizeInBytes = payloadSizeInBytes; - info.AttributeSizeInBytes = attrSizeInBytes; - uint64_t rawFlags = flags.GetShaderFlagsRaw(); - info.FeatureInfo1 = rawFlags & 0xffffffff; - info.FeatureInfo2 = (rawFlags >> 32) & 0xffffffff; - FuncToInfo.push_back({ func, info }); - } - - uint32_t GetBlobSize() const { return NumFunctions() * sizeof(RuntimeDataFunctionInfo); } - RuntimeDataTableType GetType() const { return RuntimeDataTableType::Function; } - void write(void *ptr) { - char *cur = (char *)ptr; - for (auto &&pair : FuncToInfo) { - auto offset = 
pair.second; - memcpy(cur, &offset, sizeof(RuntimeDataFunctionInfo)); - cur += sizeof(RuntimeDataFunctionInfo); - } - } + RuntimeDataPartType GetType() const { return RuntimeDataPartType::Resource; } }; -class StringTable : public RDATTable { +class FunctionTable : public RDATTable { +public: + RuntimeDataPartType GetType() const { return RuntimeDataPartType::Function; } +}; + +class StringTable : public RDATPart { private: SmallVector m_StringBuffer; uint32_t curIndex; @@ -863,42 +764,41 @@ public: curIndex += name.size() + 1; return prevIndex; } - RuntimeDataTableType GetType() const { return RuntimeDataTableType::String; } - uint32_t GetBlobSize() const { return m_StringBuffer.size(); } - void write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); } + RuntimeDataPartType GetType() const { return RuntimeDataPartType::String; } + uint32_t GetPartSize() const { return m_StringBuffer.size(); } + void Write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); } }; -template -struct IndexTable : public RDATTable { +struct IndexTable : public RDATPart { private: - std::vector> m_IndicesList; + std::vector> m_IndicesList; uint32_t m_curOffset; public: IndexTable() : m_IndicesList(), m_curOffset(0) {} - uint32_t AddIndex(const std::vector &Indices) { + uint32_t AddIndex(const std::vector &Indices) { uint32_t prevOffset = m_curOffset; m_curOffset += Indices.size() + 1; m_IndicesList.emplace_back(std::move(Indices)); return prevOffset; } - RuntimeDataTableType GetType() const { return RuntimeDataTableType::Index; } - uint32_t GetBlobSize() const { + RuntimeDataPartType GetType() const { return RuntimeDataPartType::Index; } + uint32_t GetPartSize() const { uint32_t size = 0; for (auto Indices : m_IndicesList) { size += Indices.size() + 1; } - return sizeof(T) * size; + return sizeof(uint32_t) * size; } - void write(void *ptr) { - T *cur = (T*)ptr; + void Write(void *ptr) { + uint32_t *cur = (uint32_t*)ptr; for (auto Indices : 
m_IndicesList) { uint32_t count = Indices.size(); memcpy(cur, &count, 4); std::copy(Indices.data(), Indices.data() + Indices.size(), cur + 1); - cur += sizeof(T)/sizeof(4) + Indices.size(); + cur += sizeof(uint32_t)/sizeof(4) + Indices.size(); } } }; @@ -908,62 +808,130 @@ private: const DxilModule &m_Module; SmallVector m_RDATBuffer; - std::vector> m_tables; - std::map> m_FuncToResNameOffset; + std::vector> m_tables; + typedef std::unordered_map> FunctionIndexMap; + FunctionIndexMap m_FuncToResNameOffset; // list of resources used + FunctionIndexMap m_FuncToDependencies; // list of unresolved functions used - void UpdateFunctionToResourceInfo(const DxilResourceBase *resource, uint32_t offset) { + llvm::Function *FindUsingFunction(llvm::Value *User) { + if (llvm::Instruction *I = dyn_cast(User)) { + // Instruction should be inside a basic block, which is in a function + return cast(I->getParent()->getParent()); + } + // User can be either instruction, constant, or operator. But User is an + // operator only if constant is a scalar value, not resource pointer. + llvm::Constant *CU = cast(User); + return FindUsingFunction(*CU->user_begin()); + } + + void UpdateFunctionToResourceInfo(const DxilResourceBase *resource, + uint32_t offset) { Constant *var = resource->GetGlobalSymbol(); if (var) { for (auto user : var->users()) { - if (llvm::Instruction *I = dyn_cast(user)) { - if (llvm::Function *F = dyn_cast(I->getParent()->getParent())) { - if (m_FuncToResNameOffset.find(F) != m_FuncToResNameOffset.end()) { - m_FuncToResNameOffset[F].emplace_back(offset); - } - else { - m_FuncToResNameOffset[F] = std::vector({offset}); - } - } + // Find the function. 
+ llvm::Function *F = FindUsingFunction(user); + if (m_FuncToResNameOffset.find(F) != m_FuncToResNameOffset.end()) { + m_FuncToResNameOffset[F].emplace_back(offset); + } + else { + m_FuncToResNameOffset[F] = std::vector({offset}); } } } } + + void InsertToResourceTable(DxilResourceBase &resource, + PSVResourceType resType, + ResourceTable &resourceTable, + StringTable &stringTable, + uint32_t &resourceIndex) { + uint32_t stringIndex = stringTable.Insert(resource.GetGlobalName()); + UpdateFunctionToResourceInfo(&resource, resourceIndex++); + RuntimeDataResourceInfo info = {}; + info.Kind = static_cast(resource.GetKind()); + info.ResType = (uint32_t)resType, + info.Space = resource.GetSpaceID(); + info.LowerBound = resource.GetLowerBound(); + info.UpperBound = resource.GetUpperBound(); + info.Name = stringIndex; + info.ID = resource.GetID(); + resourceTable.Insert(info); + } + void UpdateResourceInfo(StringTable &stringTable) { // Try to allocate string table for resources. String table is a sequence // of strings delimited by \0 - m_tables.emplace_back(std::make_unique(0)); + m_tables.emplace_back(std::make_unique()); ResourceTable &resourceTable = *(ResourceTable*)m_tables.back().get(); - uint32_t stringIndex; uint32_t resourceIndex = 0; for (auto &resource : m_Module.GetCBuffers()) { - stringIndex = stringTable.Insert(resource->GetGlobalName()); - UpdateFunctionToResourceInfo(resource.get(), resourceIndex++); - resourceTable.AddCBuffer(resource.get(), stringIndex); + InsertToResourceTable(*resource.get(), PSVResourceType::CBV, resourceTable, stringTable, + resourceIndex); + } for (auto &resource : m_Module.GetSamplers()) { - stringIndex = stringTable.Insert(resource->GetGlobalName()); - UpdateFunctionToResourceInfo(resource.get(), resourceIndex++); - resourceTable.AddSampler(resource.get(), stringIndex); + InsertToResourceTable(*resource.get(), PSVResourceType::Sampler, resourceTable, stringTable, + resourceIndex); } for (auto &resource : m_Module.GetSRVs()) { - 
stringIndex = stringTable.Insert(resource->GetGlobalName()); - UpdateFunctionToResourceInfo(resource.get(), resourceIndex++); - resourceTable.AddSRV(resource.get(), stringIndex); + PSVResourceType resType = PSVResourceType::Invalid; + if (resource->IsStructuredBuffer()) { + resType = PSVResourceType::SRVStructured; + } else if (resource->IsRawBuffer()) { + resType = PSVResourceType::SRVRaw; + } else { + resType = PSVResourceType::SRVTyped; + } + InsertToResourceTable(*resource.get(), resType, resourceTable, stringTable, + resourceIndex); } for (auto &resource : m_Module.GetUAVs()) { - stringIndex = stringTable.Insert(resource->GetGlobalName()); - UpdateFunctionToResourceInfo(resource.get(), resourceIndex++); - resourceTable.AddUAV(resource.get(), stringIndex); + PSVResourceType resType = PSVResourceType::Invalid; + if (resource->IsStructuredBuffer()) { + if (resource->HasCounter()) + resType = PSVResourceType::UAVStructuredWithCounter; + else + resType = PSVResourceType::UAVStructured; + } else if (resource->IsRawBuffer()) { + resType = PSVResourceType::UAVRaw; + } else { + resType = PSVResourceType::UAVTyped; + } + InsertToResourceTable(*resource.get(), resType, resourceTable, stringTable, + resourceIndex); } } + + void UpdateFunctionDependency(llvm::Function *F, StringTable &stringTable) { + for (const auto &user : F->users()) { + llvm::Function *userFunction = FindUsingFunction(user); + uint32_t index = stringTable.Insert(F->getName()); + if (m_FuncToDependencies.find(userFunction) == + m_FuncToDependencies.end()) { + m_FuncToDependencies[userFunction] = + std::vector({index}); + } else { + m_FuncToDependencies[userFunction].push_back(index); + } + } + } + void UpdateFunctionInfo(StringTable &stringTable) { - // TODO: get a list of required features // TODO: get a list of valid shader flags // TODO: get a minimum shader version + std::unordered_map> + FuncToUnresolvedDependencies; m_tables.emplace_back(std::make_unique()); FunctionTable &functionTable = 
*(FunctionTable*)(m_tables.back().get()); - m_tables.emplace_back(std::make_unique>()); - IndexTable &indexTable = *(IndexTable*)(m_tables.back().get()); + m_tables.emplace_back(std::make_unique()); + IndexTable &indexTable = *(IndexTable*)(m_tables.back().get()); + for (auto &function : m_Module.GetModule()->getFunctionList()) { + // If function is a declaration, it is an unresolved dependency in the library + if (function.isDeclaration() && !OP::IsDxilOpFunc(&function)) { + UpdateFunctionDependency(&function, stringTable); + } + } for (auto &function : m_Module.GetModule()->getFunctionList()) { if (!function.isDeclaration()) { StringRef mangled = function.getName(); @@ -972,11 +940,14 @@ private: uint32_t unmangledIndex = stringTable.Insert(unmangled); // Update resource Index uint32_t resourceIndex = UINT_MAX; + uint32_t functionDependencies = UINT_MAX; uint32_t payloadSizeInBytes = 0; uint32_t attrSizeInBytes = 0; uint32_t shaderKind = (uint32_t)PSVShaderKind::Library; if (m_FuncToResNameOffset.find(&function) != m_FuncToResNameOffset.end()) resourceIndex = indexTable.AddIndex(m_FuncToResNameOffset[&function]); + if (m_FuncToDependencies.find(&function) != m_FuncToDependencies.end()) + functionDependencies = indexTable.AddIndex(m_FuncToDependencies[&function]); if (m_Module.HasDxilFunctionProps(&function)) { auto props = m_Module.GetDxilFunctionProps(&function); if (props.IsClosestHit() || props.IsAnyHit()) { @@ -992,9 +963,18 @@ private: shaderKind = (uint32_t)props.shaderKind; } ShaderFlags flags = ShaderFlags::CollectShaderFlags(&function, &m_Module); - functionTable.AddFunction(&function, mangledIndex, unmangledIndex, - shaderKind, resourceIndex, - payloadSizeInBytes, attrSizeInBytes, flags); + RuntimeDataFunctionInfo info = {}; + info.Name = mangledIndex; + info.UnmangledName = unmangledIndex; + info.ShaderKind = shaderKind; + info.Resources = resourceIndex; + info.FunctionDependencies = functionDependencies; + info.PayloadSizeInBytes = payloadSizeInBytes; 
+ info.AttributeSizeInBytes = attrSizeInBytes; + uint64_t rawFlags = flags.GetShaderFlagsRaw(); + info.FeatureInfo1 = rawFlags & 0xffffffff; + info.FeatureInfo2 = (rawFlags >> 32) & 0xffffffff; + functionTable.Insert(info); } } } @@ -1013,7 +993,7 @@ public: // one variable to count the number of blobs and two blobs uint32_t total = 4 + m_tables.size() * sizeof(RuntimeDataTableHeader); for (auto &&table : m_tables) - total += table->GetBlobSize(); + total += table->GetPartSize(); return total; } @@ -1027,15 +1007,15 @@ public: // write records uint32_t curTableOffset = size * sizeof(RuntimeDataTableHeader) + 4; for (auto &&table : m_tables) { - RuntimeDataTableHeader record = { table->GetType(), table->GetBlobSize(), curTableOffset }; + RuntimeDataTableHeader record = { table->GetType(), table->GetPartSize(), curTableOffset }; memcpy(pCur, &record, sizeof(RuntimeDataTableHeader)); pCur += sizeof(RuntimeDataTableHeader); curTableOffset += record.size; } // write tables for (auto &&table : m_tables) { - table->write(pCur); - pCur += table->GetBlobSize(); + table->Write(pCur); + pCur += table->GetPartSize(); } ULONG cbWritten; diff --git a/tools/clang/unittests/HLSL/DxilContainerTest.cpp b/tools/clang/unittests/HLSL/DxilContainerTest.cpp index f5b328db7..b60a1cf85 100644 --- a/tools/clang/unittests/HLSL/DxilContainerTest.cpp +++ b/tools/clang/unittests/HLSL/DxilContainerTest.cpp @@ -35,6 +35,7 @@ #include "dxc/HLSL/DxilContainer.h" #include "dxc/HLSL/DxilPipelineStateValidation.h" #include "dxc/HLSL/DxilShaderFlags.h" +#include "dxc/HLSL/DxilUtil.h" #include #include @@ -72,6 +73,7 @@ public: TEST_METHOD(CompileWhenDebugSourceThenSourceMatters) TEST_METHOD(CompileWhenOkThenCheckRDAT) + TEST_METHOD(CompileWhenOkThenCheckRDAT2) TEST_METHOD(CompileWhenOKThenIncludesFeatureInfo) TEST_METHOD(CompileWhenOKThenIncludesSignatures) TEST_METHOD(CompileWhenSigSquareThenIncludeSplit) @@ -745,6 +747,68 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) { 
IFTBOOLMSG(blobFound, E_FAIL, "failed to find RDAT blob after compiling"); } +TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT2) { + if (m_ver.SkipDxilVersion(1, 3)) return; + // This is a case when the user of resource is a constant, not instruction. + // Compiler generates the following load instruction for texture. + // load %class.Texture2D, %class.Texture2D* getelementptr inbounds ([3 x + // %class.Texture2D], [3 x %class.Texture2D]* + // @"\01?ThreeTextures@@3PAV?$Texture2D@M@@A", i32 0, i32 0), align 4 + const char *shader = + "SamplerState Sampler : register(s0); RWBuffer Uav : " + "register(u0); Texture2D ThreeTextures[3] : register(t0); " + "float function1();" + "[shader(\"raygeneration\")] void RayGenMain() { Uav[0] = " + "ThreeTextures[0].Sample(Sampler, float2(0, 0)) + function1(); }"; + CComPtr pCompiler; + CComPtr pSource; + CComPtr pProgram; + CComPtr pDisassembly; + CComPtr pResult; + HRESULT status; + + VERIFY_SUCCEEDED(CreateCompiler(&pCompiler)); + CreateBlobFromText(shader, &pSource); + VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"hlsl.hlsl", L"main", + L"lib_6_3", nullptr, 0, nullptr, 0, + nullptr, &pResult)); + VERIFY_SUCCEEDED(pResult->GetResult(&pProgram)); + VERIFY_SUCCEEDED(pResult->GetStatus(&status)); + VERIFY_SUCCEEDED(status); + CComPtr pReflection; + uint32_t partCount; + IFT(m_dllSupport.CreateInstance(CLSID_DxcContainerReflection, &pReflection)); + IFT(pReflection->Load(pProgram)); + IFT(pReflection->GetPartCount(&partCount)); + bool blobFound = false; + for (uint32_t i = 0; i < partCount; ++i) { + uint32_t kind; + IFT(pReflection->GetPartKind(i, &kind)); + if (kind == (uint32_t)hlsl::DxilFourCC::DFCC_RuntimeData) { + blobFound = true; + using namespace hlsl::DXIL::PSV; + CComPtr pBlob; + IFT(pReflection->GetPartContent(i, &pBlob)); + DxilRuntimeData context; + context.InitFromRDAT((char *)pBlob->GetBufferPointer()); + FunctionTableReader *funcTableReader = context.GetFunctionTableReader(); + ResourceTableReader 
*resTableReader = context.GetResourceTableReader(); + VERIFY_IS_TRUE(funcTableReader->GetNumFunctions() == 1); + VERIFY_IS_TRUE(resTableReader->GetNumResources() == 3); + FunctionReader funcReader = funcTableReader->GetItem(0); + llvm::StringRef name(funcReader.GetUnmangledName()); + VERIFY_IS_TRUE(name.compare("RayGenMain") == 0); + VERIFY_IS_TRUE(funcReader.GetShaderKind() == PSVShaderKind::RayGeneration); + VERIFY_IS_TRUE(funcReader.GetNumResources() == 3); + VERIFY_IS_TRUE(funcReader.GetNumDependencies() == 1); + llvm::StringRef dependencyName = + hlsl::dxilutil::DemangleFunctionName(funcReader.GetDependency(0)); + VERIFY_IS_TRUE(dependencyName.compare("function1") == 0); + } + } + IFTBOOLMSG(blobFound, E_FAIL, "failed to find RDAT blob after compiling"); +} + TEST_F(DxilContainerTest, CompileWhenOKThenIncludesFeatureInfo) { CComPtr pCompiler; CComPtr pSource;