diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index e14c1914e..55c890a41 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -52,6 +52,8 @@ class SpirvBuilder { public: SpirvBuilder(ASTContext &ac, SpirvContext &c, const SpirvCodeGenOptions &, FeatureManager &featureMgr); + SpirvBuilder(SpirvContext &c, const SpirvCodeGenOptions &, + FeatureManager &featureMgr); ~SpirvBuilder() = default; // Forbid copy construction and assignment @@ -732,6 +734,7 @@ public: public: std::vector takeModule(); + std::vector takeModuleForDxilToSpv(); protected: /// Only friend classes are allowed to add capability/extension to the module @@ -792,7 +795,7 @@ private: SpirvInstruction *var); private: - ASTContext &astContext; + ASTContext *astContext; SpirvContext &context; ///< From which we allocate various SPIR-V object FeatureManager &featureManager; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 507e150dc..b69134389 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -298,7 +298,7 @@ void EmitVisitor::emitDebugLine(spv::Op op, const SourceLocation &loc, } auto fileId = debugMainFileId; - const auto &sm = astContext.getSourceManager(); + const auto &sm = astContext->getSourceManager(); const char *fileName = sm.getPresumedLoc(loc).getFilename(); if (fileName) fileId = getOrCreateOpStringId(fileName); diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 140d19c25..aedd93634 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -46,7 +46,7 @@ public: }; public: - EmitTypeHandler(ASTContext &astCtx, SpirvContext &spvContext, + EmitTypeHandler(ASTContext *astCtx, SpirvContext &spvContext, const SpirvCodeGenOptions &opts, FeatureManager &featureMgr, std::vector *debugVec, std::vector *decVec, @@ -145,13 +145,13 @@ private: template DiagnosticBuilder emitError(const char (&message)[N], SourceLocation loc = {}) { - const auto diagId = astContext.getDiagnostics().getCustomDiagID( + const auto diagId = astContext->getDiagnostics().getCustomDiagID( clang::DiagnosticsEngine::Error, message); - return astContext.getDiagnostics().Report(loc, diagId); + return astContext->getDiagnostics().Report(loc, diagId); } private: - ASTContext &astContext; + ASTContext *astContext; SpirvContext &context; FeatureManager featureManager; std::vector curTypeInst; @@ -198,7 +198,7 @@ public: }; public: - EmitVisitor(ASTContext &astCtx, SpirvContext &spvCtx, + EmitVisitor(ASTContext *astCtx, SpirvContext &spvCtx, const SpirvCodeGenOptions &opts, FeatureManager &featureMgr) : Visitor(opts, spvCtx), astContext(astCtx), id(0), typeHandler(astCtx, spvCtx, opts, featureMgr, &debugVariableBinary, @@ -373,14 +373,14 @@ private: template DiagnosticBuilder emitError(const char (&message)[N], SourceLocation loc = {}) { - const auto diagId = astContext.getDiagnostics().getCustomDiagID( + const auto diagId = astContext->getDiagnostics().getCustomDiagID( clang::DiagnosticsEngine::Error, message); - return astContext.getDiagnostics().Report(loc, diagId); + return astContext->getDiagnostics().Report(loc, diagId); } private: // Object that holds Clang AST nodes. - ASTContext &astContext; + ASTContext *astContext; // The last result-id that's been used so far. uint32_t id; // Handler for emitting types and their related instructions. diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index a93d42797..3b5756a2a 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -27,7 +27,15 @@ namespace spirv { SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx, const SpirvCodeGenOptions &opt, FeatureManager &featureMgr) - : astContext(ac), context(ctx), featureManager(featureMgr), + : astContext(&ac), context(ctx), featureManager(featureMgr), + mod(llvm::make_unique()), function(nullptr), + moduleInit(nullptr), moduleInitInsertPoint(nullptr), spirvOptions(opt), + builtinVars(), debugNone(nullptr), nullDebugExpr(nullptr), + stringLiterals() {} + +SpirvBuilder::SpirvBuilder(SpirvContext &ctx, const SpirvCodeGenOptions &opt, + FeatureManager &featureMg) + : astContext(nullptr), context(ctx), featureManager(featureMg), mod(llvm::make_unique()), function(nullptr), moduleInit(nullptr), moduleInitInsertPoint(nullptr), spirvOptions(opt), builtinVars(), debugNone(nullptr), nullDebugExpr(nullptr), @@ -540,9 +548,8 @@ SpirvInstruction *SpirvBuilder::createImageSample( if (isSparse) { // Write the Residency Code - const auto status = createCompositeExtract(astContext.UnsignedIntTy, - imageSampleInst, {0}, loc, - range); + const auto status = createCompositeExtract( + astContext->UnsignedIntTy, imageSampleInst, {0}, loc, range); createStore(residencyCode, status, loc, range); // Extract the real result from the struct return createCompositeExtract(texelType, imageSampleInst, {1}, loc, range); @@ -581,7 +588,7 @@ SpirvInstruction *SpirvBuilder::createImageFetchOrRead( if (isSparse) { // Write the Residency Code const auto status = createCompositeExtract( - astContext.UnsignedIntTy, fetchOrReadInst, {0}, loc, range); + astContext->UnsignedIntTy, fetchOrReadInst, {0}, loc, range); createStore(residencyCode, status, loc, range); // Extract the real result from the struct return createCompositeExtract(texelType, fetchOrReadInst, {1}, loc, range); @@ -642,7 +649,7 @@ SpirvInstruction *SpirvBuilder::createImageGather( if (residencyCode) { // Write the Residency Code - const auto status = createCompositeExtract(astContext.UnsignedIntTy, + const auto status = createCompositeExtract(astContext->UnsignedIntTy, imageInstruction, {0}, loc); createStore(residencyCode, status, loc); // Extract the real result from the struct @@ -655,9 +662,8 @@ SpirvInstruction *SpirvBuilder::createImageGather( SpirvImageSparseTexelsResident *SpirvBuilder::createImageSparseTexelsResident( SpirvInstruction *residentCode, SourceLocation loc, SourceRange range) { assert(insertPoint && "null insert point"); - auto *inst = new (context) - SpirvImageSparseTexelsResident(astContext.BoolTy, loc, residentCode, - range); + auto *inst = new (context) SpirvImageSparseTexelsResident( + astContext->BoolTy, loc, residentCode, range); insertPoint->addInstruction(inst); return inst; } @@ -1009,7 +1015,7 @@ SpirvInstruction *SpirvBuilder::createReadClock(SpirvInstruction *scope, assert(insertPoint && "null insert point"); assert(scope->getAstResultType()->isIntegerType()); auto *inst = - new (context) SpirvReadClock(astContext.UnsignedLongLongTy, scope, loc); + new (context) SpirvReadClock(astContext->UnsignedLongLongTy, scope, loc); insertPoint->addInstruction(inst); return inst; } @@ -1071,11 +1077,11 @@ void SpirvBuilder::createCopyArrayInFxcCTBufferToClone( for (uint32_t i = 0; i < fxcCTBufferArrTy->getElementCount(); ++i) { auto *ptrToFxcCTBufferElem = createAccessChain( fxcCTBufferElemPtrTy, fxcCTBuffer, - {getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, i))}, loc); + {getConstantInt(astContext->UnsignedIntTy, llvm::APInt(32, i))}, loc); context.addToInstructionsWithLoweredType(ptrToFxcCTBufferElem); auto *ptrToCloneElem = createAccessChain( cloneElemPtrTy, clone, - {getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, i))}, loc); + {getConstantInt(astContext->UnsignedIntTy, llvm::APInt(32, i))}, loc); context.addToInstructionsWithLoweredType(ptrToCloneElem); createCopyInstructionsFromFxcCTBufferToClone(ptrToFxcCTBufferElem, ptrToCloneElem); @@ -1095,13 +1101,13 @@ void SpirvBuilder::createCopyStructInFxcCTBufferToClone( fxcCTBufferFields[i].type, fxcCTBuffer->getStorageClass()); auto *ptrToFxcCTBufferElem = createAccessChain( fxcCTBufferElemPtrTy, fxcCTBuffer, - {getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, i))}, loc); + {getConstantInt(astContext->UnsignedIntTy, llvm::APInt(32, i))}, loc); context.addToInstructionsWithLoweredType(ptrToFxcCTBufferElem); auto *cloneElemPtrTy = context.getPointerType(cloneFields[i].type, clone->getStorageClass()); auto *ptrToCloneElem = createAccessChain( cloneElemPtrTy, clone, - {getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, i))}, loc); + {getConstantInt(astContext->UnsignedIntTy, llvm::APInt(32, i))}, loc); context.addToInstructionsWithLoweredType(ptrToCloneElem); createCopyInstructionsFromFxcCTBufferToClone(ptrToFxcCTBufferElem, ptrToCloneElem); @@ -1152,7 +1158,7 @@ void SpirvBuilder::createCopyInstructionsFromFxcCTBufferToClone( void SpirvBuilder::switchInsertPointToModuleInit() { if (moduleInitInsertPoint == nullptr) { - moduleInit = createSpirvFunction(astContext.VoidTy, SourceLocation(), + moduleInit = createSpirvFunction(astContext->VoidTy, SourceLocation(), "module.init", false); moduleInitInsertPoint = new (context) SpirvBasicBlock("module.init.bb"); moduleInit->addBasicBlock(moduleInitInsertPoint); @@ -1205,7 +1211,7 @@ SpirvBuilder::initializeCloneVarForFxcCTBuffer(SpirvInstruction *instr) { auto astType = var->getAstResultType(); const auto *spvType = var->getResultType(); - LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions); + LowerTypeVisitor lowerTypeVisitor(*astContext, context, spirvOptions); lowerTypeVisitor.visitInstruction(var); context.addToInstructionsWithLoweredType(instr); if (!lowerTypeVisitor.useSpvArrayForHlslMat1xN()) { @@ -1552,7 +1558,7 @@ SpirvConstant *SpirvBuilder::getConstantFloat(QualType type, SpirvConstant *SpirvBuilder::getConstantBool(bool value, bool specConst) { // We do not care about making unique constants at this point. auto *boolConst = - new (context) SpirvConstantBoolean(astContext.BoolTy, value, specConst); + new (context) SpirvConstantBoolean(astContext->BoolTy, value, specConst); mod->addConstant(boolConst); return boolConst; } @@ -1607,7 +1613,7 @@ void SpirvBuilder::addModuleInitCallToEntryPoints() { for (auto *entry : mod->getEntryPoints()) { auto *instruction = new (context) - SpirvFunctionCall(astContext.VoidTy, /* SourceLocation */ {}, + SpirvFunctionCall(astContext->VoidTy, /* SourceLocation */ {}, moduleInit, /* params */ {}); instruction->setRValue(true); entry->getEntryPoint()->addFirstInstruction(instruction); @@ -1633,48 +1639,75 @@ std::vector SpirvBuilder::takeModule() { addModuleInitCallToEntryPoints(); // Run necessary visitor passes first - LiteralTypeVisitor literalTypeVisitor(astContext, context, spirvOptions); - LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions); - CapabilityVisitor capabilityVisitor(astContext, context, spirvOptions, *this, - featureManager); - RelaxedPrecisionVisitor relaxedPrecisionVisitor(context, spirvOptions); - PreciseVisitor preciseVisitor(context, spirvOptions); - NonUniformVisitor nonUniformVisitor(context, spirvOptions); - RemoveBufferBlockVisitor removeBufferBlockVisitor( - astContext, context, spirvOptions, featureManager); - EmitVisitor emitVisitor(astContext, context, spirvOptions, featureManager); - - mod->invokeVisitor(&literalTypeVisitor, true); + LiteralTypeVisitor literalTypeVisitor(*astContext, context, spirvOptions); + mod->invokeVisitor(&literalTypeVisitor, true); // Propagate NonUniform decorations + NonUniformVisitor nonUniformVisitor(context, spirvOptions); mod->invokeVisitor(&nonUniformVisitor); // Lower types - mod->invokeVisitor(&lowerTypeVisitor); + LowerTypeVisitor lowerTypeVisitor(*astContext, context, spirvOptions); + mod->invokeVisitor(&lowerTypeVisitor); - // Generate debug types (if needed) - if (spirvOptions.debugInfoRich) { - DebugTypeVisitor debugTypeVisitor(astContext, context, spirvOptions, *this, - lowerTypeVisitor); - SortDebugInfoVisitor sortDebugInfoVisitor(context, spirvOptions); - mod->invokeVisitor(&debugTypeVisitor); - mod->invokeVisitor(&sortDebugInfoVisitor); - } + // Generate debug types (if needed) + if (spirvOptions.debugInfoRich) { + DebugTypeVisitor debugTypeVisitor(*astContext, context, spirvOptions, + *this, lowerTypeVisitor); + SortDebugInfoVisitor sortDebugInfoVisitor(context, spirvOptions); + mod->invokeVisitor(&debugTypeVisitor); + mod->invokeVisitor(&sortDebugInfoVisitor); + } // Add necessary capabilities and extensions + CapabilityVisitor capabilityVisitor(*astContext, context, spirvOptions, *this, + featureManager); mod->invokeVisitor(&capabilityVisitor); // Propagate RelaxedPrecision decorations + RelaxedPrecisionVisitor relaxedPrecisionVisitor(context, spirvOptions); mod->invokeVisitor(&relaxedPrecisionVisitor); // Propagate NoContraction decorations + PreciseVisitor preciseVisitor(context, spirvOptions); mod->invokeVisitor(&preciseVisitor, true); // Remove BufferBlock decoration if necessary (this decoration is deprecated // after SPIR-V 1.3). - mod->invokeVisitor(&removeBufferBlockVisitor); + RemoveBufferBlockVisitor removeBufferBlockVisitor( + *astContext, context, spirvOptions, featureManager); + mod->invokeVisitor(&removeBufferBlockVisitor); // Emit SPIR-V + EmitVisitor emitVisitor(astContext, context, spirvOptions, featureManager); + mod->invokeVisitor(&emitVisitor); + + return emitVisitor.takeBinary(); +} + +std::vector SpirvBuilder::takeModuleForDxilToSpv() { + endModuleInitFunction(); + addModuleInitCallToEntryPoints(); + + // Propagate NonUniform decorations + NonUniformVisitor nonUniformVisitor(context, spirvOptions); + mod->invokeVisitor(&nonUniformVisitor); + + // Add necessary capabilities and extensions + CapabilityVisitor capabilityVisitor(*astContext, context, spirvOptions, *this, + featureManager); + mod->invokeVisitor(&capabilityVisitor); + + // Propagate RelaxedPrecision decorations + RelaxedPrecisionVisitor relaxedPrecisionVisitor(context, spirvOptions); + mod->invokeVisitor(&relaxedPrecisionVisitor); + + // Propagate NoContraction decorations + PreciseVisitor preciseVisitor(context, spirvOptions); + mod->invokeVisitor(&preciseVisitor, true); + + // Emit SPIR-V + EmitVisitor emitVisitor(astContext, context, spirvOptions, featureManager); mod->invokeVisitor(&emitVisitor); return emitVisitor.takeBinary(); diff --git a/tools/clang/tools/dxil2spv/CMakeLists.txt b/tools/clang/tools/dxil2spv/CMakeLists.txt index 74f4d593a..d771c1663 100644 --- a/tools/clang/tools/dxil2spv/CMakeLists.txt +++ b/tools/clang/tools/dxil2spv/CMakeLists.txt @@ -16,6 +16,8 @@ add_clang_executable(dxil2spv target_link_libraries(dxil2spv dxcompiler dxclib + clangSPIRV + SPIRV-Tools ) llvm_map_components_to_libnames(llvm_libs support core irreader dxil) diff --git a/tools/clang/tools/dxil2spv/dxil2spvmain.cpp b/tools/clang/tools/dxil2spv/dxil2spvmain.cpp index a75102b85..46516d710 100644 --- a/tools/clang/tools/dxil2spv/dxil2spvmain.cpp +++ b/tools/clang/tools/dxil2spv/dxil2spvmain.cpp @@ -26,10 +26,10 @@ // // OUTPUT // ====== -// TODO: The current implementation parses a DXIL file but does not yet produce -// output. +// TODO: The current implementation produces incomplete SPIR-V output. //===----------------------------------------------------------------------===// +#include "dxc/DXIL/DxilModule.h" #include "dxc/DXIL/DxilUtil.h" #include "dxc/DxilContainer/DxilContainer.h" #include "dxc/DxilContainer/DxilContainerReader.h" @@ -42,10 +42,16 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" +#include "llvm/Support/MSFileSystem.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include "spirv-tools/libspirv.hpp" +#include "clang/Frontend/TextDiagnosticPrinter.h" +#include "clang/SPIRV/SpirvBuilder.h" +#include "clang/SPIRV/SpirvContext.h" + static dxc::DxcDllSupport dxcSupport; #ifdef _WIN32 @@ -53,11 +59,23 @@ int __cdecl wmain(int argc, const wchar_t **argv_) { #else int main(int argc, const char **argv_) { #endif // _WIN32 + // Configure filesystem for llvm stdout and stderr handling. + if (llvm::sys::fs::SetupPerThreadFileSystem()) + return DXC_E_GENERAL_INTERNAL_ERROR; + llvm::sys::fs::AutoCleanupPerThreadFileSystem auto_cleanup_fs; + llvm::sys::fs::MSFileSystem *msfPtr; + HRESULT hr; + if (!SUCCEEDED(hr = CreateMSFileSystemForDisk(&msfPtr))) + return DXC_E_GENERAL_INTERNAL_ERROR; + std::unique_ptr msf(msfPtr); + llvm::sys::fs::AutoPerThreadSystem pts(msf.get()); + llvm::STDStreamCloser stdStreamCloser; + + // Check input arguments. if (argc < 2) { - fprintf(stderr, "Required input file argument is missing.\n"); + llvm::errs() << "Required input file argument is missing\n"; return DXC_E_GENERAL_INTERNAL_ERROR; } - hlsl::options::StringRefUtf16 filename(argv_[1]); // Read input file. @@ -73,7 +91,7 @@ int main(int argc, const char **argv_) { std::unique_ptr memoryBuffer; std::unique_ptr module; - // Parse DXIL from bitcode. + // Parse LLVM module from bitcode. hlsl::DxilContainerHeader *pBlobHeader = (hlsl::DxilContainerHeader *)blob->GetBufferPointer(); if (hlsl::IsValidDxilContainer(pBlobHeader, @@ -95,7 +113,7 @@ int main(int argc, const char **argv_) { llvm::StringRef(blobContext, blobSize), context, DiagStr); } } - // Parse DXIL from IR. + // Parse LLVM module from IR. else { llvm::StringRef bufStrRef(blobContext, blobSize); memoryBuffer = llvm::MemoryBuffer::getMemBufferCopy(bufStrRef); @@ -103,9 +121,59 @@ int main(int argc, const char **argv_) { } if (module == nullptr) { - fprintf(stderr, "Could not parse DXIL module.\n"); + llvm::errs() << "Could not parse DXIL module\n"; return DXC_E_GENERAL_INTERNAL_ERROR; } + // Construct DXIL module. + hlsl::DxilModule &program = module->GetOrCreateDxilModule(); + + const hlsl::ShaderModel *shaderModel = program.GetShaderModel(); + if (shaderModel->GetKind() == hlsl::ShaderModel::Kind::Invalid) + llvm::errs() << "Unknown shader model: " << shaderModel->GetName(); + + // Set shader model kind and HLSL major/minor version. + clang::spirv::SpirvContext spvContext; + spvContext.setCurrentShaderModelKind(shaderModel->GetKind()); + spvContext.setMajorVersion(shaderModel->GetMajor()); + spvContext.setMinorVersion(shaderModel->GetMinor()); + + clang::spirv::SpirvCodeGenOptions spvOpts{}; + // TODO: Allow configuration of targetEnv via options. + spvOpts.targetEnv = "vulkan1.0"; + + // Construct SPIR-V builder with diagnostics + clang::IntrusiveRefCntPtr diagnosticOpts = + new clang::DiagnosticOptions(); + clang::TextDiagnosticPrinter diagnosticPrinter(llvm::errs(), + &*diagnosticOpts); + clang::DiagnosticsEngine diagnosticEngine( + clang::IntrusiveRefCntPtr( + new clang::DiagnosticIDs()), + &*diagnosticOpts, &diagnosticPrinter, false); + + clang::spirv::FeatureManager featureMgr(diagnosticEngine, spvOpts); + clang::spirv::SpirvBuilder spvBuilder(spvContext, spvOpts, featureMgr); + + // Set default addressing and memory model for SPIR-V module. + spvBuilder.setMemoryModel(spv::AddressingModel::Logical, + spv::MemoryModel::GLSL450); + + // Contsruct the SPIR-V module. + std::vector m = spvBuilder.takeModuleForDxilToSpv(); + + // Disassemble SPIR-V for output. + std::string assembly; + spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1); + uint32_t spirvDisOpts = (SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | + SPV_BINARY_TO_TEXT_OPTION_INDENT); + + if (!spirvTools.Disassemble(m, &assembly, spirvDisOpts)) { + llvm::errs() << "SPIR-V disassembly failed\n"; + return DXC_E_GENERAL_INTERNAL_ERROR; + } + + llvm::outs() << assembly; + return 0; }