diff --git a/include/dxc/dxctools.h b/include/dxc/dxctools.h index 7b70a89bb..db073b8e2 100644 --- a/include/dxc/dxctools.h +++ b/include/dxc/dxctools.h @@ -19,6 +19,7 @@ enum RewriterOptionMask { SkipFunctionBody = 1, SkipStatic = 2, GlobalExternByDefault = 4, + KeepUserMacro = 8, }; struct __declspec(uuid("c012115b-8893-4eb9-9c5a-111456ea1c45")) diff --git a/tools/clang/test/HLSL/rewriter/predefines2.hlsl b/tools/clang/test/HLSL/rewriter/predefines2.hlsl new file mode 100644 index 000000000..df8e240f7 --- /dev/null +++ b/tools/clang/test/HLSL/rewriter/predefines2.hlsl @@ -0,0 +1,6 @@ +#define X 1 +#define Y(A, B) ((A) + (B)) +float x = X; +float test(float a, float b) { + return Y(a, b); +} \ No newline at end of file diff --git a/tools/clang/tools/libclang/dxcrewriteunused.cpp b/tools/clang/tools/libclang/dxcrewriteunused.cpp index 87923c47f..b79e418fa 100644 --- a/tools/clang/tools/libclang/dxcrewriteunused.cpp +++ b/tools/clang/tools/libclang/dxcrewriteunused.cpp @@ -161,9 +161,9 @@ bool MacroPairCompareIsLessThan(const std::pairgetName().compare(right.first->getName()) < 0; } + static -void WriteSemanticDefines(CompilerInstance &compiler, _In_ DxcLangExtensionsHelper *helper, raw_string_ostream &o) { - ParsedSemanticDefineList macros = CollectSemanticDefinesParsedByCompiler(compiler, helper); +void WriteMacroDefines(ParsedSemanticDefineList ¯os, raw_string_ostream &o) { if (!macros.empty()) { o << "\n// Macros:\n"; for (auto&& m : macros) { @@ -172,6 +172,12 @@ void WriteSemanticDefines(CompilerInstance &compiler, _In_ DxcLangExtensionsHelp } } +static +void WriteSemanticDefines(CompilerInstance &compiler, _In_ DxcLangExtensionsHelper *helper, raw_string_ostream &o) { + ParsedSemanticDefineList macros = CollectSemanticDefinesParsedByCompiler(compiler, helper); + WriteMacroDefines(macros, o); +} + ParsedSemanticDefineList hlsl::CollectSemanticDefinesParsedByCompiler(CompilerInstance &compiler, _In_ DxcLangExtensionsHelper *helper) { ParsedSemanticDefineList parsedDefines; const llvm::SmallVector& defines = helper->GetSemanticDefines(); @@ -232,6 +238,89 @@ ParsedSemanticDefineList hlsl::CollectSemanticDefinesParsedByCompiler(CompilerIn return parsedDefines; } +static ParsedSemanticDefineList CollectUserMacrosParsedByCompiler(CompilerInstance &compiler) { + ParsedSemanticDefineList parsedDefines; + // This is very inefficient in general, but in practice we either have + // no semantic defines, or we have a star define for a some reserved prefix. These will be + // sorted so rewrites are stable. + std::vector > macros; + Preprocessor& pp = compiler.getPreprocessor(); + Preprocessor::macro_iterator end = pp.macro_end(); + SourceManager &SM = compiler.getSourceManager(); + FileID PredefineFileID = pp.getPredefinesFileID(); + + for (Preprocessor::macro_iterator i = pp.macro_begin(); i != end; ++i) { + if (!i->second.getLatest()->isDefined()) { + continue; + } + MacroInfo* mi = i->second.getLatest()->getMacroInfo(); + if (mi->getDefinitionLoc().isInvalid()) { + continue; + } + FileID FID = SM.getFileID(mi->getDefinitionEndLoc()); + if (FID == PredefineFileID) + continue; + + const IdentifierInfo* ii = i->first; + + macros.push_back(std::pair(ii, mi)); + } + + if (!macros.empty()) { + std::sort(macros.begin(), macros.end(), MacroPairCompareIsLessThan); + MacroExpander expander(pp); + for (std::pair m : macros) { + std::string expandedValue; + MacroInfo* mi = m.second; + if (!mi->isFunctionLike()) { + expander.ExpandMacro(m.second, &expandedValue); + parsedDefines.emplace_back(ParsedSemanticDefine{ m.first->getName(), expandedValue, m.second->getDefinitionLoc().getRawEncoding() }); + } else { + std::string macroStr; + raw_string_ostream macro(macroStr); + macro << m.first->getName(); + auto args = mi->args(); + + macro << "("; + for (unsigned I = 0; I != mi->getNumArgs(); ++I) { + if (I) + macro << ", "; + macro << args[I]->getName(); + } + macro << ")"; + macro.flush(); + + std::string macroValStr; + raw_string_ostream macroVal(macroValStr); + for (const Token &Tok : mi->tokens()) { + macroVal << " "; + if (const char *Punc = tok::getPunctuatorSpelling(Tok.getKind())) + macroVal << Punc; + else if (const char *Kwd = tok::getKeywordSpelling(Tok.getKind())) + macroVal << Kwd; + else if (Tok.is(tok::identifier)) + macroVal << Tok.getIdentifierInfo()->getName(); + else if (Tok.isLiteral() && Tok.getLiteralData()) + macroVal << StringRef(Tok.getLiteralData(), Tok.getLength()); + else + macroVal << Tok.getName(); + } + macroVal.flush(); + parsedDefines.emplace_back(ParsedSemanticDefine{ macroStr, macroValStr, m.second->getDefinitionLoc().getRawEncoding() }); + } + } + } + + return parsedDefines; +} + + +static +void WriteUserMacroDefines(CompilerInstance &compiler, raw_string_ostream &o) { + ParsedSemanticDefineList macros = CollectUserMacrosParsedByCompiler(compiler); + WriteMacroDefines(macros, o); +} + static HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper, _In_ LPCSTR pFileName, @@ -404,6 +493,7 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper, bool bSkipFunctionBody = rewriteOption & RewriterOptionMask::SkipFunctionBody; bool bSkipStatic = rewriteOption & RewriterOptionMask::SkipStatic; bool bGlobalExternByDefault = rewriteOption & RewriterOptionMask::GlobalExternByDefault; + bool bKeepUserMacro = rewriteOption & RewriterOptionMask::KeepUserMacro; std::string s, warnings; raw_string_ostream o(s); @@ -417,6 +507,7 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper, // Parse the source file. compiler.getDiagnosticClient().BeginSourceFile(compiler.getLangOpts(), &compiler.getPreprocessor()); + ParseAST(compiler.getSema(), false, bSkipFunctionBody); ASTContext& C = compiler.getASTContext(); @@ -437,6 +528,8 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper, tu->print(o, p); WriteSemanticDefines(compiler, pHelper, o); + if (bKeepUserMacro) + WriteUserMacroDefines(compiler, o); // Flush and return results. raw_string_ostream_to_CoString(o, pResult); diff --git a/tools/clang/unittests/HLSL/RewriterTest.cpp b/tools/clang/unittests/HLSL/RewriterTest.cpp index da4aae890..7aa310886 100644 --- a/tools/clang/unittests/HLSL/RewriterTest.cpp +++ b/tools/clang/unittests/HLSL/RewriterTest.cpp @@ -79,6 +79,7 @@ public: TEST_METHOD(RunNoFunctionBody); TEST_METHOD(RunNoFunctionBodyInclude); TEST_METHOD(RunNoStatic); + TEST_METHOD(RunKeepUserMacro); dxc::DxcDllSupport m_dllSupport; CComPtr m_pIncludeHandler; @@ -570,4 +571,41 @@ namespace b {\n\ }\n\ static int f;\n\ float4 main() : SV_Target;\n") == 0); +} + +TEST_F(RewriterTest, RunKeepUserMacro) { CComPtr pRewriter; + VERIFY_SUCCEEDED(CreateRewriter(&pRewriter)); + CComPtr pRewriteResult; + + // Get the source text from a file + FileWithBlob source( + m_dllSupport, + GetPathToHlslDataFile(L"rewriter\\predefines2.hlsl") + .c_str()); + + const int myDefinesCount = 3; + DxcDefine myDefines[myDefinesCount] = { + {L"myDefine", L"2"}, {L"myDefine3", L"1994"}, {L"myDefine4", nullptr}}; + + // Run rewrite no function body on the source code + VERIFY_SUCCEEDED(pRewriter->RewriteUnchangedWithInclude( + source.BlobEncoding, L"vector-assignments_noerr.hlsl", myDefines, + myDefinesCount, /*pIncludeHandler*/ nullptr, + RewriterOptionMask::KeepUserMacro, + &pRewriteResult)); + + CComPtr result; + VERIFY_SUCCEEDED(pRewriteResult->GetResult(&result)); + // Function decl only. + VERIFY_IS_TRUE(strcmp(BlobToUtf8(result).c_str(), + "// Rewrite unchanged result:\n\ +const float x = 1;\n\ +float test(float a, float b) {\n\ + return ((a) + (b));\n\ +}\n\ +\n\n\n\ +// Macros:\n\ +#define X 1\n\ +#define Y(A, B) ( ( A ) + ( B ) )\n\ +") == 0); } \ No newline at end of file