[SPIR-V] Emit OpUndef for undefined values (#6686)

Before this change, OpConstantNull was emitted when an undef value was
required.
This causes an issue for some types which cannot have the OpConstantNull
value.

In addition, it mixed well-defined values with undefined values, which
prevents any kind of optimization/analysis later on.

Fixes #6653

---------

Signed-off-by: Nathan Gauër <brioche@google.com>
This commit is contained in:
Nathan Gauër 2024-06-13 09:21:34 +02:00 коммит произвёл GitHub
Родитель 84d39b66cf
Коммит 56f3c40381
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
13 изменённых файлов: 167 добавлений и 9 удалений

Просмотреть файл

@ -748,6 +748,7 @@ public:
llvm::ArrayRef<SpirvConstant *> constituents,
bool specConst = false);
SpirvConstant *getConstantNull(QualType);
SpirvUndef *getUndef(QualType);
SpirvString *createString(llvm::StringRef str);
SpirvString *getString(llvm::StringRef str);

Просмотреть файл

@ -67,6 +67,9 @@ public:
IK_ConstantComposite,
IK_ConstantNull,
// OpUndef
IK_Undef,
// Function structure kinds
IK_FunctionParameter, // OpFunctionParameter
@ -1302,6 +1305,22 @@ public:
bool operator==(const SpirvConstantNull &that) const;
};
class SpirvUndef : public SpirvInstruction {
public:
SpirvUndef(QualType type);
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvUndef)
// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_Undef;
}
bool operator==(const SpirvUndef &that) const;
bool invokeVisitor(Visitor *v) override;
};
/// \brief OpCompositeConstruct instruction
class SpirvCompositeConstruct : public SpirvInstruction {
public:

Просмотреть файл

@ -142,6 +142,9 @@ public:
// Adds a constant to the module.
void addConstant(SpirvConstant *);
// Adds an Undef to the module.
void addUndef(SpirvUndef *);
// Adds given string to the module which will be emitted via OpString.
void addString(SpirvString *);
@ -202,6 +205,7 @@ private:
decorations;
std::vector<SpirvConstant *> constants;
std::vector<SpirvUndef *> undefs;
std::vector<SpirvVariable *> variables;
// A vector of functions in the module in the order that they should be
// emitted. The order starts with the entry-point function followed by a

Просмотреть файл

@ -89,6 +89,7 @@ public:
DEFINE_VISIT_METHOD(SpirvConstantFloat)
DEFINE_VISIT_METHOD(SpirvConstantComposite)
DEFINE_VISIT_METHOD(SpirvConstantNull)
DEFINE_VISIT_METHOD(SpirvUndef)
DEFINE_VISIT_METHOD(SpirvCompositeConstruct)
DEFINE_VISIT_METHOD(SpirvCompositeExtract)
DEFINE_VISIT_METHOD(SpirvCompositeInsert)

Просмотреть файл

@ -1006,6 +1006,13 @@ bool EmitVisitor::visit(SpirvConstantNull *inst) {
return true;
}
bool EmitVisitor::visit(SpirvUndef *inst) {
typeHandler.getOrCreateUndef(inst);
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvCompositeConstruct *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
@ -2010,6 +2017,8 @@ uint32_t EmitTypeHandler::getOrCreateConstant(SpirvConstant *inst) {
return getOrCreateConstantNull(constNull);
} else if (auto *constBool = dyn_cast<SpirvConstantBoolean>(inst)) {
return getOrCreateConstantBool(constBool);
} else if (auto *constUndef = dyn_cast<SpirvUndef>(inst)) {
return getOrCreateUndef(constUndef);
}
llvm_unreachable("cannot emit unknown constant type");
@ -2070,6 +2079,31 @@ uint32_t EmitTypeHandler::getOrCreateConstantNull(SpirvConstantNull *inst) {
return inst->getResultId();
}
uint32_t EmitTypeHandler::getOrCreateUndef(SpirvUndef *inst) {
auto canonicalType = inst->getAstResultType().getCanonicalType();
auto found = std::find_if(
emittedUndef.begin(), emittedUndef.end(),
[canonicalType](SpirvUndef *cached) {
return cached->getAstResultType().getCanonicalType() == canonicalType;
});
if (found != emittedUndef.end()) {
// We have already emitted this constant. Reuse.
inst->setResultId((*found)->getResultId());
return inst->getResultId();
}
// Constant wasn't emitted in the past.
const uint32_t typeId = emitType(inst->getResultType());
initTypeInstruction(inst->getopcode());
curTypeInst.push_back(typeId);
curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
finalizeTypeInstruction();
// Remember this constant for the future
emittedUndef.push_back(inst);
return inst->getResultId();
}
uint32_t EmitTypeHandler::getOrCreateConstantFloat(SpirvConstantFloat *inst) {
llvm::APFloat value = inst->getValue();
const SpirvType *type = inst->getResultType();

Просмотреть файл

@ -57,7 +57,7 @@ public:
typeConstantBinary(typesVec), takeNextIdFunction(takeNextIdFn),
emittedConstantInts({}), emittedConstantFloats({}),
emittedConstantComposites({}), emittedConstantNulls({}),
emittedConstantBools() {
emittedUndef({}), emittedConstantBools() {
assert(decVec);
assert(typesVec);
}
@ -107,6 +107,7 @@ public:
uint32_t getOrCreateConstantFloat(SpirvConstantFloat *);
uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
uint32_t getOrCreateConstantNull(SpirvConstantNull *);
uint32_t getOrCreateUndef(SpirvUndef *);
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
template <typename vecType>
void emitLiteral(const SpirvConstant *, vecType &outInst);
@ -172,6 +173,7 @@ private:
emittedConstantFloats;
llvm::SmallVector<SpirvConstantComposite *, 8> emittedConstantComposites;
llvm::SmallVector<SpirvConstantNull *, 8> emittedConstantNulls;
llvm::SmallVector<SpirvUndef *, 8> emittedUndef;
SpirvConstantBoolean *emittedConstantBools[2];
llvm::DenseSet<const SpirvInstruction *> emittedSpecConstantInstructions;
@ -252,6 +254,7 @@ public:
bool visit(SpirvConstantFloat *) override;
bool visit(SpirvConstantComposite *) override;
bool visit(SpirvConstantNull *) override;
bool visit(SpirvUndef *) override;
bool visit(SpirvCompositeConstruct *) override;
bool visit(SpirvCompositeExtract *) override;
bool visit(SpirvCompositeInsert *) override;

Просмотреть файл

@ -1826,6 +1826,13 @@ SpirvConstant *SpirvBuilder::getConstantNull(QualType type) {
return nullConst;
}
SpirvUndef *SpirvBuilder::getUndef(QualType type) {
// We do not care about making unique constants at this point.
auto *undef = new (context) SpirvUndef(type);
mod->addUndef(undef);
return undef;
}
SpirvString *SpirvBuilder::createString(llvm::StringRef str) {
// Create a SpirvString instruction
auto *instr = new (context) SpirvString(/* SourceLocation */ {}, str);

Просмотреть файл

@ -1517,10 +1517,9 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
spvBuilder.createReturn(returnLoc);
} else {
// If the source code does not provide a proper return value for some
// control flow path, it's undefined behavior. We just return null
// value here.
spvBuilder.createReturnValue(spvBuilder.getConstantNull(retType),
returnLoc);
// control flow path, it's undefined behavior. We just return an
// undefined value here.
spvBuilder.createReturnValue(spvBuilder.getUndef(retType), returnLoc);
}
}
}

Просмотреть файл

@ -57,6 +57,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantInteger)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantFloat)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantComposite)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantNull)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUndef)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeConstruct)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeExtract)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeInsert)
@ -540,6 +541,11 @@ bool SpirvConstant::operator==(const SpirvConstant &that) const {
if (thatNullInst == nullptr)
return false;
return *nullInst == *thatNullInst;
} else if (auto *nullInst = dyn_cast<SpirvUndef>(this)) {
auto *thatNullInst = dyn_cast<SpirvUndef>(&that);
if (thatNullInst == nullptr)
return false;
return *nullInst == *thatNullInst;
}
assert(false && "operator== undefined for SpirvConstant subclass");
@ -613,6 +619,15 @@ bool SpirvConstantNull::operator==(const SpirvConstantNull &that) const {
astResultType == that.astResultType;
}
SpirvUndef::SpirvUndef(QualType type)
: SpirvInstruction(IK_Undef, spv::Op::OpUndef, type,
/*SourceLocation*/ {}) {}
bool SpirvUndef::operator==(const SpirvUndef &that) const {
return opcode == that.opcode && resultType == that.resultType &&
astResultType == that.astResultType;
}
SpirvCompositeExtract::SpirvCompositeExtract(QualType resultType,
SourceLocation loc,
SpirvInstruction *compositeInst,

Просмотреть файл

@ -17,8 +17,8 @@ namespace spirv {
SpirvModule::SpirvModule()
: capabilities({}), extensions({}), extInstSets({}), memoryModel(nullptr),
entryPoints({}), executionModes({}), moduleProcesses({}), decorations({}),
constants({}), variables({}), functions({}), debugInstructions({}),
perVertexInterp(false) {}
constants({}), undefs({}), variables({}), functions({}),
debugInstructions({}), perVertexInterp(false) {}
SpirvModule::~SpirvModule() {
for (auto *cap : capabilities)
@ -43,6 +43,8 @@ SpirvModule::~SpirvModule() {
decoration->releaseMemory();
for (auto *constant : constants)
constant->releaseMemory();
for (auto *undef : undefs)
undef->releaseMemory();
for (auto *var : variables)
var->releaseMemory();
for (auto *di : debugInstructions)
@ -91,6 +93,12 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
return false;
}
for (auto iter = undefs.rbegin(); iter != undefs.rend(); ++iter) {
auto *undef = *iter;
if (!undef->invokeVisitor(visitor))
return false;
}
// Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
// manual indexing.
for (auto decorIndex = decorations.size(); decorIndex > 0; --decorIndex) {
@ -203,6 +211,10 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
if (!constant->invokeVisitor(visitor))
return false;
for (auto undef : undefs)
if (!undef->invokeVisitor(visitor))
return false;
for (auto var : variables)
if (!var->invokeVisitor(visitor))
return false;
@ -334,6 +346,11 @@ void SpirvModule::addConstant(SpirvConstant *constant) {
constants.push_back(constant);
}
void SpirvModule::addUndef(SpirvUndef *undef) {
assert(undef);
undefs.push_back(undef);
}
void SpirvModule::addString(SpirvString *str) {
assert(str);
constStrings.push_back(str);

Просмотреть файл

@ -0,0 +1,27 @@
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
// CHECK: [[undef:%[0-9]+]] = OpUndef %type_2d_image
Texture2D texA;
Texture2D texB;
// CHECK: %select = OpFunction
Texture2D select(bool lhs) {
// CHECK: %if_true = OpLabel
if (lhs)
return texA;
// CHECK: %if_false = OpLabel
else
return texB;
// no return for dead branch.
// CHECK: %if_merge = OpLabel
// CHECK-NEXT: OpReturnValue [[undef]]
}
// CHECK-NEXT: OpFunctionEnd
float main(bool a: A) : B {
Texture2D tex = select(true);
return 1.0;
}

Просмотреть файл

@ -1,12 +1,12 @@
// RUN: %dxc -T vs_6_0 -E main -Wno-return-type -fcgl %s -spirv | FileCheck %s
// CHECK:[[null:%[0-9]+]] = OpConstantNull %float
// CHECK:[[undef:%[0-9]+]] = OpUndef %float
float main(bool a: A) : B {
if (a) return 1.0;
// No return value for else
// CHECK: %if_merge = OpLabel
// CHECK-NEXT: OpReturnValue [[null]]
// CHECK-NEXT: OpReturnValue [[undef]]
// CHECK-NEXT: OpFunctionEnd
}

Просмотреть файл

@ -0,0 +1,31 @@
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
// CHECK: [[undef:%[0-9]+]] = OpUndef %type_2d_image
// CHECK-NOT: OpUndef %type_2d_image
Texture2D texA;
Texture2D texB;
Texture2D select1(bool lhs) {
if (lhs)
return texA;
else
return texB;
// no return for dead branch.
}
Texture2D select2(bool lhs) {
if (lhs)
return texA;
else
return texB;
// no return for dead branch.
}
float main(bool a: A) : B {
Texture2D x = select1(true);
Texture2D y = select2(true);
return 1.0;
}