diff --git a/include/dxc/HLSL/DxilModule.h b/include/dxc/HLSL/DxilModule.h index f11b64307..64f67f6f9 100644 --- a/include/dxc/HLSL/DxilModule.h +++ b/include/dxc/HLSL/DxilModule.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace llvm { class LLVMContext; @@ -132,6 +133,14 @@ public: DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F); // Move DxilFunctionProps of F to NewF. void ReplaceDxilFunctionProps(llvm::Function *F, llvm::Function *NewF); + void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc); + bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps + bool IsPatchConstantShader(llvm::Function *F); + bool IsComputeShader(llvm::Function *F); + + // Is an entry function that uses input/output signature conventions? + // Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function. + bool IsEntryThatUsesSignatures(llvm::Function *F); // Remove Root Signature from module metadata void StripRootSignatureFromMetadata(); @@ -436,6 +445,9 @@ private: std::unordered_map> m_DxilEntrySignatureMap; + // Keeps track of patch constant functions used by hull shaders + std::unordered_set m_PatchConstantFunctions; + // ViewId state. std::unique_ptr m_pViewIdState; diff --git a/include/dxc/HLSL/HLModule.h b/include/dxc/HLSL/HLModule.h index c6f0a29c6..30fdba0ab 100644 --- a/include/dxc/HLSL/HLModule.h +++ b/include/dxc/HLSL/HLModule.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace llvm { class LLVMContext; @@ -127,6 +128,14 @@ public: bool HasDxilFunctionProps(llvm::Function *F); DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F); void AddDxilFunctionProps(llvm::Function *F, std::unique_ptr &info); + void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc); + bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps + bool IsPatchConstantShader(llvm::Function *F); + bool IsComputeShader(llvm::Function *F); + + // Is an entry function that uses input/output signature conventions? + // Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function. + bool IsEntryThatUsesSignatures(llvm::Function *F); DxilFunctionAnnotation *GetFunctionAnnotation(llvm::Function *F); DxilFunctionAnnotation *AddFunctionAnnotation(llvm::Function *F); @@ -238,6 +247,7 @@ private: // High level function info. std::unordered_map> m_DxilFunctionPropsMap; + std::unordered_set m_PatchConstantFunctions; // Resource type annotation. std::unordered_map> m_ResTypeAnnotation; diff --git a/lib/HLSL/DxilGenerationPass.cpp b/lib/HLSL/DxilGenerationPass.cpp index 4ec5c6ccf..7e4f4fd74 100644 --- a/lib/HLSL/DxilGenerationPass.cpp +++ b/lib/HLSL/DxilGenerationPass.cpp @@ -238,7 +238,7 @@ public: for (auto It = M.begin(); It != M.end();) { Function &F = *(It++); // Lower signature for each entry function. - if (m_pHLModule->HasDxilFunctionProps(&F)) { + if (m_pHLModule->IsEntryThatUsesSignatures(&F)) { DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(&F); std::unique_ptr pSig = llvm::make_unique(props.shaderKind, m_pHLModule->GetHLOptions().bUseMinPrecision); diff --git a/lib/HLSL/DxilLinker.cpp b/lib/HLSL/DxilLinker.cpp index fb3a3f94c..94117dab1 100644 --- a/lib/HLSL/DxilLinker.cpp +++ b/lib/HLSL/DxilLinker.cpp @@ -607,7 +607,8 @@ DxilLinkJob::Link(std::pair &entryLinkPair, Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc; Function *newPatchConstantFunc = m_newFunctions[patchConstantFunc->getName()]; - props.ShaderProps.HS.patchConstantFunc = newPatchConstantFunc; + DM.SetPatchConstantFunctionForHS(entryFunc, nullptr); + DM.SetPatchConstantFunctionForHS(NewEntryFunc, newPatchConstantFunc); if (newPatchConstantFunc->hasFnAttribute(llvm::Attribute::AlwaysInline)) newPatchConstantFunc->removeFnAttr(llvm::Attribute::AlwaysInline); diff --git a/lib/HLSL/DxilModule.cpp b/lib/HLSL/DxilModule.cpp index 55068e5c5..4ecf2dd28 100644 --- a/lib/HLSL/DxilModule.cpp +++ b/lib/HLSL/DxilModule.cpp @@ -1102,6 +1102,35 @@ void DxilModule::ReplaceDxilFunctionProps(llvm::Function *F, m_DxilFunctionPropsMap.erase(F); m_DxilFunctionPropsMap[NewF] = std::move(props); } +void DxilModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc) { + auto propIter = m_DxilFunctionPropsMap.find(hullShaderFunc); + DXASSERT(propIter != m_DxilFunctionPropsMap.end(), "Hull shader must already have function props!"); + DxilFunctionProps &props = *(propIter->second); + DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader"); + if (props.ShaderProps.HS.patchConstantFunc) + m_PatchConstantFunctions.erase(props.ShaderProps.HS.patchConstantFunc); + props.ShaderProps.HS.patchConstantFunc = patchConstantFunc; + if (patchConstantFunc) + m_PatchConstantFunctions.insert(patchConstantFunc); +} +bool DxilModule::IsGraphicsShader(llvm::Function *F) { + return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsGraphics(); +} +bool DxilModule::IsPatchConstantShader(llvm::Function *F) { + return m_PatchConstantFunctions.count(F) != 0; +} +bool DxilModule::IsComputeShader(llvm::Function *F) { + return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsCS(); +} +bool DxilModule::IsEntryThatUsesSignatures(llvm::Function *F) { + auto propIter = m_DxilFunctionPropsMap.find(F); + if (propIter != m_DxilFunctionPropsMap.end()) { + DxilFunctionProps &props = *(propIter->second); + return props.IsGraphics() || props.IsCS(); + } + // Otherwise, return true if patch constant function + return IsPatchConstantShader(F); +} void DxilModule::StripRootSignatureFromMetadata() { NamedMDNode *pRootSignatureNamedMD = GetModule()->getNamedMetadata(DxilMDHelper::kDxilRootSignatureMDName); @@ -1319,6 +1348,11 @@ void DxilModule::LoadDxilMetadata() { Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get()); + if (props->IsHS() && props->ShaderProps.HS.patchConstantFunc) { + // Add patch constant function to m_PatchConstantFunctions + m_PatchConstantFunctions.insert(props->ShaderProps.HS.patchConstantFunc); + } + m_DxilFunctionPropsMap[F] = std::move(props); } diff --git a/lib/HLSL/DxilPreparePasses.cpp b/lib/HLSL/DxilPreparePasses.cpp index ef0e7f1f2..cb10df3a6 100644 --- a/lib/HLSL/DxilPreparePasses.cpp +++ b/lib/HLSL/DxilPreparePasses.cpp @@ -374,7 +374,7 @@ private: } else { std::vector entries; for (iplist::iterator F : M.getFunctionList()) { - if (DM.HasDxilFunctionProps(F)) { + if (DM.IsEntryThatUsesSignatures(F)) { entries.emplace_back(F); } } @@ -384,7 +384,7 @@ private: // Strip patch constant function first. Function *patchConstFunc = StripFunctionParameter( props.ShaderProps.HS.patchConstantFunc, DM, FunctionDIs); - props.ShaderProps.HS.patchConstantFunc = patchConstFunc; + DM.SetPatchConstantFunctionForHS(entry, patchConstFunc); } StripFunctionParameter(entry, DM, FunctionDIs); } diff --git a/lib/HLSL/HLMatrixLowerPass.cpp b/lib/HLSL/HLMatrixLowerPass.cpp index cb690cde9..8fba26d95 100644 --- a/lib/HLSL/HLMatrixLowerPass.cpp +++ b/lib/HLSL/HLMatrixLowerPass.cpp @@ -272,6 +272,9 @@ private: // Get new matrix value corresponding to vecVal Value *GetMatrixForVec(Value *vecVal, Type *matTy); + // Translate library function input/output to preserve function signatures + void TranslateLibraryArgs(Function &F); + // Replace matVal with vecVal on matUseInst. void TrivialMatReplace(Value *matVal, Value *vecVal, CallInst *matUseInst); @@ -1269,6 +1272,16 @@ void HLMatrixLowerPass::TrivialMatReplace(Value *matVal, } } +static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned row, unsigned col) { + SmallVector castMask(col * row); + unsigned idx = 0; + for (unsigned c = 0; c < col; c++) + for (unsigned r = 0; r < row; r++) + castMask[idx++] = r * col + c; + return cast( + Builder.CreateShuffleVector(vecVal, vecVal, castMask)); +} + void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal, Value *vecVal, CallInst *castInst, @@ -1291,25 +1304,9 @@ void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal, IRBuilder<> Builder(castInst); - // shuf to change major. - SmallVector castMask(col * row); - unsigned idx = 0; - if (bRowToCol) { - for (unsigned c = 0; c < col; c++) - for (unsigned r = 0; r < row; r++) { - unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col); - castMask[idx++] = matIdx; - } - } else { - for (unsigned r = 0; r < row; r++) - for (unsigned c = 0; c < col; c++) { - unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row); - castMask[idx++] = matIdx; - } - } - - Instruction *vecCast = cast( - Builder.CreateShuffleVector(vecVal, vecVal, castMask)); + if (bRowToCol) + std::swap(row, col); + Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col); // Replace vec cast function call with vecCast. DXASSERT(matToVecMap.count(castInst), "must has vec version"); @@ -2109,12 +2106,10 @@ Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) { void HLMatrixLowerPass::replaceMatWithVec(Value *matVal, Value *vecVal) { + Type *matTy = matVal->getType(); for (Value::user_iterator user = matVal->user_begin(); user != matVal->user_end();) { Instruction *useInst = cast(*(user++)); - // Skip return here. - if (isa(useInst)) - continue; // User must be function call. if (CallInst *useCall = dyn_cast(useInst)) { hlsl::HLOpcodeGroup group = @@ -2183,7 +2178,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal, for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) { if (useCall->getArgOperand(i) == matVal) { // update the user call with the correct matrix value in new code sequence - Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType()); + Value *newMatVal = GetMatrixForVec(vecVal, matTy); if (matVal != newMatVal) useCall->setArgOperand(i, newMatVal); } @@ -2194,8 +2189,10 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal, // Just replace the src with vec version. useInst->setOperand(0, vecVal); } else if (ReturnInst *RI = dyn_cast(useInst)) { - Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType()); + Value *newMatVal = GetMatrixForVec(vecVal, matTy); RI->setOperand(0, newMatVal); + } else if (isa(useInst)) { + DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values"); } else { // Must be GEP on mat array alloca. GetElementPtrInst *GEP = cast(useInst); @@ -2467,6 +2464,85 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) { } } +void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) { + // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec) + // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major + for (auto &arg : F.args()) { + SmallVector Candidates; + for (User *U : arg.users()) { + if (CallInst *CI = dyn_cast(U)) { + HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction()); + switch (group) { + case HLOpcodeGroup::HLCast: + case HLOpcodeGroup::HLMatLoadStore: + Candidates.push_back(CI); + break; + } + } + } + for (CallInst *CI : Candidates) { + IRBuilder<> Builder(CI); + HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction()); + switch (group) { + case HLOpcodeGroup::HLCast: { + HLCastOpcode opcode = static_cast(hlsl::GetHLOpcode(CI)); + if (opcode == HLCastOpcode::RowMatrixToVecCast || + opcode == HLCastOpcode::ColMatrixToVecCast) { + Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx); + Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(), + /*bOrigAllocaTy*/false, + matVal->getName()); + if (opcode == HLCastOpcode::ColMatrixToVecCast) { + unsigned row, col; + HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row); + vecVal = CreateTransposeShuffle(Builder, vecVal, row, col); + } + CI->replaceAllUsesWith(vecVal); + CI->eraseFromParent(); + } + } break; + case HLOpcodeGroup::HLMatLoadStore: { + HLMatLoadStoreOpcode opcode = static_cast(hlsl::GetHLOpcode(CI)); + bool bTranspose = false; + switch (opcode) { + case HLMatLoadStoreOpcode::ColMatStore: + bTranspose = true; + case HLMatLoadStoreOpcode::RowMatStore: { + // shuffle if transposed, bitcast, and store + Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx); + Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx); + if (bTranspose) { + unsigned row, col; + HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row); + vecVal = CreateTransposeShuffle(Builder, vecVal, row, col); + } + Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo()); + Builder.CreateStore(vecVal, castPtr); + CI->eraseFromParent(); + } break; + case HLMatLoadStoreOpcode::ColMatLoad: + bTranspose = true; + case HLMatLoadStoreOpcode::RowMatLoad: { + // bitcast, load, and shuffle if transposed + Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx); + Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo()); + Value *vecVal = Builder.CreateLoad(castPtr); + if (bTranspose) { + unsigned row, col; + HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row); + // row/col swapped for col major source + vecVal = CreateTransposeShuffle(Builder, vecVal, col, row); + } + CI->replaceAllUsesWith(vecVal); + CI->eraseFromParent(); + } break; + } + } break; + } + } + } +} + void HLMatrixLowerPass::runOnFunction(Function &F) { // Create vector version of matrix instructions first. // The matrix operands will be undefval for these instructions. @@ -2531,4 +2607,12 @@ void HLMatrixLowerPass::runOnFunction(Function &F) { DeleteDeadInsts(); matToVecMap.clear(); + vecToMatMap.clear(); + + // If this is a library function, now fix input/output matrix params + // TODO: What about Patch Constant Shaders? + if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) { + TranslateLibraryArgs(F); + } + return; } diff --git a/lib/HLSL/HLModule.cpp b/lib/HLSL/HLModule.cpp index a5a9d07b8..ec56c8b22 100644 --- a/lib/HLSL/HLModule.cpp +++ b/lib/HLSL/HLModule.cpp @@ -350,6 +350,35 @@ void HLModule::AddDxilFunctionProps(llvm::Function *F, std::unique_ptrshaderKind != DXIL::ShaderKind::Invalid); m_DxilFunctionPropsMap[F] = std::move(info); } +void HLModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc) { + auto propIter = m_DxilFunctionPropsMap.find(hullShaderFunc); + DXASSERT(propIter != m_DxilFunctionPropsMap.end(), "else Hull Shader missing function props"); + DxilFunctionProps &props = *(propIter->second); + DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader"); + if (props.ShaderProps.HS.patchConstantFunc) + m_PatchConstantFunctions.erase(props.ShaderProps.HS.patchConstantFunc); + props.ShaderProps.HS.patchConstantFunc = patchConstantFunc; + if (patchConstantFunc) + m_PatchConstantFunctions.insert(patchConstantFunc); +} +bool HLModule::IsGraphicsShader(llvm::Function *F) { + return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsGraphics(); +} +bool HLModule::IsPatchConstantShader(llvm::Function *F) { + return m_PatchConstantFunctions.count(F) != 0; +} +bool HLModule::IsComputeShader(llvm::Function *F) { + return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsCS(); +} +bool HLModule::IsEntryThatUsesSignatures(llvm::Function *F) { + auto propIter = m_DxilFunctionPropsMap.find(F); + if (propIter != m_DxilFunctionPropsMap.end()) { + DxilFunctionProps &props = *(propIter->second); + return props.IsGraphics() || props.IsCS(); + } + // Otherwise, return true if patch constant function + return IsPatchConstantShader(F); +} DxilFunctionAnnotation *HLModule::GetFunctionAnnotation(llvm::Function *F) { return m_pTypeSystem->GetFunctionAnnotation(F); @@ -475,6 +504,11 @@ void HLModule::LoadHLMetadata() { Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get()); + if (props->IsHS() && props->ShaderProps.HS.patchConstantFunc) { + // Add patch constant function to m_PatchConstantFunctions + m_PatchConstantFunctions.insert(props->ShaderProps.HS.patchConstantFunc); + } + m_DxilFunctionPropsMap[F] = std::move(props); } diff --git a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp index a39244cbb..90436daa5 100644 --- a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp +++ b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp @@ -5173,6 +5173,9 @@ void SROA_Parameter_HLSL::flattenArgument( Type *Ty = V->getType(); if (Ty->isPointerTy()) Ty = Ty->getPointerElementType(); + + // Stop doing this when preserving resource types and using new + // createHandleFrom??? whatever it's going to be called... V = castResourceArgIfRequired(V, Ty, bOut, inputQual, Builder); // Cannot SROA, save it to final parameter list. @@ -5829,20 +5832,8 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) { IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get()); RetBuilder.CreateRetVoid(); } else { - Function *Entry = m_pHLModule->GetEntryFunction(); - hasShaderInputOutput = F == Entry; - if (m_pHLModule->HasDxilFunctionProps(F)) { - DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F); - if (!funcProps.IsRay()) - hasShaderInputOutput = true; - } - if (m_pHLModule->HasDxilFunctionProps(Entry)) { - DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(Entry); - if (funcProps.shaderKind == DXIL::ShaderKind::Hull) { - Function *patchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc; - hasShaderInputOutput |= F == patchConstantFunc; - } - } + hasShaderInputOutput = F == m_pHLModule->GetEntryFunction() || + m_pHLModule->IsEntryThatUsesSignatures(F); } std::vector FlatParamList; @@ -6361,9 +6352,9 @@ void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) { if (funcProps.shaderKind == DXIL::ShaderKind::Hull) { Function *oldPatchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc; - if (funcMap.count(oldPatchConstantFunc)) - funcProps.ShaderProps.HS.patchConstantFunc = - funcMap[oldPatchConstantFunc]; + if (funcMap.count(oldPatchConstantFunc)) { + m_pHLModule->SetPatchConstantFunctionForHS(flatF, funcMap[oldPatchConstantFunc]); + } } } // TODO: flatten vector argument and lower resource argument when flatten diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 511e68b6c..9704075a1 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -4317,11 +4317,11 @@ void CGMSHLSLRuntime::SetPatchConstantFunctionWithAttr( } Function *patchConstFunc = Entry->second.Func; - DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func); - DXASSERT(HSProps != nullptr, + DXASSERT(m_pHLModule->HasDxilFunctionProps(EntryFunc.Func), " else AddHLSLFunctionInfo did not save the dxil function props for the " "HS entry."); - HSProps->ShaderProps.HS.patchConstantFunc = patchConstFunc; + DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func); + m_pHLModule->SetPatchConstantFunctionForHS(EntryFunc.Func, patchConstFunc); DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc)); // Check no inout parameter for patch constant function. DxilFunctionAnnotation *patchConstFuncAnnotation = diff --git a/tools/clang/test/CodeGenHLSL/quick-test/lib_rt.hlsl b/tools/clang/test/CodeGenHLSL/quick-test/lib_rt.hlsl index 8b04e5dd3..2e8b063ca 100644 --- a/tools/clang/test/CodeGenHLSL/quick-test/lib_rt.hlsl +++ b/tools/clang/test/CodeGenHLSL/quick-test/lib_rt.hlsl @@ -2,19 +2,19 @@ //////////////////////////////////////////////////////////////////////////// // Prototype header contents to be removed on implementation of features: -#define HIT_KIND_TRIANGLE_FRONT_FACE 0xFE -#define HIT_KIND_TRIANGLE_BACK_FACE 0xFF +#define HIT_KIND_TRIANGLE_FRONT_FACE 0xFE +#define HIT_KIND_TRIANGLE_BACK_FACE 0xFF typedef uint RAY_FLAG; -#define RAY_FLAG_NONE 0x00 -#define RAY_FLAG_FORCE_OPAQUE 0x01 -#define RAY_FLAG_FORCE_NON_OPAQUE 0x02 -#define RAY_FLAG_TERMINATE_ON_FIRST_HIT 0x04 -#define RAY_FLAG_SKIP_CLOSEST_HIT_SHADER 0x08 -#define RAY_FLAG_CULL_BACK_FACING_TRIANGLES 0x10 -#define RAY_FLAG_CULL_FRONT_FACING_TRIANGLES 0x20 -#define RAY_FLAG_CULL_OPAQUE 0x40 -#define RAY_FLAG_CULL_NON_OPAQUE 0x80 +#define RAY_FLAG_NONE 0x00 +#define RAY_FLAG_FORCE_OPAQUE 0x01 +#define RAY_FLAG_FORCE_NON_OPAQUE 0x02 +#define RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH 0x04 +#define RAY_FLAG_SKIP_CLOSEST_HIT_SHADER 0x08 +#define RAY_FLAG_CULL_BACK_FACING_TRIANGLES 0x10 +#define RAY_FLAG_CULL_FRONT_FACING_TRIANGLES 0x20 +#define RAY_FLAG_CULL_OPAQUE 0x40 +#define RAY_FLAG_CULL_NON_OPAQUE 0x80 struct RayDesc { @@ -29,38 +29,46 @@ struct BuiltInTriangleIntersectionAttributes float2 barycentrics; }; -typedef ByteAddressBuffer RayTracingAccelerationStructure; +typedef ByteAddressBuffer RaytracingAccelerationStructure; +// group: Indirect Shader Invocation // Declare TraceRay overload for given payload structure #define Declare_TraceRay(payload_t) \ - void TraceRay(RayTracingAccelerationStructure, uint RayFlags, uint InstanceCullMask, uint RayContributionToHitGroupIndex, uint MultiplierForGeometryContributionToHitGroupIndex, uint MissShaderIndex, RayDesc, inout payload_t); + void TraceRay(RaytracingAccelerationStructure, uint RayFlags, uint InstanceInclusionMask, uint RayContributionToHitGroupIndex, uint MultiplierForGeometryContributionToHitGroupIndex, uint MissShaderIndex, RayDesc, inout payload_t); -// Declare ReportIntersection overload for given attribute structure -#define Declare_ReportIntersection(attr_t) \ - bool ReportIntersection(float HitT, uint HitKind, attr_t); +// Declare ReportHit overload for given attribute structure +#define Declare_ReportHit(attr_t) \ + bool ReportHit(float HitT, uint HitKind, attr_t); // Declare CallShader overload for given param structure #define Declare_CallShader(param_t) \ void CallShader(uint ShaderIndex, inout param_t); -void IgnoreIntersection(); -void TerminateRay(); +// group: AnyHit Terminals +void IgnoreHit(); +void AcceptHitAndEndSearch(); // System Value retrieval functions +// group: Ray Dispatch Arguments uint2 RayDispatchIndex(); uint2 RayDispatchDimension(); +// group: Ray Vectors float3 WorldRayOrigin(); float3 WorldRayDirection(); -float RayTMin(); -float CurrentRayT(); -uint PrimitiveID(); -uint InstanceID(); -uint InstanceIndex(); float3 ObjectRayOrigin(); float3 ObjectRayDirection(); +// group: RayT +float RayTMin(); +float CurrentRayT(); +// group: Raytracing uint System Values +uint PrimitiveID(); // watch for existing +uint InstanceID(); +uint InstanceIndex(); +uint HitKind(); +uint RayFlag(); +// group: Ray Transforms row_major float3x4 ObjectToWorld(); row_major float3x4 WorldToObject(); -uint HitKind(); //////////////////////////////////////////////////////////////////////////// struct MyPayload { @@ -79,7 +87,7 @@ struct MyParam { }; Declare_TraceRay(MyPayload); -Declare_ReportIntersection(MyAttributes); +Declare_ReportHit(MyAttributes); Declare_CallShader(MyParam); // CHECK: ; S sampler NA NA S0 s1 1 @@ -90,7 +98,7 @@ Declare_CallShader(MyParam); // CHECK: @T_rangeID = external constant i32 // CHECK: @S_rangeID = external constant i32 -RayTracingAccelerationStructure RTAS : register(t5); +RaytracingAccelerationStructure RTAS : register(t5); // CHECK: define void [[raygen1:@"\\01\?raygen1@[^\"]+"]]() { // CHECK: [[RAWBUF_ID:[^ ]+]] = load i32, i32* @RTAS_rangeID @@ -114,7 +122,7 @@ void raygen1() // CHECK: define void [[intersection1:@"\\01\?intersection1@[^\"]+"]]() { // CHECK: call void {{.*}}CurrentRayT{{.*}}(float* nonnull [[pCurrentRayT:%[^)]+]]) // CHECK: [[CurrentRayT:%[^ ]+]] = load float, float* [[pCurrentRayT]], align 4 -// CHECK: call void {{.*}}ReportIntersection{{.*}}(float [[CurrentRayT]], i32 0, float 0.000000e+00, float 0.000000e+00, i32 0, i1* nonnull {{.*}}) +// CHECK: call void {{.*}}ReportHit{{.*}}(float [[CurrentRayT]], i32 0, float 0.000000e+00, float 0.000000e+00, i32 0, i1* nonnull {{.*}}) // CHECK: ret void [shader("intersection")] @@ -122,15 +130,15 @@ void intersection1() { float hitT = CurrentRayT(); MyAttributes attr = (MyAttributes)0; - bool bReported = ReportIntersection(hitT, 0, attr); + bool bReported = ReportHit(hitT, 0, attr); } // CHECK: define void [[anyhit1:@"\\01\?anyhit1@[^\"]+"]](float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, i32* noalias nocapture, i32* noalias nocapture, float, float, i32) // CHECK: call void {{.*}}ObjectRayOrigin{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}}) // CHECK: call void {{.*}}ObjectRayDirection{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}}) // CHECK: call void {{.*}}CurrentRayT{{.*}}(float* nonnull {{.*}}) -// CHECK: call void {{.*}}TerminateRay{{.*}}() -// CHECK: call void {{.*}}IgnoreIntersection{{.*}}() +// CHECK: call void {{.*}}AcceptHitAndEndSearch{{.*}}() +// CHECK: call void {{.*}}IgnoreHit{{.*}}() // CHECK: store float {{.*}}, float* %0, align 4 // CHECK: store float {{.*}}, float* %1, align 4 // CHECK: store float {{.*}}, float* %2, align 4 @@ -145,9 +153,9 @@ void anyhit1( inout MyPayload payload : SV_RayPayload, { float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * CurrentRayT(); if (hitLocation.z < attr.bary.x) - TerminateRay(); // aborts function + AcceptHitAndEndSearch(); // aborts function if (hitLocation.z < attr.bary.y) - IgnoreIntersection(); // aborts function + IgnoreHit(); // aborts function payload.color += float4(0.125, 0.25, 0.5, 1.0); }