[spirv] Call SPIR-V legalization passes from SPIRV-Tools (#655)

When seeing opaque types within structs in function parameter,
function return, and variable definition, invoke SPIRV-Tools
legalization passes.

Also refreshed external projects
This commit is contained in:
Lei Zhang 2017-09-27 16:04:35 -04:00 коммит произвёл GitHub
Родитель 52c27ffbbd
Коммит 62629d982a
10 изменённых файлов: 134 добавлений и 15 удалений

2
external/CMakeLists.txt поставляемый
Просмотреть файл

@ -28,7 +28,7 @@ if (${ENABLE_SPIRV_CODEGEN})
if (NOT TARGET SPIRV-Tools) if (NOT TARGET SPIRV-Tools)
message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen") message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen")
else() else()
set(SPIRV_TOOLS_INCLUDE_DIR ${SPIRV-Tools_SOURCE_DIR}/include PARENT_SCOPE) set(SPIRV_TOOLS_INCLUDE_DIR ${spirv-tools_SOURCE_DIR}/include PARENT_SCOPE)
endif() endif()
set(SPIRV_DEP_TARGETS set(SPIRV_DEP_TARGETS

2
external/SPIRV-Tools поставляемый

@ -1 +1 @@
Subproject commit 768d9b42d38c7562bd42dbc29b22c61046848ee8 Subproject commit dcf42433a63c9779cf1269a4e5f1caea3a887b63

2
external/googletest поставляемый

@ -1 +1 @@
Subproject commit b7e8a993b4125d1083cb431d91407d8ee4dba2ad Subproject commit f1a87d73fc604c5ab8fbb0cc6fa9a86ffd845530

2
external/re2 поставляемый

@ -1 +1 @@
Subproject commit 971f917a35125c6dcfabf099d5fe9a1e5c383265 Subproject commit d2b639578a17f459ff90f4bf9b904f66c3ebb93d

Просмотреть файл

@ -24,6 +24,8 @@ add_clang_library(clangSPIRV
clangBasic clangBasic
clangFrontend clangFrontend
clangLex clangLex
SPIRV-Tools-opt
) )
target_include_directories(clangSPIRV PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}) target_include_directories(clangSPIRV PUBLIC ${SPIRV_HEADER_INCLUDE_DIR})
target_include_directories(clangSPIRV PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR})

Просмотреть файл

@ -14,6 +14,7 @@
#include "SPIRVEmitter.h" #include "SPIRVEmitter.h"
#include "dxc/HlslIntrinsicOp.h" #include "dxc/HlslIntrinsicOp.h"
#include "spirv-tools/optimizer.hpp"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "InitListHandler.h" #include "InitListHandler.h"
@ -148,15 +149,33 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
return nullptr; return nullptr;
} }
/// \brief Returns the statement that is the immediate parent AST node of the bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
/// given statement. Returns nullptr if there are no parents nodes. spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
const Stmt *getImmediateParent(ASTContext &astContext, const Stmt *stmt) {
const auto &parents = astContext.getParents(*stmt);
return parents.empty() ? nullptr : parents[0].get<Stmt>();
}
bool isLoopStmt(const Stmt *stmt) { optimizer.SetMessageConsumer(
return isa<ForStmt>(stmt) || isa<WhileStmt>(stmt) || isa<DoStmt>(stmt); [messages](spv_message_level_t /*level*/, const char * /*source*/,
const spv_position_t & /*position*/,
const char *message) { *messages += message; });
optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass());
optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass());
optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass());
optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass());
optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass());
optimizer.RegisterPass(spvtools::CreateBlockMergePass());
optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass());
optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass());
optimizer.RegisterPass(spvtools::CreateEliminateDeadConstantPass());
optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
return optimizer.Run(module->data(), module->size(), module);
} }
} // namespace } // namespace
@ -171,7 +190,7 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
theContext(), theBuilder(&theContext), theContext(), theBuilder(&theContext),
declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions), declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions),
typeTranslator(astContext, theBuilder, diags), entryFunctionId(0), typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
curFunction(nullptr), curThis(0) { curFunction(nullptr), curThis(0), needsLegalization(false) {
if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid) if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
emitError("unknown shader module: %0") << shaderModel.GetName(); emitError("unknown shader module: %0") << shaderModel.GetName();
} }
@ -230,6 +249,19 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
// Output the constructed module. // Output the constructed module.
std::vector<uint32_t> m = theBuilder.takeModule(); std::vector<uint32_t> m = theBuilder.takeModule();
const auto optLevel = theCompilerInstance.getCodeGenOpts().OptimizationLevel;
if (needsLegalization || optLevel > 0) {
if (needsLegalization && optLevel == 0)
emitWarning("-O0 ignored since SPIR-V legalization required");
std::string messages;
if (!spirvToolsOptimize(&m, &messages)) {
emitFatalError("failed to legalize/optimize SPIR-V: %0") << messages;
return;
}
}
theCompilerInstance.getOutStream()->write( theCompilerInstance.getOutStream()->write(
reinterpret_cast<const char *>(m.data()), m.size() * 4); reinterpret_cast<const char *>(m.data()), m.size() * 4);
} }
@ -425,6 +457,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
funcId = declIdMapper.getDeclResultId(decl); funcId = declIdMapper.getDeclResultId(decl);
} }
if (!needsLegalization &&
TypeTranslator::isOpaqueStructType(decl->getReturnType()))
needsLegalization = true;
const uint32_t retType = typeTranslator.translateType(decl->getReturnType()); const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
// Construct the function signature. // Construct the function signature.
@ -454,6 +490,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
const uint32_t ptrType = const uint32_t ptrType =
theBuilder.getPointerType(valueType, spv::StorageClass::Function); theBuilder.getPointerType(valueType, spv::StorageClass::Function);
paramTypes.push_back(ptrType); paramTypes.push_back(ptrType);
if (!needsLegalization &&
TypeTranslator::isOpaqueStructType(param->getType()))
needsLegalization = true;
} }
const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes); const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
@ -555,6 +595,9 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) { if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision); theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
} }
if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType()))
needsLegalization = true;
} }
spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) { spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {

Просмотреть файл

@ -500,6 +500,15 @@ private:
const CXXMemberCallExpr *); const CXXMemberCallExpr *);
private: private:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
template <unsigned N>
DiagnosticBuilder emitFatalError(const char (&message)[N]) {
const auto diagId =
diags.getCustomDiagID(clang::DiagnosticsEngine::Fatal, message);
return diags.Report(diagId);
}
/// \brief Wrapper method to create an error message and report it /// \brief Wrapper method to create an error message and report it
/// in the diagnostic engine associated with this consumer. /// in the diagnostic engine associated with this consumer.
template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) { template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
@ -548,6 +557,15 @@ private:
/// The SPIR-V function parameter for the current this object. /// The SPIR-V function parameter for the current this object.
uint32_t curThis; uint32_t curThis;
/// Whether the translated SPIR-V binary needs legalization.
///
/// The following cases will require legalization:
/// * Opaque types (textures, samplers) within structs
///
/// If this is true, SPIRV-Tools legalization passes will be executed after
/// the translation to legalize the generated SPIR-V binary.
bool needsLegalization;
/// Global variables that should be initialized once at the begining of the /// Global variables that should be initialized once at the begining of the
/// entry function. /// entry function.
llvm::SmallVector<const VarDecl *, 4> toInitGloalVars; llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;

Просмотреть файл

@ -61,6 +61,53 @@ bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
return false; return false;
} }
bool TypeTranslator::isOpaqueType(QualType type) {
if (const auto *recordType = type->getAs<RecordType>()) {
const auto name = recordType->getDecl()->getName();
if (name == "Texture1D" || name == "RWTexture1D")
return true;
if (name == "Texture2D" || name == "RWTexture2D")
return true;
if (name == "Texture2DMS" || name == "RWTexture2DMS")
return true;
if (name == "Texture3D" || name == "RWTexture3D")
return true;
if (name == "TextureCube" || name == "RWTextureCube")
return true;
if (name == "Texture1DArray" || name == "RWTexture1DArray")
return true;
if (name == "Texture2DArray" || name == "RWTexture2DArray")
return true;
if (name == "Texture2DMSArray" || name == "RWTexture2DMSArray")
return true;
if (name == "TextureCubeArray" || name == "RWTextureCubeArray")
return true;
if (name == "Buffer" || name == "RWBuffer")
return true;
if (name == "SamplerState" || name == "SamplerComparisonState")
return true;
}
return false;
}
bool TypeTranslator::isOpaqueStructType(QualType type) {
if (isOpaqueType(type))
return false;
if (const auto *recordType = type->getAs<RecordType>())
for (const auto *field : recordType->getDecl()->decls())
if (const auto *fieldDecl = dyn_cast<FieldDecl>(field))
if (isOpaqueType(fieldDecl->getType()) ||
isOpaqueStructType(fieldDecl->getType()))
return true;
return false;
}
uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule, uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
bool isRowMajor) { bool isRowMajor) {
// We can only apply row_major to matrices or arrays of matrices. // We can only apply row_major to matrices or arrays of matrices.

Просмотреть файл

@ -133,6 +133,14 @@ public:
/// operated on with a relaxed precision. /// operated on with a relaxed precision.
static bool isRelaxedPrecisionType(QualType); static bool isRelaxedPrecisionType(QualType);
/// Returns true if the given type will be translated into a SPIR-V image,
/// sampler or struct containing images or samplers.
static bool isOpaqueType(QualType type);
/// Returns true if the given type is a struct type who has an opaque field
/// (in a recursive away).
static bool isOpaqueStructType(QualType tye);
/// \brief Returns the the element type for the given scalar/vector/matrix /// \brief Returns the the element type for the given scalar/vector/matrix
/// type. Returns empty QualType for other cases. /// type. Returns empty QualType for other cases.
QualType getElementType(QualType type); QualType getElementType(QualType type);

Просмотреть файл

@ -22,7 +22,7 @@ namespace utils {
bool disassembleSpirvBinary(std::vector<uint32_t> &binary, bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
std::string *generatedSpirvAsm, std::string *generatedSpirvAsm,
bool generateHeader) { bool generateHeader) {
spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0); spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
spirvTools.SetMessageConsumer( spirvTools.SetMessageConsumer(
[](spv_message_level_t, const char *, const spv_position_t &, [](spv_message_level_t, const char *, const spv_position_t &,
const char *message) { fprintf(stdout, "%s\n", message); }); const char *message) { fprintf(stdout, "%s\n", message); });
@ -33,7 +33,7 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
} }
bool validateSpirvBinary(std::vector<uint32_t> &binary) { bool validateSpirvBinary(std::vector<uint32_t> &binary) {
spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0); spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
spirvTools.SetMessageConsumer( spirvTools.SetMessageConsumer(
[](spv_message_level_t, const char *, const spv_position_t &, [](spv_message_level_t, const char *, const spv_position_t &,
const char *message) { fprintf(stdout, "%s\n", message); }); const char *message) { fprintf(stdout, "%s\n", message); });
@ -134,6 +134,7 @@ bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath,
flags.push_back(L"-T"); flags.push_back(L"-T");
flags.push_back(profile.c_str()); flags.push_back(profile.c_str());
flags.push_back(L"-spirv"); flags.push_back(L"-spirv");
flags.push_back(L"-O0"); // Disable optimization for testing
flags.push_back(rest.c_str()); flags.push_back(rest.c_str());
IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary)); IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));