[spirv] Call SPIR-V legalization passes from SPIRV-Tools (#655)
When seeing opaque types within structs in function parameter, function return, and variable definition, invoke SPIRV-Tools legalization passes. Also refreshed external projects
This commit is contained in:
Родитель
52c27ffbbd
Коммит
62629d982a
|
@ -28,7 +28,7 @@ if (${ENABLE_SPIRV_CODEGEN})
|
||||||
if (NOT TARGET SPIRV-Tools)
|
if (NOT TARGET SPIRV-Tools)
|
||||||
message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen")
|
message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen")
|
||||||
else()
|
else()
|
||||||
set(SPIRV_TOOLS_INCLUDE_DIR ${SPIRV-Tools_SOURCE_DIR}/include PARENT_SCOPE)
|
set(SPIRV_TOOLS_INCLUDE_DIR ${spirv-tools_SOURCE_DIR}/include PARENT_SCOPE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(SPIRV_DEP_TARGETS
|
set(SPIRV_DEP_TARGETS
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 768d9b42d38c7562bd42dbc29b22c61046848ee8
|
Subproject commit dcf42433a63c9779cf1269a4e5f1caea3a887b63
|
|
@ -1 +1 @@
|
||||||
Subproject commit b7e8a993b4125d1083cb431d91407d8ee4dba2ad
|
Subproject commit f1a87d73fc604c5ab8fbb0cc6fa9a86ffd845530
|
|
@ -1 +1 @@
|
||||||
Subproject commit 971f917a35125c6dcfabf099d5fe9a1e5c383265
|
Subproject commit d2b639578a17f459ff90f4bf9b904f66c3ebb93d
|
|
@ -24,6 +24,8 @@ add_clang_library(clangSPIRV
|
||||||
clangBasic
|
clangBasic
|
||||||
clangFrontend
|
clangFrontend
|
||||||
clangLex
|
clangLex
|
||||||
|
SPIRV-Tools-opt
|
||||||
)
|
)
|
||||||
|
|
||||||
target_include_directories(clangSPIRV PUBLIC ${SPIRV_HEADER_INCLUDE_DIR})
|
target_include_directories(clangSPIRV PUBLIC ${SPIRV_HEADER_INCLUDE_DIR})
|
||||||
|
target_include_directories(clangSPIRV PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR})
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "SPIRVEmitter.h"
|
#include "SPIRVEmitter.h"
|
||||||
|
|
||||||
#include "dxc/HlslIntrinsicOp.h"
|
#include "dxc/HlslIntrinsicOp.h"
|
||||||
|
#include "spirv-tools/optimizer.hpp"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
|
||||||
#include "InitListHandler.h"
|
#include "InitListHandler.h"
|
||||||
|
@ -148,15 +149,33 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// \brief Returns the statement that is the immediate parent AST node of the
|
bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
|
||||||
/// given statement. Returns nullptr if there are no parents nodes.
|
spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
|
||||||
const Stmt *getImmediateParent(ASTContext &astContext, const Stmt *stmt) {
|
|
||||||
const auto &parents = astContext.getParents(*stmt);
|
|
||||||
return parents.empty() ? nullptr : parents[0].get<Stmt>();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isLoopStmt(const Stmt *stmt) {
|
optimizer.SetMessageConsumer(
|
||||||
return isa<ForStmt>(stmt) || isa<WhileStmt>(stmt) || isa<DoStmt>(stmt);
|
[messages](spv_message_level_t /*level*/, const char * /*source*/,
|
||||||
|
const spv_position_t & /*position*/,
|
||||||
|
const char *message) { *messages += message; });
|
||||||
|
|
||||||
|
optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
|
||||||
|
|
||||||
|
optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateBlockMergePass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
|
||||||
|
|
||||||
|
optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass());
|
||||||
|
optimizer.RegisterPass(spvtools::CreateEliminateDeadConstantPass());
|
||||||
|
|
||||||
|
optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
|
||||||
|
|
||||||
|
return optimizer.Run(module->data(), module->size(), module);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -171,7 +190,7 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
|
||||||
theContext(), theBuilder(&theContext),
|
theContext(), theBuilder(&theContext),
|
||||||
declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions),
|
declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions),
|
||||||
typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
|
typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
|
||||||
curFunction(nullptr), curThis(0) {
|
curFunction(nullptr), curThis(0), needsLegalization(false) {
|
||||||
if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
|
if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
|
||||||
emitError("unknown shader module: %0") << shaderModel.GetName();
|
emitError("unknown shader module: %0") << shaderModel.GetName();
|
||||||
}
|
}
|
||||||
|
@ -230,6 +249,19 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
|
||||||
|
|
||||||
// Output the constructed module.
|
// Output the constructed module.
|
||||||
std::vector<uint32_t> m = theBuilder.takeModule();
|
std::vector<uint32_t> m = theBuilder.takeModule();
|
||||||
|
|
||||||
|
const auto optLevel = theCompilerInstance.getCodeGenOpts().OptimizationLevel;
|
||||||
|
if (needsLegalization || optLevel > 0) {
|
||||||
|
if (needsLegalization && optLevel == 0)
|
||||||
|
emitWarning("-O0 ignored since SPIR-V legalization required");
|
||||||
|
|
||||||
|
std::string messages;
|
||||||
|
if (!spirvToolsOptimize(&m, &messages)) {
|
||||||
|
emitFatalError("failed to legalize/optimize SPIR-V: %0") << messages;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
theCompilerInstance.getOutStream()->write(
|
theCompilerInstance.getOutStream()->write(
|
||||||
reinterpret_cast<const char *>(m.data()), m.size() * 4);
|
reinterpret_cast<const char *>(m.data()), m.size() * 4);
|
||||||
}
|
}
|
||||||
|
@ -425,6 +457,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
|
||||||
funcId = declIdMapper.getDeclResultId(decl);
|
funcId = declIdMapper.getDeclResultId(decl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!needsLegalization &&
|
||||||
|
TypeTranslator::isOpaqueStructType(decl->getReturnType()))
|
||||||
|
needsLegalization = true;
|
||||||
|
|
||||||
const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
|
const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
|
||||||
|
|
||||||
// Construct the function signature.
|
// Construct the function signature.
|
||||||
|
@ -454,6 +490,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
|
||||||
const uint32_t ptrType =
|
const uint32_t ptrType =
|
||||||
theBuilder.getPointerType(valueType, spv::StorageClass::Function);
|
theBuilder.getPointerType(valueType, spv::StorageClass::Function);
|
||||||
paramTypes.push_back(ptrType);
|
paramTypes.push_back(ptrType);
|
||||||
|
|
||||||
|
if (!needsLegalization &&
|
||||||
|
TypeTranslator::isOpaqueStructType(param->getType()))
|
||||||
|
needsLegalization = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
|
const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
|
||||||
|
@ -555,6 +595,9 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
|
||||||
if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
|
if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
|
||||||
theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
|
theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType()))
|
||||||
|
needsLegalization = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {
|
spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {
|
||||||
|
|
|
@ -500,6 +500,15 @@ private:
|
||||||
const CXXMemberCallExpr *);
|
const CXXMemberCallExpr *);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/// \brief Wrapper method to create a fatal error message and report it
|
||||||
|
/// in the diagnostic engine associated with this consumer.
|
||||||
|
template <unsigned N>
|
||||||
|
DiagnosticBuilder emitFatalError(const char (&message)[N]) {
|
||||||
|
const auto diagId =
|
||||||
|
diags.getCustomDiagID(clang::DiagnosticsEngine::Fatal, message);
|
||||||
|
return diags.Report(diagId);
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Wrapper method to create an error message and report it
|
/// \brief Wrapper method to create an error message and report it
|
||||||
/// in the diagnostic engine associated with this consumer.
|
/// in the diagnostic engine associated with this consumer.
|
||||||
template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
|
template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
|
||||||
|
@ -548,6 +557,15 @@ private:
|
||||||
/// The SPIR-V function parameter for the current this object.
|
/// The SPIR-V function parameter for the current this object.
|
||||||
uint32_t curThis;
|
uint32_t curThis;
|
||||||
|
|
||||||
|
/// Whether the translated SPIR-V binary needs legalization.
|
||||||
|
///
|
||||||
|
/// The following cases will require legalization:
|
||||||
|
/// * Opaque types (textures, samplers) within structs
|
||||||
|
///
|
||||||
|
/// If this is true, SPIRV-Tools legalization passes will be executed after
|
||||||
|
/// the translation to legalize the generated SPIR-V binary.
|
||||||
|
bool needsLegalization;
|
||||||
|
|
||||||
/// Global variables that should be initialized once at the begining of the
|
/// Global variables that should be initialized once at the begining of the
|
||||||
/// entry function.
|
/// entry function.
|
||||||
llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;
|
llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;
|
||||||
|
|
|
@ -61,6 +61,53 @@ bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TypeTranslator::isOpaqueType(QualType type) {
|
||||||
|
if (const auto *recordType = type->getAs<RecordType>()) {
|
||||||
|
const auto name = recordType->getDecl()->getName();
|
||||||
|
|
||||||
|
if (name == "Texture1D" || name == "RWTexture1D")
|
||||||
|
return true;
|
||||||
|
if (name == "Texture2D" || name == "RWTexture2D")
|
||||||
|
return true;
|
||||||
|
if (name == "Texture2DMS" || name == "RWTexture2DMS")
|
||||||
|
return true;
|
||||||
|
if (name == "Texture3D" || name == "RWTexture3D")
|
||||||
|
return true;
|
||||||
|
if (name == "TextureCube" || name == "RWTextureCube")
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (name == "Texture1DArray" || name == "RWTexture1DArray")
|
||||||
|
return true;
|
||||||
|
if (name == "Texture2DArray" || name == "RWTexture2DArray")
|
||||||
|
return true;
|
||||||
|
if (name == "Texture2DMSArray" || name == "RWTexture2DMSArray")
|
||||||
|
return true;
|
||||||
|
if (name == "TextureCubeArray" || name == "RWTextureCubeArray")
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (name == "Buffer" || name == "RWBuffer")
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (name == "SamplerState" || name == "SamplerComparisonState")
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TypeTranslator::isOpaqueStructType(QualType type) {
|
||||||
|
if (isOpaqueType(type))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (const auto *recordType = type->getAs<RecordType>())
|
||||||
|
for (const auto *field : recordType->getDecl()->decls())
|
||||||
|
if (const auto *fieldDecl = dyn_cast<FieldDecl>(field))
|
||||||
|
if (isOpaqueType(fieldDecl->getType()) ||
|
||||||
|
isOpaqueStructType(fieldDecl->getType()))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
|
uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
|
||||||
bool isRowMajor) {
|
bool isRowMajor) {
|
||||||
// We can only apply row_major to matrices or arrays of matrices.
|
// We can only apply row_major to matrices or arrays of matrices.
|
||||||
|
|
|
@ -133,6 +133,14 @@ public:
|
||||||
/// operated on with a relaxed precision.
|
/// operated on with a relaxed precision.
|
||||||
static bool isRelaxedPrecisionType(QualType);
|
static bool isRelaxedPrecisionType(QualType);
|
||||||
|
|
||||||
|
/// Returns true if the given type will be translated into a SPIR-V image,
|
||||||
|
/// sampler or struct containing images or samplers.
|
||||||
|
static bool isOpaqueType(QualType type);
|
||||||
|
|
||||||
|
/// Returns true if the given type is a struct type who has an opaque field
|
||||||
|
/// (in a recursive away).
|
||||||
|
static bool isOpaqueStructType(QualType tye);
|
||||||
|
|
||||||
/// \brief Returns the the element type for the given scalar/vector/matrix
|
/// \brief Returns the the element type for the given scalar/vector/matrix
|
||||||
/// type. Returns empty QualType for other cases.
|
/// type. Returns empty QualType for other cases.
|
||||||
QualType getElementType(QualType type);
|
QualType getElementType(QualType type);
|
||||||
|
|
|
@ -22,7 +22,7 @@ namespace utils {
|
||||||
bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
|
bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
|
||||||
std::string *generatedSpirvAsm,
|
std::string *generatedSpirvAsm,
|
||||||
bool generateHeader) {
|
bool generateHeader) {
|
||||||
spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0);
|
spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
|
||||||
spirvTools.SetMessageConsumer(
|
spirvTools.SetMessageConsumer(
|
||||||
[](spv_message_level_t, const char *, const spv_position_t &,
|
[](spv_message_level_t, const char *, const spv_position_t &,
|
||||||
const char *message) { fprintf(stdout, "%s\n", message); });
|
const char *message) { fprintf(stdout, "%s\n", message); });
|
||||||
|
@ -33,7 +33,7 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool validateSpirvBinary(std::vector<uint32_t> &binary) {
|
bool validateSpirvBinary(std::vector<uint32_t> &binary) {
|
||||||
spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0);
|
spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
|
||||||
spirvTools.SetMessageConsumer(
|
spirvTools.SetMessageConsumer(
|
||||||
[](spv_message_level_t, const char *, const spv_position_t &,
|
[](spv_message_level_t, const char *, const spv_position_t &,
|
||||||
const char *message) { fprintf(stdout, "%s\n", message); });
|
const char *message) { fprintf(stdout, "%s\n", message); });
|
||||||
|
@ -134,6 +134,7 @@ bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath,
|
||||||
flags.push_back(L"-T");
|
flags.push_back(L"-T");
|
||||||
flags.push_back(profile.c_str());
|
flags.push_back(profile.c_str());
|
||||||
flags.push_back(L"-spirv");
|
flags.push_back(L"-spirv");
|
||||||
|
flags.push_back(L"-O0"); // Disable optimization for testing
|
||||||
flags.push_back(rest.c_str());
|
flags.push_back(rest.c_str());
|
||||||
|
|
||||||
IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
|
IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
|
||||||
|
|
Загрузка…
Ссылка в новой задаче