From 62629d982af708eedd6d1f4ab4a96f84df0b2eaa Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 27 Sep 2017 16:04:35 -0400 Subject: [PATCH] [spirv] Call SPIR-V legalization passes from SPIRV-Tools (#655) When seeing opaque types within structs in function parameter, function return, and variable definition, invoke SPIRV-Tools legalization passes. Also refreshed external projects --- external/CMakeLists.txt | 2 +- external/SPIRV-Tools | 2 +- external/googletest | 2 +- external/re2 | 2 +- tools/clang/lib/SPIRV/CMakeLists.txt | 2 + tools/clang/lib/SPIRV/SPIRVEmitter.cpp | 61 ++++++++++++++++--- tools/clang/lib/SPIRV/SPIRVEmitter.h | 18 ++++++ tools/clang/lib/SPIRV/TypeTranslator.cpp | 47 ++++++++++++++ tools/clang/lib/SPIRV/TypeTranslator.h | 8 +++ tools/clang/unittests/SPIRV/FileTestUtils.cpp | 5 +- 10 files changed, 134 insertions(+), 15 deletions(-) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 7ee001fb1..bdb529d1f 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -28,7 +28,7 @@ if (${ENABLE_SPIRV_CODEGEN}) if (NOT TARGET SPIRV-Tools) message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen") else() - set(SPIRV_TOOLS_INCLUDE_DIR ${SPIRV-Tools_SOURCE_DIR}/include PARENT_SCOPE) + set(SPIRV_TOOLS_INCLUDE_DIR ${spirv-tools_SOURCE_DIR}/include PARENT_SCOPE) endif() set(SPIRV_DEP_TARGETS diff --git a/external/SPIRV-Tools b/external/SPIRV-Tools index 768d9b42d..dcf42433a 160000 --- a/external/SPIRV-Tools +++ b/external/SPIRV-Tools @@ -1 +1 @@ -Subproject commit 768d9b42d38c7562bd42dbc29b22c61046848ee8 +Subproject commit dcf42433a63c9779cf1269a4e5f1caea3a887b63 diff --git a/external/googletest b/external/googletest index b7e8a993b..f1a87d73f 160000 --- a/external/googletest +++ b/external/googletest @@ -1 +1 @@ -Subproject commit b7e8a993b4125d1083cb431d91407d8ee4dba2ad +Subproject commit f1a87d73fc604c5ab8fbb0cc6fa9a86ffd845530 diff --git a/external/re2 b/external/re2 index 971f917a3..d2b639578 160000 --- a/external/re2 +++ b/external/re2 @@ -1 +1 @@ -Subproject commit 971f917a35125c6dcfabf099d5fe9a1e5c383265 +Subproject commit d2b639578a17f459ff90f4bf9b904f66c3ebb93d diff --git a/tools/clang/lib/SPIRV/CMakeLists.txt b/tools/clang/lib/SPIRV/CMakeLists.txt index fea35bade..df00bb2b2 100644 --- a/tools/clang/lib/SPIRV/CMakeLists.txt +++ b/tools/clang/lib/SPIRV/CMakeLists.txt @@ -24,6 +24,8 @@ add_clang_library(clangSPIRV clangBasic clangFrontend clangLex + SPIRV-Tools-opt ) target_include_directories(clangSPIRV PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}) +target_include_directories(clangSPIRV PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR}) diff --git a/tools/clang/lib/SPIRV/SPIRVEmitter.cpp b/tools/clang/lib/SPIRV/SPIRVEmitter.cpp index 329893668..9a40e6f91 100644 --- a/tools/clang/lib/SPIRV/SPIRVEmitter.cpp +++ b/tools/clang/lib/SPIRV/SPIRVEmitter.cpp @@ -14,6 +14,7 @@ #include "SPIRVEmitter.h" #include "dxc/HlslIntrinsicOp.h" +#include "spirv-tools/optimizer.hpp" #include "llvm/ADT/StringExtras.h" #include "InitListHandler.h" @@ -148,15 +149,33 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) { return nullptr; } -/// \brief Returns the statement that is the immediate parent AST node of the -/// given statement. Returns nullptr if there are no parents nodes. -const Stmt *getImmediateParent(ASTContext &astContext, const Stmt *stmt) { - const auto &parents = astContext.getParents(*stmt); - return parents.empty() ? nullptr : parents[0].get(); -} +bool spirvToolsOptimize(std::vector *module, std::string *messages) { + spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0); -bool isLoopStmt(const Stmt *stmt) { - return isa(stmt) || isa(stmt) || isa(stmt); + optimizer.SetMessageConsumer( + [messages](spv_message_level_t /*level*/, const char * /*source*/, + const spv_position_t & /*position*/, + const char *message) { *messages += message; }); + + optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass()); + optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass()); + optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass()); + optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass()); + optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass()); + optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass()); + + optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass()); + optimizer.RegisterPass(spvtools::CreateBlockMergePass()); + optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass()); + optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass()); + optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass()); + + optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass()); + optimizer.RegisterPass(spvtools::CreateEliminateDeadConstantPass()); + + optimizer.RegisterPass(spvtools::CreateCompactIdsPass()); + + return optimizer.Run(module->data(), module->size(), module); } } // namespace @@ -171,7 +190,7 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci, theContext(), theBuilder(&theContext), declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions), typeTranslator(astContext, theBuilder, diags), entryFunctionId(0), - curFunction(nullptr), curThis(0) { + curFunction(nullptr), curThis(0), needsLegalization(false) { if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid) emitError("unknown shader module: %0") << shaderModel.GetName(); } @@ -230,6 +249,19 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) { // Output the constructed module. std::vector m = theBuilder.takeModule(); + + const auto optLevel = theCompilerInstance.getCodeGenOpts().OptimizationLevel; + if (needsLegalization || optLevel > 0) { + if (needsLegalization && optLevel == 0) + emitWarning("-O0 ignored since SPIR-V legalization required"); + + std::string messages; + if (!spirvToolsOptimize(&m, &messages)) { + emitFatalError("failed to legalize/optimize SPIR-V: %0") << messages; + return; + } + } + theCompilerInstance.getOutStream()->write( reinterpret_cast(m.data()), m.size() * 4); } @@ -425,6 +457,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) { funcId = declIdMapper.getDeclResultId(decl); } + if (!needsLegalization && + TypeTranslator::isOpaqueStructType(decl->getReturnType())) + needsLegalization = true; + const uint32_t retType = typeTranslator.translateType(decl->getReturnType()); // Construct the function signature. @@ -454,6 +490,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) { const uint32_t ptrType = theBuilder.getPointerType(valueType, spv::StorageClass::Function); paramTypes.push_back(ptrType); + + if (!needsLegalization && + TypeTranslator::isOpaqueStructType(param->getType())) + needsLegalization = true; } const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes); @@ -555,6 +595,9 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) { if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) { theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision); } + + if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType())) + needsLegalization = true; } spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) { diff --git a/tools/clang/lib/SPIRV/SPIRVEmitter.h b/tools/clang/lib/SPIRV/SPIRVEmitter.h index 171a20aae..a4b9152ba 100644 --- a/tools/clang/lib/SPIRV/SPIRVEmitter.h +++ b/tools/clang/lib/SPIRV/SPIRVEmitter.h @@ -500,6 +500,15 @@ private: const CXXMemberCallExpr *); private: + /// \brief Wrapper method to create a fatal error message and report it + /// in the diagnostic engine associated with this consumer. + template + DiagnosticBuilder emitFatalError(const char (&message)[N]) { + const auto diagId = + diags.getCustomDiagID(clang::DiagnosticsEngine::Fatal, message); + return diags.Report(diagId); + } + /// \brief Wrapper method to create an error message and report it /// in the diagnostic engine associated with this consumer. template DiagnosticBuilder emitError(const char (&message)[N]) { @@ -548,6 +557,15 @@ private: /// The SPIR-V function parameter for the current this object. uint32_t curThis; + /// Whether the translated SPIR-V binary needs legalization. + /// + /// The following cases will require legalization: + /// * Opaque types (textures, samplers) within structs + /// + /// If this is true, SPIRV-Tools legalization passes will be executed after + /// the translation to legalize the generated SPIR-V binary. + bool needsLegalization; + /// Global variables that should be initialized once at the begining of the /// entry function. llvm::SmallVector toInitGloalVars; diff --git a/tools/clang/lib/SPIRV/TypeTranslator.cpp b/tools/clang/lib/SPIRV/TypeTranslator.cpp index 1c3239974..646daacc0 100644 --- a/tools/clang/lib/SPIRV/TypeTranslator.cpp +++ b/tools/clang/lib/SPIRV/TypeTranslator.cpp @@ -61,6 +61,53 @@ bool TypeTranslator::isRelaxedPrecisionType(QualType type) { return false; } +bool TypeTranslator::isOpaqueType(QualType type) { + if (const auto *recordType = type->getAs()) { + const auto name = recordType->getDecl()->getName(); + + if (name == "Texture1D" || name == "RWTexture1D") + return true; + if (name == "Texture2D" || name == "RWTexture2D") + return true; + if (name == "Texture2DMS" || name == "RWTexture2DMS") + return true; + if (name == "Texture3D" || name == "RWTexture3D") + return true; + if (name == "TextureCube" || name == "RWTextureCube") + return true; + + if (name == "Texture1DArray" || name == "RWTexture1DArray") + return true; + if (name == "Texture2DArray" || name == "RWTexture2DArray") + return true; + if (name == "Texture2DMSArray" || name == "RWTexture2DMSArray") + return true; + if (name == "TextureCubeArray" || name == "RWTextureCubeArray") + return true; + + if (name == "Buffer" || name == "RWBuffer") + return true; + + if (name == "SamplerState" || name == "SamplerComparisonState") + return true; + } + return false; +} + +bool TypeTranslator::isOpaqueStructType(QualType type) { + if (isOpaqueType(type)) + return false; + + if (const auto *recordType = type->getAs()) + for (const auto *field : recordType->getDecl()->decls()) + if (const auto *fieldDecl = dyn_cast(field)) + if (isOpaqueType(fieldDecl->getType()) || + isOpaqueStructType(fieldDecl->getType())) + return true; + + return false; +} + uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule, bool isRowMajor) { // We can only apply row_major to matrices or arrays of matrices. diff --git a/tools/clang/lib/SPIRV/TypeTranslator.h b/tools/clang/lib/SPIRV/TypeTranslator.h index e8385627a..8e569eabc 100644 --- a/tools/clang/lib/SPIRV/TypeTranslator.h +++ b/tools/clang/lib/SPIRV/TypeTranslator.h @@ -133,6 +133,14 @@ public: /// operated on with a relaxed precision. static bool isRelaxedPrecisionType(QualType); + /// Returns true if the given type will be translated into a SPIR-V image, + /// sampler or struct containing images or samplers. + static bool isOpaqueType(QualType type); + + /// Returns true if the given type is a struct type who has an opaque field + /// (in a recursive away). + static bool isOpaqueStructType(QualType tye); + /// \brief Returns the the element type for the given scalar/vector/matrix /// type. Returns empty QualType for other cases. QualType getElementType(QualType type); diff --git a/tools/clang/unittests/SPIRV/FileTestUtils.cpp b/tools/clang/unittests/SPIRV/FileTestUtils.cpp index 77f06765e..ec9f8728f 100644 --- a/tools/clang/unittests/SPIRV/FileTestUtils.cpp +++ b/tools/clang/unittests/SPIRV/FileTestUtils.cpp @@ -22,7 +22,7 @@ namespace utils { bool disassembleSpirvBinary(std::vector &binary, std::string *generatedSpirvAsm, bool generateHeader) { - spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0); + spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0); spirvTools.SetMessageConsumer( [](spv_message_level_t, const char *, const spv_position_t &, const char *message) { fprintf(stdout, "%s\n", message); }); @@ -33,7 +33,7 @@ bool disassembleSpirvBinary(std::vector &binary, } bool validateSpirvBinary(std::vector &binary) { - spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0); + spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0); spirvTools.SetMessageConsumer( [](spv_message_level_t, const char *, const spv_position_t &, const char *message) { fprintf(stdout, "%s\n", message); }); @@ -134,6 +134,7 @@ bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath, flags.push_back(L"-T"); flags.push_back(profile.c_str()); flags.push_back(L"-spirv"); + flags.push_back(L"-O0"); // Disable optimization for testing flags.push_back(rest.c_str()); IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));