[NFC][SPIR-V] Refactor SpirvGroupNonUniformOps (#6596)
A follow-up change will use the PartitionedExclusiveScanNV GroupOperation, which requires that an additional operand is added to all GroupNonUniformArithmetic instructions. This means that some of the SPIR-V opcodes which are currently categorized as unary will become either unary or binary depending on the GroupOp. Since the arity distinctions between the OpGroupNonUniform* instructions were already somewhat arbitrary, I'm prefacing that change by refactoring them into a single SpirvGroupNonUniformOp instruction type for better reusability. Follow up: #6608
This commit is contained in:
Родитель
ff623f8a74
Коммит
d9caef5289
|
@ -238,17 +238,10 @@ public:
|
|||
|
||||
/// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
|
||||
/// opcode.
|
||||
SpirvNonUniformElect *createGroupNonUniformElect(spv::Op op,
|
||||
QualType resultType,
|
||||
spv::Scope execScope,
|
||||
SourceLocation);
|
||||
SpirvNonUniformUnaryOp *createGroupNonUniformUnaryOp(
|
||||
SourceLocation, spv::Op op, QualType resultType, spv::Scope execScope,
|
||||
SpirvInstruction *operand,
|
||||
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
|
||||
SpirvNonUniformBinaryOp *createGroupNonUniformBinaryOp(
|
||||
SpirvGroupNonUniformOp *createGroupNonUniformOp(
|
||||
spv::Op op, QualType resultType, spv::Scope execScope,
|
||||
SpirvInstruction *operand1, SpirvInstruction *operand2, SourceLocation);
|
||||
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation,
|
||||
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
|
||||
|
||||
/// \brief Creates an atomic instruction with the given parameters and returns
|
||||
/// its pointer.
|
||||
|
|
|
@ -111,11 +111,7 @@ public:
|
|||
|
||||
IK_SetMeshOutputsEXT, // OpSetMeshOutputsEXT
|
||||
|
||||
// The following section is for group non-uniform instructions.
|
||||
// Used by LLVM-style RTTI; order matters.
|
||||
IK_GroupNonUniformBinaryOp, // Group non-uniform binary operations
|
||||
IK_GroupNonUniformElect, // OpGroupNonUniformElect
|
||||
IK_GroupNonUniformUnaryOp, // Group non-uniform unary operations
|
||||
IK_GroupNonUniformOp, // Group non-uniform operations
|
||||
|
||||
IK_ImageOp, // OpImage*
|
||||
IK_ImageQuery, // OpImageQuery*
|
||||
|
@ -1495,102 +1491,43 @@ private:
|
|||
llvm::SmallVector<SpirvInstruction *, 4> args;
|
||||
};
|
||||
|
||||
/// \brief Base for OpGroupNonUniform* instructions
|
||||
/// \brief OpGroupNonUniform* instructions
|
||||
class SpirvGroupNonUniformOp : public SpirvInstruction {
|
||||
public:
|
||||
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope,
|
||||
llvm::ArrayRef<SpirvInstruction *> operands,
|
||||
SourceLocation loc,
|
||||
llvm::Optional<spv::GroupOperation> group);
|
||||
|
||||
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvGroupNonUniformOp)
|
||||
|
||||
// For LLVM-style RTTI
|
||||
static bool classof(const SpirvInstruction *inst) {
|
||||
return inst->getKind() >= IK_GroupNonUniformBinaryOp &&
|
||||
inst->getKind() <= IK_GroupNonUniformUnaryOp;
|
||||
return inst->getKind() == IK_GroupNonUniformOp;
|
||||
}
|
||||
|
||||
bool invokeVisitor(Visitor *v) override;
|
||||
|
||||
spv::Scope getExecutionScope() const { return execScope; }
|
||||
|
||||
protected:
|
||||
SpirvGroupNonUniformOp(Kind kind, spv::Op opcode, QualType resultType,
|
||||
SourceLocation loc, spv::Scope scope);
|
||||
llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }
|
||||
|
||||
bool hasGroupOp() const { return groupOp.hasValue(); }
|
||||
spv::GroupOperation getGroupOp() const { return groupOp.getValue(); }
|
||||
|
||||
void replaceOperand(
|
||||
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
|
||||
bool inEntryFunctionWrapper) override {
|
||||
for (auto *operand : getOperands()) {
|
||||
operand = remapOp(operand);
|
||||
}
|
||||
if (inEntryFunctionWrapper)
|
||||
setAstResultType(getOperands()[0]->getAstResultType());
|
||||
}
|
||||
|
||||
private:
|
||||
spv::Scope execScope;
|
||||
};
|
||||
|
||||
/// \brief OpGroupNonUniform* binary instructions.
|
||||
class SpirvNonUniformBinaryOp : public SpirvGroupNonUniformOp {
|
||||
public:
|
||||
SpirvNonUniformBinaryOp(spv::Op opcode, QualType resultType,
|
||||
SourceLocation loc, spv::Scope scope,
|
||||
SpirvInstruction *arg1, SpirvInstruction *arg2);
|
||||
|
||||
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformBinaryOp)
|
||||
|
||||
// For LLVM-style RTTI
|
||||
static bool classof(const SpirvInstruction *inst) {
|
||||
return inst->getKind() == IK_GroupNonUniformBinaryOp;
|
||||
}
|
||||
|
||||
bool invokeVisitor(Visitor *v) override;
|
||||
|
||||
SpirvInstruction *getArg1() const { return arg1; }
|
||||
SpirvInstruction *getArg2() const { return arg2; }
|
||||
void replaceOperand(
|
||||
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
|
||||
bool inEntryFunctionWrapper) override {
|
||||
arg1 = remapOp(arg1);
|
||||
arg2 = remapOp(arg2);
|
||||
}
|
||||
|
||||
private:
|
||||
SpirvInstruction *arg1;
|
||||
SpirvInstruction *arg2;
|
||||
};
|
||||
|
||||
/// \brief OpGroupNonUniformElect instruction. This is currently the only
|
||||
/// non-uniform instruction that takes no other arguments.
|
||||
class SpirvNonUniformElect : public SpirvGroupNonUniformOp {
|
||||
public:
|
||||
SpirvNonUniformElect(QualType resultType, SourceLocation loc,
|
||||
spv::Scope scope);
|
||||
|
||||
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformElect)
|
||||
|
||||
// For LLVM-style RTTI
|
||||
static bool classof(const SpirvInstruction *inst) {
|
||||
return inst->getKind() == IK_GroupNonUniformElect;
|
||||
}
|
||||
|
||||
bool invokeVisitor(Visitor *v) override;
|
||||
};
|
||||
|
||||
/// \brief OpGroupNonUniform* unary instructions.
|
||||
class SpirvNonUniformUnaryOp : public SpirvGroupNonUniformOp {
|
||||
public:
|
||||
SpirvNonUniformUnaryOp(spv::Op opcode, QualType resultType,
|
||||
SourceLocation loc, spv::Scope scope,
|
||||
llvm::Optional<spv::GroupOperation> group,
|
||||
SpirvInstruction *arg);
|
||||
|
||||
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformUnaryOp)
|
||||
|
||||
// For LLVM-style RTTI
|
||||
static bool classof(const SpirvInstruction *inst) {
|
||||
return inst->getKind() == IK_GroupNonUniformUnaryOp;
|
||||
}
|
||||
|
||||
bool invokeVisitor(Visitor *v) override;
|
||||
|
||||
SpirvInstruction *getArg() const { return arg; }
|
||||
bool hasGroupOp() const { return groupOp.hasValue(); }
|
||||
spv::GroupOperation getGroupOp() const { return groupOp.getValue(); }
|
||||
void replaceOperand(
|
||||
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
|
||||
bool inEntryFunctionWrapper) override {
|
||||
arg = remapOp(arg);
|
||||
if (inEntryFunctionWrapper)
|
||||
setAstResultType(arg->getAstResultType());
|
||||
}
|
||||
|
||||
private:
|
||||
SpirvInstruction *arg;
|
||||
llvm::SmallVector<SpirvInstruction *, 4> operands;
|
||||
llvm::Optional<spv::GroupOperation> groupOp;
|
||||
};
|
||||
|
||||
|
|
|
@ -96,9 +96,7 @@ public:
|
|||
DEFINE_VISIT_METHOD(SpirvEndPrimitive)
|
||||
DEFINE_VISIT_METHOD(SpirvExtInst)
|
||||
DEFINE_VISIT_METHOD(SpirvFunctionCall)
|
||||
DEFINE_VISIT_METHOD(SpirvNonUniformBinaryOp)
|
||||
DEFINE_VISIT_METHOD(SpirvNonUniformElect)
|
||||
DEFINE_VISIT_METHOD(SpirvNonUniformUnaryOp)
|
||||
DEFINE_VISIT_METHOD(SpirvGroupNonUniformOp)
|
||||
DEFINE_VISIT_METHOD(SpirvImageOp)
|
||||
DEFINE_VISIT_METHOD(SpirvImageQuery)
|
||||
DEFINE_VISIT_METHOD(SpirvImageSparseTexelsResident)
|
||||
|
|
|
@ -1087,35 +1087,7 @@ bool EmitVisitor::visit(SpirvFunctionCall *inst) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
|
||||
initInstruction(inst);
|
||||
curInst.push_back(inst->getResultTypeId());
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
|
||||
curInst.push_back(typeHandler.getOrCreateConstantInt(
|
||||
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
|
||||
context.getUIntType(32), /* isSpecConst */ false));
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg1()));
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg2()));
|
||||
finalizeInstruction(&mainBinary);
|
||||
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
|
||||
inst->getDebugName());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
|
||||
initInstruction(inst);
|
||||
curInst.push_back(inst->getResultTypeId());
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
|
||||
curInst.push_back(typeHandler.getOrCreateConstantInt(
|
||||
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
|
||||
context.getUIntType(32), /* isSpecConst */ false));
|
||||
finalizeInstruction(&mainBinary);
|
||||
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
|
||||
inst->getDebugName());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
|
||||
bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) {
|
||||
initInstruction(inst);
|
||||
curInst.push_back(inst->getResultTypeId());
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
|
||||
|
@ -1124,7 +1096,8 @@ bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
|
|||
context.getUIntType(32), /* isSpecConst */ false));
|
||||
if (inst->hasGroupOp())
|
||||
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg()));
|
||||
for (auto *operand : inst->getOperands())
|
||||
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
|
||||
finalizeInstruction(&mainBinary);
|
||||
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
|
||||
inst->getDebugName());
|
||||
|
|
|
@ -257,9 +257,7 @@ public:
|
|||
bool visit(SpirvCompositeInsert *) override;
|
||||
bool visit(SpirvExtInst *) override;
|
||||
bool visit(SpirvFunctionCall *) override;
|
||||
bool visit(SpirvNonUniformBinaryOp *) override;
|
||||
bool visit(SpirvNonUniformElect *) override;
|
||||
bool visit(SpirvNonUniformUnaryOp *) override;
|
||||
bool visit(SpirvGroupNonUniformOp *) override;
|
||||
bool visit(SpirvImageOp *) override;
|
||||
bool visit(SpirvImageQuery *) override;
|
||||
bool visit(SpirvImageSparseTexelsResident *) override;
|
||||
|
|
|
@ -294,17 +294,9 @@ bool LiteralTypeVisitor::visit(SpirvVectorShuffle *inst) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool LiteralTypeVisitor::visit(SpirvNonUniformUnaryOp *inst) {
|
||||
// Went through each non-uniform binary operation and made sure the following
|
||||
// does not result in a wrong type deduction.
|
||||
tryToUpdateInstLitType(inst->getArg(), inst->getAstResultType());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
|
||||
// Went through each non-uniform unary operation and made sure the following
|
||||
// does not result in a wrong type deduction.
|
||||
tryToUpdateInstLitType(inst->getArg1(), inst->getAstResultType());
|
||||
bool LiteralTypeVisitor::visit(SpirvGroupNonUniformOp *inst) {
|
||||
for (auto *operand : inst->getOperands())
|
||||
tryToUpdateInstLitType(operand, inst->getAstResultType());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,8 +32,7 @@ public:
|
|||
bool visit(SpirvBitFieldExtract *) override;
|
||||
bool visit(SpirvSelect *) override;
|
||||
bool visit(SpirvVectorShuffle *) override;
|
||||
bool visit(SpirvNonUniformUnaryOp *) override;
|
||||
bool visit(SpirvNonUniformBinaryOp *) override;
|
||||
bool visit(SpirvGroupNonUniformOp *) override;
|
||||
bool visit(SpirvLoad *) override;
|
||||
bool visit(SpirvStore *) override;
|
||||
bool visit(SpirvConstantComposite *) override;
|
||||
|
|
|
@ -98,7 +98,6 @@ public:
|
|||
REMAP_FUNC_OP(ImageOp)
|
||||
REMAP_FUNC_OP(ExtInst)
|
||||
REMAP_FUNC_OP(Atomic)
|
||||
REMAP_FUNC_OP(NonUniformBinaryOp)
|
||||
REMAP_FUNC_OP(BitFieldInsert)
|
||||
REMAP_FUNC_OP(BitFieldExtract)
|
||||
REMAP_FUNC_OP(IntrinsicInstruction)
|
||||
|
@ -115,7 +114,7 @@ public:
|
|||
REMAP_FUNC_OP(Select)
|
||||
REMAP_FUNC_OP(Switch)
|
||||
REMAP_FUNC_OP(CopyObject)
|
||||
REMAP_FUNC_OP(NonUniformUnaryOp)
|
||||
REMAP_FUNC_OP(GroupNonUniformOp)
|
||||
|
||||
private:
|
||||
///< Whether in entry function wrapper, which will influence replace steps.
|
||||
|
|
|
@ -233,14 +233,9 @@ bool PreciseVisitor::visit(SpirvUnaryOp *inst) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool PreciseVisitor::visit(SpirvNonUniformBinaryOp *inst) {
|
||||
inst->getArg1()->setPrecise(inst->isPrecise());
|
||||
inst->getArg2()->setPrecise(inst->isPrecise());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PreciseVisitor::visit(SpirvNonUniformUnaryOp *inst) {
|
||||
inst->getArg()->setPrecise(inst->isPrecise());
|
||||
bool PreciseVisitor::visit(SpirvGroupNonUniformOp *inst) {
|
||||
for (auto *operand : inst->getOperands())
|
||||
operand->setPrecise(inst->isPrecise());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -36,8 +36,7 @@ public:
|
|||
bool visit(SpirvStore *) override;
|
||||
bool visit(SpirvBinaryOp *) override;
|
||||
bool visit(SpirvUnaryOp *) override;
|
||||
bool visit(SpirvNonUniformBinaryOp *) override;
|
||||
bool visit(SpirvNonUniformUnaryOp *) override;
|
||||
bool visit(SpirvGroupNonUniformOp *) override;
|
||||
bool visit(SpirvExtInst *) override;
|
||||
bool visit(SpirvFunctionCall *) override;
|
||||
|
||||
|
|
|
@ -431,32 +431,13 @@ SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp(
|
|||
return instruction;
|
||||
}
|
||||
|
||||
SpirvNonUniformElect *SpirvBuilder::createGroupNonUniformElect(
|
||||
spv::Op op, QualType resultType, spv::Scope execScope, SourceLocation loc) {
|
||||
assert(insertPoint && "null insert point");
|
||||
auto *instruction =
|
||||
new (context) SpirvNonUniformElect(resultType, loc, execScope);
|
||||
insertPoint->addInstruction(instruction);
|
||||
return instruction;
|
||||
}
|
||||
|
||||
SpirvNonUniformUnaryOp *SpirvBuilder::createGroupNonUniformUnaryOp(
|
||||
SourceLocation loc, spv::Op op, QualType resultType, spv::Scope execScope,
|
||||
SpirvInstruction *operand, llvm::Optional<spv::GroupOperation> groupOp) {
|
||||
SpirvGroupNonUniformOp *SpirvBuilder::createGroupNonUniformOp(
|
||||
spv::Op op, QualType resultType, spv::Scope execScope,
|
||||
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation loc,
|
||||
llvm::Optional<spv::GroupOperation> groupOp) {
|
||||
assert(insertPoint && "null insert point");
|
||||
auto *instruction = new (context)
|
||||
SpirvNonUniformUnaryOp(op, resultType, loc, execScope, groupOp, operand);
|
||||
insertPoint->addInstruction(instruction);
|
||||
return instruction;
|
||||
}
|
||||
|
||||
SpirvNonUniformBinaryOp *SpirvBuilder::createGroupNonUniformBinaryOp(
|
||||
spv::Op op, QualType resultType, spv::Scope execScope,
|
||||
SpirvInstruction *operand1, SpirvInstruction *operand2,
|
||||
SourceLocation loc) {
|
||||
assert(insertPoint && "null insert point");
|
||||
auto *instruction = new (context) SpirvNonUniformBinaryOp(
|
||||
op, resultType, loc, execScope, operand1, operand2);
|
||||
SpirvGroupNonUniformOp(op, resultType, execScope, operands, loc, groupOp);
|
||||
insertPoint->addInstruction(instruction);
|
||||
return instruction;
|
||||
}
|
||||
|
|
|
@ -9421,8 +9421,8 @@ SpirvInstruction *SpirvEmitter::processWaveQuery(const CallExpr *callExpr,
|
|||
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
|
||||
callExpr->getExprLoc());
|
||||
const QualType retType = callExpr->getCallReturnType(astContext);
|
||||
return spvBuilder.createGroupNonUniformElect(
|
||||
opcode, retType, spv::Scope::Subgroup, callExpr->getExprLoc());
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
opcode, retType, spv::Scope::Subgroup, {}, callExpr->getExprLoc());
|
||||
}
|
||||
|
||||
SpirvInstruction *SpirvEmitter::processIsHelperLane(const CallExpr *callExpr,
|
||||
|
@ -9463,8 +9463,9 @@ SpirvInstruction *SpirvEmitter::processWaveVote(const CallExpr *callExpr,
|
|||
callExpr->getExprLoc());
|
||||
auto *predicate = doExpr(callExpr->getArg(0));
|
||||
const QualType retType = callExpr->getCallReturnType(astContext);
|
||||
return spvBuilder.createGroupNonUniformUnaryOp(
|
||||
callExpr->getExprLoc(), opcode, retType, spv::Scope::Subgroup, predicate);
|
||||
return spvBuilder.createGroupNonUniformOp(opcode, retType,
|
||||
spv::Scope::Subgroup, {predicate},
|
||||
callExpr->getExprLoc());
|
||||
}
|
||||
|
||||
spv::Op SpirvEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
|
||||
|
@ -9551,14 +9552,13 @@ SpirvEmitter::processWaveCountBits(const CallExpr *callExpr,
|
|||
const QualType u32Type = astContext.UnsignedIntTy;
|
||||
const QualType v4u32Type = astContext.getExtVectorType(u32Type, 4);
|
||||
const QualType retType = callExpr->getCallReturnType(astContext);
|
||||
auto *ballot = spvBuilder.createGroupNonUniformUnaryOp(
|
||||
srcLoc, spv::Op::OpGroupNonUniformBallot, v4u32Type, spv::Scope::Subgroup,
|
||||
predicate);
|
||||
auto *ballot = spvBuilder.createGroupNonUniformOp(
|
||||
spv::Op::OpGroupNonUniformBallot, v4u32Type, spv::Scope::Subgroup,
|
||||
{predicate}, srcLoc);
|
||||
|
||||
return spvBuilder.createGroupNonUniformUnaryOp(
|
||||
srcLoc, spv::Op::OpGroupNonUniformBallotBitCount, retType,
|
||||
spv::Scope::Subgroup, ballot,
|
||||
llvm::Optional<spv::GroupOperation>(groupOp));
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
spv::Op::OpGroupNonUniformBallotBitCount, retType, spv::Scope::Subgroup,
|
||||
{ballot}, srcLoc, groupOp);
|
||||
}
|
||||
|
||||
SpirvInstruction *SpirvEmitter::processWaveReductionOrPrefix(
|
||||
|
@ -9580,9 +9580,9 @@ SpirvInstruction *SpirvEmitter::processWaveReductionOrPrefix(
|
|||
callExpr->getExprLoc());
|
||||
auto *predicate = doExpr(callExpr->getArg(0));
|
||||
const QualType retType = callExpr->getCallReturnType(astContext);
|
||||
return spvBuilder.createGroupNonUniformUnaryOp(
|
||||
callExpr->getExprLoc(), opcode, retType, spv::Scope::Subgroup, predicate,
|
||||
llvm::Optional<spv::GroupOperation>(groupOp));
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
opcode, retType, spv::Scope::Subgroup, {predicate},
|
||||
callExpr->getExprLoc(), llvm::Optional<spv::GroupOperation>(groupOp));
|
||||
}
|
||||
|
||||
SpirvInstruction *SpirvEmitter::processWaveBroadcast(const CallExpr *callExpr) {
|
||||
|
@ -9600,13 +9600,13 @@ SpirvInstruction *SpirvEmitter::processWaveBroadcast(const CallExpr *callExpr) {
|
|||
// WaveReadLaneAt is in fact not a broadcast operation (even though its name
|
||||
// might incorrectly suggest so). The proper mapping to SPIR-V for
|
||||
// it is OpGroupNonUniformShuffle, *not* OpGroupNonUniformBroadcast.
|
||||
return spvBuilder.createGroupNonUniformBinaryOp(
|
||||
spv::Op::OpGroupNonUniformShuffle, retType, spv::Scope::Subgroup, value,
|
||||
doExpr(callExpr->getArg(1)), srcLoc);
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
spv::Op::OpGroupNonUniformShuffle, retType, spv::Scope::Subgroup,
|
||||
{value, doExpr(callExpr->getArg(1))}, srcLoc);
|
||||
else
|
||||
return spvBuilder.createGroupNonUniformUnaryOp(
|
||||
srcLoc, spv::Op::OpGroupNonUniformBroadcastFirst, retType,
|
||||
spv::Scope::Subgroup, value);
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
spv::Op::OpGroupNonUniformBroadcastFirst, retType, spv::Scope::Subgroup,
|
||||
{value}, srcLoc);
|
||||
}
|
||||
|
||||
SpirvInstruction *
|
||||
|
@ -9648,8 +9648,8 @@ SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
|
|||
llvm_unreachable("case should not appear here");
|
||||
}
|
||||
|
||||
return spvBuilder.createGroupNonUniformBinaryOp(
|
||||
opcode, retType, spv::Scope::Subgroup, value, target, srcLoc);
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
opcode, retType, spv::Scope::Subgroup, {value, target}, srcLoc);
|
||||
}
|
||||
|
||||
SpirvInstruction *
|
||||
|
@ -9673,9 +9673,9 @@ SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) {
|
|||
SpirvInstruction *
|
||||
SpirvEmitter::processWaveActiveAllEqualScalar(SpirvInstruction *arg,
|
||||
clang::SourceLocation srcLoc) {
|
||||
return spvBuilder.createGroupNonUniformUnaryOp(
|
||||
srcLoc, spv::Op::OpGroupNonUniformAllEqual, astContext.BoolTy,
|
||||
spv::Scope::Subgroup, arg);
|
||||
return spvBuilder.createGroupNonUniformOp(
|
||||
spv::Op::OpGroupNonUniformAllEqual, astContext.BoolTy,
|
||||
spv::Scope::Subgroup, {arg}, srcLoc);
|
||||
}
|
||||
|
||||
SpirvInstruction *
|
||||
|
|
|
@ -64,9 +64,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEmitVertex)
|
|||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEndPrimitive)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInst)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvFunctionCall)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformBinaryOp)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformElect)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformUnaryOp)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvGroupNonUniformOp)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageOp)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageQuery)
|
||||
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageSparseTexelsResident)
|
||||
|
@ -662,65 +660,66 @@ SpirvFunctionCall::SpirvFunctionCall(QualType resultType, SourceLocation loc,
|
|||
loc, range),
|
||||
function(fn), args(argsVec.begin(), argsVec.end()) {}
|
||||
|
||||
SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(Kind kind, spv::Op op,
|
||||
QualType resultType,
|
||||
SourceLocation loc,
|
||||
spv::Scope scope)
|
||||
: SpirvInstruction(kind, op, resultType, loc), execScope(scope) {}
|
||||
SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
|
||||
spv::Op op, QualType resultType, spv::Scope scope,
|
||||
llvm::ArrayRef<SpirvInstruction *> operandsVec, SourceLocation loc,
|
||||
llvm::Optional<spv::GroupOperation> group)
|
||||
: SpirvInstruction(IK_GroupNonUniformOp, op, resultType, loc),
|
||||
execScope(scope), operands(operandsVec.begin(), operandsVec.end()),
|
||||
groupOp(group) {
|
||||
switch (op) {
|
||||
|
||||
SpirvNonUniformBinaryOp::SpirvNonUniformBinaryOp(
|
||||
spv::Op op, QualType resultType, SourceLocation loc, spv::Scope scope,
|
||||
SpirvInstruction *arg1Inst, SpirvInstruction *arg2Inst)
|
||||
: SpirvGroupNonUniformOp(IK_GroupNonUniformBinaryOp, op, resultType, loc,
|
||||
scope),
|
||||
arg1(arg1Inst), arg2(arg2Inst) {
|
||||
assert(op == spv::Op::OpGroupNonUniformBroadcast ||
|
||||
op == spv::Op::OpGroupNonUniformBallotBitExtract ||
|
||||
op == spv::Op::OpGroupNonUniformShuffle ||
|
||||
op == spv::Op::OpGroupNonUniformShuffleXor ||
|
||||
op == spv::Op::OpGroupNonUniformShuffleUp ||
|
||||
op == spv::Op::OpGroupNonUniformShuffleDown ||
|
||||
op == spv::Op::OpGroupNonUniformQuadBroadcast ||
|
||||
op == spv::Op::OpGroupNonUniformQuadSwap);
|
||||
}
|
||||
// Group non-uniform nullary operations.
|
||||
case spv::Op::OpGroupNonUniformElect:
|
||||
assert(operandsVec.size() == 0);
|
||||
break;
|
||||
|
||||
SpirvNonUniformElect::SpirvNonUniformElect(QualType resultType,
|
||||
SourceLocation loc, spv::Scope scope)
|
||||
: SpirvGroupNonUniformOp(IK_GroupNonUniformElect,
|
||||
spv::Op::OpGroupNonUniformElect, resultType, loc,
|
||||
scope) {}
|
||||
// Group non-uniform unary operations.
|
||||
case spv::Op::OpGroupNonUniformAll:
|
||||
case spv::Op::OpGroupNonUniformAny:
|
||||
case spv::Op::OpGroupNonUniformAllEqual:
|
||||
case spv::Op::OpGroupNonUniformBroadcastFirst:
|
||||
case spv::Op::OpGroupNonUniformBallot:
|
||||
case spv::Op::OpGroupNonUniformInverseBallot:
|
||||
case spv::Op::OpGroupNonUniformBallotBitCount:
|
||||
case spv::Op::OpGroupNonUniformBallotFindLSB:
|
||||
case spv::Op::OpGroupNonUniformBallotFindMSB:
|
||||
case spv::Op::OpGroupNonUniformIAdd:
|
||||
case spv::Op::OpGroupNonUniformFAdd:
|
||||
case spv::Op::OpGroupNonUniformIMul:
|
||||
case spv::Op::OpGroupNonUniformFMul:
|
||||
case spv::Op::OpGroupNonUniformSMin:
|
||||
case spv::Op::OpGroupNonUniformUMin:
|
||||
case spv::Op::OpGroupNonUniformFMin:
|
||||
case spv::Op::OpGroupNonUniformSMax:
|
||||
case spv::Op::OpGroupNonUniformUMax:
|
||||
case spv::Op::OpGroupNonUniformFMax:
|
||||
case spv::Op::OpGroupNonUniformBitwiseAnd:
|
||||
case spv::Op::OpGroupNonUniformBitwiseOr:
|
||||
case spv::Op::OpGroupNonUniformBitwiseXor:
|
||||
case spv::Op::OpGroupNonUniformLogicalAnd:
|
||||
case spv::Op::OpGroupNonUniformLogicalOr:
|
||||
case spv::Op::OpGroupNonUniformLogicalXor:
|
||||
assert(operandsVec.size() == 1);
|
||||
break;
|
||||
|
||||
SpirvNonUniformUnaryOp::SpirvNonUniformUnaryOp(
|
||||
spv::Op op, QualType resultType, SourceLocation loc, spv::Scope scope,
|
||||
llvm::Optional<spv::GroupOperation> group, SpirvInstruction *argInst)
|
||||
: SpirvGroupNonUniformOp(IK_GroupNonUniformUnaryOp, op, resultType, loc,
|
||||
scope),
|
||||
arg(argInst), groupOp(group) {
|
||||
assert(op == spv::Op::OpGroupNonUniformAll ||
|
||||
op == spv::Op::OpGroupNonUniformAny ||
|
||||
op == spv::Op::OpGroupNonUniformAllEqual ||
|
||||
op == spv::Op::OpGroupNonUniformBroadcastFirst ||
|
||||
op == spv::Op::OpGroupNonUniformBallot ||
|
||||
op == spv::Op::OpGroupNonUniformInverseBallot ||
|
||||
op == spv::Op::OpGroupNonUniformBallotBitCount ||
|
||||
op == spv::Op::OpGroupNonUniformBallotFindLSB ||
|
||||
op == spv::Op::OpGroupNonUniformBallotFindMSB ||
|
||||
op == spv::Op::OpGroupNonUniformIAdd ||
|
||||
op == spv::Op::OpGroupNonUniformFAdd ||
|
||||
op == spv::Op::OpGroupNonUniformIMul ||
|
||||
op == spv::Op::OpGroupNonUniformFMul ||
|
||||
op == spv::Op::OpGroupNonUniformSMin ||
|
||||
op == spv::Op::OpGroupNonUniformUMin ||
|
||||
op == spv::Op::OpGroupNonUniformFMin ||
|
||||
op == spv::Op::OpGroupNonUniformSMax ||
|
||||
op == spv::Op::OpGroupNonUniformUMax ||
|
||||
op == spv::Op::OpGroupNonUniformFMax ||
|
||||
op == spv::Op::OpGroupNonUniformBitwiseAnd ||
|
||||
op == spv::Op::OpGroupNonUniformBitwiseOr ||
|
||||
op == spv::Op::OpGroupNonUniformBitwiseXor ||
|
||||
op == spv::Op::OpGroupNonUniformLogicalAnd ||
|
||||
op == spv::Op::OpGroupNonUniformLogicalOr ||
|
||||
op == spv::Op::OpGroupNonUniformLogicalXor);
|
||||
// Group non-uniform binary operations.
|
||||
case spv::Op::OpGroupNonUniformBroadcast:
|
||||
case spv::Op::OpGroupNonUniformBallotBitExtract:
|
||||
case spv::Op::OpGroupNonUniformShuffle:
|
||||
case spv::Op::OpGroupNonUniformShuffleXor:
|
||||
case spv::Op::OpGroupNonUniformShuffleUp:
|
||||
case spv::Op::OpGroupNonUniformShuffleDown:
|
||||
case spv::Op::OpGroupNonUniformQuadBroadcast:
|
||||
case spv::Op::OpGroupNonUniformQuadSwap:
|
||||
assert(operandsVec.size() == 2);
|
||||
break;
|
||||
|
||||
// Unexpected opcode.
|
||||
default:
|
||||
assert(false && "Unexpected Group non-uniform opcode");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
SpirvImageOp::SpirvImageOp(
|
||||
|
|
Загрузка…
Ссылка в новой задаче