[DXIL generation] Merge GepUse last to avoid crash in EmitGetNodeRecordPtrAndUpdateUsers (#6314)

In EmitGetNodeRecordPtrAndUpdateUsers, the type will be mutated. And the
GEP user of the RecordPtr will be merged at same time. This make things
complex because the GEP index need to be updated since type is mutated.

To make things easier, merge the GepUse after mutate type.

Fixes #6223
This commit is contained in:
Xiang Li 2024-02-16 13:34:57 -08:00 коммит произвёл GitHub
Родитель df588beb48
Коммит 823125b32e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 87 добавлений и 6 удалений

Просмотреть файл

@ -178,7 +178,7 @@ hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
return Init;
}
void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
static void ReplaceUsesForLoweredUDTImpl(Value *V, Value *NewV) {
Type *Ty = V->getType();
Type *NewTy = NewV->getType();
@ -255,8 +255,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
IRBuilder<> Builder(GEP);
SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
Value *NewGEP = Builder.CreateGEP(NewV, idxList);
ReplaceUsesForLoweredUDT(GEP, NewGEP);
dxilutil::MergeGepUse(NewGEP);
ReplaceUsesForLoweredUDTImpl(GEP, NewGEP);
GEP->eraseFromParent();
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(user)) {
@ -264,14 +263,14 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
Constant *NewGEP = ConstantExpr::getGetElementPtr(
nullptr, cast<Constant>(NewV), idxList, true);
ReplaceUsesForLoweredUDT(GEP, NewGEP);
ReplaceUsesForLoweredUDTImpl(GEP, NewGEP);
} else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
// Address space cast
IRBuilder<> Builder(AC);
Value *NewAC = Builder.CreateAddrSpaceCast(
NewV, PointerType::get(Ty, AC->getType()->getPointerAddressSpace()));
ReplaceUsesForLoweredUDT(user, NewAC);
ReplaceUsesForLoweredUDTImpl(user, NewAC);
AC->eraseFromParent();
} else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
IRBuilder<> Builder(BC);
@ -295,7 +294,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
Constant *NewAC = ConstantExpr::getAddrSpaceCast(
cast<Constant>(NewV),
PointerType::get(Ty, CE->getType()->getPointerAddressSpace()));
ReplaceUsesForLoweredUDT(user, NewAC);
ReplaceUsesForLoweredUDTImpl(user, NewAC);
} else if (CE->getOpcode() == Instruction::BitCast) {
if (CE->getType()->getPointerElementType() == NewTy) {
// if alreday bitcast to new type, just replace the bitcast
@ -475,3 +474,9 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
CV->removeDeadConstantUsers();
}
}
void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
ReplaceUsesForLoweredUDTImpl(V, NewV);
// Merge GepUse later to avoid mutate type and merge gep use at same time.
dxilutil::MergeGepUse(NewV);
}

Просмотреть файл

@ -0,0 +1,38 @@
// RUN: %dxc -Tlib_6_8 %s | FileCheck %s
// Make sure generate correct metadata for Entry.
// CHECK: !{void ()* @Entry, !"Entry", null, null, ![[ENTRY:[0-9]+]]}
// CHECK: ![[ENTRY]] = !{i32 8, i32 15, i32 13, i32 1, i32 14, i1 true, i32 15, ![[NodeId:[0-9]+]], i32 16, i32 -1, i32 18, ![[NodeDispatchGrid:[0-9]+]], i32 20, ![[NodeInputs:[0-9]+]], i32 4, ![[NumThreads:[0-9]+]], i32 5, ![[AutoBindingSpace:[0-9]+]]}
// CHECK: ![[NodeId]] = !{!"Entry", i32 0}
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
// CHECK: ![[AutoBindingSpace]] = !{i32 0}
static const int maxPoints = 8;
struct EntryRecord {
float2 points[maxPoints];
int pointCoint;
};
[shader("node")]
[NodeIsProgramEntry]
[NodeLaunch("broadcasting")]
[NodeDispatchGrid(1, 1, 1)]
[NumThreads(32, 1, 1)]
void Entry(
uint gtid : SV_GroupThreadId,
DispatchNodeInputRecord<EntryRecord> inputData
)
{
EntryRecord input = inputData.Get();
[[unroll]]
for (int i = 0; i < 8; ++i) {
float2 p = input.points[i];
}
}

Просмотреть файл

@ -0,0 +1,38 @@
// RUN: %dxc -Tlib_6_8 %s | FileCheck %s
// Make sure generate correct metadata for Entry.
// CHECK: !{void ()* @Entry, !"Entry", null, null, ![[ENTRY:[0-9]+]]}
// CHECK: ![[ENTRY]] = !{i32 8, i32 15, i32 13, i32 1, i32 14, i1 true, i32 15, ![[NodeId:[0-9]+]], i32 16, i32 -1, i32 18, ![[NodeDispatchGrid:[0-9]+]], i32 20, ![[NodeInputs:[0-9]+]], i32 4, ![[NumThreads:[0-9]+]], i32 5, ![[AutoBindingSpace:[0-9]+]]}
// CHECK: ![[NodeId]] = !{!"Entry", i32 0}
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
// CHECK: ![[AutoBindingSpace]] = !{i32 0}
static const int maxPoints = 8;
struct EntryRecord {
float2 points[maxPoints];
int pointCoint;
};
[shader("node")]
[NodeIsProgramEntry]
[NodeLaunch("broadcasting")]
[NodeDispatchGrid(1, 1, 1)]
[NumThreads(32, 1, 1)]
void Entry(
uint gtid : SV_GroupThreadId,
DispatchNodeInputRecord<EntryRecord> inputData
)
{
EntryRecord input = inputData.Get();
if (gtid < input.pointCoint) {
// reading input.points[0] works
float2 p = input.points[gtid];
}
}