[spirv] Minimally working compiler using v2.

This commit is contained in:
Ehsan Nasiri 2018-11-12 14:18:48 -05:00 коммит произвёл Ehsan
Родитель 2bae115d7f
Коммит 1f32d35d21
21 изменённых файлов: 659 добавлений и 327 удалений

Просмотреть файл

@ -31,7 +31,7 @@ enum class SpirvLayoutRule {
RelaxedGLSLStd430, // std430 with relaxed vector layout
FxcCTBuffer, // fxc.exe layout rule for cbuffer/tbuffer
FxcSBuffer, // fxc.exe layout rule for structured buffers
Scalar, // VK_EXT_scalar_block_layout
Max, // This is an invalid layout rule
};
struct SpirvCodeGenOptions {

Просмотреть файл

@ -125,6 +125,9 @@ bool canFitIntoOneRegister(QualType structType, QualType *elemType,
/// struct member type.
QualType getElementType(QualType type);
QualType getTypeWithCustomBitwidth(const ASTContext &, QualType type,
uint32_t bitwidth);
} // namespace spirv
} // namespace clang

Просмотреть файл

@ -26,9 +26,9 @@ class SpirvType;
//
// Mostly from DenseMapInfo<unsigned> in DenseMapInfo.h.
struct SpirvLayoutRuleDenseMapInfo {
static inline SpirvLayoutRule getEmptyKey() { return SpirvLayoutRule::Void; }
static inline SpirvLayoutRule getEmptyKey() { return SpirvLayoutRule::Max; }
static inline SpirvLayoutRule getTombstoneKey() {
return SpirvLayoutRule::Void;
return SpirvLayoutRule::Max;
}
static unsigned getHashValue(const SpirvLayoutRule &Val) {
return static_cast<unsigned>(Val) * 37U;
@ -85,6 +85,17 @@ private:
void emitLayoutDecorations(const StructType *, SpirvLayoutRule);
// There is no guarantee that an instruction or a function or a basic block
// has been assigned result-id. This method returns the result-id for the
// given object. If a result-id has not been assigned yet, it'll assign
// one and return it.
template <class T> uint32_t getResultId(T *obj) {
if (!obj->getResultId()) {
obj->setResultId(takeNextIdFunction());
}
return obj->getResultId();
}
private:
/// Emits error to the diagnostic engine associated with this visitor.
template <unsigned N>
@ -114,6 +125,22 @@ private:
/// \breif The visitor class that emits the SPIR-V words from the in-memory
/// representation.
class EmitVisitor : public Visitor {
public:
/// \brief The struct representing a SPIR-V module header.
struct Header {
/// \brief Default constructs a SPIR-V module header with id bound 0.
Header(uint32_t bound);
/// \brief Feeds the consumer with all the SPIR-V words for this header.
std::vector<uint32_t> takeBinary();
const uint32_t magicNumber;
uint32_t version;
const uint32_t generator;
uint32_t bound;
const uint32_t reserved;
};
public:
EmitVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
const SpirvCodeGenOptions &opts)
@ -142,7 +169,6 @@ public:
bool visit(SpirvFunctionParameter *);
bool visit(SpirvLoopMerge *);
bool visit(SpirvSelectionMerge *);
bool visit(SpirvBranching *);
bool visit(SpirvBranch *);
bool visit(SpirvBranchConditional *);
bool visit(SpirvKill *);
@ -181,10 +207,24 @@ public:
bool visit(SpirvUnaryOp *);
bool visit(SpirvVectorShuffle *);
// Returns the assembled binary built up in this visitor.
std::vector<uint32_t> takeBinary();
private:
// Returns the next available result-id.
uint32_t takeNextId() { return ++id; }
// There is no guarantee that an instruction or a function or a basic block
// has been assigned result-id. This method returns the result-id for the
// given object. If a result-id has not been assigned yet, it'll assign
// one and return it.
template <class T> uint32_t getResultId(T *obj) {
if (!obj->getResultId()) {
obj->setResultId(takeNextId());
}
return obj->getResultId();
}
// Initiates the creation of a new instruction with the given Opcode.
void initInstruction(spv::Op);
// Initiates the creation of the given SPIR-V instruction.

Просмотреть файл

@ -25,6 +25,18 @@ public:
const SpirvCodeGenOptions &opts)
: Visitor(opts, spvCtx), astContext(astCtx), spvContext(spvCtx) {}
// Visiting different SPIR-V constructs.
bool visit(SpirvModule *, Phase) { return true; }
bool visit(SpirvFunction *, Phase);
bool visit(SpirvBasicBlock *, Phase) { return true; }
/// The "sink" visit function for all instructions.
///
/// By default, all other visit instructions redirect to this visit function.
/// So that you want override this visit function to handle all instructions,
/// regardless of their polymorphism.
bool visitInstruction(SpirvInstruction *instr);
private:
/// Emits error to the diagnostic engine associated with this visitor.
template <unsigned N>

Просмотреть файл

@ -137,6 +137,20 @@ struct StorageClassDenseMapInfo {
}
};
// Provides DenseMapInfo for QualType so that we can use it key to DenseMap.
//
// Mostly from DenseMapInfo<unsigned> in DenseMapInfo.h.
struct QualTypeDenseMapInfo {
static inline QualType getEmptyKey() { return {}; }
static inline QualType getTombstoneKey() { return {}; }
static unsigned getHashValue(const QualType &Val) {
return static_cast<unsigned>(Val.getTypePtr()->getScalarTypeKind()) * 37U;
}
static bool isEqual(const QualType &LHS, const QualType &RHS) {
return LHS == RHS;
}
};
/// The class owning various SPIR-V entities allocated in memory during CodeGen.
///
/// All entities should be allocated from an object of this class using
@ -184,6 +198,7 @@ public:
spv::ImageFormat);
const SamplerType *getSamplerType() const { return samplerType; }
const SampledImageType *getSampledImageType(const ImageType *image);
const HybridSampledImageType *getSampledImageType(QualType image);
const ArrayType *getArrayType(const SpirvType *elemType, uint32_t elemCount);
const RuntimeArrayType *getRuntimeArrayType(const SpirvType *elemType);
@ -204,10 +219,10 @@ public:
spv::StorageClass);
const HybridPointerType *getPointerType(QualType pointee, spv::StorageClass);
const FunctionType *getFunctionType(const SpirvType *ret,
FunctionType *getFunctionType(const SpirvType *ret,
llvm::ArrayRef<const SpirvType *> param);
HybridFunctionType *getFunctionType(QualType ret,
llvm::ArrayRef<const SpirvType *> param);
const HybridFunctionType *
getFunctionType(QualType ret, llvm::ArrayRef<const SpirvType *> param);
const StructType *getByteAddressBufferType(bool isWritable);
const StructType *getACSBufferCounterType();
@ -235,7 +250,7 @@ private:
SpirvConstant *getConstantInt(T value, bool isSigned, uint32_t bitwidth,
bool specConst) {
const IntegerType *intType =
isSigend ? getSIntType(bitwidth) : getUIntType(bitwidth);
isSigned ? getSIntType(bitwidth) : getUIntType(bitwidth);
SpirvConstantInteger tempConstant(intType, value, specConst);
auto found =
@ -332,6 +347,8 @@ private:
llvm::SmallVector<const ImageType *, 8> imageTypes;
const SamplerType *samplerType;
llvm::DenseMap<const ImageType *, const SampledImageType *> sampledImageTypes;
llvm::DenseMap<QualType, const HybridSampledImageType *, QualTypeDenseMapInfo>
hybridSampledImageTypes;
llvm::DenseMap<const SpirvType *, CountToArrayMap> arrayTypes;
llvm::DenseMap<const SpirvType *, const RuntimeArrayType *> runtimeArrayTypes;
@ -340,10 +357,11 @@ private:
llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::DenseMap<QualType, SCToHybridPtrTyMap> hybridPointerTypes;
llvm::DenseMap<QualType, SCToHybridPtrTyMap, QualTypeDenseMapInfo>
hybridPointerTypes;
llvm::SmallVector<const FunctionType *, 8> functionTypes;
llvm::SmallVector<const HybridFunctionType *, 8> hybridFunctionTypes;
llvm::SmallVector<FunctionType *, 8> functionTypes;
llvm::SmallVector<HybridFunctionType *, 8> hybridFunctionTypes;
// Unique constants
// We currently do a linear search to find an existing constant (if any). This

Просмотреть файл

@ -36,7 +36,8 @@ public:
SpirvBasicBlock &operator=(SpirvBasicBlock &&) = delete;
/// Returns the label's <result-id> of this basic block.
uint32_t getLabelId() const { return labelId; }
uint32_t getResultId() const { return labelId; }
void setResultId(uint32_t id) { labelId = id; }
/// Returns the debug name of this basic block.
llvm::StringRef getName() const { return labelName; }

Просмотреть файл

@ -55,9 +55,9 @@ public:
/// on failure.
///
/// At any time, there can only exist at most one function under building.
SpirvFunction *beginFunction(QualType returnType,
const SpirvType *functionType, SourceLocation,
llvm::StringRef name = "");
SpirvFunction *beginFunction(QualType returnType, SpirvType *functionType,
SourceLocation, llvm::StringRef name = "",
SpirvFunction *func = nullptr);
/// \brief Creates and registers a function parameter of the given pointer
/// type in the current function and returns its pointer.
@ -70,10 +70,12 @@ public:
/// functions. This does not change the current function under construction.
/// The handle can be used to create function call instructions for functions
/// that we have not yet been discovered in the source code.
/*
SpirvFunction *createFunction(QualType returnType,
const SpirvType *functionType, SourceLocation,
llvm::StringRef name = "",
bool isAlias = false);
*/
/// \brief Creates a local variable of the given type in the current
/// function and returns it.
@ -199,6 +201,9 @@ public:
SpirvBinaryOp *createBinaryOp(spv::Op op, QualType resultType,
SpirvInstruction *lhs, SpirvInstruction *rhs,
SourceLocation loc = {});
SpirvBinaryOp *createBinaryOp(spv::Op op, const SpirvType *,
SpirvInstruction *lhs, SpirvInstruction *rhs,
SourceLocation loc = {});
SpirvSpecConstantBinaryOp *
createSpecConstantBinaryOp(spv::Op op, QualType resultType,
SpirvInstruction *lhs, SpirvInstruction *rhs,
@ -538,6 +543,9 @@ public:
void decorateNonUniformEXT(SpirvInstruction *target,
SourceLocation srcLoc = {});
public:
std::vector<uint32_t> takeModule();
private:
/// \brief Returns the composed ImageOperandsMask from non-zero parameters
/// and pushes non-zero parameters to *orderedParams in the expected order.

Просмотреть файл

@ -24,7 +24,7 @@ class SpirvVisitor;
/// The class representing a SPIR-V function in memory.
class SpirvFunction {
public:
SpirvFunction(QualType returnType, const SpirvType *fnSpirvType, uint32_t id,
SpirvFunction(QualType astReturnType, SpirvType *fnSpirvType, uint32_t id,
spv::FunctionControlMask, SourceLocation,
llvm::StringRef name = "");
~SpirvFunction() = default;
@ -40,9 +40,8 @@ public:
// Handle SPIR-V function visitors.
bool invokeVisitor(Visitor *);
// TODO: The responsibility of assigning the result-id of a function shouldn't
// be on the function itself.
uint32_t getResultId() const { return functionId; }
void setResultId(uint32_t id) { functionId = id; }
// TODO: There should be a pass for lowering QualType to SPIR-V type,
// and this method should be able to return the result-id of the SPIR-V type.
@ -52,15 +51,18 @@ public:
uint32_t getReturnTypeId() const { return returnTypeId; }
void setReturnTypeId(uint32_t id) { returnTypeId = id; }
// Sets the lowered (SPIR-V) function type.
// Sets the lowered (SPIR-V) return type.
void setReturnType(SpirvType *type) { returnType = type; }
// Returns the lowered (SPIR-V) function type.
const SpirvType *getReturnType() const { return returnType; }
// Returns the lowered (SPIR-V) return type.
SpirvType *getReturnType() const { return returnType; }
void setAstReturnType(QualType type) { astReturnType = type; }
QualType getAstReturnType() const { return astReturnType; }
// Sets the SPIR-V type of the function
void setFunctionType(FunctionType *type) { fnType = type; }
void setFunctionType(SpirvType *type) { fnType = type; }
// Returns the SPIR-V type of the function
const FunctionType *getFunctionType() const { return fnType; }
SpirvType *getFunctionType() const { return fnType; }
// Sets the result-id of the OpTypeFunction
void setFunctionTypeId(uint32_t id) { fnTypeId = id; }
@ -87,8 +89,8 @@ private:
SpirvType *returnType; ///< The lowered return type
uint32_t returnTypeId; ///< result-id for the return type
const SpirvType *fnType; ///< The SPIR-V function type
uint32_t fnTypeId; ///< result-id for the SPIR-V function type
SpirvType *fnType; ///< The SPIR-V function type
uint32_t fnTypeId; ///< result-id for the SPIR-V function type
bool containsAlias; ///< Whether function return type is aliased
bool rvalue; ///< Whether the return value is an rvalue

Просмотреть файл

@ -1062,10 +1062,10 @@ private:
class SpirvConstantComposite : public SpirvConstant {
public:
SpirvConstantComposite(const SpirvType *type,
llvm::ArrayRef<const SpirvConstant *> constituents,
llvm::ArrayRef<SpirvConstant *> constituents,
bool isSpecConst = false);
SpirvConstantComposite(QualType type,
llvm::ArrayRef<const SpirvConstant *> constituents,
llvm::ArrayRef<SpirvConstant *> constituents,
bool isSpecConst = false);
// For LLVM-style RTTI
@ -1077,12 +1077,12 @@ public:
DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantComposite)
llvm::ArrayRef<const SpirvConstant *> getConstituents() const {
llvm::ArrayRef<SpirvConstant *> getConstituents() const {
return constituents;
}
private:
llvm::SmallVector<const SpirvConstant *, 4> constituents;
llvm::SmallVector<SpirvConstant *, 4> constituents;
};
class SpirvConstantNull : public SpirvConstant {

Просмотреть файл

@ -39,8 +39,10 @@ public:
TK_Struct,
TK_Pointer,
TK_Function,
// Order matters: all the following are hybrid types
TK_HybridStruct,
TK_HybridPointer,
TK_HybridSampledImage,
TK_HybridFunction,
};
@ -335,6 +337,7 @@ public:
return returnType == that.returnType && paramTypes == that.paramTypes;
}
//void setReturnType(const SpirvType *t) { returnType = t; }
const SpirvType *getReturnType() const { return returnType; }
llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
@ -436,26 +439,44 @@ private:
spv::StorageClass storageClass;
};
class HybridSampledImageType : public HybridType {
public:
HybridSampledImageType(QualType image)
: HybridType(TK_HybridSampledImage), imageType(image) {}
static bool classof(const SpirvType *t) {
return t->getKind() == TK_HybridSampledImage;
}
QualType getImageType() const { return imageType; }
private:
QualType imageType;
};
// This class can be extended to also accept QualType vector as param types.
class HybridFunctionType : public HybridType {
public:
HybridFunctionType(QualType ret, llvm::ArrayRef<const SpirvType *> param)
: HybridType(TK_HybridFunction), returnType(ret),
: HybridType(TK_HybridFunction), astReturnType(ret),
paramTypes(param.begin(), param.end()) {}
static bool classof(const SpirvType *t) {
return t->getKind() == TK_Function;
return t->getKind() == TK_HybridFunction;
}
bool operator==(const HybridFunctionType &that) const {
return returnType == that.returnType && paramTypes == that.paramTypes;
return astReturnType == that.astReturnType &&
returnType == that.returnType && paramTypes == that.paramTypes;
}
QualType getReturnType() const { return returnType; }
void setReturnType(const SpirvType *t) { returnType = t; }
const SpirvType *getReturnType() const { return returnType; }
llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
private:
QualType returnType;
QualType astReturnType;
const SpirvType *returnType;
llvm::SmallVector<const SpirvType *, 8> paramTypes;
};

Просмотреть файл

@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//
#include "clang/SPIRV/AstTypeProbe.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/HlslTypes.h"
@ -538,5 +539,55 @@ QualType getElementType(QualType type) {
return type;
}
QualType getTypeWithCustomBitwidth(const ASTContext &ctx, QualType type,
uint32_t bitwidth) {
// Cases where the given type is a vector of float/int.
{
QualType elemType = {};
uint32_t elemCount = 0;
const bool isVec = isVectorType(type, &elemType, &elemCount);
if (isVec) {
return ctx.getExtVectorType(
getTypeWithCustomBitwidth(ctx, elemType, bitwidth), elemCount);
}
}
// Scalar cases.
assert(!type->isBooleanType());
assert(type->isIntegerType() || type->isFloatingType());
if (type->isFloatingType()) {
switch (bitwidth) {
case 16:
return ctx.HalfTy;
case 32:
return ctx.FloatTy;
case 64:
return ctx.DoubleTy;
}
}
if (type->isSignedIntegerType()) {
switch (bitwidth) {
case 16:
return ctx.ShortTy;
case 32:
return ctx.IntTy;
case 64:
return ctx.LongLongTy;
}
}
if (type->isUnsignedIntegerType()) {
switch (bitwidth) {
case 16:
return ctx.UnsignedShortTy;
case 32:
return ctx.UnsignedIntTy;
case 64:
return ctx.UnsignedLongLongTy;
}
}
llvm_unreachable(
"invalid type or bitwidth passed to getTypeWithCustomBitwidth");
}
} // namespace spirv
} // namespace clang

Просмотреть файл

@ -182,19 +182,14 @@ bool shouldSkipInStructLayout(const Decl *decl) {
return false;
}
void collectDeclsInNamespace(const NamespaceDecl *nsDecl,
llvm::SmallVector<const Decl *, 4> *decls) {
for (const auto *decl : nsDecl->decls()) {
collectDeclsInField(decl, decls);
}
}
void collectDeclsInField(const Decl *field,
llvm::SmallVector<const Decl *, 4> *decls) {
// Case of nested namespaces.
if (const auto *nsDecl = dyn_cast<NamespaceDecl>(field)) {
collectDeclsInNamespace(nsDecl, decls);
for (const auto *decl : nsDecl->decls()) {
collectDeclsInField(decl, decls);
}
}
if (shouldSkipInStructLayout(field))
@ -895,8 +890,9 @@ SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
bool isAlias = false;
(void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
SpirvFunction *spirvFunction = spvBuilder.createFunction(
fn->getReturnType(), fn->getLocation(), fn->getName(), isAlias);
SpirvFunction *spirvFunction = new (spvContext) SpirvFunction(
fn->getReturnType(), /*functionType*/ nullptr, /*id*/ 0,
spv::FunctionControlMask::MaskNone, fn->getLocation(), fn->getName());
// No need to dereference to get the pointer. Function returns that are
// stand-alone aliases are already pointers to values. All other cases should
@ -2058,7 +2054,7 @@ bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
const auto found = stageVarInstructions.find(cast<DeclaratorDecl>(decl));
// We should have recorded its stage output variable previously.
assert(found != stageVarIds.end());
assert(found != stageVarInstructions.end());
// Negate SV_Position.y if requested
if (semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position)

Просмотреть файл

@ -18,6 +18,9 @@
namespace {
constexpr uint32_t kGeneratorNumber = 14;
constexpr uint32_t kToolVersion = 0;
/// The alignment for 4-component float vectors.
constexpr uint32_t kStd140Vec4Alignment = 16u;
@ -77,11 +80,29 @@ uint32_t signExtendTo32Bits(int16_t value) {
}
return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
}
} // anonymous namespace
namespace clang {
namespace spirv {
EmitVisitor::Header::Header(uint32_t bound_)
// We are using the unfied header, which shows spv::Version as the newest
// version. But we need to stick to 1.0 for Vulkan consumption by default.
: magicNumber(spv::MagicNumber), version(0x00010000),
generator((kGeneratorNumber << 16) | kToolVersion), bound(bound_),
reserved(0) {}
std::vector<uint32_t> EmitVisitor::Header::takeBinary() {
std::vector<uint32_t> words;
words.push_back(magicNumber);
words.push_back(version);
words.push_back(generator);
words.push_back(bound);
words.push_back(reserved);
return words;
}
void EmitVisitor::emitDebugNameForInstruction(uint32_t resultId,
llvm::StringRef debugName) {
// Most instructions do not have a debug name associated with them.
@ -150,6 +171,21 @@ void EmitVisitor::finalizeInstruction() {
}
}
std::vector<uint32_t> EmitVisitor::takeBinary() {
std::vector<uint32_t> result;
Header header(takeNextId());
auto headerBinary = header.takeBinary();
result.insert(result.end(), headerBinary.begin(), headerBinary.end());
result.insert(result.end(), preambleBinary.begin(), preambleBinary.end());
result.insert(result.end(), debugBinary.begin(), debugBinary.end());
result.insert(result.end(), annotationsBinary.begin(),
annotationsBinary.end());
result.insert(result.end(), typeConstantBinary.begin(),
typeConstantBinary.end());
result.insert(result.end(), mainBinary.begin(), mainBinary.end());
return result;
}
void EmitVisitor::encodeString(llvm::StringRef value) {
const auto &words = string::encodeSPIRVString(value);
curInst.insert(curInst.end(), words.begin(), words.end());
@ -179,12 +215,13 @@ bool EmitVisitor::visit(SpirvFunction *fn, Phase phase) {
// Emit OpFunction
initInstruction(spv::Op::OpFunction);
curInst.push_back(returnTypeId);
curInst.push_back(fn->getResultId());
curInst.push_back(getResultId<SpirvFunction>(fn));
curInst.push_back(
static_cast<uint32_t>(spv::FunctionControlMask::MaskNone));
curInst.push_back(functionTypeId);
finalizeInstruction();
emitDebugNameForInstruction(fn->getResultId(), fn->getFunctionName());
emitDebugNameForInstruction(getResultId<SpirvFunction>(fn),
fn->getFunctionName());
}
// After emitting the function
else if (phase == Visitor::Phase::Done) {
@ -203,9 +240,10 @@ bool EmitVisitor::visit(SpirvBasicBlock *bb, Phase phase) {
if (phase == Visitor::Phase::Init) {
// Emit OpLabel
initInstruction(spv::Op::OpLabel);
curInst.push_back(bb->getLabelId());
curInst.push_back(getResultId<SpirvBasicBlock>(bb));
finalizeInstruction();
emitDebugNameForInstruction(bb->getLabelId(), bb->getName());
emitDebugNameForInstruction(getResultId<SpirvBasicBlock>(bb),
bb->getName());
}
// After emitting the basic block
else if (phase == Visitor::Phase::Done) {
@ -230,7 +268,7 @@ bool EmitVisitor::visit(SpirvExtension *ext) {
bool EmitVisitor::visit(SpirvExtInstImport *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
encodeString(inst->getExtendedInstSetName());
finalizeInstruction();
return true;
@ -247,18 +285,19 @@ bool EmitVisitor::visit(SpirvMemoryModel *inst) {
bool EmitVisitor::visit(SpirvEntryPoint *inst) {
initInstruction(inst);
curInst.push_back(static_cast<uint32_t>(inst->getExecModel()));
curInst.push_back(inst->getEntryPoint()->getResultId());
curInst.push_back(getResultId<SpirvFunction>(inst->getEntryPoint()));
encodeString(inst->getEntryPointName());
for (auto *var : inst->getInterface())
curInst.push_back(var->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(var));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvExecutionMode *inst) {
initInstruction(inst);
curInst.push_back(inst->getEntryPoint()->getResultId());
curInst.push_back(getResultId<SpirvFunction>(inst->getEntryPoint()));
curInst.push_back(static_cast<uint32_t>(inst->getExecutionMode()));
curInst.insert(curInst.end(), inst->getParams().begin(),
inst->getParams().end());
@ -268,7 +307,7 @@ bool EmitVisitor::visit(SpirvExecutionMode *inst) {
bool EmitVisitor::visit(SpirvString *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
encodeString(inst->getString());
finalizeInstruction();
return true;
@ -279,7 +318,7 @@ bool EmitVisitor::visit(SpirvSource *inst) {
curInst.push_back(static_cast<uint32_t>(inst->getSourceLanguage()));
curInst.push_back(static_cast<uint32_t>(inst->getVersion()));
if (inst->hasFile())
curInst.push_back(inst->getFile()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getFile()));
if (!inst->getSource().empty()) {
// Note: in order to improve performance and avoid multiple copies, we
// encode this (potentially large) string directly into the debugBinary.
@ -301,7 +340,7 @@ bool EmitVisitor::visit(SpirvModuleProcessed *inst) {
bool EmitVisitor::visit(SpirvDecoration *inst) {
initInstruction(inst);
curInst.push_back(inst->getTarget()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getTarget()));
if (inst->isMemberDecoration())
curInst.push_back(inst->getMemberIndex());
curInst.push_back(static_cast<uint32_t>(inst->getDecoration()));
@ -310,8 +349,8 @@ bool EmitVisitor::visit(SpirvDecoration *inst) {
inst->getParams().end());
}
if (!inst->getIdParams().empty()) {
curInst.insert(curInst.end(), inst->getIdParams().begin(),
inst->getIdParams().end());
for (auto *paramInstr : inst->getIdParams())
curInst.push_back(getResultId<SpirvInstruction>(paramInstr));
}
finalizeInstruction();
return true;
@ -320,104 +359,115 @@ bool EmitVisitor::visit(SpirvDecoration *inst) {
bool EmitVisitor::visit(SpirvVariable *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(static_cast<uint32_t>(inst->getStorageClass()));
if (inst->hasInitializer())
curInst.push_back(inst->getInitializer()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getInitializer()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvFunctionParameter *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvLoopMerge *inst) {
initInstruction(inst);
curInst.push_back(inst->getMergeBlock()->getLabelId());
curInst.push_back(inst->getContinueTarget()->getLabelId());
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getMergeBlock()));
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getContinueTarget()));
curInst.push_back(static_cast<uint32_t>(inst->getLoopControlMask()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvSelectionMerge *inst) {
initInstruction(inst);
curInst.push_back(inst->getMergeBlock()->getLabelId());
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getMergeBlock()));
curInst.push_back(static_cast<uint32_t>(inst->getSelectionControlMask()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvBranch *inst) {
initInstruction(inst);
curInst.push_back(inst->getTargetLabel()->getLabelId());
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getTargetLabel()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvBranchConditional *inst) {
initInstruction(inst);
curInst.push_back(inst->getCondition()->getResultId());
curInst.push_back(inst->getTrueLabel()->getLabelId());
curInst.push_back(inst->getFalseLabel()->getLabelId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getCondition()));
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getTrueLabel()));
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getFalseLabel()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvKill *inst) {
initInstruction(inst);
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvReturn *inst) {
initInstruction(inst);
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvSwitch *inst) {
initInstruction(inst);
curInst.push_back(inst->getSelector()->getResultId());
curInst.push_back(inst->getDefaultLabel()->getLabelId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getSelector()));
curInst.push_back(getResultId<SpirvBasicBlock>(inst->getDefaultLabel()));
for (const auto &target : inst->getTargets()) {
curInst.push_back(target.first);
curInst.push_back(target.second->getLabelId());
curInst.push_back(getResultId<SpirvBasicBlock>(target.second));
}
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvUnreachable *inst) {
initInstruction(inst);
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvAccessChain *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getBase()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getBase()));
for (const auto index : inst->getIndexes())
curInst.push_back(index->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(index));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
@ -426,19 +476,20 @@ bool EmitVisitor::visit(SpirvAtomic *inst) {
initInstruction(inst);
if (op != spv::Op::OpAtomicStore && op != spv::Op::OpAtomicFlagClear) {
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
}
curInst.push_back(inst->getPointer()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getPointer()));
curInst.push_back(static_cast<uint32_t>(inst->getScope()));
curInst.push_back(static_cast<uint32_t>(inst->getMemorySemantics()));
if (inst->hasComparator())
curInst.push_back(static_cast<uint32_t>(inst->getMemorySemanticsUnequal()));
if (inst->hasValue())
curInst.push_back(inst->getValue()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getValue()));
if (inst->hasComparator())
curInst.push_back(inst->getComparator()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getComparator()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
@ -449,59 +500,64 @@ bool EmitVisitor::visit(SpirvBarrier *inst) {
curInst.push_back(static_cast<uint32_t>(inst->getMemoryScope()));
curInst.push_back(static_cast<uint32_t>(inst->getMemorySemantics()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvBinaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getOperand1()->getResultId());
curInst.push_back(inst->getOperand2()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getOperand1()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getOperand2()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvBitFieldExtract *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getBase()->getResultId());
curInst.push_back(inst->getOffset()->getResultId());
curInst.push_back(inst->getCount()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getBase()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getOffset()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getCount()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvBitFieldInsert *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getBase()->getResultId());
curInst.push_back(inst->getInsert()->getResultId());
curInst.push_back(inst->getOffset()->getResultId());
curInst.push_back(inst->getCount()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getBase()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getInsert()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getOffset()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getCount()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvConstantBoolean *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvConstantInteger *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
// 16-bit cases
if (inst->getBitwidth() == 16) {
if (inst->isSigned()) {
@ -536,14 +592,15 @@ bool EmitVisitor::visit(SpirvConstantInteger *inst) {
curInst.push_back(words.word1);
}
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvConstantFloat *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
if (inst->getBitwidth() == 16) {
// According to the SPIR-V Spec:
// When the type's bit width is less than 32-bits, the literal's value
@ -564,63 +621,69 @@ bool EmitVisitor::visit(SpirvConstantFloat *inst) {
curInst.push_back(words.word1);
}
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvConstantComposite *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
for (const auto constituent : inst->getConstituents())
curInst.push_back(constituent->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
for (auto constituent : inst->getConstituents())
curInst.push_back(getResultId<SpirvInstruction>(constituent));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvConstantNull *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvComposite *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
for (const auto constituent : inst->getConstituents())
curInst.push_back(constituent->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(constituent));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvCompositeExtract *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getComposite()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getComposite()));
for (const auto constituent : inst->getIndexes())
curInst.push_back(constituent);
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvCompositeInsert *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getObject()->getResultId());
curInst.push_back(inst->getComposite()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getObject()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getComposite()));
for (const auto constituent : inst->getIndexes())
curInst.push_back(constituent);
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
@ -639,60 +702,65 @@ bool EmitVisitor::visit(SpirvEndPrimitive *inst) {
bool EmitVisitor::visit(SpirvExtInst *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getInstructionSet()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getInstructionSet()));
curInst.push_back(inst->getInstruction());
for (const auto operand : inst->getOperands())
curInst.push_back(operand->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(operand));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvFunctionCall *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getFunction()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvFunction>(inst->getFunction()));
for (const auto arg : inst->getArgs())
curInst.push_back(arg->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(arg));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
curInst.push_back(inst->getArg1()->getResultId());
curInst.push_back(inst->getArg2()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getArg1()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getArg2()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(static_cast<uint32_t>(inst->getExecutionScope()));
if (inst->hasGroupOp())
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
curInst.push_back(inst->getArg()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getArg()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
@ -701,170 +769,182 @@ bool EmitVisitor::visit(SpirvImageOp *inst) {
if (!inst->isImageWrite()) {
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
}
curInst.push_back(inst->getImage()->getResultId());
curInst.push_back(inst->getCoordinate()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getImage()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getCoordinate()));
if (inst->isImageWrite())
curInst.push_back(inst->getTexelToWrite()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getTexelToWrite()));
if (inst->hasDref())
curInst.push_back(inst->getDref()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getDref()));
if (inst->hasComponent())
curInst.push_back(inst->getComponent()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getComponent()));
if (inst->getImageOperandsMask() != spv::ImageOperandsMask::MaskNone) {
curInst.push_back(static_cast<uint32_t>(inst->getImageOperandsMask()));
if (inst->hasBias())
curInst.push_back(inst->getBias()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getBias()));
if (inst->hasLod())
curInst.push_back(inst->getLod()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getLod()));
if (inst->hasGrad()) {
curInst.push_back(inst->getGradDx()->getResultId());
curInst.push_back(inst->getGradDy()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getGradDx()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getGradDy()));
}
if (inst->hasConstOffset())
curInst.push_back(inst->getConstOffset()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getConstOffset()));
if (inst->hasOffset())
curInst.push_back(inst->getOffset()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getOffset()));
if (inst->hasConstOffsets())
curInst.push_back(inst->getConstOffsets()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getConstOffsets()));
if (inst->hasSample())
curInst.push_back(inst->getSample()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getSample()));
if (inst->hasMinLod())
curInst.push_back(inst->getMinLod()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getMinLod()));
}
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvImageQuery *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getImage()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getImage()));
if (inst->hasCoordinate())
curInst.push_back(inst->getCoordinate()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getCoordinate()));
if (inst->hasLod())
curInst.push_back(inst->getLod()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getLod()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvImageSparseTexelsResident *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getResidentCode()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getResidentCode()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvImageTexelPointer *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getImage()->getResultId());
curInst.push_back(inst->getCoordinate()->getResultId());
curInst.push_back(inst->getSample()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getImage()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getCoordinate()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getSample()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvLoad *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getPointer()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getPointer()));
if (inst->hasMemoryAccessSemantics())
curInst.push_back(static_cast<uint32_t>(inst->getMemoryAccess()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvSampledImage *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getImage()->getResultId());
curInst.push_back(inst->getSampler()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getImage()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getSampler()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvSelect *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getCondition()->getResultId());
curInst.push_back(inst->getTrueObject()->getResultId());
curInst.push_back(inst->getFalseObject()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getCondition()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getTrueObject()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getFalseObject()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvSpecConstantBinaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(static_cast<uint32_t>(inst->getSpecConstantopcode()));
curInst.push_back(inst->getOperand1()->getResultId());
curInst.push_back(inst->getOperand2()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getOperand1()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getOperand2()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvSpecConstantUnaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(static_cast<uint32_t>(inst->getSpecConstantopcode()));
curInst.push_back(inst->getOperand()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getOperand()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvStore *inst) {
initInstruction(inst);
curInst.push_back(inst->getPointer()->getResultId());
curInst.push_back(inst->getObject()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst->getPointer()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getObject()));
if (inst->hasMemoryAccessSemantics())
curInst.push_back(static_cast<uint32_t>(inst->getMemoryAccess()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvUnaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getOperand()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getOperand()));
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
bool EmitVisitor::visit(SpirvVectorShuffle *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(inst->getResultId());
curInst.push_back(inst->getVec1()->getResultId());
curInst.push_back(inst->getVec2()->getResultId());
curInst.push_back(getResultId<SpirvInstruction>(inst));
curInst.push_back(getResultId<SpirvInstruction>(inst->getVec1()));
curInst.push_back(getResultId<SpirvInstruction>(inst->getVec2()));
for (const auto component : inst->getComponents())
curInst.push_back(component);
finalizeInstruction();
emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
emitDebugNameForInstruction(getResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}
@ -996,14 +1076,14 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
// the array length.
SpirvConstant *constant =
spirvContext.getConstantUint32(arrayType->getElementCount());
if (constant->getResultId() == 0) {
if (getResultId<SpirvInstruction>(constant) == 0) {
constant->setResultId(takeNextIdFunction());
}
IntegerType constantIntType(32, 0);
const uint32_t uint32TypeId = emitType(&constantIntType, rule);
initTypeInstruction(spv::Op::OpConstant);
curTypeInst.push_back(uint32TypeId);
curTypeInst.push_back(constant->getResultId());
curTypeInst.push_back(getResultId<SpirvInstruction>(constant));
curTypeInst.push_back(arrayType->getElementCount());
finalizeTypeInstruction();
@ -1012,7 +1092,7 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
initTypeInstruction(spv::Op::OpTypeArray);
curTypeInst.push_back(id);
curTypeInst.push_back(elemTypeId);
curTypeInst.push_back(constant->getResultId());
curTypeInst.push_back(getResultId<SpirvInstruction>(constant));
finalizeTypeInstruction();
// ArrayStride decoration is needed for array types, but we won't have
@ -1089,6 +1169,20 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
curTypeInst.push_back(paramTypeId);
finalizeTypeInstruction();
}
// Hybrid Function types
else if (const auto *fnType = dyn_cast<HybridFunctionType>(type)) {
const uint32_t retTypeId = emitType(fnType->getReturnType(), rule);
llvm::SmallVector<uint32_t, 4> paramTypeIds;
for (auto *paramType : fnType->getParamTypes())
paramTypeIds.push_back(emitType(paramType, rule));
initTypeInstruction(spv::Op::OpTypeFunction);
curTypeInst.push_back(id);
curTypeInst.push_back(retTypeId);
for (auto paramTypeId : paramTypeIds)
curTypeInst.push_back(paramTypeId);
finalizeTypeInstruction();
}
// Unhandled types
else {
llvm_unreachable("unhandled type in emitType");

Просмотреть файл

@ -7,16 +7,47 @@
//
//===----------------------------------------------------------------------===//
#include "LowerTypeVisitor.h"
#include "clang/SPIRV/LowerTypeVisitor.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/HlslTypes.h"
#include "clang/SPIRV/AstTypeProbe.h"
#include "clang/SPIRV/SpirvFunction.h"
namespace clang {
namespace spirv {
bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
if (phase == Visitor::Phase::Init) {
// Lower the function return type.
const SpirvType *spirvReturnType =
lowerType(fn->getAstReturnType(), SpirvLayoutRule::Void,
/*SourceLocation*/ {});
fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
// In case the function type is a hybrid type, we should also lower the
// return type of the SPIR-V function type.
if (auto *fnRetType = dyn_cast<HybridFunctionType>(fn->getFunctionType())) {
fnRetType->setReturnType(spirvReturnType);
}
}
return true;
}
bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
if (instr->getAstResultType() != QualType({})) {
const auto loweredType =
lowerType(instr->getAstResultType(), instr->getLayoutRule(),
instr->getSourceLocation());
instr->setResultType(loweredType);
return loweredType != nullptr;
}
return true;
}
const SpirvType *LowerTypeVisitor::lowerType(QualType type,
SpirvLayoutRule rule,
SourceLocation srcLoc) {

Просмотреть файл

@ -201,6 +201,17 @@ SpirvContext::getSampledImageType(const ImageType *image) {
return sampledImageTypes[image] = new (this) SampledImageType(image);
}
const HybridSampledImageType *
SpirvContext::getSampledImageType(QualType image) {
auto found = hybridSampledImageTypes.find(image);
if (found != hybridSampledImageTypes.end())
return found->second;
return hybridSampledImageTypes[image] =
new (this) HybridSampledImageType(image);
}
const ArrayType *SpirvContext::getArrayType(const SpirvType *elemType,
uint32_t elemCount) {
auto foundElemType = arrayTypes.find(elemType);
@ -306,7 +317,7 @@ const HybridPointerType *SpirvContext::getPointerType(QualType pointee,
new (this) HybridPointerType(pointee, sc);
}
const FunctionType *
FunctionType *
SpirvContext::getFunctionType(const SpirvType *ret,
llvm::ArrayRef<const SpirvType *> param) {
// Create a temporary object for finding in the vector.
@ -324,7 +335,7 @@ SpirvContext::getFunctionType(const SpirvType *ret,
return functionTypes.back();
}
const HybridFunctionType *
HybridFunctionType *
SpirvContext::getFunctionType(QualType ret,
llvm::ArrayRef<const SpirvType *> param) {
// Create a temporary object for finding in the vector.

Просмотреть файл

@ -708,9 +708,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
return;
// Output the constructed module.
// TODO(ehsan): Switch to new infra. Should get the module binary from the
// EmitVisitor.
std::vector<uint32_t> m = theBuilder.takeModule();
std::vector<uint32_t> m = spvBuilder.takeModule();
if (!spirvOptions.codeGenHighLevel) {
// Run legalization passes
@ -1144,13 +1142,14 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
paramTypes.push_back(ptrType);
}
const auto *funcType = spvContext.getFunctionType(retType, paramTypes);
auto *funcType = spvContext.getFunctionType(retType, paramTypes);
spvBuilder.beginFunction(retType, funcType, decl->getLocation(), funcName);
if (isNonStaticMemberFn) {
// Remember the parameter for the this object so later we can handle
// CXXThisExpr correctly.
curThis = spvBuilder.addFnParam(paramTypes[0], "param.this");
curThis = spvBuilder.addFnParam(paramTypes[0], /*SourceLocation*/ {},
"param.this");
}
// Create all parameters.
@ -1183,7 +1182,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
}
}
theBuilder.endFunction();
spvBuilder.endFunction();
}
bool SPIRVEmitter::validateVKAttributes(const NamedDecl *decl) {
@ -1515,7 +1514,7 @@ void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt,
if (const Stmt *body = theDoStmt->getBody()) {
doStmt(body);
}
if (!theBuilder.isCurrentBasicBlockTerminated())
if (!spvBuilder.isCurrentBasicBlockTerminated())
spvBuilder.createBranch(continueBB);
spvBuilder.addSuccessor(continueBB);
@ -1857,7 +1856,7 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt,
// Handle the then branch
spvBuilder.setInsertPoint(thenBB);
doStmt(ifStmt->getThen());
if (!theBuilder.isCurrentBasicBlockTerminated())
if (!spvBuilder.isCurrentBasicBlockTerminated())
spvBuilder.createBranch(mergeBB);
spvBuilder.addSuccessor(mergeBB);
@ -1993,8 +1992,8 @@ SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
assert(!foundNonUniformResourceIndex);
llvm::SmallVector<SpirvInstruction *, 4> indices;
const auto *base =
collectArrayStructIndices(expr, &indices, /*rawIndices*/ nullptr);
const auto *base = collectArrayStructIndices(
expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
auto *info = loadIfAliasVarRef(base);
if (foundNonUniformResourceIndex) {
@ -2323,6 +2322,8 @@ SpirvInstruction *SPIRVEmitter::doCastExpr(const CastExpr *expr) {
// is casting from size-4 vector to size-2-by-2 matrix.
auto *vec = loadIfGLValue(subExpr);
// TODO: remove this line:
(void)vec;
QualType elemType = {};
uint32_t rowCount = 0, colCount = 0;
@ -3039,9 +3040,7 @@ SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
auto *samplerState = doExpr(expr->getArg(0));
auto *coordinate = doExpr(expr->getArg(1));
// EHSAN: need to create SPIRV sampledImageType that takes QualType :(
const auto *sampledImageType =
spvContext.getSampledImageType(object->getType());
auto *sampledImageType = spvContext.getSampledImageType(object->getType());
// EHSAN: need to create createBinaryOp that takes SpirvType :(
auto *sampledImage = spvBuilder.createBinaryOp(
spv::Op::OpSampledImage, sampledImageType, objectInfo, samplerState);
@ -3533,21 +3532,21 @@ SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
return declIdMapper.getCounterIdAliasPair(decl);
// AssocCounter#2: referencing some non-struct field
llvm::SmallVector<uint32_t, 4> indices;
llvm::SmallVector<uint32_t, 4> rawIndices;
const auto *base =
collectArrayStructIndices(expr, &indices, /*rawIndex=*/true);
const auto *base = collectArrayStructIndices(
expr, /*rawIndex=*/true, &rawIndices, /*indices*/ nullptr);
const auto *decl =
(base && isa<CXXThisExpr>(base))
? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
: getReferencedDef(base);
return declIdMapper.getCounterIdAliasPair(decl, &indices);
return declIdMapper.getCounterIdAliasPair(decl, &rawIndices);
}
const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
const Expr *expr, llvm::SmallVector<uint32_t, 4> *indices) {
const auto *base =
collectArrayStructIndices(expr, indices, /*rawIndex=*/true);
const Expr *expr, llvm::SmallVector<uint32_t, 4> *rawIndices) {
const auto *base = collectArrayStructIndices(expr, /*rawIndex=*/true,
rawIndices, /*indices*/ nullptr);
const auto *decl =
(base && isa<CXXThisExpr>(base))
// Use the decl we created to represent the implicit object
@ -3619,7 +3618,7 @@ SPIRVEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) {
auto *value = doExpr(expr->getArg(0));
declIdMapper.writeBackOutputStream(stream, stream->getType(), value);
theBuilder.createEmitVertex();
spvBuilder.createEmitVertex();
return nullptr;
}
@ -3690,7 +3689,6 @@ SPIRVEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
// }
const auto v2f32Type = astContext.getExtVectorType(astContext.FloatTy, 2);
const uint32_t boolType = theBuilder.getBoolType();
// Creates a SPIR-V function scope variable of type float2[len].
const auto createArray = [this, v2f32Type](const Float2 *ptr, uint32_t len) {
@ -4003,8 +4001,6 @@ SpirvInstruction *SPIRVEmitter::createImageSample(
SpirvInstruction *constOffsets, SpirvInstruction *sample,
SpirvInstruction *minLod, SpirvInstruction *residencyCodeId) {
const auto retType = retType;
// SampleDref* instructions in SPIR-V always return a scalar.
// They also have the correct type in HLSL.
if (compareVal) {
@ -4492,8 +4488,9 @@ SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
}
}
llvm::SmallVector<uint32_t, 4> indices;
const Expr *baseExpr = collectArrayStructIndices(expr, &indices);
llvm::SmallVector<SpirvInstruction *, 4> indices;
const Expr *baseExpr = collectArrayStructIndices(
expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
auto base = loadIfAliasVarRef(baseExpr);
@ -4505,7 +4502,7 @@ SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
//
// TODO: We can optimize the codegen by emitting OpCompositeExtract if
// all indices are contant integers.
if (base.isRValue()) {
if (base->isRValue()) {
base = createTemporaryVar(baseExpr->getType(), "vector", base);
}
@ -4690,8 +4687,9 @@ SpirvInstruction *SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
}
SpirvInstruction *SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
llvm::SmallVector<uint32_t, 4> indices;
const Expr *base = collectArrayStructIndices(expr, &indices);
llvm::SmallVector<SpirvInstruction *, 4> indices;
const Expr *base = collectArrayStructIndices(
expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
auto *instr = loadIfAliasVarRef(base);
if (!indices.empty()) {
@ -5751,6 +5749,7 @@ SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
{accessor.Swz0}, rhs);
auto *result = tryToAssignToRWBufferRWTexture(base, newVec);
assert(result); // Definitely RWBuffer/RWTexture assignment
(void)result;
return rhs; // TODO: incorrect for compound assignments
} else {
// Assigning to one normal vector component. Nothing special, just fall
@ -6039,9 +6038,11 @@ SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
}
const Expr *SPIRVEmitter::collectArrayStructIndices(
const Expr *expr, llvm::SmallVectorImpl<SpirvInstruction *> *indices,
llvm::SmallVectorImpl<uint32_t> *rawIndices) {
assert(indices || rawIndices);
const Expr *expr, bool rawIndex,
llvm::SmallVectorImpl<uint32_t> *rawIndices,
llvm::SmallVectorImpl<SpirvInstruction *> *indices) {
assert((rawIndex && rawIndices) || (!rawIndex && indices));
if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
// First check whether this is referring to a static member. If it is, we
// create a DeclRefExpr for it.
@ -6053,8 +6054,8 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
varDecl->getType(), VK_LValue);
const Expr *base = collectArrayStructIndices(
indexing->getBase()->IgnoreParenNoopCasts(astContext), indices,
rawIndices);
indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
rawIndices, indices);
// Append the index of the current level
const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
@ -6064,10 +6065,11 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
// derived struct.
const uint32_t index = getNumBaseClasses(indexing->getBase()->getType()) +
fieldDecl->getFieldIndex();
if (indices)
indices->push_back(spvContext.getConstantInt32(index));
if (rawIndices)
if (rawIndex) {
rawIndices->push_back(index);
} else {
indices->push_back(spvContext.getConstantInt32(index));
}
return base;
}
@ -6077,20 +6079,21 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);
if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
if (rawIndices)
if (rawIndex)
return nullptr; // TODO: handle constant array index
// The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
// cast. We need to ingore it to avoid creating OpLoad.
const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
const Expr *base = collectArrayStructIndices(thisBase, indices, rawIndices);
const Expr *base =
collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
indices->push_back(doExpr(indexing->getIdx()));
return base;
}
if (const auto *indexing = dyn_cast<CXXOperatorCallExpr>(expr))
if (indexing->getOperator() == OverloadedOperatorKind::OO_Subscript) {
if (rawIndices)
if (rawIndex)
return nullptr; // TODO: handle constant array index
// If this is indexing into resources, we need specific OpImage*
@ -6104,7 +6107,7 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
const auto thisBaseType = thisBase->getType();
const Expr *base =
collectArrayStructIndices(thisBase, indices, rawIndices);
collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
if (thisBaseType != base->getType() &&
TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) {
@ -6138,7 +6141,7 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
const Expr *index = nullptr;
// TODO: the following is duplicating the logic in doCXXMemberCallExpr.
if (const auto *object = isStructuredBufferLoad(expr, &index)) {
if (rawIndices)
if (rawIndex)
return nullptr; // TODO: handle constant array index
// For object.Load(index), there should be no more indexing into the
@ -6217,15 +6220,12 @@ SpirvInstruction *SPIRVEmitter::castToBool(SpirvInstruction *fromVal,
return spvBuilder.createBinaryOp(spvOp, toBoolType, fromVal, zeroVal);
}
// TODO: casting bitwidths is a problem
SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
QualType fromType, QualType toIntType,
SourceLocation srcLoc) {
if (TypeTranslator::isSameScalarOrVecType(fromType, toIntType))
return fromVal;
uint32_t intType = typeTranslator.translateType(toIntType);
if (isBoolOrVecOfBoolType(fromType)) {
auto *one = getValueOne(toIntType);
auto *zero = getValueZero(toIntType);
@ -6234,21 +6234,23 @@ SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
if (isSintOrVecOfSintType(fromType) || isUintOrVecOfUintType(fromType)) {
// First convert the source to the bitwidth of the destination if necessary.
uint32_t convertedType = 0;
QualType convertedType = {};
fromVal = convertBitwidth(fromVal, fromType, toIntType, &convertedType);
// If bitwidth conversion was the only thing we needed to do, we're done.
if (convertedType == typeTranslator.translateType(toIntType))
if (convertedType == toIntType)
return fromVal;
return theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toIntType, fromVal);
}
if (isFloatOrVecOfFloatType(fromType)) {
// First convert the source to the bitwidth of the destination if necessary.
fromVal = convertBitwidth(fromVal, fromType, toIntType);
if (isSintOrVecOfSintType(toIntType)) {
return theBuilder.createUnaryOp(spv::Op::OpConvertFToS, intType, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpConvertFToS, toIntType,
fromVal);
} else if (isUintOrVecOfUintType(toIntType)) {
return theBuilder.createUnaryOp(spv::Op::OpConvertFToU, intType, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpConvertFToU, toIntType,
fromVal);
} else {
emitError("casting from floating point to integer unimplemented", srcLoc);
}
@ -6270,27 +6272,29 @@ SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
// Casting to a matrix of integers: Cast each row and construct a
// composite.
llvm::SmallVector<uint32_t, 4> castedRows;
const uint32_t vecType = typeTranslator.getComponentVectorType(fromType);
llvm::SmallVector<SpirvInstruction *, 4> castedRows;
const QualType vecType = typeTranslator.getComponentVectorType(fromType);
const auto fromVecQualType =
astContext.getExtVectorType(elemType, numCols);
const auto toIntVecQualType =
astContext.getExtVectorType(toElemType, numCols);
for (uint32_t row = 0; row < numRows; ++row) {
const auto rowId =
theBuilder.createCompositeExtract(vecType, fromVal, {row});
auto *rowId =
spvBuilder.createCompositeExtract(vecType, fromVal, {row});
castedRows.push_back(
castToInt(rowId, fromVecQualType, toIntVecQualType, srcLoc));
}
return theBuilder.createCompositeConstruct(intType, castedRows);
return spvBuilder.createCompositeConstruct(toIntType, castedRows);
}
}
return 0;
return nullptr;
}
uint32_t SPIRVEmitter::convertBitwidth(uint32_t fromVal, QualType fromType,
QualType toType, uint32_t *resultType) {
SpirvInstruction *SPIRVEmitter::convertBitwidth(SpirvInstruction *fromVal,
QualType fromType,
QualType toType,
QualType *resultType) {
// At the moment, we will not make bitwidth conversions for literal int and
// literal float types because they always indicate 64-bit and do not
// represent what SPIR-V was actually resolved to.
@ -6304,26 +6308,25 @@ uint32_t SPIRVEmitter::convertBitwidth(uint32_t fromVal, QualType fromType,
const auto toBitwidth = typeTranslator.getElementSpirvBitwidth(toType);
if (fromBitwidth == toBitwidth) {
if (resultType)
*resultType = typeTranslator.translateType(fromType);
*resultType = fromType;
return fromVal;
}
// We want the 'fromType' with the 'toBitwidth'.
const uint32_t targetTypeId =
typeTranslator.getTypeWithCustomBitwidth(fromType, toBitwidth);
const QualType targetType =
getTypeWithCustomBitwidth(astContext, fromType, toBitwidth);
if (resultType)
*resultType = targetTypeId;
*resultType = targetType;
if (isFloatOrVecOfFloatType(fromType))
return theBuilder.createUnaryOp(spv::Op::OpFConvert, targetTypeId, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpFConvert, targetType, fromVal);
if (isSintOrVecOfSintType(fromType))
return theBuilder.createUnaryOp(spv::Op::OpSConvert, targetTypeId, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpSConvert, targetType, fromVal);
if (isUintOrVecOfUintType(fromType))
return theBuilder.createUnaryOp(spv::Op::OpUConvert, targetTypeId, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpUConvert, targetType, fromVal);
llvm_unreachable("invalid type passed to convertBitwidth");
}
// TODO: converting bitwidths is a problem
SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
QualType fromType,
QualType toFloatType,
@ -6331,8 +6334,6 @@ SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
if (TypeTranslator::isSameScalarOrVecType(fromType, toFloatType))
return fromVal;
const uint32_t floatType = typeTranslator.translateType(toFloatType);
if (isBoolOrVecOfBoolType(fromType)) {
auto *one = getValueOne(toFloatType);
auto *zero = getValueZero(toFloatType);
@ -6342,13 +6343,15 @@ SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
if (isSintOrVecOfSintType(fromType)) {
// First convert the source to the bitwidth of the destination if necessary.
fromVal = convertBitwidth(fromVal, fromType, toFloatType);
return theBuilder.createUnaryOp(spv::Op::OpConvertSToF, floatType, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpConvertSToF, toFloatType,
fromVal);
}
if (isUintOrVecOfUintType(fromType)) {
// First convert the source to the bitwidth of the destination if necessary.
fromVal = convertBitwidth(fromVal, fromType, toFloatType);
return theBuilder.createUnaryOp(spv::Op::OpConvertUToF, floatType, fromVal);
return spvBuilder.createUnaryOp(spv::Op::OpConvertUToF, toFloatType,
fromVal);
}
if (isFloatOrVecOfFloatType(fromType)) {
@ -6373,24 +6376,24 @@ SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
// Casting to a matrix of floats: Cast each row and construct a
// composite.
llvm::SmallVector<uint32_t, 4> castedRows;
const uint32_t vecType = typeTranslator.getComponentVectorType(fromType);
llvm::SmallVector<SpirvInstruction *, 4> castedRows;
const QualType vecType = typeTranslator.getComponentVectorType(fromType);
const auto fromVecQualType =
astContext.getExtVectorType(elemType, numCols);
const auto toIntVecQualType =
astContext.getExtVectorType(toElemType, numCols);
for (uint32_t row = 0; row < numRows; ++row) {
const auto rowId =
theBuilder.createCompositeExtract(vecType, fromVal, {row});
auto *rowId =
spvBuilder.createCompositeExtract(vecType, fromVal, {row});
castedRows.push_back(
castToFloat(rowId, fromVecQualType, toIntVecQualType, srcLoc));
}
return theBuilder.createCompositeConstruct(floatType, castedRows);
return spvBuilder.createCompositeConstruct(toFloatType, castedRows);
}
}
emitError("casting to floating point unimplemented", srcLoc);
return 0;
return nullptr;
}
// ehsan was here.
@ -6414,7 +6417,7 @@ SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
#define INTRINSIC_SPIRV_OP_WITH_CAP_CASE(intrinsicOp, spirvOp, doEachVec, cap) \
case hlsl::IntrinsicOp::IOP_##intrinsicOp: { \
theBuilder.requireCapability(cap); \
spvBuilder.requireCapability(cap); \
retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp, \
doEachVec); \
} break
@ -6618,20 +6621,17 @@ SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
case hlsl::IntrinsicOp::IOP_WaveGetLaneCount: {
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneCount",
callExpr->getExprLoc());
const uint32_t retType =
typeTranslator.translateType(callExpr->getCallReturnType(astContext));
const uint32_t varId =
declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize);
retVal = theBuilder.createLoad(retType, varId);
const QualType retType = callExpr->getCallReturnType(astContext);
auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize);
retVal = spvBuilder.createLoad(retType, var);
} break;
case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneIndex",
callExpr->getExprLoc());
const uint32_t retType =
typeTranslator.translateType(callExpr->getCallReturnType(astContext));
const uint32_t varId =
const QualType retType = callExpr->getCallReturnType(astContext);
auto *var =
declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
retVal = theBuilder.createLoad(retType, varId);
retVal = spvBuilder.createLoad(retType, var);
} break;
case hlsl::IntrinsicOp::IOP_WaveIsFirstLane:
retVal = processWaveQuery(callExpr, spv::Op::OpGroupNonUniformElect);
@ -6810,7 +6810,6 @@ SPIRVEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
// where necessary. To ensure SPIR-V validity, we add that where necessary.
auto *zero = spvContext.getConstantUint32(0);
auto *scope = spvContext.getConstantUint32(1); // Device
const auto *dest = expr->getArg(0);
const auto baseType = dest->getType();
@ -9277,9 +9276,9 @@ void SPIRVEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
decl->getLocation());
}
if (decl->getAttr<VKPostDepthCoverageAttr>()) {
theBuilder.addExtension(Extension::KHR_post_depth_coverage,
spvBuilder.addExtension(Extension::KHR_post_depth_coverage,
"[[vk::post_depth_coverage]]", decl->getLocation());
theBuilder.requireCapability(spv::Capability::SampleMaskPostDepthCoverage);
spvBuilder.requireCapability(spv::Capability::SampleMaskPostDepthCoverage);
spvBuilder.addExecutionMode(entryFunction,
spv::ExecutionMode::PostDepthCoverage, {},
decl->getLocation());
@ -9406,13 +9405,14 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
// Construct the wrapper function signature.
const SpirvType *voidType = spvContext.getVoidType();
const FunctionType *funcType = spvContext.getFunctionType(voidType, {});
FunctionType *funcType = spvContext.getFunctionType(voidType, {});
// The wrapper entry function surely does not have pre-assigned <result-id>
// for it like other functions that got added to the work queue following
// function calls. And the wrapper is the entry function.
entryFunction = spvBuilder.beginFunction(
astContext.VoidTy, funcType, /*SourceLocation*/ {}, decl->getName());
entryFunction = spvBuilder.beginFunction(astContext.VoidTy, funcType,
/*SourceLocation*/ {},
decl->getName(), entryFuncInstr);
// Note this should happen before using declIdMapper for other tasks.
declIdMapper.setEntryFunction(entryFunction);

Просмотреть файл

@ -290,16 +290,11 @@ private:
SpirvInstruction *initValue);
/// Collects all indices from consecutive MemberExprs
/// or ArraySubscriptExprs or operator[] calls. If indices is not null,
/// SPIR-V constant values are written into *indices. Returns the real base.
/// If rawIndices is not null, the raw integer indices will be written to
/// *rawIndices, and the base returned can be nullptr, which means some
/// indices are not constant.
/// Either indices or rawIndices must be non-null.
/// TODO: Update method description here.
const Expr *
collectArrayStructIndices(const Expr *expr,
llvm::SmallVectorImpl<SpirvInstruction *> *indices,
llvm::SmallVectorImpl<uint32_t> *rawIndices);
collectArrayStructIndices(const Expr *expr, bool rawIndex,
llvm::SmallVectorImpl<uint32_t> *rawIndices,
llvm::SmallVectorImpl<SpirvInstruction *> *indices);
/// Creates an access chain to index into the given SPIR-V evaluation result
/// and returns the new SPIR-V evaluation result.
@ -319,8 +314,9 @@ private:
/// If resultType is not nullptr, the resulting value's type will be written
/// to resultType. Panics if the given types are not scalar or vector of
/// float/integer type.
uint32_t convertBitwidth(uint32_t value, QualType fromType, QualType toType,
uint32_t *resultType = nullptr);
SpirvInstruction *convertBitwidth(SpirvInstruction *value, QualType fromType,
QualType toType,
QualType *resultType = nullptr);
/// Processes the given expr, casts the result into the given bool (vector)
/// type and returns the <result-id> of the casted value.

Просмотреть файл

@ -9,6 +9,8 @@
#include "clang/SPIRV/SpirvBuilder.h"
#include "TypeTranslator.h"
#include "clang/SPIRV/EmitVisitor.h"
#include "clang/SPIRV/LowerTypeVisitor.h"
namespace clang {
namespace spirv {
@ -21,28 +23,38 @@ SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
}
SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
const SpirvType *functionType,
SpirvType *functionType,
SourceLocation loc,
llvm::StringRef funcName) {
llvm::StringRef funcName,
SpirvFunction *func) {
assert(!function && "found nested function");
function = new (context)
SpirvFunction(returnType, functionType, /*id*/ 0,
spv::FunctionControlMask::MaskNone, loc, funcName);
if (func) {
function = func;
function->setAstReturnType(returnType);
function->setFunctionType(functionType);
} else {
function = new (context)
SpirvFunction(returnType, functionType, /*id*/ 0,
spv::FunctionControlMask::MaskNone, loc, funcName);
}
return function;
}
/*
SpirvFunction *SpirvBuilder::createFunction(QualType returnType,
const SpirvType *functionType,
SourceLocation loc,
llvm::StringRef funcName,
bool isAlias) {
SpirvFunction *fn = new (context)
SpirvFunction(returnType, functionType, /*id*/ 0,
SpirvFunction(returnType, functionType, 0,
spv::FunctionControlMask::MaskNone, loc, funcName);
function->setConstainsAliasComponent(isAlias);
module->addFunction(function);
return function;
}
*/
SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
SourceLocation loc,
@ -268,6 +280,28 @@ SpirvBinaryOp *SpirvBuilder::createBinaryOp(spv::Op op, QualType resultType,
return instruction;
}
SpirvBinaryOp *SpirvBuilder::createBinaryOp(spv::Op op,
const SpirvType *resultType,
SpirvInstruction *lhs,
SpirvInstruction *rhs,
SourceLocation loc) {
assert(insertPoint && "null insert point");
auto *instruction =
new (context) SpirvBinaryOp(op, /*QualType*/ {}, /*id*/ 0, loc, lhs, rhs);
instruction->setResultType(resultType);
insertPoint->addInstruction(instruction);
switch (op) {
case spv::Op::OpImageQueryLod:
case spv::Op::OpImageQuerySizeLod:
requireCapability(spv::Capability::ImageQuery);
break;
default:
// Only checking for ImageQueries, the other Ops can be ignored.
break;
}
return instruction;
}
SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp(
spv::Op op, QualType resultType, SpirvInstruction *lhs,
SpirvInstruction *rhs, SourceLocation loc) {
@ -955,5 +989,15 @@ void SpirvBuilder::decorateNonUniformEXT(SpirvInstruction *target,
module->addDecoration(decor);
}
std::vector<uint32_t> SpirvBuilder::takeModule() {
// Run necessary visitor passes first
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
EmitVisitor emitVisitor(astContext, context, spirvOptions);
module->invokeVisitor(&lowerTypeVisitor);
module->invokeVisitor(&emitVisitor);
return emitVisitor.takeBinary();
}
} // end namespace spirv
} // end namespace clang

Просмотреть файл

@ -13,7 +13,7 @@
namespace clang {
namespace spirv {
SpirvFunction::SpirvFunction(QualType returnType, const SpirvType *functionType,
SpirvFunction::SpirvFunction(QualType returnType, SpirvType *functionType,
uint32_t id, spv::FunctionControlMask control,
SourceLocation loc, llvm::StringRef name)
: functionId(id), astReturnType(returnType), returnType(nullptr),
@ -27,6 +27,9 @@ bool SpirvFunction::invokeVisitor(Visitor *visitor) {
for (auto *param : parameters)
visitor->visit(param);
for (auto *var : variables)
visitor->visit(var);
for (auto *bb : basicBlocks)
if (!bb->invokeVisitor(visitor))
return false;

Просмотреть файл

@ -118,7 +118,7 @@ SpirvEntryPoint::SpirvEntryPoint(SourceLocation loc,
SpirvFunction *entryPointFn,
llvm::StringRef nameStr,
llvm::ArrayRef<SpirvVariable *> iface)
: SpirvInstruction(IK_EntryPoint, spv::Op::OpMemoryModel, QualType(),
: SpirvInstruction(IK_EntryPoint, spv::Op::OpEntryPoint, QualType(),
/*resultId=*/0, loc),
execModel(executionModel), entryPoint(entryPointFn), name(nameStr),
interfaceVec(iface.begin(), iface.end()) {}
@ -562,8 +562,8 @@ bool SpirvConstantFloat::operator==(const SpirvConstantFloat &that) const {
}
SpirvConstantComposite::SpirvConstantComposite(
const SpirvType *type,
llvm::ArrayRef<const SpirvConstant *> constituentsVec, bool isSpecConst)
const SpirvType *type, llvm::ArrayRef<SpirvConstant *> constituentsVec,
bool isSpecConst)
: SpirvConstant(IK_ConstantComposite,
isSpecConst ? spv::Op::OpSpecConstantComposite
: spv::Op::OpConstantComposite,
@ -571,7 +571,7 @@ SpirvConstantComposite::SpirvConstantComposite(
constituents(constituentsVec.begin(), constituentsVec.end()) {}
SpirvConstantComposite::SpirvConstantComposite(
QualType type, llvm::ArrayRef<const SpirvConstant *> constituentsVec,
QualType type, llvm::ArrayRef<SpirvConstant *> constituentsVec,
bool isSpecConst)
: SpirvConstant(IK_ConstantComposite,
isSpecConst ? spv::Op::OpSpecConstantComposite

Просмотреть файл

@ -44,8 +44,9 @@ bool SpirvModule::invokeVisitor(Visitor *visitor) {
if (!execMode->invokeVisitor(visitor))
return false;
if (!debugSource->invokeVisitor(visitor))
return false;
if (debugSource)
if (!debugSource->invokeVisitor(visitor))
return false;
for (auto decoration : decorations)
if (!decoration->invokeVisitor(visitor))