From 6eb541244a6951e056f619d318da1548f53aa654 Mon Sep 17 00:00:00 2001
From: Tex Riddell
Date: Mon, 21 Oct 2019 16:21:23 -0700
Subject: [PATCH] Lower vector/matrix early for UDT ptrs used directly such as
 Payload
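
Keep a payload UDT in aggregate form when its pointer is consumed
directly by an intrinsic, instead of letting SROA split or copy it:

- GetLoweredUDT() translates vector/matrix fields to arrays and now
  copies struct annotations to the lowered type.
- TranslateInitForLoweredUDT() uses the per-field matrix orientation
  from the struct annotation when rewriting initializers.
- SROA detects pointers used directly by lowered intrinsics (currently
  only the DispatchMesh payload; TraceRay/ReportHit/CallShader remain
  TODO), replaces the global or alloca with the lowered type, and
  rewrites the call to take the new pointer.
- InPayload arguments get the same lowering in castArgumentIfRequired().

As a sketch of the case being addressed (condensed from the new and
updated tests), a groupshared payload with vector/matrix fields passed
straight to DispatchMesh:

    struct MeshPayload {
      int4 data;
      bool2x2 mat;   // vector/matrix fields lowered to arrays, not SROA'd
    };

    groupshared MeshPayload pld;
    // ...
    DispatchMesh(1, 1, 1, pld);

now reaches the dx.op.dispatchMesh DXIL op directly; the amplification
shader tests check there is no intervening alloca, addrspacecast, or
bitcast.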
---
 include/dxc/HLSL/HLLowerUDT.h                 |   7 +-
 lib/HLSL/HLLowerUDT.cpp                       |  31 ++-
 .../Scalar/ScalarReplAggregatesHLSL.cpp       | 246 +++++++++++++++++-
 .../mesh/as-groupshared-payload-matrix.hlsl   |  62 +++++
 .../mesh/as-groupshared-payload.hlsl          |  25 +-
 .../mesh/mesh-payload-matrix.hlsl             |  81 ++++++
 .../shader_targets/mesh/mesh.hlsl             |   8 +-
 7 files changed, 432 insertions(+), 28 deletions(-)
 create mode 100644 tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload-matrix.hlsl
 create mode 100644 tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh-payload-matrix.hlsl

diff --git a/include/dxc/HLSL/HLLowerUDT.h b/include/dxc/HLSL/HLLowerUDT.h
index 0487a6992..bfd8a3e81 100644
--- a/include/dxc/HLSL/HLLowerUDT.h
+++ b/include/dxc/HLSL/HLLowerUDT.h
@@ -23,9 +23,12 @@ class Value;
 } // namespace llvm
 
 namespace hlsl {
+class DxilTypeSystem;
 
-llvm::StructType *GetLoweredUDT(llvm::StructType *structTy);
-llvm::Constant *TranslateInitForLoweredUDT(llvm::Constant *Init, llvm::Type *NewTy,
+llvm::StructType *GetLoweredUDT(
+    llvm::StructType *structTy, hlsl::DxilTypeSystem *pTypeSys = nullptr);
+llvm::Constant *TranslateInitForLoweredUDT(
+    llvm::Constant *Init, llvm::Type *NewTy,
     // We need orientation for matrix fields
     hlsl::DxilTypeSystem *pTypeSys,
     hlsl::MatrixOrientation matOrientation = hlsl::MatrixOrientation::Undefined);

diff --git a/lib/HLSL/HLLowerUDT.cpp b/lib/HLSL/HLLowerUDT.cpp
index 03c2057b8..c989dc97d 100644
--- a/lib/HLSL/HLLowerUDT.cpp
+++ b/lib/HLSL/HLLowerUDT.cpp
@@ -54,7 +54,7 @@ static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, un
 // Lowered UDT is the same layout, but with vectors and matrices translated to
 // arrays.
 // Returns nullptr for failure due to embedded HLSL object type.
-StructType *hlsl::GetLoweredUDT(StructType *structTy) {
+StructType *hlsl::GetLoweredUDT(StructType *structTy, DxilTypeSystem *pTypeSys) {
   bool changed = false;
   SmallVector<Type *, 8> NewElTys(structTy->getNumContainedTypes());
 
@@ -106,17 +106,29 @@ StructType *hlsl::GetLoweredUDT(StructType *structTy) {
   if (changed) {
-    return StructType::create(
+    StructType *newStructTy = StructType::create(
         structTy->getContext(), NewElTys, structTy->getStructName());
+    if (DxilStructAnnotation *pSA = pTypeSys ?
+          pTypeSys->GetStructAnnotation(structTy) : nullptr) {
+      if (!pTypeSys->GetStructAnnotation(newStructTy)) {
+        DxilStructAnnotation &NewSA = *pTypeSys->AddStructAnnotation(newStructTy);
+        for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
+          NewSA.GetFieldAnnotation(iField) = pSA->GetFieldAnnotation(iField);
+        }
+      }
+    }
+    return newStructTy;
   }
   return structTy;
 }
 
-Constant *hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
+Constant *hlsl::TranslateInitForLoweredUDT(
+    Constant *Init, Type *NewTy,
     // We need orientation for matrix fields
     DxilTypeSystem *pTypeSys,
     MatrixOrientation matOrientation) {
+  // handle undef and zero init
   if (isa<UndefValue>(Init))
     return UndefValue::get(NewTy);
 
@@ -159,14 +171,23 @@ Constant *hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
       }
     }
   } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    DxilStructAnnotation *pStructAnnotation =
+        pTypeSys ? pTypeSys->GetStructAnnotation(ST) : nullptr;
     values.reserve(ST->getNumContainedTypes());
     ConstantStruct *CS = cast<ConstantStruct>(Init);
     for (unsigned i = 0; i < ST->getStructNumElements(); ++i) {
+      MatrixOrientation matFieldOrientation = matOrientation;
+      if (pStructAnnotation) {
+        DxilFieldAnnotation &FA = pStructAnnotation->GetFieldAnnotation(i);
+        if (FA.HasMatrixAnnotation()) {
+          matFieldOrientation = FA.GetMatrixAnnotation().Orientation;
+        }
+      }
       values.emplace_back(
           TranslateInitForLoweredUDT(
             cast<Constant>(CS->getAggregateElement(i)),
             NewTy->getStructElementType(i),
-            pTypeSys, matOrientation));
+            pTypeSys, matFieldOrientation));
     }
     return ConstantStruct::get(cast<StructType>(NewTy), values);
   }
@@ -411,7 +432,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
       }
     } break;
-    case HLOpcodeGroup::NotHL:
+    //case HLOpcodeGroup::NotHL:  // TODO: Support lib functions
     case HLOpcodeGroup::HLIntrinsic: {
       // Just bitcast for now
       IRBuilder<> Builder(CI);

diff --git a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
index 605f14abf..52fb4cc21 100644
--- a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
+++ b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
@@ -58,6 +58,7 @@
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilOperations.h"
+#include "dxc/HLSL/HLLowerUDT.h"
 #include <deque>
 #include <unordered_map>
 #include <unordered_set>
@@ -777,6 +778,200 @@ static unsigned getNestedLevelInStruct(const Type *ty) {
   return lvl;
 }
 
+/// Returns first GEP index that indexes a struct member, or 0 otherwise.
+/// Ignores initial ptr index.
+static unsigned FindFirstStructMemberIdxInGEP(GEPOperator *GEP) {
+  StructType *ST = dyn_cast<StructType>(
+      GEP->getPointerOperandType()->getPointerElementType());
+  int index = 1;
+  for (auto it = gep_type_begin(GEP), E = gep_type_end(GEP); it != E;
+       ++it, ++index) {
+    if (ST) {
+      DXASSERT(!HLMatrixType::isa(ST) && !dxilutil::IsHLSLObjectType(ST),
+               "otherwise, indexing into hlsl object");
+      return index;
+    }
+    ST = dyn_cast<StructType>(it->getPointerElementType());
+  }
+  return 0;
+}
+
+/// Return true when ptr should not be SROA'd or copied, but used directly
+/// by a function in its lowered form.  Also collect uses for translation.
+/// What is meant by directly here:
+///   Possibly accessed through GEP array index or address space cast, but
+///   not under another struct member (always allow SROA of outer struct).
+typedef SmallMapVector<CallInst *, unsigned, 4> FunctionUseMap;
+static bool IsPtrUsedByLoweredFn(
+    Value *V, FunctionUseMap &CollectedUses) {
+  bool bFound = false;
+  for (Use &U : V->uses()) {
+    User *user = U.getUser();
+
+    if (CallInst *CI = dyn_cast<CallInst>(user)) {
+      unsigned foundIdx = (unsigned)-1;
+      Function *F = CI->getCalledFunction();
+      Type *Ty = V->getType();
+      if (F->isDeclaration() && !F->isIntrinsic() &&
+          Ty->isPointerTy()) {
+        HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
+        if (group == HLOpcodeGroup::HLIntrinsic) {
+          unsigned opIdx = U.getOperandNo();
+          switch ((IntrinsicOp)hlsl::GetHLOpcode(CI)) {
+          // TODO: Lower these as well, along with function parameter types
+          //case IntrinsicOp::IOP_TraceRay:
+          //  if (opIdx != HLOperandIndex::kTraceRayPayLoadOpIdx)
+          //    continue;
+          //  break;
+          //case IntrinsicOp::IOP_ReportHit:
+          //  if (opIdx != HLOperandIndex::kReportIntersectionAttributeOpIdx)
+          //    continue;
+          //  break;
+          //case IntrinsicOp::IOP_CallShader:
+          //  if (opIdx != HLOperandIndex::kCallShaderPayloadOpIdx)
+          //    continue;
+          //  break;
+          case IntrinsicOp::IOP_DispatchMesh:
+            if (opIdx != HLOperandIndex::kDispatchMeshOpPayload)
+              continue;
+            break;
+          default:
+            continue;
+          }
+          foundIdx = opIdx;
+
+        // TODO: Lower these as well, along with function parameter types
+        //} else if (group == HLOpcodeGroup::NotHL) {
+        //  foundIdx = U.getOperandNo();
+        }
+      }
+      if (foundIdx != (unsigned)-1) {
+        bFound = true;
+        auto insRes = CollectedUses.insert(std::make_pair(CI, foundIdx));
+        DXASSERT_LOCALVAR(insRes, insRes.second,
+                          "otherwise, multiple uses in single call");
+      }
+
+    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
+      // Not what we are looking for if GEP result is not [array of] struct.
+      // If use is under struct member, we can still SROA the outer struct.
+      if (!dxilutil::StripArrayTypes(GEP->getType()->getPointerElementType())
+              ->isStructTy() ||
+          FindFirstStructMemberIdxInGEP(cast<GEPOperator>(GEP)))
+        continue;
+      if (IsPtrUsedByLoweredFn(user, CollectedUses))
+        bFound = true;
+
+    } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
+      if (IsPtrUsedByLoweredFn(user, CollectedUses))
+        bFound = true;
+
+    } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
+      unsigned opcode = CE->getOpcode();
+      if (opcode == Instruction::AddrSpaceCast || opcode == Instruction::GetElementPtr)
+        if (IsPtrUsedByLoweredFn(user, CollectedUses))
+          bFound = true;
+    }
+  }
+  return bFound;
+}
+
+/// Rewrite call to natively use an argument with addrspace cast/bitcast
+static CallInst *RewriteIntrinsicCallForCastedArg(CallInst *CI, unsigned argIdx) {
+  Function *F = CI->getCalledFunction();
+  HLOpcodeGroup group = GetHLOpcodeGroupByName(F);
+  DXASSERT_NOMSG(group == HLOpcodeGroup::HLIntrinsic);
+  unsigned opcode = GetHLOpcode(CI);
+  SmallVector<Type *, 8> newArgTypes(CI->getFunctionType()->param_begin(),
+                                     CI->getFunctionType()->param_end());
+  SmallVector<Value *, 8> newArgs(CI->arg_operands());
+
+  Value *newArg = CI->getOperand(argIdx)->stripPointerCasts();
+  newArgTypes[argIdx] = newArg->getType();
+  newArgs[argIdx] = newArg;
+
+  FunctionType *newFuncTy = FunctionType::get(CI->getType(), newArgTypes, false);
+  Function *newF = GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode);
+
+  IRBuilder<> Builder(CI);
+  return Builder.CreateCall(newF, newArgs);
+}
+
+/// Translate pointer for cases where intrinsics use UDT pointers directly.
+/// Return existing or new ptr if it needs preserving,
+/// otherwise nullptr to proceed with existing checks and SROA.
+static Value *TranslatePtrIfUsedByLoweredFn(
+    Value *Ptr, DxilTypeSystem &TypeSys) {
+  if (!Ptr->getType()->isPointerTy())
+    return nullptr;
+  Type *Ty = Ptr->getType()->getPointerElementType();
+  SmallVector<unsigned, 4> outerToInnerLengths;
+  Ty = dxilutil::StripArrayTypes(Ty, &outerToInnerLengths);
+  if (!Ty->isStructTy())
+    return nullptr;
+  if (HLMatrixType::isa(Ty) || dxilutil::IsHLSLObjectType(Ty))
+    return nullptr;
+  unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
+  FunctionUseMap FunctionUses;
+  if (!IsPtrUsedByLoweredFn(Ptr, FunctionUses))
+    return nullptr;
+
+  // Translate vectors to arrays in type, but don't SROA
+  Type *NewTy = GetLoweredUDT(cast<StructType>(Ty));
+
+  // No work to do here, but prevent SROA.
+  if (Ty == NewTy && AddrSpace != DXIL::kTGSMAddrSpace)
+    return Ptr;
+
+  // If type changed, replace value, otherwise casting may still
+  // require a rewrite of the calls.
+  Value *NewPtr = Ptr;
+  if (Ty != NewTy) {
+    // TODO: Transfer type annotation
+    DxilStructAnnotation *pOldAnnotation =
+        TypeSys.GetStructAnnotation(cast<StructType>(Ty));
+    if (pOldAnnotation) {
+    }
+
+    NewTy = dxilutil::WrapInArrayTypes(NewTy, outerToInnerLengths);
+    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
+      Module &M = *GV->getParent();
+      // Rewrite init expression for arrays instead of vectors
+      Constant *Init = GV->hasInitializer() ?
+          GV->getInitializer() : UndefValue::get(Ptr->getType());
+      Constant *NewInit = TranslateInitForLoweredUDT(
+          Init, NewTy, &TypeSys);
+      // Replace with new GV, and rewrite vector load/store users
+      GlobalVariable *NewGV = new GlobalVariable(
+          M, NewTy, GV->isConstant(), GV->getLinkage(),
+          NewInit, GV->getName(), /*InsertBefore*/ GV,
+          GV->getThreadLocalMode(), AddrSpace);
+      NewPtr = NewGV;
+    } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
+      IRBuilder<> Builder(AI);
+      AllocaInst *NewAI = Builder.CreateAlloca(NewTy, nullptr, AI->getName());
+      NewPtr = NewAI;
+    } else {
+      DXASSERT(false, "Ptr must be global or alloca");
+    }
+    // This will rewrite vector load/store users
+    // and insert bitcasts for CallInst users
+    ReplaceUsesForLoweredUDT(Ptr, NewPtr);
+  }
+
+  // Rewrite the HLIntrinsic calls
+  for (auto it : FunctionUses) {
+    CallInst *CI = it.first;
+    HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+    if (group == HLOpcodeGroup::NotHL)
+      continue;
+    CallInst *newCI = RewriteIntrinsicCallForCastedArg(CI, it.second);
+    CI->replaceAllUsesWith(newCI);
+    CI->eraseFromParent();
+  }
+
+  return NewPtr;
+}
+
+
 // performScalarRepl - This algorithm is a simple worklist driven algorithm,
 // which runs on all of the alloca instructions in the entry block, removing
 // them if they are only used by getelementptr instructions.
@@ -866,6 +1061,15 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
       continue;
     }
 
+    if (Value *NewV = TranslatePtrIfUsedByLoweredFn(AI, typeSys)) {
+      if (NewV != AI) {
+        DXASSERT(AI->getNumUses() == 0, "must have zero users.");
+        AI->eraseFromParent();
+        Changed = true;
+      }
+      continue;
+    }
+
     // If the alloca looks like a good candidate for scalar replacement, and
     // if
     // all its users can be transformed, then split up the aggregate into its
@@ -1053,8 +1257,7 @@ void SROA_HLSL::isSafeForScalarRepl(Instruction *I, uint64_t Offset,
       IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(CI));
       if (IntrinsicOp::IOP_TraceRay == opcode ||
           IntrinsicOp::IOP_ReportHit == opcode ||
-          IntrinsicOp::IOP_CallShader == opcode ||
-          IntrinsicOp::IOP_DispatchMesh == opcode) {
+          IntrinsicOp::IOP_CallShader == opcode) {
         return MarkUnsafe(Info, User);
       }
     }
@@ -2588,13 +2791,6 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
           RewriteCallArg(CI, HLOperandIndex::kCallShaderPayloadOpIdx,
                          /*bIn*/ true, /*bOut*/ true);
         }
       } break;
-      case IntrinsicOp::IOP_DispatchMesh: {
-        if (OldVal ==
-            CI->getArgOperand(HLOperandIndex::kDispatchMeshOpPayload)) {
-          RewriteCallArg(CI, HLOperandIndex::kDispatchMeshOpPayload,
-                         /*bIn*/ true, /*bOut*/ false);
-        }
-      } break;
       case IntrinsicOp::MOP_TraceRayInline: {
         if (OldVal ==
             CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
@@ -4068,10 +4264,20 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
       bFlatVector = false;
 
       std::vector<Value *> Elts;
-      bool SROAed = SROA_Helper::DoScalarReplacement(
-          EltGV, Elts, Builder, bFlatVector,
-          // TODO: set precise.
-          /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts);
+      bool SROAed = false;
+      if (GlobalVariable *NewEltGV = dyn_cast_or_null<GlobalVariable>(
+              TranslatePtrIfUsedByLoweredFn(EltGV, dxilTypeSys))) {
+        if (GV != EltGV) {
+          EltGV->removeDeadConstantUsers();
+          EltGV->eraseFromParent();
+        }
+        EltGV = NewEltGV;
+      } else {
+        SROAed = SROA_Helper::DoScalarReplacement(
+            EltGV, Elts, Builder, bFlatVector,
+            // TODO: set precise.
+            /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts);
+      }
 
       if (SROAed) {
         // Push Elts into workList.
@@ -4722,6 +4928,19 @@ Value *SROA_Parameter_HLSL::castArgumentIfRequired(
   Module &M = *m_pHLModule->GetModule();
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
 
+  if (inputQual == DxilParamInputQual::InPayload) {
+    DXASSERT_NOMSG(isa<StructType>(Ty));
+    // Lower payload type here
+    StructType *LoweredTy = GetLoweredUDT(cast<StructType>(Ty));
+    if (LoweredTy != Ty) {
+      Value *Ptr = AllocaBuilder.CreateAlloca(LoweredTy);
+      ReplaceUsesForLoweredUDT(V, Ptr);
+      castParamMap[V] = std::make_pair(Ptr, inputQual);
+      V = Ptr;
+    }
+    return V;
+  }
+
   // Remove pointer for vector/scalar which is not out.
   if (V->getType()->isPointerTy() && !Ty->isAggregateType() && !bOut) {
     Value *Ptr = AllocaBuilder.CreateAlloca(Ty);
@@ -5419,6 +5638,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       bLoadOutputFromTemp = true;
     } else if (bLoad && bStore) {
       switch (qual) {
+      case DxilParamInputQual::InPayload:
       case DxilParamInputQual::InputPrimitive:
       case DxilParamInputQual::InputPatch:
       case DxilParamInputQual::OutputPatch: {

diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload-matrix.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload-matrix.hlsl
new file mode 100644
index 000000000..8bfc18325
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload-matrix.hlsl
@@ -0,0 +1,62 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: define void @main
+
+struct MeshPayload
+{
+  int4 data;
+  bool2x2 mat;
+};
+
+struct GSStruct
+{
+  row_major bool2x2 mat;
+  int4 vecmat;
+  MeshPayload pld[2];
+};
+
+groupshared GSStruct gs[2];
+
+row_major bool2x2 row_mat_array[2];
+
+int i, j;
+
+[numthreads(4,1,1)]
+void main(uint gtid : SV_GroupIndex)
+{
+  // write to dynamic row/col
+  gs[j].pld[i].mat[gtid >> 1][gtid & 1] = (int)gtid;
+  gs[j].vecmat[gtid] = (int)gtid;
+
+  int2x2 mat = gs[j].pld[i].mat;
+  gs[j].pld[i].mat = (bool2x2)gs[j].vecmat;
+
+  // subscript + constant GEP for component
+  gs[j].pld[i].mat[1].x = mat[1].y;
+  mat[0].y = gs[j].pld[i].mat[0].x;
+
+  // dynamic subscript + constant component index
+  gs[j].pld[i].mat[gtid & 1].x = mat[gtid & 1].y;
+  mat[gtid & 1].y = gs[j].pld[i].mat[gtid & 1].x;
+
+  // dynamic subscript + GEP for component
+  gs[j].pld[i].mat[gtid & 1] = mat[gtid & 1].y;
+  mat[gtid & 1].y = gs[j].pld[i].mat[gtid & 1].x;
+
+  // subscript element
+  gs[j].pld[i].mat._m01_m10 = mat[1];
+  mat[0] = gs[j].pld[i].mat._m00_m11;
+
+  // dynamic index of subscript element vector
+  mat[0].x = gs[j].pld[i].mat._m00_m11_m10[gtid & 1];
+  gs[j].pld[i].mat._m11_m10[gtid & 1] = gtid;
+
+  // Dynamic index into vector
+  int idx = gs[j].vecmat.x;
+  gs[j].pld[i].mat[1][idx] = mat[1].y;
+  mat[0].y = gs[j].pld[i].mat[0][idx];
+
+  int2 vec = gs[j].mat[0];
+  int2 multiplied = mul(mat, vec);
+  gs[j].pld[i].data = multiplied.xyxy;
+  DispatchMesh(1,1,1,gs[j].pld[i]);
+}

diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload.hlsl
index 2d15e7156..f57bbfa2e 100644
--- a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload.hlsl
+++ b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload.hlsl
@@ -1,17 +1,34 @@
 // RUN: %dxc -E amplification -T as_6_5 %s | FileCheck %s
 
+// Make sure we pass a constant GEP of the groupshared mesh payload directly
+// into DispatchMesh, with no alloca involved.
+
 // CHECK: define void @amplification
+// CHECK-NOT: alloca
+// CHECK-NOT: addrspacecast
+// CHECK-NOT: bitcast
+// CHECK: call void @dx.op.dispatchMesh.struct.MeshPayload{{[^ ]*}}(i32 173, i32 1, i32 1, i32 1, %struct.MeshPayload{{[^ ]*}} addrspace(3)* getelementptr inbounds (%struct.GSStruct{{[^ ]*}}, %struct.GSStruct{{[^ ]*}} addrspace(3)* @"\01?gs@@3UGSStruct@@A{{[^ ]*}}", i32 0, i32 1))
+// CHECK: ret void
 
 struct MeshPayload
 {
-  uint data[4];
+  uint4 data;
 };
 
-groupshared MeshPayload pld;
+struct GSStruct
+{
+  uint i;
+  MeshPayload pld;
+};
+
+groupshared GSStruct gs;
+GSStruct cb_gs;
 
 [numthreads(4,1,1)]
 void amplification(uint gtid : SV_GroupIndex)
 {
-  pld.data[gtid] = gtid;
-  DispatchMesh(1,1,1,pld);
+  gs = cb_gs;
+  gs.i = 1;
+  gs.pld.data[gtid] = gtid;
+  DispatchMesh(1,1,1,gs.pld);
 }

diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh-payload-matrix.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh-payload-matrix.hlsl
new file mode 100644
index 000000000..e792ee8e1
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh-payload-matrix.hlsl
@@ -0,0 +1,81 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: %[[pld:[^ ]+]] = call %struct.MeshPayload{{[^ ]*}} @dx.op.getMeshPayload.struct.MeshPayload{{.*}}(i32 170)
+// CHECK: call void @dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
+// CHECK: call void @dx.op.emitIndices
+
+// Verify bool translated from mem type
+// CHECK: %[[ppld0:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 0
+// CHECK: %[[pld0:[^ ]+]] = load i32, i32* %[[ppld0]], align 4
+// CHECK: %[[ppld1:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 1
+// CHECK: %[[pld1:[^ ]+]] = load i32, i32* %[[ppld1]], align 4
+// CHECK: %[[ppld2:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 2
+// CHECK: %[[pld2:[^ ]+]] = load i32, i32* %[[ppld2]], align 4
+// CHECK: %[[ppld3:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 3
+// CHECK: %[[pld3:[^ ]+]] = load i32, i32* %[[ppld3]], align 4
+// Inner components reversed due to column_major
+// CHECK: icmp ne i32 %[[pld0]], 0
+// CHECK: icmp ne i32 %[[pld2]], 0
+// CHECK: icmp ne i32 %[[pld1]], 0
+// CHECK: icmp ne i32 %[[pld3]], 0
+
+// CHECK: call void @dx.op.storePrimitiveOutput
+// CHECK: call void @dx.op.storeVertexOutput
+
+// CHECK: ret void
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+
+struct MeshPerVertex {
+  float4 position : SV_Position;
+  float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+  float normal : NORMAL;
+};
+
+struct MeshPayload {
+  float normal;
+  int4 data;
+  bool2x2 mat;
+};
+
+groupshared float gsMem[MAX_PRIM];
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+    out indices uint3 primIndices[MAX_PRIM],
+    out vertices MeshPerVertex verts[MAX_VERT],
+    out primitives MeshPerPrimitive prims[MAX_PRIM],
+    in payload MeshPayload mpl,
+    in uint tig : SV_GroupIndex,
+    in uint vid : SV_ViewID
+    )
+{
+  SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+  MeshPerVertex ov;
+  if (vid % 2) {
+    ov.position = float4(4.0,5.0,6.0,7.0);
+    ov.color[0] = 4.0;
+    ov.color[1] = 5.0;
+    ov.color[2] = 6.0;
+    ov.color[3] = 7.0;
+  } else {
+    ov.position = float4(14.0,15.0,16.0,17.0);
+    ov.color[0] = 14.0;
+    ov.color[1] = 15.0;
+    ov.color[2] = 16.0;
+    ov.color[3] = 17.0;
+  }
+  if (tig % 3) {
+    primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+    MeshPerPrimitive op;
+    op.normal = dot(mpl.normal.xx, mul(mpl.data.xy, mpl.mat));
+    gsMem[tig / 3] = op.normal;
+    prims[tig / 3] = op;
+  }
+  verts[tig] = ov;
+}

diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh.hlsl
index 0d1353e2c..81ae715f9 100644
--- a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh.hlsl
+++ b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh.hlsl
@@ -1,10 +1,10 @@
 // RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
 
-// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload(i32 170)
 // CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
-// CHECK: dx.op.emitIndices
-// CHECK: dx.op.storeVertexOutput
-// CHECK: dx.op.storePrimitiveOutput
+// CHECK: dx.op.emitIndices(i32 169,
+// CHECK: dx.op.storePrimitiveOutput.f32(i32 172,
+// CHECK: dx.op.storeVertexOutput.f32(i32 171,
 // CHECK: !"cullPrimitive", i32 3, i32 100, i32 4, !"SV_CullPrimitive", i32 7, i32 1}
 
 #define MAX_VERT 32