From d60dffef1a3ccb237050a18087598a79e9bd8c2d Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Mon, 15 Apr 2024 13:58:49 -0700 Subject: [PATCH] [SPIR-V] Implement SpirvType and SpirvOpaqueType (#6156) Implements hlsl-specs proposal 0011, adding `vk::SpirvType` and `vk::SpirvOpaqueType` templates which allow users to define and use SPIR-V level types. --- .../clang/AST/HlslBuiltinTypeDeclBuilder.h | 3 +- tools/clang/include/clang/AST/HlslTypes.h | 10 ++ .../clang/include/clang/SPIRV/SpirvContext.h | 12 +- .../include/clang/SPIRV/SpirvInstruction.h | 2 + tools/clang/include/clang/SPIRV/SpirvType.h | 13 +- tools/clang/lib/AST/ASTContextHLSL.cpp | 37 ++++ .../lib/AST/HlslBuiltinTypeDeclBuilder.cpp | 7 +- .../lib/SPIRV/AlignmentSizeCalculator.cpp | 16 ++ tools/clang/lib/SPIRV/ConstEvaluator.h | 12 +- tools/clang/lib/SPIRV/EmitVisitor.cpp | 6 +- tools/clang/lib/SPIRV/InitListHandler.cpp | 3 +- tools/clang/lib/SPIRV/LowerTypeVisitor.cpp | 164 +++++++++++++++++- tools/clang/lib/SPIRV/LowerTypeVisitor.h | 27 ++- tools/clang/lib/SPIRV/SpirvBuilder.cpp | 6 +- tools/clang/lib/SPIRV/SpirvContext.cpp | 37 +++- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 25 ++- tools/clang/lib/SPIRV/SpirvInstruction.cpp | 32 ++++ tools/clang/lib/SPIRV/SpirvType.cpp | 16 ++ tools/clang/lib/Sema/SemaHLSL.cpp | 55 +++++- .../spv.inline.type.alignment.hlsl | 36 ++++ .../spv.inline.type.enum-class.hlsl | 28 +++ .../test/CodeGenSPIRV/spv.inline.type.hlsl | 32 ++++ .../spv.inline.type.literal.error.hlsl | 8 + 23 files changed, 533 insertions(+), 54 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/spv.inline.type.alignment.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/spv.inline.type.enum-class.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/spv.inline.type.literal.error.hlsl diff --git a/tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h b/tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h index 60603a0b6..ef36ec75c 100644 --- a/tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h +++ b/tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h @@ -34,7 +34,8 @@ public: clang::TemplateTypeParmDecl * addTypeTemplateParam(llvm::StringRef name, - clang::TypeSourceInfo *defaultValue = nullptr); + clang::TypeSourceInfo *defaultValue = nullptr, + bool parameterPack = false); clang::TemplateTypeParmDecl * addTypeTemplateParam(llvm::StringRef name, clang::QualType defaultValue); clang::NonTypeTemplateParmDecl * diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 4339cf511..e1da1193a 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -399,6 +399,16 @@ DeclareNodeOrRecordType(clang::ASTContext &Ctx, DXIL::NodeIOKind Type, bool HasGetMethods = false, bool IsArray = false, bool IsCompleteType = false); +#ifdef ENABLE_SPIRV_CODEGEN +clang::CXXRecordDecl *DeclareInlineSpirvType(clang::ASTContext &context, + clang::DeclContext *declContext, + llvm::StringRef typeName, + bool opaque); +clang::CXXRecordDecl *DeclareVkIntegralConstant( + clang::ASTContext &context, clang::DeclContext *declContext, + llvm::StringRef typeName, clang::ClassTemplateDecl **templateDecl); +#endif + clang::CXXRecordDecl *DeclareNodeOutputArray(clang::ASTContext &Ctx, DXIL::NodeIOKind Type, clang::CXXRecordDecl *OutputType, diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index 328548b29..b42ba8767 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -286,9 +286,12 @@ public: const RayQueryTypeKHR *getRayQueryTypeKHR() const { return rayQueryTypeKHR; } - const SpirvIntrinsicType * - getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, - llvm::ArrayRef operands); + const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType( + unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef operands); + + const SpirvIntrinsicType *getOrCreateSpirvIntrinsicType( + unsigned typeOpCode, llvm::ArrayRef operands); SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId); @@ -471,7 +474,8 @@ private: llvm::DenseMap pointerTypes; llvm::SmallVector hybridPointerTypes; llvm::DenseSet functionTypes; - llvm::DenseMap spirvIntrinsicTypes; + llvm::DenseMap spirvIntrinsicTypesById; + llvm::SmallVector spirvIntrinsicTypes; const AccelerationStructureTypeNV *accelerationStructureTypeNV; const RayQueryTypeKHR *rayQueryTypeKHR; diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index f6a1dc520..5a7f7ae72 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -1193,6 +1193,8 @@ public: inst->getKind() <= IK_ConstantNull; } + bool operator==(const SpirvConstant &that) const; + bool isSpecConstant() const; void setLiteral(bool literal = true) { literalConstant = literal; } bool isLiteral() { return literalConstant; } diff --git a/tools/clang/include/clang/SPIRV/SpirvType.h b/tools/clang/include/clang/SPIRV/SpirvType.h index 68ee1d080..d4012b0b8 100644 --- a/tools/clang/include/clang/SPIRV/SpirvType.h +++ b/tools/clang/include/clang/SPIRV/SpirvType.h @@ -429,15 +429,16 @@ public: class SpirvInstruction; struct SpvIntrinsicTypeOperand { - SpvIntrinsicTypeOperand(SpirvType *type_operand) + SpvIntrinsicTypeOperand(const SpirvType *type_operand) : operand_as_type(type_operand), isTypeOperand(true) {} SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand) : operand_as_inst(inst_operand), isTypeOperand(false) {} + bool operator==(const SpvIntrinsicTypeOperand &that) const; union { - SpirvType *operand_as_type; + const SpirvType *operand_as_type; SpirvInstruction *operand_as_inst; }; - bool isTypeOperand; + const bool isTypeOperand; }; class SpirvIntrinsicType : public SpirvType { @@ -453,6 +454,12 @@ public: return operands; } + bool operator==(const SpirvIntrinsicType &that) const { + return typeOpCode == that.typeOpCode && + operands.size() == that.operands.size() && + std::equal(operands.begin(), operands.end(), that.operands.begin()); + } + private: unsigned typeOpCode; llvm::SmallVector operands; diff --git a/tools/clang/lib/AST/ASTContextHLSL.cpp b/tools/clang/lib/AST/ASTContextHLSL.cpp index 87c1ba549..bf048ac75 100644 --- a/tools/clang/lib/AST/ASTContextHLSL.cpp +++ b/tools/clang/lib/AST/ASTContextHLSL.cpp @@ -1241,6 +1241,43 @@ CXXRecordDecl *hlsl::DeclareNodeOrRecordType( return Builder.getRecordDecl(); } +#ifdef ENABLE_SPIRV_CODEGEN +CXXRecordDecl *hlsl::DeclareInlineSpirvType(clang::ASTContext &context, + clang::DeclContext *declContext, + llvm::StringRef typeName, + bool opaque) { + // template vk::SpirvType { ... } + // template vk::SpirvOpaqueType { ... } + BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName, + clang::TagTypeKind::TTK_Class); + typeDeclBuilder.addIntegerTemplateParam("opcode", context.UnsignedIntTy); + if (!opaque) { + typeDeclBuilder.addIntegerTemplateParam("size", context.UnsignedIntTy); + typeDeclBuilder.addIntegerTemplateParam("alignment", context.UnsignedIntTy); + } + typeDeclBuilder.addTypeTemplateParam("operands", nullptr, true); + typeDeclBuilder.startDefinition(); + typeDeclBuilder.addField( + "h", context.UnsignedIntTy); // Add an 'h' field to hold the handle. + return typeDeclBuilder.getRecordDecl(); +} + +CXXRecordDecl *hlsl::DeclareVkIntegralConstant( + clang::ASTContext &context, clang::DeclContext *declContext, + llvm::StringRef typeName, ClassTemplateDecl **templateDecl) { + // template vk::integral_constant { ... } + BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName, + clang::TagTypeKind::TTK_Class); + typeDeclBuilder.addTypeTemplateParam("T"); + typeDeclBuilder.addIntegerTemplateParam("v", context.UnsignedIntTy); + typeDeclBuilder.startDefinition(); + typeDeclBuilder.addField( + "h", context.UnsignedIntTy); // Add an 'h' field to hold the handle. + *templateDecl = typeDeclBuilder.getTemplateDecl(); + return typeDeclBuilder.getRecordDecl(); +} +#endif + CXXRecordDecl *hlsl::DeclareNodeOutputArray(clang::ASTContext &Ctx, DXIL::NodeIOKind Type, CXXRecordDecl *OutputType, diff --git a/tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp b/tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp index d53b1e84d..fb3fde379 100644 --- a/tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp +++ b/tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp @@ -33,9 +33,8 @@ BuiltinTypeDeclBuilder::BuiltinTypeDeclBuilder(DeclContext *declContext, m_recordDecl->setImplicit(true); } -TemplateTypeParmDecl * -BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name, - TypeSourceInfo *defaultValue) { +TemplateTypeParmDecl *BuiltinTypeDeclBuilder::addTypeTemplateParam( + StringRef name, TypeSourceInfo *defaultValue, bool parameterPack) { DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() && !m_recordDecl->isCompleteDefinition()); @@ -45,7 +44,7 @@ BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name, astContext, m_recordDecl->getDeclContext(), NoLoc, NoLoc, /* TemplateDepth */ 0, index, &astContext.Idents.get(name, tok::TokenKind::identifier), - /* Typename */ false, /* ParameterPack */ false); + /* Typename */ false, parameterPack); if (defaultValue != nullptr) decl->setDefaultArgument(defaultValue); m_templateParams.emplace_back(decl); diff --git a/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp b/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp index 0416ba86c..072a83fa5 100644 --- a/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp +++ b/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp @@ -9,6 +9,7 @@ #include "AlignmentSizeCalculator.h" #include "clang/AST/Attr.h" +#include "clang/AST/DeclTemplate.h" namespace { @@ -264,6 +265,21 @@ std::pair AlignmentSizeCalculator::getAlignmentAndSize( return getAlignmentAndSize(desugaredType, rule, isRowMajor, stride); } + const auto *recordType = type->getAs(); + if (recordType != nullptr) { + const llvm::StringRef name = recordType->getDecl()->getName(); + + if (isTypeInVkNamespace(recordType) && name == "SpirvType") { + const ClassTemplateSpecializationDecl *templateDecl = + cast(recordType->getDecl()); + const uint64_t size = + templateDecl->getTemplateArgs()[1].getAsIntegral().getZExtValue(); + const uint64_t alignment = + templateDecl->getTemplateArgs()[2].getAsIntegral().getZExtValue(); + return {alignment, size}; + } + } + if (isEnumType(type)) type = astContext.IntTy; diff --git a/tools/clang/lib/SPIRV/ConstEvaluator.h b/tools/clang/lib/SPIRV/ConstEvaluator.h index 85a9827da..858405db3 100644 --- a/tools/clang/lib/SPIRV/ConstEvaluator.h +++ b/tools/clang/lib/SPIRV/ConstEvaluator.h @@ -35,6 +35,12 @@ public: SpirvConstant *translateAPFloat(llvm::APFloat floatValue, QualType targetType, bool isSpecConstantMode); + /// Translates the given frontend APValue into its SPIR-V equivalent for the + /// given targetType. + SpirvConstant *translateAPValue(const APValue &value, + const QualType targetType, + bool isSpecConstantMode); + /// Tries to evaluate the given APInt as a 32-bit integer. If the evaluation /// can be performed without loss, it returns the of the SPIR-V /// constant for that value. @@ -52,12 +58,6 @@ public: bool isSpecConstantMode); private: - /// Translates the given frontend APValue into its SPIR-V equivalent for the - /// given targetType. - SpirvConstant *translateAPValue(const APValue &value, - const QualType targetType, - bool isSpecConstantMode); - /// Emits error to the diagnostic engine associated with the AST context. template DiagnosticBuilder emitError(const char (&message)[N], diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index a8199cc56..7eb1a7384 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -2577,7 +2577,11 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { for (const SpvIntrinsicTypeOperand &operand : spvIntrinsicType->getOperands()) { if (operand.isTypeOperand) { - curTypeInst.push_back(emitType(operand.operand_as_type)); + // calling emitType recursively will potentially replace the contents of + // curTypeInst, so we need to save them and restore after the call + std::vector outerTypeInst = curTypeInst; + outerTypeInst.push_back(emitType(operand.operand_as_type)); + curTypeInst = outerTypeInst; } else { auto *literal = dyn_cast(operand.operand_as_inst); if (literal && literal->isLiteral()) { diff --git a/tools/clang/lib/SPIRV/InitListHandler.cpp b/tools/clang/lib/SPIRV/InitListHandler.cpp index 452940cfe..b2d74fe99 100644 --- a/tools/clang/lib/SPIRV/InitListHandler.cpp +++ b/tools/clang/lib/SPIRV/InitListHandler.cpp @@ -417,7 +417,8 @@ InitListHandler::createInitForStructType(QualType type, SourceLocation srcLoc, assert(recordType); LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(), - theEmitter.getSpirvOptions()); + theEmitter.getSpirvOptions(), + theEmitter.getSpirvBuilder()); const SpirvType *spirvType = lowerTypeVisitor.lowerType(type, SpirvLayoutRule::Void, false, srcLoc); diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index c60ca7e54..8fdc0cb9e 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -8,6 +8,8 @@ //===----------------------------------------------------------------------===// #include "LowerTypeVisitor.h" + +#include "ConstEvaluator.h" #include "clang/AST/Attr.h" #include "clang/AST/DeclCXX.h" #include "clang/AST/HlslTypes.h" @@ -530,7 +532,8 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type, // (ClassTemplateSpecializationDecl is a subclass of CXXRecordDecl, which // is then a subclass of RecordDecl.) So we need to check them before // checking the general struct type. - if (const auto *spvType = lowerResourceType(type, rule, srcLoc)) { + if (const auto *spvType = + lowerResourceType(type, rule, isRowMajor, srcLoc)) { spvContext.registerStructDeclForSpirvType(spvType, decl); return spvType; } @@ -620,10 +623,152 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type, return 0; } -const SpirvType * -LowerTypeVisitor::lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name, - SpirvLayoutRule rule, - SourceLocation srcLoc) { +QualType LowerTypeVisitor::createASTTypeFromTemplateName(TemplateName name) { + auto *decl = name.getAsTemplateDecl(); + if (decl == nullptr) { + return QualType(); + } + + auto *classTemplateDecl = dyn_cast(decl); + if (classTemplateDecl == nullptr) { + return QualType(); + } + + TemplateParameterList *parameters = + classTemplateDecl->getTemplateParameters(); + if (parameters->size() != 1) { + return QualType(); + } + + auto *parmDecl = dyn_cast(parameters->getParam(0)); + if (parmDecl == nullptr) { + return QualType(); + } + + if (!parmDecl->hasDefaultArgument()) { + return QualType(); + } + + TemplateArgument *arg = + new (context) TemplateArgument(parmDecl->getDefaultArgument()); + + auto *specialized = ClassTemplateSpecializationDecl::Create( + astContext, TagDecl::TagKind::TTK_Class, + classTemplateDecl->getDeclContext(), classTemplateDecl->getLocStart(), + classTemplateDecl->getLocStart(), classTemplateDecl, /* Args */ arg, + /* NumArgs */ 1, + /* PrevDecl */ nullptr); + QualType type = astContext.getTypeDeclType(specialized); + + return type; +} + +bool LowerTypeVisitor::getVkIntegralConstantValue(QualType type, + SpirvConstant *&result, + SourceLocation srcLoc) { + auto *recordType = type->getAs(); + if (!recordType) + return false; + if (!isTypeInVkNamespace(recordType)) + return false; + + if (recordType->getDecl()->getName() == "Literal") { + auto *specDecl = + dyn_cast(recordType->getDecl()); + assert(specDecl); + + const TemplateArgumentList &args = specDecl->getTemplateArgs(); + QualType constant = args[0].getAsType(); + bool val = getVkIntegralConstantValue(constant, result, srcLoc); + + if (val) { + result->setLiteral(true); + } else { + emitError("The template argument to vk::Literal must be a " + "vk::integral_constant", + srcLoc); + } + return true; + } + + if (recordType->getDecl()->getName() != "integral_constant") + return false; + + auto *specDecl = + dyn_cast(recordType->getDecl()); + assert(specDecl); + + const TemplateArgumentList &args = specDecl->getTemplateArgs(); + + QualType constantType = args[0].getAsType(); + llvm::APSInt value = args[1].getAsIntegral(); + result = ConstEvaluator(astContext, spvBuilder) + .translateAPValue(APValue(value), constantType, false); + return true; +} + +const SpirvType *LowerTypeVisitor::lowerInlineSpirvType( + llvm::StringRef name, unsigned int opcode, + const ClassTemplateSpecializationDecl *specDecl, SpirvLayoutRule rule, + llvm::Optional isRowMajor, SourceLocation srcLoc) { + assert(specDecl); + + SmallVector operands; + + // Lower each operand argument + + size_t operandsIndex = 1; + if (name == "SpirvType") + operandsIndex = 3; + + auto args = specDecl->getTemplateArgs()[operandsIndex].getPackAsArray(); + + for (TemplateArgument arg : args) { + switch (arg.getKind()) { + case TemplateArgument::ArgKind::Type: { + QualType typeArg = arg.getAsType(); + + SpirvConstant *constant = nullptr; + if (getVkIntegralConstantValue(typeArg, constant, srcLoc)) { + if (constant) { + visitInstruction(constant); + operands.emplace_back(constant); + } + } else { + operands.emplace_back(lowerType(typeArg, rule, isRowMajor, srcLoc)); + } + break; + } + case TemplateArgument::ArgKind::Template: { + // Handle HLSL template types that allow the omission of < and >; for + // example, Texture2D + TemplateName templateName = arg.getAsTemplate(); + QualType typeArg = createASTTypeFromTemplateName(templateName); + assert(!typeArg.isNull() && + "Could not create HLSL type from template name"); + + operands.emplace_back(lowerType(typeArg, rule, isRowMajor, srcLoc)); + break; + } + default: + emitError("template argument kind %0 unimplemented", srcLoc) + << arg.getKind(); + } + } + return spvContext.getOrCreateSpirvIntrinsicType(opcode, operands); +} + +const SpirvType *LowerTypeVisitor::lowerVkTypeInVkNamespace( + QualType type, llvm::StringRef name, SpirvLayoutRule rule, + llvm::Optional isRowMajor, SourceLocation srcLoc) { + if (name == "SpirvType" || name == "SpirvOpaqueType") { + auto opcode = hlsl::GetHLSLResourceTemplateUInt(type); + auto *specDecl = dyn_cast( + type->getAs()->getDecl()); + + return lowerInlineSpirvType(name, opcode, specDecl, rule, isRowMajor, + srcLoc); + } if (name == "ext_type") { auto typeId = hlsl::GetHLSLResourceTemplateUInt(type); return spvContext.getCreatedSpirvIntrinsicType(typeId); @@ -636,9 +781,10 @@ LowerTypeVisitor::lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name, return nullptr; } -const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type, - SpirvLayoutRule rule, - SourceLocation srcLoc) { +const SpirvType * +LowerTypeVisitor::lowerResourceType(QualType type, SpirvLayoutRule rule, + llvm::Optional isRowMajor, + SourceLocation srcLoc) { // Resource types are either represented like C struct or C++ class in the // AST. Samplers are represented like C struct, so isStructureType() will // return true for it; textures are represented like C++ class, so @@ -651,7 +797,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type, const llvm::StringRef name = recordType->getDecl()->getName(); if (isTypeInVkNamespace(recordType)) { - return lowerVkTypeInVkNamespace(type, name, rule, srcLoc); + return lowerVkTypeInVkNamespace(type, name, rule, isRowMajor, srcLoc); } // TODO: avoid string comparison once hlsl::IsHLSLResouceType() does that. diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.h b/tools/clang/lib/SPIRV/LowerTypeVisitor.h index a591d9c53..895e0e6cf 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.h +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.h @@ -12,6 +12,7 @@ #include "AlignmentSizeCalculator.h" #include "clang/AST/ASTContext.h" +#include "clang/SPIRV/SpirvBuilder.h" #include "clang/SPIRV/SpirvContext.h" #include "clang/SPIRV/SpirvVisitor.h" #include "llvm/ADT/Optional.h" @@ -23,9 +24,10 @@ namespace spirv { class LowerTypeVisitor : public Visitor { public: LowerTypeVisitor(ASTContext &astCtx, SpirvContext &spvCtx, - const SpirvCodeGenOptions &opts) + const SpirvCodeGenOptions &opts, SpirvBuilder &builder) : Visitor(opts, spvCtx), astContext(astCtx), spvContext(spvCtx), - alignmentCalc(astCtx, opts), useArrayForMat1xN(false) {} + alignmentCalc(astCtx, opts), useArrayForMat1xN(false), + spvBuilder(builder) {} // Visiting different SPIR-V constructs. bool visit(SpirvModule *, Phase) override { return true; } @@ -69,6 +71,7 @@ private: /// Lowers the given HLSL resource type into its SPIR-V type. const SpirvType *lowerResourceType(QualType type, SpirvLayoutRule rule, + llvm::Optional isRowMajor, SourceLocation); /// Lowers the fields of a RecordDecl into SPIR-V StructType field @@ -76,9 +79,28 @@ private: llvm::SmallVector lowerStructFields(const RecordDecl *structType, SpirvLayoutRule rule); + /// Creates the default AST type from a TemplateName for HLSL templates + /// which have optional parameters (e.g. Texture2D). + QualType createASTTypeFromTemplateName(TemplateName templateName); + + /// If the given type is an integral_constant or a Literal, + /// return the constant value as a SpirvConstant, which will be set as a + /// literal constant if wrapped in Literal. + bool getVkIntegralConstantValue(QualType type, SpirvConstant *&result, + SourceLocation srcLoc); + + /// Lowers the given vk::SpirvType or vk::SpirvOpaqueType into its SPIR-V + /// type. + const SpirvType * + lowerInlineSpirvType(llvm::StringRef name, unsigned int opcode, + const ClassTemplateSpecializationDecl *specDecl, + SpirvLayoutRule rule, llvm::Optional isRowMajor, + SourceLocation srcLoc); + /// Lowers the given type defined in vk namespace into its SPIR-V type. const SpirvType *lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name, SpirvLayoutRule rule, + llvm::Optional isRowMajor, SourceLocation srcLoc); /// For the given sampled type, returns the corresponding image format @@ -107,6 +129,7 @@ private: SpirvContext &spvContext; /// SPIR-V context AlignmentSizeCalculator alignmentCalc; /// alignment calculator bool useArrayForMat1xN; /// SPIR-V array for HLSL Matrix 1xN + SpirvBuilder &spvBuilder; }; } // end namespace spirv diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index f7d8e9b61..2adc0cd15 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -293,7 +293,7 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address, if (bitfieldInfo.hasValue()) { // Generate SPIR-V type for value. This is required to know the final // layout. - LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions, *this); lowerTypeVisitor.visitInstruction(value); context.addToInstructionsWithLoweredType(value); @@ -1403,7 +1403,7 @@ SpirvBuilder::initializeCloneVarForFxcCTBuffer(SpirvInstruction *instr) { auto astType = var->getAstResultType(); const auto *spvType = var->getResultType(); - LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions, *this); lowerTypeVisitor.visitInstruction(var); context.addToInstructionsWithLoweredType(instr); if (!lowerTypeVisitor.useSpvArrayForHlslMat1xN()) { @@ -1881,7 +1881,7 @@ std::vector SpirvBuilder::takeModule() { // Run necessary visitor passes first LiteralTypeVisitor literalTypeVisitor(astContext, context, spirvOptions); - LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions, *this); CapabilityVisitor capabilityVisitor(astContext, context, spirvOptions, *this, featureManager); RelaxedPrecisionVisitor relaxedPrecisionVisitor(context, spirvOptions); diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index 4f9fabafd..b1f6a6bd6 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -96,10 +96,14 @@ SpirvContext::~SpirvContext() { for (auto &typePair : typeTemplateParams) typePair.second->releaseMemory(); - for (auto &pair : spirvIntrinsicTypes) { + for (auto &pair : spirvIntrinsicTypesById) { assert(pair.second); pair.second->~SpirvIntrinsicType(); } + + for (auto *spirvIntrinsicType : spirvIntrinsicTypes) { + spirvIntrinsicType->~SpirvIntrinsicType(); + } } inline uint32_t log2ForBitwidth(uint32_t bitwidth) { @@ -534,22 +538,41 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) { typeTemplateParams.clear(); } -const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType( +const SpirvIntrinsicType *SpirvContext::getOrCreateSpirvIntrinsicType( unsigned typeId, unsigned typeOpCode, llvm::ArrayRef operands) { - if (spirvIntrinsicTypes[typeId] == nullptr) { - spirvIntrinsicTypes[typeId] = + if (spirvIntrinsicTypesById[typeId] == nullptr) { + spirvIntrinsicTypesById[typeId] = new (this) SpirvIntrinsicType(typeOpCode, operands); } - return spirvIntrinsicTypes[typeId]; + return spirvIntrinsicTypesById[typeId]; +} + +const SpirvIntrinsicType *SpirvContext::getOrCreateSpirvIntrinsicType( + unsigned typeOpCode, llvm::ArrayRef operands) { + SpirvIntrinsicType type(typeOpCode, operands); + + auto found = + std::find_if(spirvIntrinsicTypes.begin(), spirvIntrinsicTypes.end(), + [&type](const SpirvIntrinsicType *cachedType) { + return type == *cachedType; + }); + + if (found != spirvIntrinsicTypes.end()) + return *found; + + spirvIntrinsicTypes.push_back(new (this) + SpirvIntrinsicType(typeOpCode, operands)); + + return spirvIntrinsicTypes.back(); } SpirvIntrinsicType * SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) { - if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()) { + if (spirvIntrinsicTypesById.find(typeId) == spirvIntrinsicTypesById.end()) { return nullptr; } - return spirvIntrinsicTypes[typeId]; + return spirvIntrinsicTypesById[typeId]; } } // end namespace spirv diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index d531d0bb3..5cbc95792 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -6069,7 +6069,8 @@ SpirvInstruction *SpirvEmitter::doMemberExpr(const MemberExpr *expr, } const uint32_t indexAST = getNumBaseClasses(baseType) + fieldDecl->getFieldIndex(); - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); const StructType *spirvStructType = lowerStructType(spirvOptions, lowerTypeVisitor, baseType); assert(spirvStructType); @@ -6608,7 +6609,8 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal, if (const auto *recordType = valType->getAs()) { assert(recordType->isStructureType()); - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); const StructType *spirvStructType = lowerStructType(spirvOptions, lowerTypeVisitor, recordType->desugar()); @@ -7157,7 +7159,8 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType, SourceRange range) { assert(astStructType->isStructureType()); - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); const StructType *spirvStructType = lowerStructType(spirvOptions, lowerTypeVisitor, astStructType); uint32_t vectorIndex = 0; @@ -7966,7 +7969,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices( } { - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); const auto &astStructType = /* structType */ indexing->getBase()->getType(); const StructType *spirvStructType = @@ -14224,8 +14228,8 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { } auto typeDefAttr = funcDecl->getAttr(); - spvContext.getSpirvIntrinsicType(typeDefAttr->getId(), - typeDefAttr->getOpcode(), operands); + spvContext.getOrCreateSpirvIntrinsicType(typeDefAttr->getId(), + typeDefAttr->getOpcode(), operands); return createSpirvIntrInstExt( funcDecl->getAttrs(), QualType(), @@ -14521,7 +14525,8 @@ SpirvEmitter::decomposeToScalars(SpirvInstruction *inst) { std::vector result; const SpirvType *type = nullptr; - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); type = lowerTypeVisitor.lowerType(resultType, inst->getLayoutRule(), false, inst->getSourceLocation()); @@ -14619,7 +14624,8 @@ SpirvEmitter::generateFromScalars(QualType type, return result; } else if (const RecordType *recordType = dyn_cast(type)) { std::vector elements; - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); const SpirvType *spirvType = lowerTypeVisitor.lowerType(type, layoutRule, false, sourceLocation); @@ -14701,7 +14707,8 @@ SpirvEmitter::splatScalarToGenerate(QualType type, SpirvInstruction *scalar, return result; } else if (const RecordType *recordType = dyn_cast(type)) { std::vector elements; - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); const SpirvType *spirvType = lowerTypeVisitor.lowerType( type, SpirvLayoutRule::Void, false, sourceLocation); diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index c6ec7e960..90dcc8909 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -513,6 +513,38 @@ SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType, /*SourceLocation*/ {}), literalConstant(literal) {} +bool SpirvConstant::operator==(const SpirvConstant &that) const { + if (auto *booleanInst = dyn_cast(this)) { + auto *thatBooleanInst = dyn_cast(&that); + if (thatBooleanInst == nullptr) + return false; + return *booleanInst == *thatBooleanInst; + } else if (auto *integerInst = dyn_cast(this)) { + auto *thatIntegerInst = dyn_cast(&that); + if (thatIntegerInst == nullptr) + return false; + return *integerInst == *thatIntegerInst; + } else if (auto *floatInst = dyn_cast(this)) { + auto *thatFloatInst = dyn_cast(&that); + if (thatFloatInst == nullptr) + return false; + return *floatInst == *thatFloatInst; + } else if (auto *compositeInst = dyn_cast(this)) { + auto *thatCompositeInst = dyn_cast(&that); + if (thatCompositeInst == nullptr) + return false; + return *compositeInst == *thatCompositeInst; + } else if (auto *nullInst = dyn_cast(this)) { + auto *thatNullInst = dyn_cast(&that); + if (thatNullInst == nullptr) + return false; + return *nullInst == *thatNullInst; + } + + assert(false && "operator== undefined for SpirvConstant subclass"); + return false; +} + bool SpirvConstant::isSpecConstant() const { return opcode == spv::Op::OpSpecConstant || opcode == spv::Op::OpSpecConstantTrue || diff --git a/tools/clang/lib/SPIRV/SpirvType.cpp b/tools/clang/lib/SPIRV/SpirvType.cpp index d6a41f19f..cabeba4cd 100644 --- a/tools/clang/lib/SPIRV/SpirvType.cpp +++ b/tools/clang/lib/SPIRV/SpirvType.cpp @@ -167,6 +167,22 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const { (!stride.hasValue() || stride.getValue() == that.stride.getValue()); } +bool SpvIntrinsicTypeOperand::operator==( + const SpvIntrinsicTypeOperand &that) const { + if (isTypeOperand != that.isTypeOperand) + return false; + + if (isTypeOperand) { + return operand_as_type == that.operand_as_type; + } else { + auto constantInst = dyn_cast(operand_as_inst); + assert(constantInst != nullptr); + auto thatConstantInst = dyn_cast(that.operand_as_inst); + assert(thatConstantInst != nullptr); + return *constantInst == *thatConstantInst; + } +} + SpirvIntrinsicType::SpirvIntrinsicType( unsigned typeOp, llvm::ArrayRef inOps) : SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"), diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 7b30f28b9..d78f273e8 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -185,6 +185,10 @@ enum ArBasicKind { #ifdef ENABLE_SPIRV_CODEGEN AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS, + AR_OBJECT_VK_SPIRV_TYPE, + AR_OBJECT_VK_SPIRV_OPAQUE_TYPE, + AR_OBJECT_VK_INTEGRAL_CONSTANT, + AR_OBJECT_VK_LITERAL, AR_OBJECT_VK_SPV_INTRINSIC_TYPE, AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID, #endif // ENABLE_SPIRV_CODEGEN @@ -557,6 +561,10 @@ const UINT g_uBasicKindProps[] = { #ifdef ENABLE_SPIRV_CODEGEN BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT_MS + BPROP_OBJECT, // AR_OBJECT_VK_SPIRV_TYPE + BPROP_OBJECT, // AR_OBJECT_VK_SPIRV_OPAQUE_TYPE + BPROP_OBJECT, // AR_OBJECT_VK_INTEGRAL_CONSTANT, + BPROP_OBJECT, // AR_OBJECT_VK_LITERAL, BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE use recordType BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID use recordType #endif // ENABLE_SPIRV_CODEGEN @@ -1401,6 +1409,8 @@ static const ArBasicKind g_ArBasicKindsAsTypes[] = { // SPIRV change starts #ifdef ENABLE_SPIRV_CODEGEN AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS, + AR_OBJECT_VK_SPIRV_TYPE, AR_OBJECT_VK_SPIRV_OPAQUE_TYPE, + AR_OBJECT_VK_INTEGRAL_CONSTANT, AR_OBJECT_VK_LITERAL, AR_OBJECT_VK_SPV_INTRINSIC_TYPE, AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID, #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1503,6 +1513,10 @@ static const uint8_t g_ArBasicKindsTemplateCount[] = { #ifdef ENABLE_SPIRV_CODEGEN 1, // AR_OBJECT_VK_SUBPASS_INPUT 1, // AR_OBJECT_VK_SUBPASS_INPUT_MS, + 1, // AR_OBJECT_VK_SPIRV_TYPE + 1, // AR_OBJECT_VK_SPIRV_OPAQUE_TYPE + 1, // AR_OBJECT_VK_INTEGRAL_CONSTANT, + 1, // AR_OBJECT_VK_LITERAL, 1, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE 1, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID #endif // ENABLE_SPIRV_CODEGEN @@ -1650,6 +1664,10 @@ static const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] = { {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SUBPASS_INPUT (SubpassInput) {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SUBPASS_INPUT_MS (SubpassInputMS) + {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPIRV_TYPE + {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPIRV_OPAQUE_TYPE + {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_INTEGRAL_CONSTANT, + {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_LITERAL, {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE {0, MipsFalse, SampleFalse}, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID #endif // ENABLE_SPIRV_CODEGEN @@ -1743,7 +1761,8 @@ static const char *g_ArBasicTypeNames[] = { // SPIRV change starts #ifdef ENABLE_SPIRV_CODEGEN - "SubpassInput", "SubpassInputMS", "ext_type", "ext_result_id", + "SubpassInput", "SubpassInputMS", "SpirvType", "SpirvOpaqueType", + "integral_constant", "Literal", "ext_type", "ext_result_id", #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -2926,6 +2945,9 @@ private: ClassTemplateDecl *m_matrixTemplateDecl; ClassTemplateDecl *m_vectorTemplateDecl; + ClassTemplateDecl *m_vkIntegralConstantTemplateDecl; + ClassTemplateDecl *m_vkLiteralTemplateDecl; + // Declarations for Work Graph Output Record types ClassTemplateDecl *m_GroupNodeOutputRecordsTemplateDecl; ClassTemplateDecl *m_ThreadNodeOutputRecordsTemplateDecl; @@ -3795,7 +3817,25 @@ private: recordDecl = m_ThreadNodeOutputRecordsTemplateDecl->getTemplatedDecl(); } #ifdef ENABLE_SPIRV_CODEGEN - else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE && m_vkNSDecl) { + else if (kind == AR_OBJECT_VK_SPIRV_TYPE && m_vkNSDecl) { + recordDecl = + DeclareInlineSpirvType(*m_context, m_vkNSDecl, typeName, false); + recordDecl->setImplicit(true); + } else if (kind == AR_OBJECT_VK_SPIRV_OPAQUE_TYPE && m_vkNSDecl) { + recordDecl = + DeclareInlineSpirvType(*m_context, m_vkNSDecl, typeName, true); + recordDecl->setImplicit(true); + } else if (kind == AR_OBJECT_VK_INTEGRAL_CONSTANT && m_vkNSDecl) { + recordDecl = + DeclareVkIntegralConstant(*m_context, m_vkNSDecl, typeName, + &m_vkIntegralConstantTemplateDecl); + recordDecl->setImplicit(true); + } else if (kind == AR_OBJECT_VK_LITERAL && m_vkNSDecl) { + recordDecl = DeclareTemplateTypeWithHandleInDeclContext( + *m_context, m_vkNSDecl, typeName, 1, nullptr); + recordDecl->setImplicit(true); + m_vkLiteralTemplateDecl = recordDecl->getDescribedClassTemplate(); + } else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE && m_vkNSDecl) { recordDecl = DeclareUIntTemplatedTypeWithHandleInDeclContext( *m_context, m_vkNSDecl, typeName, "id"); recordDecl->setImplicit(true); @@ -3914,8 +3954,10 @@ private: public: HLSLExternalSource() : m_matrixTemplateDecl(nullptr), m_vectorTemplateDecl(nullptr), - m_hlslNSDecl(nullptr), m_vkNSDecl(nullptr), m_context(nullptr), - m_sema(nullptr), m_hlslStringTypedef(nullptr) { + m_vkIntegralConstantTemplateDecl(nullptr), + m_vkLiteralTemplateDecl(nullptr), m_hlslNSDecl(nullptr), + m_vkNSDecl(nullptr), m_context(nullptr), m_sema(nullptr), + m_hlslStringTypedef(nullptr) { memset(m_matrixTypes, 0, sizeof(m_matrixTypes)); memset(m_matrixShorthandTypes, 0, sizeof(m_matrixShorthandTypes)); memset(m_vectorTypes, 0, sizeof(m_vectorTypes)); @@ -4183,6 +4225,9 @@ public: return AR_TOBJ_MATRIX; else if (decl == m_vectorTemplateDecl) return AR_TOBJ_VECTOR; + else if (decl == m_vkIntegralConstantTemplateDecl || + decl == m_vkLiteralTemplateDecl) + return AR_TOBJ_COMPOUND; else if (!decl->isImplicit()) return AR_TOBJ_COMPOUND; return AR_TOBJ_OBJECT; @@ -14593,6 +14638,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth, if (!getLangOpts().SPIRV) { if (basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT || basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS || + basicKind == ArBasicKind::AR_OBJECT_VK_SPIRV_TYPE || + basicKind == ArBasicKind::AR_OBJECT_VK_SPIRV_OPAQUE_TYPE || basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_TYPE || basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID) { Diag(D.getLocStart(), diag::err_hlsl_vulkan_specific_feature) diff --git a/tools/clang/test/CodeGenSPIRV/spv.inline.type.alignment.hlsl b/tools/clang/test/CodeGenSPIRV/spv.inline.type.alignment.hlsl new file mode 100644 index 000000000..7e733ffcc --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.inline.type.alignment.hlsl @@ -0,0 +1,36 @@ +// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s + +typedef vk::SpirvType >, vk::Literal > > type1; +typedef vk::SpirvType >, vk::Literal > > type2; +typedef vk::SpirvType >, vk::Literal > > type3; + +// CHECK: OpDecorate %_arr_spirvIntrinsicType_uint_3 ArrayStride 16 +// CHECK: OpDecorate %_arr_spirvIntrinsicType_uint_3_0 ArrayStride 32 + +// CHECK: OpMemberDecorate %type__Globals 0 Offset 0 +type1 a; + +// CHECK: OpMemberDecorate %type__Globals 1 Offset 16 +type1 a_arr[3]; + +// CHECK: OpMemberDecorate %type__Globals 2 Offset 64 +type2 b; + +// CHECK: OpMemberDecorate %type__Globals 3 Offset 96 +type2 b_arr[3]; + +// CHECK: OpMemberDecorate %type__Globals 4 Offset 192 +type3 c; + +// CHECK: OpMemberDecorate %type__Globals 5 Offset 224 +type3 c_arr[3]; + +// CHECK: OpMemberDecorate %type__Globals 6 Offset 320 +int end; + + +// CHECK: %spirvIntrinsicType = OpTypeInt 8 0 + +[[vk::ext_capability(/* Int8 */ 39), vk::ext_capability(/* Int16 */ 22)]] +void main() { +} diff --git a/tools/clang/test/CodeGenSPIRV/spv.inline.type.enum-class.hlsl b/tools/clang/test/CodeGenSPIRV/spv.inline.type.enum-class.hlsl new file mode 100644 index 000000000..1650f442c --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.inline.type.enum-class.hlsl @@ -0,0 +1,28 @@ +// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv + +// This won't work until #5554 is fixed. When it is, add a check that it compiles correctly. + +enum class Scope { + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamily = 5, + QueueFamilyKHR = 5, + ShaderCallKHR = 6, +}; + +enum class CooperativeMatrixUse { + MatrixAKHR = 0, + MatrixBKHR = 1, + MatrixAccumulatorKHR = 2, +}; + +typedef vk::SpirvOpaqueType, 32, 32, CooperativeMatrixUse::MatrixAKHR> mat_t; + +[[vk::ext_extension("SPV_KHR_cooperative_matrix")]] +[[vk::ext_capability(/* CooperativeMatrixKHR */ 6022)]] +void main() { + mat_t mat; +} diff --git a/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl b/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl new file mode 100644 index 000000000..18bd5e596 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.inline.type.hlsl @@ -0,0 +1,32 @@ +// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s + +// TODO(6498): enable Array test when using `Texture2D` with an alias template of `SpirvType` is fixed +// CHECK-TODO: %type_Array_type_2d_image = OpTypeArray %type_2d_image +// template +// using Array = vk::SpirvOpaqueType; + +// CHECK: %spirvIntrinsicType = OpTypeArray %type_2d_image %uint_4 +typedef vk::SpirvOpaqueType > ArrayTex2D; + +// CHECK: %spirvIntrinsicType_0 = OpTypeInt 8 0 +typedef vk::SpirvOpaqueType >, vk::Literal > > uint8_t; + +// CHECK: %_arr_spirvIntrinsicType_0_uint_4 = OpTypeArray %spirvIntrinsicType_0 %uint_4 + +// TODO: maybe I've checked this before, but can we add this to uint8_t instead? +[[vk::ext_capability(/* Int8 */ 39)]] +void main() { + // CHECK: %image = OpVariable %_ptr_Function_spirvIntrinsicType Function + // Array image; + ArrayTex2D image; + + // CHECK: %byte = OpVariable %_ptr_Function_spirvIntrinsicType_0 + uint8_t byte; + + // Check that uses of the same type use the same SPIR-V type definition. + // CHECK: %byte1 = OpVariable %_ptr_Function_spirvIntrinsicType_0 + uint8_t byte1; + + // CHECK: %bytes = OpVariable %_ptr_Function__arr_spirvIntrinsicType_0_uint_4 + uint8_t bytes[4]; +} diff --git a/tools/clang/test/CodeGenSPIRV/spv.inline.type.literal.error.hlsl b/tools/clang/test/CodeGenSPIRV/spv.inline.type.literal.error.hlsl new file mode 100644 index 000000000..f56a87906 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.inline.type.literal.error.hlsl @@ -0,0 +1,8 @@ +// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s + +// CHECK: error: The template argument to vk::Literal must be a vk::integral_constant +typedef vk::SpirvOpaqueType > Invalid; + +void main() { + Invalid a; +}