[SPIR-V] Fix invalid isnan codegen (#6754)

IsNan returns a boolean, even is the input-type is a float. This was
working in most cases except:
 - if the layout was not Void
 - if the input type was not a matrix

The first bug is because a bool memory layout/representation is not
specified, and shall never be exposed to externaly-accessible memory.
Hence, if we saw a layout rule != Void, we converted it to a UINT. When
calling isnan, the layout rule should not be propagated as we loose any
layout info.

The second is because our codegen assumed matrix operations returned a
matrix with the same type as the input parameters. In the case of isnan,
this was just wrong.

Fixes #6712

Signed-off-by: Nathan Gauër <brioche@google.com>
This commit is contained in:
Nathan Gauër 2024-07-19 17:18:07 +02:00 коммит произвёл GitHub
Родитель 6f1c8e2443
Коммит 1028410a55
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 148 добавлений и 62 удалений

Просмотреть файл

@ -6286,9 +6286,10 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
if (isMxNMatrix(subType)) {
// For matrices, we can only increment/decrement each vector of it.
const auto actOnEachVec = [this, spvOp, one, expr,
range](uint32_t /*index*/, QualType vecType,
range](uint32_t /*index*/, QualType inType,
QualType outType,
SpirvInstruction *lhsVec) {
auto *val = spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, one,
auto *val = spvBuilder.createBinaryOp(spvOp, outType, lhsVec, one,
expr->getOperatorLoc(), range);
if (val)
val->setRValue();
@ -6356,9 +6357,10 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
if (isMxNMatrix(subType)) {
// For matrices, we can only negate each vector of it.
const auto actOnEachVec = [this, spvOp, expr,
range](uint32_t /*index*/, QualType vecType,
range](uint32_t /*index*/, QualType inType,
QualType outType,
SpirvInstruction *lhsVec) {
return spvBuilder.createUnaryOp(spvOp, vecType, lhsVec,
return spvBuilder.createUnaryOp(spvOp, outType, lhsVec,
expr->getOperatorLoc(), range);
};
return processEachVectorInMatrix(subExpr, subValue, actOnEachVec,
@ -7929,13 +7931,24 @@ void SpirvEmitter::assignToMSOutIndices(
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
const Expr *matrix, SpirvInstruction *matrixVal,
llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
SpirvInstruction *)>
actOnEachVector,
SourceLocation loc, SourceRange range) {
return processEachVectorInMatrix(matrix, matrix->getType(), matrixVal,
actOnEachVector, loc, range);
}
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
const Expr *matrix, QualType outputType, SpirvInstruction *matrixVal,
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
SpirvInstruction *)>
actOnEachVector,
SourceLocation loc, SourceRange range) {
const auto matType = matrix->getType();
assert(isMxNMatrix(matType));
const QualType vecType = getComponentVectorType(astContext, matType);
assert(isMxNMatrix(matType) && isMxNMatrix(outputType));
const QualType inVecType = getComponentVectorType(astContext, matType);
const QualType outVecType = getComponentVectorType(astContext, outputType);
uint32_t rowCount = 0, colCount = 0;
hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);
@ -7943,13 +7956,14 @@ SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
llvm::SmallVector<SpirvInstruction *, 4> vectors;
// Extract each component vector and do operation on it
for (uint32_t i = 0; i < rowCount; ++i) {
auto *lhsVec = spvBuilder.createCompositeExtract(vecType, matrixVal, {i},
auto *lhsVec = spvBuilder.createCompositeExtract(inVecType, matrixVal, {i},
matrix->getLocStart());
vectors.push_back(actOnEachVector(i, vecType, lhsVec));
vectors.push_back(actOnEachVector(i, inVecType, outVecType, lhsVec));
}
// Construct the result matrix
auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc, range);
auto *val =
spvBuilder.createCompositeConstruct(outputType, vectors, loc, range);
if (!val)
return nullptr;
val->setRValue();
@ -8056,15 +8070,15 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
case BO_MulAssign:
case BO_DivAssign:
case BO_RemAssign: {
const auto actOnEachVec = [this, spvOp, rhsVal, rhs, loc,
range](uint32_t index, QualType vecType,
SpirvInstruction *lhsVec) {
const auto actOnEachVec = [this, spvOp, rhsVal, rhs, loc, range](
uint32_t index, QualType inType,
QualType outType, SpirvInstruction *lhsVec) {
// For each vector of lhs, we need to load the corresponding vector of
// rhs and do the operation on them.
auto *rhsVec = spvBuilder.createCompositeExtract(vecType, rhsVal, {index},
auto *rhsVec = spvBuilder.createCompositeExtract(inType, rhsVal, {index},
rhs->getLocStart());
auto *val =
spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec, loc, range);
spvBuilder.createBinaryOp(spvOp, outType, lhsVec, rhsVec, loc, range);
if (val)
val->setRValue();
return val;
@ -9066,6 +9080,15 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
case hlsl::IntrinsicOp::IOP_firstbitlow: {
retVal = processIntrinsicFirstbit(callExpr, GLSLstd450::GLSLstd450FindILsb);
break;
}
case hlsl::IntrinsicOp::IOP_isnan: {
retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpIsNan,
/* doEachVec= */ true);
// OpIsNan returns a bool/vec<bool>, so the only valid layout is void. It
// will be the responsibility of the store to do an OpSelect and correctly
// convert this type to an externally storable type.
retVal->setLayoutRule(SpirvLayoutRule::Void);
break;
}
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@ -9075,7 +9098,6 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
@ -10030,14 +10052,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMad(const CallExpr *callExpr) {
if (isMxNMatrix(arg0->getType())) {
const auto actOnEachVec = [this, loc, arg1Instr, arg2Instr, arg1Loc,
arg2Loc,
range](uint32_t index, QualType vecType,
range](uint32_t index, QualType inType,
QualType outType,
SpirvInstruction *arg0Row) {
auto *arg1Row = spvBuilder.createCompositeExtract(
vecType, arg1Instr, {index}, arg1Loc, range);
inType, arg1Instr, {index}, arg1Loc, range);
auto *arg2Row = spvBuilder.createCompositeExtract(
vecType, arg2Instr, {index}, arg2Loc, range);
inType, arg2Instr, {index}, arg2Loc, range);
auto *fma = spvBuilder.createGLSLExtInst(
vecType, GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc, range);
outType, GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc, range);
spvBuilder.decorateNoContraction(fma, loc);
return fma;
};
@ -10257,13 +10280,14 @@ SpirvEmitter::processIntrinsicLdexp(const CallExpr *callExpr) {
uint32_t rowCount = 0, colCount = 0;
if (isMxNMatrix(paramType, nullptr, &rowCount, &colCount)) {
const auto actOnEachVec = [this, loc, expInstr, arg1Loc,
range](uint32_t index, QualType vecType,
range](uint32_t index, QualType inType,
QualType outType,
SpirvInstruction *xRowInstr) {
auto *expRowInstr = spvBuilder.createCompositeExtract(
vecType, expInstr, {index}, arg1Loc, range);
inType, expInstr, {index}, arg1Loc, range);
auto *twoExp = spvBuilder.createGLSLExtInst(
vecType, GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc, range);
return spvBuilder.createBinaryOp(spv::Op::OpFMul, vecType, xRowInstr,
outType, GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc, range);
return spvBuilder.createBinaryOp(spv::Op::OpFMul, outType, xRowInstr,
twoExp, loc, range);
};
return processEachVectorInMatrix(x, xInstr, actOnEachVec, loc, range);
@ -10427,15 +10451,15 @@ SpirvEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
// the operation on each vector of the matrix.
if (isMxNMatrix(argX->getType())) {
const auto actOnEachVec = [this, loc, range, glslOpcode, argMinInstr,
argMaxInstr, argMinLoc,
argMaxLoc](uint32_t index, QualType vecType,
SpirvInstruction *curRow) {
argMaxInstr, argMinLoc, argMaxLoc](
uint32_t index, QualType inType,
QualType outType, SpirvInstruction *curRow) {
auto *minRowInstr = spvBuilder.createCompositeExtract(
vecType, argMinInstr, {index}, argMinLoc, range);
inType, argMinInstr, {index}, argMinLoc, range);
auto *maxRowInstr = spvBuilder.createCompositeExtract(
vecType, argMaxInstr, {index}, argMaxLoc, range);
inType, argMaxInstr, {index}, argMaxLoc, range);
return spvBuilder.createGLSLExtInst(
vecType, glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc, range);
outType, glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc, range);
};
return processEachVectorInMatrix(argX, argXInstr, actOnEachVec, loc, range);
}
@ -11013,12 +11037,12 @@ SpirvInstruction *SpirvEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
uint32_t numRows = 0, numCols = 0;
if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
auto *vecOne = getVecValueOne(elemType, numCols);
const auto actOnEachVec = [this, vecOne, loc,
range](uint32_t /*index*/, QualType vecType,
SpirvInstruction *curRow) {
return spvBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne, curRow,
loc, range);
};
const auto actOnEachVec =
[this, vecOne, loc, range](uint32_t /*index*/, QualType inType,
QualType outType, SpirvInstruction *curRow) {
return spvBuilder.createBinaryOp(spv::Op::OpFDiv, outType, vecOne,
curRow, loc, range);
};
return processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
}
@ -11335,10 +11359,10 @@ SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
auto *vecZero = getVecValueZero(elemType, numCols);
auto *vecOne = getVecValueOne(elemType, numCols);
const auto actOnEachVec = [this, loc, vecZero, vecOne,
range](uint32_t /*index*/, QualType vecType,
SpirvInstruction *curRow) {
return spvBuilder.createGLSLExtInst(vecType, GLSLstd450::GLSLstd450FClamp,
const auto actOnEachVec = [this, loc, vecZero, vecOne, range](
uint32_t /*index*/, QualType inType,
QualType outType, SpirvInstruction *curRow) {
return spvBuilder.createGLSLExtInst(outType, GLSLstd450::GLSLstd450FClamp,
{curRow, vecZero, vecOne}, loc,
range);
};
@ -11364,10 +11388,10 @@ SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
// For matrices, we can perform the instruction on each vector of the matrix.
if (isMxNMatrix(argType)) {
const auto actOnEachVec = [this, loc, range](uint32_t /*index*/,
QualType vecType,
SpirvInstruction *curRow) {
return spvBuilder.createGLSLExtInst(vecType, GLSLstd450::GLSLstd450FSign,
const auto actOnEachVec = [this, loc, range](
uint32_t /*index*/, QualType inType,
QualType outType, SpirvInstruction *curRow) {
return spvBuilder.createGLSLExtInst(outType, GLSLstd450::GLSLstd450FSign,
{curRow}, loc, range);
};
floatSign = processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
@ -11496,12 +11520,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
// If the instruction does not operate on matrices, we can perform the
// instruction on each vector of the matrix.
if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
assert(isMxNMatrix(returnType));
const auto actOnEachVec = [this, opcode, loc,
range](uint32_t /*index*/, QualType vecType,
range](uint32_t /*index*/, QualType inType,
QualType outType,
SpirvInstruction *curRow) {
return spvBuilder.createUnaryOp(opcode, vecType, curRow, loc, range);
return spvBuilder.createUnaryOp(opcode, outType, curRow, loc, range);
};
return processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
return processEachVectorInMatrix(arg, returnType, argId, actOnEachVec,
loc, range);
}
return spvBuilder.createUnaryOp(opcode, returnType, argId, loc, range);
} else if (callExpr->getNumArgs() == 2u) {
@ -11514,11 +11541,12 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
// instruction on each vector of the matrix.
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
const auto actOnEachVec = [this, opcode, arg1Id, loc, range, arg1Loc,
arg1Range](uint32_t index, QualType vecType,
arg1Range](uint32_t index, QualType inType,
QualType outType,
SpirvInstruction *arg0Row) {
auto *arg1Row = spvBuilder.createCompositeExtract(
vecType, arg1Id, {index}, arg1Loc, arg1Range);
return spvBuilder.createBinaryOp(opcode, vecType, arg0Row, arg1Row, loc,
inType, arg1Id, {index}, arg1Loc, arg1Range);
return spvBuilder.createBinaryOp(opcode, outType, arg0Row, arg1Row, loc,
range);
};
return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec, loc, range);
@ -11546,9 +11574,10 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
// instruction on each vector of the matrix.
if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
const auto actOnEachVec = [this, loc, range,
opcode](uint32_t /*index*/, QualType vecType,
opcode](uint32_t /*index*/, QualType inType,
QualType outType,
SpirvInstruction *curRowInstr) {
return spvBuilder.createGLSLExtInst(vecType, opcode, {curRowInstr}, loc,
return spvBuilder.createGLSLExtInst(outType, opcode, {curRowInstr}, loc,
range);
};
return processEachVectorInMatrix(arg, argInstr, actOnEachVec, loc, range);
@ -11565,12 +11594,13 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
// instruction on each vector of the matrix.
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
const auto actOnEachVec = [this, loc, range, opcode, arg1Instr, arg1Range,
arg1Loc](uint32_t index, QualType vecType,
arg1Loc](uint32_t index, QualType inType,
QualType outType,
SpirvInstruction *arg0RowInstr) {
auto *arg1RowInstr = spvBuilder.createCompositeExtract(
vecType, arg1Instr, {index}, arg1Loc, arg1Range);
inType, arg1Instr, {index}, arg1Loc, arg1Range);
return spvBuilder.createGLSLExtInst(
vecType, opcode, {arg0RowInstr, arg1RowInstr}, loc, range);
outType, opcode, {arg0RowInstr, arg1RowInstr}, loc, range);
};
return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc,
range);
@ -11591,14 +11621,15 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
const auto actOnEachVec = [this, loc, range, opcode, arg1Instr, arg2Instr,
arg1Loc, arg2Loc, arg1Range,
arg2Range](uint32_t index, QualType vecType,
arg2Range](uint32_t index, QualType inType,
QualType outType,
SpirvInstruction *arg0RowInstr) {
auto *arg1RowInstr = spvBuilder.createCompositeExtract(
vecType, arg1Instr, {index}, arg1Loc, arg1Range);
inType, arg1Instr, {index}, arg1Loc, arg1Range);
auto *arg2RowInstr = spvBuilder.createCompositeExtract(
vecType, arg2Instr, {index}, arg2Loc, arg2Range);
inType, arg2Instr, {index}, arg2Loc, arg2Range);
return spvBuilder.createGLSLExtInst(
vecType, opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc,
outType, opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc,
range);
};
return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc,

Просмотреть файл

@ -361,7 +361,14 @@ private:
/// the value. It returns the <result-id> of the processed vector.
SpirvInstruction *processEachVectorInMatrix(
const Expr *matrix, SpirvInstruction *matrixVal,
llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
SpirvInstruction *)>
actOnEachVector,
SourceLocation loc = {}, SourceRange range = {});
SpirvInstruction *processEachVectorInMatrix(
const Expr *matrix, QualType outputType, SpirvInstruction *matrixVal,
llvm::function_ref<SpirvInstruction *(uint32_t, QualType, QualType,
SpirvInstruction *)>
actOnEachVector,
SourceLocation loc = {}, SourceRange range = {});

Просмотреть файл

@ -1,10 +1,16 @@
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s
RWStructuredBuffer<float> buffer;
RWStructuredBuffer<float2x3> buffer_mat;
RWByteAddressBuffer byte_buffer;
void main() {
float a;
float4 b;
float2x3 c;
// CHECK: %isnan_c = OpVariable %_ptr_Function__arr_v3bool_uint_2 Function
// CHECK: [[a:%[0-9]+]] = OpLoad %float %a
// CHECK-NEXT: {{%[0-9]+}} = OpIsNan %bool [[a]]
bool isnan_a = isnan(a);
@ -13,6 +19,48 @@ void main() {
// CHECK-NEXT: {{%[0-9]+}} = OpIsNan %v4bool [[b]]
bool4 isnan_b = isnan(b);
// TODO: We can not translate the following since boolean matrices are currently not supported.
// bool2x3 isnan_c = isnan(c);
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_float %buffer %int_0 %uint_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %float [[ptr]]
// CHECK: [[res:%[0-9]+]] = OpIsNan %bool [[tmp]]
// CHECK: OpStore %res [[res]]
// CHECK: [[res:%[0-9]+]] = OpLoad %bool %res
// CHECK: [[tmp:%[0-9]+]] = OpSelect %float [[res]] %float_1 %float_0
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_float %buffer %int_0 %uint_0
// CHECK: OpStore [[ptr]] [[tmp]]
bool res = isnan(buffer[0]);
buffer[0] = (float)res;
// CHECK: [[c:%[0-9]+]] = OpLoad %mat2v3float %c
// CHECK: [[r0:%[0-9]+]] = OpCompositeExtract %v3float [[c]] 0
// CHECK: [[isnan_r0:%[0-9]+]] = OpIsNan %v3bool [[r0]]
// CHECK: [[r1:%[0-9]+]] = OpCompositeExtract %v3float [[c]] 1
// CHECK: [[isnan_r1:%[0-9]+]] = OpIsNan %v3bool [[r1]]
// CHECK: [[tmp:%[0-9]+]] = OpCompositeConstruct %_arr_v3bool_uint_2 [[isnan_r0]] [[isnan_r1]]
// CHECK: OpStore %isnan_c [[tmp]]
bool2x3 isnan_c = isnan(c);
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_mat2v3float %buffer_mat %int_0 %uint_0
// CHECK: [[tmp:%[0-9]+]] = OpLoad %mat2v3float [[ptr]]
// CHECK: [[r0:%[0-9]+]] = OpCompositeExtract %v3float [[tmp]] 0
// CHECK: [[isnan_r0:%[0-9]+]] = OpIsNan %v3bool [[r0]]
// CHECK: [[r1:%[0-9]+]] = OpCompositeExtract %v3float [[tmp]] 1
// CHECK: [[isnan_r1:%[0-9]+]] = OpIsNan %v3bool [[r1]]
// CHECK: [[tmp:%[0-9]+]] = OpCompositeConstruct %_arr_v3bool_uint_2 [[isnan_r0]] [[isnan_r1]]
// CHECK: OpStore %isnan_d [[tmp]]
bool2x3 isnan_d = isnan(buffer_mat[0]);
// CHECK: [[addr:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %byte_buffer %uint_0 [[addr]]
// CHECK: [[tmp:%[0-9]+]] = OpLoad %uint [[ptr]]
// CHECK: [[val:%[0-9]+]] = OpBitcast %float [[tmp]]
// CHECK: [[res:%[0-9]+]] = OpIsNan %bool [[val]]
// CHECK: OpStore %isnan_e [[res]]
bool isnan_e = isnan(byte_buffer.Load<float>(0));
// CHECK: [[res:%[0-9]+]] = OpLoad %bool %isnan_e
// CHECK: [[addr:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %byte_buffer %uint_0 [[addr]]
// CHECK: [[tmp:%[0-9]+]] = OpSelect %uint [[res]] %uint_1 %uint_0
// CHECK: OpStore [[ptr]] [[tmp]]
byte_buffer.Store(0, isnan_e);
}