[SPIR-V] Implement SpirvType and SpirvOpaqueType (#6156)

Implements hlsl-specs proposal 0011, adding `vk::SpirvType` and
`vk::SpirvOpaqueType` templates which allow users to define and use
SPIR-V level types.
This commit is contained in:
Cassandra Beckley 2024-04-15 13:58:49 -07:00 коммит произвёл GitHub
Родитель dc84d72d18
Коммит d60dffef1a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
23 изменённых файлов: 533 добавлений и 54 удалений

Просмотреть файл

@ -34,7 +34,8 @@ public:
clang::TemplateTypeParmDecl *
addTypeTemplateParam(llvm::StringRef name,
clang::TypeSourceInfo *defaultValue = nullptr);
clang::TypeSourceInfo *defaultValue = nullptr,
bool parameterPack = false);
clang::TemplateTypeParmDecl *
addTypeTemplateParam(llvm::StringRef name, clang::QualType defaultValue);
clang::NonTypeTemplateParmDecl *

Просмотреть файл

@ -399,6 +399,16 @@ DeclareNodeOrRecordType(clang::ASTContext &Ctx, DXIL::NodeIOKind Type,
bool HasGetMethods = false, bool IsArray = false,
bool IsCompleteType = false);
#ifdef ENABLE_SPIRV_CODEGEN
clang::CXXRecordDecl *DeclareInlineSpirvType(clang::ASTContext &context,
clang::DeclContext *declContext,
llvm::StringRef typeName,
bool opaque);
clang::CXXRecordDecl *DeclareVkIntegralConstant(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef typeName, clang::ClassTemplateDecl **templateDecl);
#endif
clang::CXXRecordDecl *DeclareNodeOutputArray(clang::ASTContext &Ctx,
DXIL::NodeIOKind Type,
clang::CXXRecordDecl *OutputType,

Просмотреть файл

@ -286,9 +286,12 @@ public:
const RayQueryTypeKHR *getRayQueryTypeKHR() const { return rayQueryTypeKHR; }
const SpirvIntrinsicType *
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType(
unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType(
unsigned typeOpCode, llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);
SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);
@ -471,7 +474,8 @@ private:
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Просмотреть файл

@ -1193,6 +1193,8 @@ public:
inst->getKind() <= IK_ConstantNull;
}
bool operator==(const SpirvConstant &that) const;
bool isSpecConstant() const;
void setLiteral(bool literal = true) { literalConstant = literal; }
bool isLiteral() { return literalConstant; }

Просмотреть файл

@ -429,15 +429,16 @@ public:
class SpirvInstruction;
struct SpvIntrinsicTypeOperand {
SpvIntrinsicTypeOperand(SpirvType *type_operand)
SpvIntrinsicTypeOperand(const SpirvType *type_operand)
: operand_as_type(type_operand), isTypeOperand(true) {}
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
: operand_as_inst(inst_operand), isTypeOperand(false) {}
bool operator==(const SpvIntrinsicTypeOperand &that) const;
union {
SpirvType *operand_as_type;
const SpirvType *operand_as_type;
SpirvInstruction *operand_as_inst;
};
bool isTypeOperand;
const bool isTypeOperand;
};
class SpirvIntrinsicType : public SpirvType {
@ -453,6 +454,12 @@ public:
return operands;
}
bool operator==(const SpirvIntrinsicType &that) const {
return typeOpCode == that.typeOpCode &&
operands.size() == that.operands.size() &&
std::equal(operands.begin(), operands.end(), that.operands.begin());
}
private:
unsigned typeOpCode;
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;

Просмотреть файл

@ -1241,6 +1241,43 @@ CXXRecordDecl *hlsl::DeclareNodeOrRecordType(
return Builder.getRecordDecl();
}
#ifdef ENABLE_SPIRV_CODEGEN
CXXRecordDecl *hlsl::DeclareInlineSpirvType(clang::ASTContext &context,
clang::DeclContext *declContext,
llvm::StringRef typeName,
bool opaque) {
// template<uint opcode, int size, int alignment> vk::SpirvType { ... }
// template<uint opcode> vk::SpirvOpaqueType { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName,
clang::TagTypeKind::TTK_Class);
typeDeclBuilder.addIntegerTemplateParam("opcode", context.UnsignedIntTy);
if (!opaque) {
typeDeclBuilder.addIntegerTemplateParam("size", context.UnsignedIntTy);
typeDeclBuilder.addIntegerTemplateParam("alignment", context.UnsignedIntTy);
}
typeDeclBuilder.addTypeTemplateParam("operands", nullptr, true);
typeDeclBuilder.startDefinition();
typeDeclBuilder.addField(
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
return typeDeclBuilder.getRecordDecl();
}
CXXRecordDecl *hlsl::DeclareVkIntegralConstant(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef typeName, ClassTemplateDecl **templateDecl) {
// template<typename T, T v> vk::integral_constant { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName,
clang::TagTypeKind::TTK_Class);
typeDeclBuilder.addTypeTemplateParam("T");
typeDeclBuilder.addIntegerTemplateParam("v", context.UnsignedIntTy);
typeDeclBuilder.startDefinition();
typeDeclBuilder.addField(
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
*templateDecl = typeDeclBuilder.getTemplateDecl();
return typeDeclBuilder.getRecordDecl();
}
#endif
CXXRecordDecl *hlsl::DeclareNodeOutputArray(clang::ASTContext &Ctx,
DXIL::NodeIOKind Type,
CXXRecordDecl *OutputType,

Просмотреть файл

@ -33,9 +33,8 @@ BuiltinTypeDeclBuilder::BuiltinTypeDeclBuilder(DeclContext *declContext,
m_recordDecl->setImplicit(true);
}
TemplateTypeParmDecl *
BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name,
TypeSourceInfo *defaultValue) {
TemplateTypeParmDecl *BuiltinTypeDeclBuilder::addTypeTemplateParam(
StringRef name, TypeSourceInfo *defaultValue, bool parameterPack) {
DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() &&
!m_recordDecl->isCompleteDefinition());
@ -45,7 +44,7 @@ BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name,
astContext, m_recordDecl->getDeclContext(), NoLoc, NoLoc,
/* TemplateDepth */ 0, index,
&astContext.Idents.get(name, tok::TokenKind::identifier),
/* Typename */ false, /* ParameterPack */ false);
/* Typename */ false, parameterPack);
if (defaultValue != nullptr)
decl->setDefaultArgument(defaultValue);
m_templateParams.emplace_back(decl);

Просмотреть файл

@ -9,6 +9,7 @@
#include "AlignmentSizeCalculator.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclTemplate.h"
namespace {
@ -264,6 +265,21 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
return getAlignmentAndSize(desugaredType, rule, isRowMajor, stride);
}
const auto *recordType = type->getAs<RecordType>();
if (recordType != nullptr) {
const llvm::StringRef name = recordType->getDecl()->getName();
if (isTypeInVkNamespace(recordType) && name == "SpirvType") {
const ClassTemplateSpecializationDecl *templateDecl =
cast<ClassTemplateSpecializationDecl>(recordType->getDecl());
const uint64_t size =
templateDecl->getTemplateArgs()[1].getAsIntegral().getZExtValue();
const uint64_t alignment =
templateDecl->getTemplateArgs()[2].getAsIntegral().getZExtValue();
return {alignment, size};
}
}
if (isEnumType(type))
type = astContext.IntTy;

Просмотреть файл

@ -35,6 +35,12 @@ public:
SpirvConstant *translateAPFloat(llvm::APFloat floatValue, QualType targetType,
bool isSpecConstantMode);
/// Translates the given frontend APValue into its SPIR-V equivalent for the
/// given targetType.
SpirvConstant *translateAPValue(const APValue &value,
const QualType targetType,
bool isSpecConstantMode);
/// Tries to evaluate the given APInt as a 32-bit integer. If the evaluation
/// can be performed without loss, it returns the <result-id> of the SPIR-V
/// constant for that value.
@ -52,12 +58,6 @@ public:
bool isSpecConstantMode);
private:
/// Translates the given frontend APValue into its SPIR-V equivalent for the
/// given targetType.
SpirvConstant *translateAPValue(const APValue &value,
const QualType targetType,
bool isSpecConstantMode);
/// Emits error to the diagnostic engine associated with the AST context.
template <unsigned N>
DiagnosticBuilder emitError(const char (&message)[N],

Просмотреть файл

@ -2577,7 +2577,11 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
for (const SpvIntrinsicTypeOperand &operand :
spvIntrinsicType->getOperands()) {
if (operand.isTypeOperand) {
curTypeInst.push_back(emitType(operand.operand_as_type));
// calling emitType recursively will potentially replace the contents of
// curTypeInst, so we need to save them and restore after the call
std::vector<uint32_t> outerTypeInst = curTypeInst;
outerTypeInst.push_back(emitType(operand.operand_as_type));
curTypeInst = outerTypeInst;
} else {
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
if (literal && literal->isLiteral()) {

Просмотреть файл

@ -417,7 +417,8 @@ InitListHandler::createInitForStructType(QualType type, SourceLocation srcLoc,
assert(recordType);
LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
theEmitter.getSpirvOptions());
theEmitter.getSpirvOptions(),
theEmitter.getSpirvBuilder());
const SpirvType *spirvType =
lowerTypeVisitor.lowerType(type, SpirvLayoutRule::Void, false, srcLoc);

Просмотреть файл

@ -8,6 +8,8 @@
//===----------------------------------------------------------------------===//
#include "LowerTypeVisitor.h"
#include "ConstEvaluator.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/HlslTypes.h"
@ -530,7 +532,8 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
// (ClassTemplateSpecializationDecl is a subclass of CXXRecordDecl, which
// is then a subclass of RecordDecl.) So we need to check them before
// checking the general struct type.
if (const auto *spvType = lowerResourceType(type, rule, srcLoc)) {
if (const auto *spvType =
lowerResourceType(type, rule, isRowMajor, srcLoc)) {
spvContext.registerStructDeclForSpirvType(spvType, decl);
return spvType;
}
@ -620,10 +623,152 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
return 0;
}
const SpirvType *
LowerTypeVisitor::lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name,
SpirvLayoutRule rule,
SourceLocation srcLoc) {
QualType LowerTypeVisitor::createASTTypeFromTemplateName(TemplateName name) {
auto *decl = name.getAsTemplateDecl();
if (decl == nullptr) {
return QualType();
}
auto *classTemplateDecl = dyn_cast<ClassTemplateDecl>(decl);
if (classTemplateDecl == nullptr) {
return QualType();
}
TemplateParameterList *parameters =
classTemplateDecl->getTemplateParameters();
if (parameters->size() != 1) {
return QualType();
}
auto *parmDecl = dyn_cast<TemplateTypeParmDecl>(parameters->getParam(0));
if (parmDecl == nullptr) {
return QualType();
}
if (!parmDecl->hasDefaultArgument()) {
return QualType();
}
TemplateArgument *arg =
new (context) TemplateArgument(parmDecl->getDefaultArgument());
auto *specialized = ClassTemplateSpecializationDecl::Create(
astContext, TagDecl::TagKind::TTK_Class,
classTemplateDecl->getDeclContext(), classTemplateDecl->getLocStart(),
classTemplateDecl->getLocStart(), classTemplateDecl, /* Args */ arg,
/* NumArgs */ 1,
/* PrevDecl */ nullptr);
QualType type = astContext.getTypeDeclType(specialized);
return type;
}
bool LowerTypeVisitor::getVkIntegralConstantValue(QualType type,
SpirvConstant *&result,
SourceLocation srcLoc) {
auto *recordType = type->getAs<RecordType>();
if (!recordType)
return false;
if (!isTypeInVkNamespace(recordType))
return false;
if (recordType->getDecl()->getName() == "Literal") {
auto *specDecl =
dyn_cast<ClassTemplateSpecializationDecl>(recordType->getDecl());
assert(specDecl);
const TemplateArgumentList &args = specDecl->getTemplateArgs();
QualType constant = args[0].getAsType();
bool val = getVkIntegralConstantValue(constant, result, srcLoc);
if (val) {
result->setLiteral(true);
} else {
emitError("The template argument to vk::Literal must be a "
"vk::integral_constant",
srcLoc);
}
return true;
}
if (recordType->getDecl()->getName() != "integral_constant")
return false;
auto *specDecl =
dyn_cast<ClassTemplateSpecializationDecl>(recordType->getDecl());
assert(specDecl);
const TemplateArgumentList &args = specDecl->getTemplateArgs();
QualType constantType = args[0].getAsType();
llvm::APSInt value = args[1].getAsIntegral();
result = ConstEvaluator(astContext, spvBuilder)
.translateAPValue(APValue(value), constantType, false);
return true;
}
const SpirvType *LowerTypeVisitor::lowerInlineSpirvType(
llvm::StringRef name, unsigned int opcode,
const ClassTemplateSpecializationDecl *specDecl, SpirvLayoutRule rule,
llvm::Optional<bool> isRowMajor, SourceLocation srcLoc) {
assert(specDecl);
SmallVector<SpvIntrinsicTypeOperand, 4> operands;
// Lower each operand argument
size_t operandsIndex = 1;
if (name == "SpirvType")
operandsIndex = 3;
auto args = specDecl->getTemplateArgs()[operandsIndex].getPackAsArray();
for (TemplateArgument arg : args) {
switch (arg.getKind()) {
case TemplateArgument::ArgKind::Type: {
QualType typeArg = arg.getAsType();
SpirvConstant *constant = nullptr;
if (getVkIntegralConstantValue(typeArg, constant, srcLoc)) {
if (constant) {
visitInstruction(constant);
operands.emplace_back(constant);
}
} else {
operands.emplace_back(lowerType(typeArg, rule, isRowMajor, srcLoc));
}
break;
}
case TemplateArgument::ArgKind::Template: {
// Handle HLSL template types that allow the omission of < and >; for
// example, Texture2D
TemplateName templateName = arg.getAsTemplate();
QualType typeArg = createASTTypeFromTemplateName(templateName);
assert(!typeArg.isNull() &&
"Could not create HLSL type from template name");
operands.emplace_back(lowerType(typeArg, rule, isRowMajor, srcLoc));
break;
}
default:
emitError("template argument kind %0 unimplemented", srcLoc)
<< arg.getKind();
}
}
return spvContext.getOrCreateSpirvIntrinsicType(opcode, operands);
}
const SpirvType *LowerTypeVisitor::lowerVkTypeInVkNamespace(
QualType type, llvm::StringRef name, SpirvLayoutRule rule,
llvm::Optional<bool> isRowMajor, SourceLocation srcLoc) {
if (name == "SpirvType" || name == "SpirvOpaqueType") {
auto opcode = hlsl::GetHLSLResourceTemplateUInt(type);
auto *specDecl = dyn_cast<ClassTemplateSpecializationDecl>(
type->getAs<RecordType>()->getDecl());
return lowerInlineSpirvType(name, opcode, specDecl, rule, isRowMajor,
srcLoc);
}
if (name == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
return spvContext.getCreatedSpirvIntrinsicType(typeId);
@ -636,9 +781,10 @@ LowerTypeVisitor::lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name,
return nullptr;
}
const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
SpirvLayoutRule rule,
SourceLocation srcLoc) {
const SpirvType *
LowerTypeVisitor::lowerResourceType(QualType type, SpirvLayoutRule rule,
llvm::Optional<bool> isRowMajor,
SourceLocation srcLoc) {
// Resource types are either represented like C struct or C++ class in the
// AST. Samplers are represented like C struct, so isStructureType() will
// return true for it; textures are represented like C++ class, so
@ -651,7 +797,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
const llvm::StringRef name = recordType->getDecl()->getName();
if (isTypeInVkNamespace(recordType)) {
return lowerVkTypeInVkNamespace(type, name, rule, srcLoc);
return lowerVkTypeInVkNamespace(type, name, rule, isRowMajor, srcLoc);
}
// TODO: avoid string comparison once hlsl::IsHLSLResouceType() does that.

Просмотреть файл

@ -12,6 +12,7 @@
#include "AlignmentSizeCalculator.h"
#include "clang/AST/ASTContext.h"
#include "clang/SPIRV/SpirvBuilder.h"
#include "clang/SPIRV/SpirvContext.h"
#include "clang/SPIRV/SpirvVisitor.h"
#include "llvm/ADT/Optional.h"
@ -23,9 +24,10 @@ namespace spirv {
class LowerTypeVisitor : public Visitor {
public:
LowerTypeVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
const SpirvCodeGenOptions &opts)
const SpirvCodeGenOptions &opts, SpirvBuilder &builder)
: Visitor(opts, spvCtx), astContext(astCtx), spvContext(spvCtx),
alignmentCalc(astCtx, opts), useArrayForMat1xN(false) {}
alignmentCalc(astCtx, opts), useArrayForMat1xN(false),
spvBuilder(builder) {}
// Visiting different SPIR-V constructs.
bool visit(SpirvModule *, Phase) override { return true; }
@ -69,6 +71,7 @@ private:
/// Lowers the given HLSL resource type into its SPIR-V type.
const SpirvType *lowerResourceType(QualType type, SpirvLayoutRule rule,
llvm::Optional<bool> isRowMajor,
SourceLocation);
/// Lowers the fields of a RecordDecl into SPIR-V StructType field
@ -76,9 +79,28 @@ private:
llvm::SmallVector<StructType::FieldInfo, 4>
lowerStructFields(const RecordDecl *structType, SpirvLayoutRule rule);
/// Creates the default AST type from a TemplateName for HLSL templates
/// which have optional parameters (e.g. Texture2D).
QualType createASTTypeFromTemplateName(TemplateName templateName);
/// If the given type is an integral_constant or a Literal<integral_constant>,
/// return the constant value as a SpirvConstant, which will be set as a
/// literal constant if wrapped in Literal.
bool getVkIntegralConstantValue(QualType type, SpirvConstant *&result,
SourceLocation srcLoc);
/// Lowers the given vk::SpirvType or vk::SpirvOpaqueType into its SPIR-V
/// type.
const SpirvType *
lowerInlineSpirvType(llvm::StringRef name, unsigned int opcode,
const ClassTemplateSpecializationDecl *specDecl,
SpirvLayoutRule rule, llvm::Optional<bool> isRowMajor,
SourceLocation srcLoc);
/// Lowers the given type defined in vk namespace into its SPIR-V type.
const SpirvType *lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name,
SpirvLayoutRule rule,
llvm::Optional<bool> isRowMajor,
SourceLocation srcLoc);
/// For the given sampled type, returns the corresponding image format
@ -107,6 +129,7 @@ private:
SpirvContext &spvContext; /// SPIR-V context
AlignmentSizeCalculator alignmentCalc; /// alignment calculator
bool useArrayForMat1xN; /// SPIR-V array for HLSL Matrix 1xN
SpirvBuilder &spvBuilder;
};
} // end namespace spirv

Просмотреть файл

@ -293,7 +293,7 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address,
if (bitfieldInfo.hasValue()) {
// Generate SPIR-V type for value. This is required to know the final
// layout.
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions, *this);
lowerTypeVisitor.visitInstruction(value);
context.addToInstructionsWithLoweredType(value);
@ -1403,7 +1403,7 @@ SpirvBuilder::initializeCloneVarForFxcCTBuffer(SpirvInstruction *instr) {
auto astType = var->getAstResultType();
const auto *spvType = var->getResultType();
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions, *this);
lowerTypeVisitor.visitInstruction(var);
context.addToInstructionsWithLoweredType(instr);
if (!lowerTypeVisitor.useSpvArrayForHlslMat1xN()) {
@ -1881,7 +1881,7 @@ std::vector<uint32_t> SpirvBuilder::takeModule() {
// Run necessary visitor passes first
LiteralTypeVisitor literalTypeVisitor(astContext, context, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions, *this);
CapabilityVisitor capabilityVisitor(astContext, context, spirvOptions, *this,
featureManager);
RelaxedPrecisionVisitor relaxedPrecisionVisitor(context, spirvOptions);

Просмотреть файл

@ -96,10 +96,14 @@ SpirvContext::~SpirvContext() {
for (auto &typePair : typeTemplateParams)
typePair.second->releaseMemory();
for (auto &pair : spirvIntrinsicTypes) {
for (auto &pair : spirvIntrinsicTypesById) {
assert(pair.second);
pair.second->~SpirvIntrinsicType();
}
for (auto *spirvIntrinsicType : spirvIntrinsicTypes) {
spirvIntrinsicType->~SpirvIntrinsicType();
}
}
inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
@ -534,22 +538,41 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
typeTemplateParams.clear();
}
const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType(
const SpirvIntrinsicType *SpirvContext::getOrCreateSpirvIntrinsicType(
unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands) {
if (spirvIntrinsicTypes[typeId] == nullptr) {
spirvIntrinsicTypes[typeId] =
if (spirvIntrinsicTypesById[typeId] == nullptr) {
spirvIntrinsicTypesById[typeId] =
new (this) SpirvIntrinsicType(typeOpCode, operands);
}
return spirvIntrinsicTypes[typeId];
return spirvIntrinsicTypesById[typeId];
}
const SpirvIntrinsicType *SpirvContext::getOrCreateSpirvIntrinsicType(
unsigned typeOpCode, llvm::ArrayRef<SpvIntrinsicTypeOperand> operands) {
SpirvIntrinsicType type(typeOpCode, operands);
auto found =
std::find_if(spirvIntrinsicTypes.begin(), spirvIntrinsicTypes.end(),
[&type](const SpirvIntrinsicType *cachedType) {
return type == *cachedType;
});
if (found != spirvIntrinsicTypes.end())
return *found;
spirvIntrinsicTypes.push_back(new (this)
SpirvIntrinsicType(typeOpCode, operands));
return spirvIntrinsicTypes.back();
}
SpirvIntrinsicType *
SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) {
if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()) {
if (spirvIntrinsicTypesById.find(typeId) == spirvIntrinsicTypesById.end()) {
return nullptr;
}
return spirvIntrinsicTypes[typeId];
return spirvIntrinsicTypesById[typeId];
}
} // end namespace spirv

Просмотреть файл

@ -6069,7 +6069,8 @@ SpirvInstruction *SpirvEmitter::doMemberExpr(const MemberExpr *expr,
}
const uint32_t indexAST =
getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
const StructType *spirvStructType =
lowerStructType(spirvOptions, lowerTypeVisitor, baseType);
assert(spirvStructType);
@ -6608,7 +6609,8 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal,
if (const auto *recordType = valType->getAs<RecordType>()) {
assert(recordType->isStructureType());
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
const StructType *spirvStructType =
lowerStructType(spirvOptions, lowerTypeVisitor, recordType->desugar());
@ -7157,7 +7159,8 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType,
SourceRange range) {
assert(astStructType->isStructureType());
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
const StructType *spirvStructType =
lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
uint32_t vectorIndex = 0;
@ -7966,7 +7969,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
}
{
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
const auto &astStructType =
/* structType */ indexing->getBase()->getType();
const StructType *spirvStructType =
@ -14224,8 +14228,8 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) {
}
auto typeDefAttr = funcDecl->getAttr<VKTypeDefExtAttr>();
spvContext.getSpirvIntrinsicType(typeDefAttr->getId(),
typeDefAttr->getOpcode(), operands);
spvContext.getOrCreateSpirvIntrinsicType(typeDefAttr->getId(),
typeDefAttr->getOpcode(), operands);
return createSpirvIntrInstExt(
funcDecl->getAttrs(), QualType(),
@ -14521,7 +14525,8 @@ SpirvEmitter::decomposeToScalars(SpirvInstruction *inst) {
std::vector<SpirvInstruction *> result;
const SpirvType *type = nullptr;
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
type = lowerTypeVisitor.lowerType(resultType, inst->getLayoutRule(), false,
inst->getSourceLocation());
@ -14619,7 +14624,8 @@ SpirvEmitter::generateFromScalars(QualType type,
return result;
} else if (const RecordType *recordType = dyn_cast<RecordType>(type)) {
std::vector<SpirvInstruction *> elements;
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
const SpirvType *spirvType =
lowerTypeVisitor.lowerType(type, layoutRule, false, sourceLocation);
@ -14701,7 +14707,8 @@ SpirvEmitter::splatScalarToGenerate(QualType type, SpirvInstruction *scalar,
return result;
} else if (const RecordType *recordType = dyn_cast<RecordType>(type)) {
std::vector<SpirvInstruction *> elements;
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions,
spvBuilder);
const SpirvType *spirvType = lowerTypeVisitor.lowerType(
type, SpirvLayoutRule::Void, false, sourceLocation);

Просмотреть файл

@ -513,6 +513,38 @@ SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType,
/*SourceLocation*/ {}),
literalConstant(literal) {}
bool SpirvConstant::operator==(const SpirvConstant &that) const {
if (auto *booleanInst = dyn_cast<SpirvConstantBoolean>(this)) {
auto *thatBooleanInst = dyn_cast<SpirvConstantBoolean>(&that);
if (thatBooleanInst == nullptr)
return false;
return *booleanInst == *thatBooleanInst;
} else if (auto *integerInst = dyn_cast<SpirvConstantInteger>(this)) {
auto *thatIntegerInst = dyn_cast<SpirvConstantInteger>(&that);
if (thatIntegerInst == nullptr)
return false;
return *integerInst == *thatIntegerInst;
} else if (auto *floatInst = dyn_cast<SpirvConstantFloat>(this)) {
auto *thatFloatInst = dyn_cast<SpirvConstantFloat>(&that);
if (thatFloatInst == nullptr)
return false;
return *floatInst == *thatFloatInst;
} else if (auto *compositeInst = dyn_cast<SpirvConstantComposite>(this)) {
auto *thatCompositeInst = dyn_cast<SpirvConstantComposite>(&that);
if (thatCompositeInst == nullptr)
return false;
return *compositeInst == *thatCompositeInst;
} else if (auto *nullInst = dyn_cast<SpirvConstantNull>(this)) {
auto *thatNullInst = dyn_cast<SpirvConstantNull>(&that);
if (thatNullInst == nullptr)
return false;
return *nullInst == *thatNullInst;
}
assert(false && "operator== undefined for SpirvConstant subclass");
return false;
}
bool SpirvConstant::isSpecConstant() const {
return opcode == spv::Op::OpSpecConstant ||
opcode == spv::Op::OpSpecConstantTrue ||

Просмотреть файл

@ -167,6 +167,22 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const {
(!stride.hasValue() || stride.getValue() == that.stride.getValue());
}
bool SpvIntrinsicTypeOperand::operator==(
const SpvIntrinsicTypeOperand &that) const {
if (isTypeOperand != that.isTypeOperand)
return false;
if (isTypeOperand) {
return operand_as_type == that.operand_as_type;
} else {
auto constantInst = dyn_cast<SpirvConstant>(operand_as_inst);
assert(constantInst != nullptr);
auto thatConstantInst = dyn_cast<SpirvConstant>(that.operand_as_inst);
assert(thatConstantInst != nullptr);
return *constantInst == *thatConstantInst;
}
}
SpirvIntrinsicType::SpirvIntrinsicType(
unsigned typeOp, llvm::ArrayRef<SpvIntrinsicTypeOperand> inOps)
: SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"),

Просмотреть файл

@ -185,6 +185,10 @@ enum ArBasicKind {
#ifdef ENABLE_SPIRV_CODEGEN
AR_OBJECT_VK_SUBPASS_INPUT,
AR_OBJECT_VK_SUBPASS_INPUT_MS,
AR_OBJECT_VK_SPIRV_TYPE,
AR_OBJECT_VK_SPIRV_OPAQUE_TYPE,
AR_OBJECT_VK_INTEGRAL_CONSTANT,
AR_OBJECT_VK_LITERAL,
AR_OBJECT_VK_SPV_INTRINSIC_TYPE,
AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID,
#endif // ENABLE_SPIRV_CODEGEN
@ -557,6 +561,10 @@ const UINT g_uBasicKindProps[] = {
#ifdef ENABLE_SPIRV_CODEGEN
BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT
BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT_MS
BPROP_OBJECT, // AR_OBJECT_VK_SPIRV_TYPE
BPROP_OBJECT, // AR_OBJECT_VK_SPIRV_OPAQUE_TYPE
BPROP_OBJECT, // AR_OBJECT_VK_INTEGRAL_CONSTANT,
BPROP_OBJECT, // AR_OBJECT_VK_LITERAL,
BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE use recordType
BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID use recordType
#endif // ENABLE_SPIRV_CODEGEN
@ -1401,6 +1409,8 @@ static const ArBasicKind g_ArBasicKindsAsTypes[] = {
// SPIRV change starts
#ifdef ENABLE_SPIRV_CODEGEN
AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS,
AR_OBJECT_VK_SPIRV_TYPE, AR_OBJECT_VK_SPIRV_OPAQUE_TYPE,
AR_OBJECT_VK_INTEGRAL_CONSTANT, AR_OBJECT_VK_LITERAL,
AR_OBJECT_VK_SPV_INTRINSIC_TYPE, AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID,
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends
@ -1503,6 +1513,10 @@ static const uint8_t g_ArBasicKindsTemplateCount[] = {
#ifdef ENABLE_SPIRV_CODEGEN
1, // AR_OBJECT_VK_SUBPASS_INPUT
1, // AR_OBJECT_VK_SUBPASS_INPUT_MS,
1, // AR_OBJECT_VK_SPIRV_TYPE
1, // AR_OBJECT_VK_SPIRV_OPAQUE_TYPE
1, // AR_OBJECT_VK_INTEGRAL_CONSTANT,
1, // AR_OBJECT_VK_LITERAL,
1, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE
1, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID
#endif // ENABLE_SPIRV_CODEGEN
@ -1650,6 +1664,10 @@ static const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] = {
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SUBPASS_INPUT (SubpassInput)
{0, MipsFalse,
SampleFalse}, // AR_OBJECT_VK_SUBPASS_INPUT_MS (SubpassInputMS)
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPIRV_TYPE
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPIRV_OPAQUE_TYPE
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_INTEGRAL_CONSTANT,
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_LITERAL,
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE
{0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID
#endif // ENABLE_SPIRV_CODEGEN
@ -1743,7 +1761,8 @@ static const char *g_ArBasicTypeNames[] = {
// SPIRV change starts
#ifdef ENABLE_SPIRV_CODEGEN
"SubpassInput", "SubpassInputMS", "ext_type", "ext_result_id",
"SubpassInput", "SubpassInputMS", "SpirvType", "SpirvOpaqueType",
"integral_constant", "Literal", "ext_type", "ext_result_id",
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends
@ -2926,6 +2945,9 @@ private:
ClassTemplateDecl *m_matrixTemplateDecl;
ClassTemplateDecl *m_vectorTemplateDecl;
ClassTemplateDecl *m_vkIntegralConstantTemplateDecl;
ClassTemplateDecl *m_vkLiteralTemplateDecl;
// Declarations for Work Graph Output Record types
ClassTemplateDecl *m_GroupNodeOutputRecordsTemplateDecl;
ClassTemplateDecl *m_ThreadNodeOutputRecordsTemplateDecl;
@ -3795,7 +3817,25 @@ private:
recordDecl = m_ThreadNodeOutputRecordsTemplateDecl->getTemplatedDecl();
}
#ifdef ENABLE_SPIRV_CODEGEN
else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE && m_vkNSDecl) {
else if (kind == AR_OBJECT_VK_SPIRV_TYPE && m_vkNSDecl) {
recordDecl =
DeclareInlineSpirvType(*m_context, m_vkNSDecl, typeName, false);
recordDecl->setImplicit(true);
} else if (kind == AR_OBJECT_VK_SPIRV_OPAQUE_TYPE && m_vkNSDecl) {
recordDecl =
DeclareInlineSpirvType(*m_context, m_vkNSDecl, typeName, true);
recordDecl->setImplicit(true);
} else if (kind == AR_OBJECT_VK_INTEGRAL_CONSTANT && m_vkNSDecl) {
recordDecl =
DeclareVkIntegralConstant(*m_context, m_vkNSDecl, typeName,
&m_vkIntegralConstantTemplateDecl);
recordDecl->setImplicit(true);
} else if (kind == AR_OBJECT_VK_LITERAL && m_vkNSDecl) {
recordDecl = DeclareTemplateTypeWithHandleInDeclContext(
*m_context, m_vkNSDecl, typeName, 1, nullptr);
recordDecl->setImplicit(true);
m_vkLiteralTemplateDecl = recordDecl->getDescribedClassTemplate();
} else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE && m_vkNSDecl) {
recordDecl = DeclareUIntTemplatedTypeWithHandleInDeclContext(
*m_context, m_vkNSDecl, typeName, "id");
recordDecl->setImplicit(true);
@ -3914,8 +3954,10 @@ private:
public:
HLSLExternalSource()
: m_matrixTemplateDecl(nullptr), m_vectorTemplateDecl(nullptr),
m_hlslNSDecl(nullptr), m_vkNSDecl(nullptr), m_context(nullptr),
m_sema(nullptr), m_hlslStringTypedef(nullptr) {
m_vkIntegralConstantTemplateDecl(nullptr),
m_vkLiteralTemplateDecl(nullptr), m_hlslNSDecl(nullptr),
m_vkNSDecl(nullptr), m_context(nullptr), m_sema(nullptr),
m_hlslStringTypedef(nullptr) {
memset(m_matrixTypes, 0, sizeof(m_matrixTypes));
memset(m_matrixShorthandTypes, 0, sizeof(m_matrixShorthandTypes));
memset(m_vectorTypes, 0, sizeof(m_vectorTypes));
@ -4183,6 +4225,9 @@ public:
return AR_TOBJ_MATRIX;
else if (decl == m_vectorTemplateDecl)
return AR_TOBJ_VECTOR;
else if (decl == m_vkIntegralConstantTemplateDecl ||
decl == m_vkLiteralTemplateDecl)
return AR_TOBJ_COMPOUND;
else if (!decl->isImplicit())
return AR_TOBJ_COMPOUND;
return AR_TOBJ_OBJECT;
@ -14593,6 +14638,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
if (!getLangOpts().SPIRV) {
if (basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT ||
basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS ||
basicKind == ArBasicKind::AR_OBJECT_VK_SPIRV_TYPE ||
basicKind == ArBasicKind::AR_OBJECT_VK_SPIRV_OPAQUE_TYPE ||
basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_TYPE ||
basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID) {
Diag(D.getLocStart(), diag::err_hlsl_vulkan_specific_feature)

Просмотреть файл

@ -0,0 +1,36 @@
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s
typedef vk::SpirvType</* OpTypeInt */ 21, /* size */ 1, /* alignment */ 1, vk::Literal<vk::integral_constant<uint, 8> >, vk::Literal<vk::integral_constant<bool, false> > > type1;
typedef vk::SpirvType</* OpTypeInt */ 21, /* size */ 1, /* alignment */ 32, vk::Literal<vk::integral_constant<uint, 8> >, vk::Literal<vk::integral_constant<bool, false> > > type2;
typedef vk::SpirvType</* OpTypeInt */ 21, /* size */ 32, /* alignment */ 1, vk::Literal<vk::integral_constant<uint, 8> >, vk::Literal<vk::integral_constant<bool, false> > > type3;
// CHECK: OpDecorate %_arr_spirvIntrinsicType_uint_3 ArrayStride 16
// CHECK: OpDecorate %_arr_spirvIntrinsicType_uint_3_0 ArrayStride 32
// CHECK: OpMemberDecorate %type__Globals 0 Offset 0
type1 a;
// CHECK: OpMemberDecorate %type__Globals 1 Offset 16
type1 a_arr[3];
// CHECK: OpMemberDecorate %type__Globals 2 Offset 64
type2 b;
// CHECK: OpMemberDecorate %type__Globals 3 Offset 96
type2 b_arr[3];
// CHECK: OpMemberDecorate %type__Globals 4 Offset 192
type3 c;
// CHECK: OpMemberDecorate %type__Globals 5 Offset 224
type3 c_arr[3];
// CHECK: OpMemberDecorate %type__Globals 6 Offset 320
int end;
// CHECK: %spirvIntrinsicType = OpTypeInt 8 0
[[vk::ext_capability(/* Int8 */ 39), vk::ext_capability(/* Int16 */ 22)]]
void main() {
}

Просмотреть файл

@ -0,0 +1,28 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv
// This won't work until #5554 is fixed. When it is, add a check that it compiles correctly.
enum class Scope {
CrossDevice = 0,
Device = 1,
Workgroup = 2,
Subgroup = 3,
Invocation = 4,
QueueFamily = 5,
QueueFamilyKHR = 5,
ShaderCallKHR = 6,
};
enum class CooperativeMatrixUse {
MatrixAKHR = 0,
MatrixBKHR = 1,
MatrixAccumulatorKHR = 2,
};
typedef vk::SpirvOpaqueType</* OpTypeCooperativeMatrixKHR */ 4456, float, vk::integral_constant<Scope, Scope::Subgroup>, 32, 32, CooperativeMatrixUse::MatrixAKHR> mat_t;
[[vk::ext_extension("SPV_KHR_cooperative_matrix")]]
[[vk::ext_capability(/* CooperativeMatrixKHR */ 6022)]]
void main() {
mat_t mat;
}

Просмотреть файл

@ -0,0 +1,32 @@
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s
// TODO(6498): enable Array test when using `Texture2D` with an alias template of `SpirvType` is fixed
// CHECK-TODO: %type_Array_type_2d_image = OpTypeArray %type_2d_image
// template<typename SomeType>
// using Array = vk::SpirvOpaqueType</* OpTypeArray */ 28, SomeType, 4>;
// CHECK: %spirvIntrinsicType = OpTypeArray %type_2d_image %uint_4
typedef vk::SpirvOpaqueType</* OpTypeArray */ 28, Texture2D, vk::integral_constant<uint, 4> > ArrayTex2D;
// CHECK: %spirvIntrinsicType_0 = OpTypeInt 8 0
typedef vk::SpirvOpaqueType</* OpTypeInt */ 21, vk::Literal<vk::integral_constant<uint, 8> >, vk::Literal<vk::integral_constant<bool, false> > > uint8_t;
// CHECK: %_arr_spirvIntrinsicType_0_uint_4 = OpTypeArray %spirvIntrinsicType_0 %uint_4
// TODO: maybe I've checked this before, but can we add this to uint8_t instead?
[[vk::ext_capability(/* Int8 */ 39)]]
void main() {
// CHECK: %image = OpVariable %_ptr_Function_spirvIntrinsicType Function
// Array<Texture2D> image;
ArrayTex2D image;
// CHECK: %byte = OpVariable %_ptr_Function_spirvIntrinsicType_0
uint8_t byte;
// Check that uses of the same type use the same SPIR-V type definition.
// CHECK: %byte1 = OpVariable %_ptr_Function_spirvIntrinsicType_0
uint8_t byte1;
// CHECK: %bytes = OpVariable %_ptr_Function__arr_spirvIntrinsicType_0_uint_4
uint8_t bytes[4];
}

Просмотреть файл

@ -0,0 +1,8 @@
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s
// CHECK: error: The template argument to vk::Literal must be a vk::integral_constant
typedef vk::SpirvOpaqueType</* OpTypeArray */ 28, Texture2D, vk::Literal<bool> > Invalid;
void main() {
Invalid a;
}