Add rewrite pass to create global cb. (#4000)

* Add rewrite pass to create global cb.
This commit is contained in:
Xiang Li 2021-10-15 14:41:25 -07:00 коммит произвёл GitHub
Родитель 44e0509d89
Коммит e39defba07
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
15 изменённых файлов: 551 добавлений и 69 удалений

Просмотреть файл

@ -98,6 +98,7 @@ struct RewriterOpts {
bool RemoveUnusedGlobals = false; // OPT_rw_remove_unused_globals
bool RemoveUnusedFunctions = false; // OPT_rw_remove_unused_functions
bool WithLineDirective = false; // OPT_rw_line_directive
bool DeclGlobalCB = false; // OPT_rw_decl_global_cb
};
/// Use this class to capture all options.

Просмотреть файл

@ -512,5 +512,7 @@ def rw_remove_unused_functions : Flag<["-", "/"], "remove-unused-functions">, Gr
HelpText<"Remove unused functions and types">;
def rw_line_directive : Flag<["-", "/"], "line-directive">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
HelpText<"Add line directive">;
def rw_decl_global_cb : Flag<["-", "/"], "decl-global-cb">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
HelpText<"Collect all global constants outside cbuffer declarations into cbuffer GlobalCB { ... }. Still experimental, not all dependency scenarios handled.">;
// Also removed: compress, decompress, /Gch (child effect), /Gpp (partial precision)
// /Op - no support for preshaders.

Просмотреть файл

@ -41,6 +41,7 @@ public:
virtual void EnableDisplayIncludeProcess() = 0;
virtual HRESULT CreateStdStreams(_In_ IMalloc *pMalloc) = 0;
virtual HRESULT RegisterOutputStream(LPCWSTR pName, IStream *pStream) = 0;
virtual HRESULT UnRegisterOutputStream() = 0;
};
DxcArgsFileSystem *

Просмотреть файл

@ -1064,6 +1064,8 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
opts.RWOpt.RemoveUnusedGlobals = Args.hasFlag(OPT_rw_remove_unused_globals, OPT_INVALID, false);
opts.RWOpt.RemoveUnusedFunctions = Args.hasFlag(OPT_rw_remove_unused_functions, OPT_INVALID, false);
opts.RWOpt.WithLineDirective = Args.hasFlag(OPT_rw_line_directive, OPT_INVALID, false);
opts.RWOpt.DeclGlobalCB =
Args.hasFlag(OPT_rw_decl_global_cb, OPT_INVALID, false);
if (opts.EntryPoint.empty() &&
(opts.RWOpt.RemoveUnusedGlobals || opts.RWOpt.ExtractEntryUniforms ||
opts.RWOpt.RemoveUnusedFunctions)) {

Просмотреть файл

@ -44,7 +44,11 @@ struct PrintingPolicy {
Half(LO.HLSL || LO.Half), // HLSL Change - always print 'half' for HLSL
MSWChar(LO.MicrosoftExt && !LO.WChar),
IncludeNewlines(true),
HLSLSuppressUniformParameters(false) { }
// HLSL Change Begin - hlsl print policy.
HLSLSuppressUniformParameters(false), HLSLOnlyDecl(false),
HLSLNoinlineMethod(false)
// HLSL Change End.
{}
/// \brief What language we're printing.
LangOptions LangOpts;
@ -169,6 +173,10 @@ struct PrintingPolicy {
// HLSL Change Begin
/// \brief When true, exclude uniform function parameters
unsigned HLSLSuppressUniformParameters : 1;
/// \brief When true, only print function decl without function body.
unsigned HLSLOnlyDecl : 1;
/// \brief When true, print inline method define as outside struct scope define.
unsigned HLSLNoinlineMethod : 1;
// HLSL Change Ends
};

Просмотреть файл

@ -1414,10 +1414,10 @@ void NamedDecl::printQualifiedName(raw_ostream &OS,
OS << "(anonymous namespace)";
else
OS << *ND;
// HLSL Change Begin - not add cbuffer name to qualified name.
// HLSL Change Begin - not add cbuffer name to qualified name.
} else if (isa<HLSLBufferDecl>(*I)) {
continue;
// HLSL Change End.
// HLSL Change End.
} else if (const RecordDecl *RD = dyn_cast<RecordDecl>(*I)) {
if (!RD->getIdentifier())
OS << "(anonymous " << RD->getKindName() << ')';

Просмотреть файл

@ -488,10 +488,14 @@ void DeclPrinter::VisitFunctionDecl(FunctionDecl *D) {
// HLSL Change Begin
DeclContext *Namespace = D->getEnclosingNamespaceContext();
DeclContext *Enclosing = D->getLexicalParent();
if (!Enclosing->isNamespace() && Namespace->isNamespace()) {
if (!Enclosing->isNamespace() && Namespace->isNamespace() &&
!Policy.HLSLOnlyDecl) {
NamespaceDecl* ns = (NamespaceDecl*)Namespace;
Proto = ns->getName().str() + "::" + Proto;
}
if (Policy.HLSLNoinlineMethod) {
Proto = D->getQualifiedNameAsString();
}
// HLSL Change End
QualType Ty = D->getType();
@ -677,8 +681,15 @@ void DeclPrinter::VisitFunctionDecl(FunctionDecl *D) {
} else
Out << ' ';
if (D->getBody())
D->getBody()->printPretty(Out, nullptr, SubPolicy, Indentation);
if (D->getBody()) {
// HLSL Change Begin - only print decl.
if (Policy.HLSLOnlyDecl) {
Out << ";";
} else {
// HLSL Change end.
D->getBody()->printPretty(Out, nullptr, SubPolicy, Indentation);
}
}
Out << '\n';
}
}

Просмотреть файл

@ -6,7 +6,7 @@
// CHECK:float main
// KEEP_GLOBAL:struct
// KEEP_GLOBAL:ConstantBuffer.
// KEEP_GLOBAL:ConstantBuffer
// KEEP_GLOBAL:float main

Просмотреть файл

@ -0,0 +1,39 @@
// RUN: %dxr -decl-global-cb -line-directive %s | FileCheck %s
// Make sure created GlobalCB and the dependent type is before GlobalCB.
// CHECK:struct S0
// CHECK:namespace NN
// CHECK:namespace N
// CHECK:struct C
// CHECK:struct S1
// CHECK:NN::N::C
// CHECK:cbuffer GlobalCB
// CHECK:float4 a;
// CHECK:S1 s;
// CHECK:float foo(S0 s0)
// CHECK:float main()
#include "inc/globalCB.h"
struct S0 {
float4 b;
};
float foo(S0 s0) {
return s0.b.y;
}
struct S1 : S0 {
float4 c;
float getX() { return c.x; }
NN::N::C c0;
};
S1 s;
//C c;
float main() : SV_Target {
return s.getX();
}

Просмотреть файл

@ -0,0 +1,29 @@
// RUN: %dxr -decl-global-cb -line-directive %s | FileCheck %s
// Make sure created GlobalCB and the dependent type is before GlobalCB.
// Methods define stay after globalCB.
// CHECK:struct UseA {
// CHECK:float b;
// CHECK:float foo() ;
// CHECK:};
// CHECK:cbuffer GlobalCB {
// CHECK:float4 a;
// CHECK:UseA use;
// CHECK:}
// CHECK:float UseA::foo() {
float4 a;
struct UseA {
float b;
float foo() { return a.x + b;}
};
UseA use;
float main() : SV_Target {
return use.foo();
}

Просмотреть файл

@ -0,0 +1,34 @@
// RUN: %dxr -decl-global-cb -line-directive %s | FileCheck %s
// Make sure namespace in global cb works.
// CHECK:cbuffer GlobalCB {
// CHECK-NEXT:namespace N {
// CHECK-NEXT:float a;
// CHECK-NEXT:}
// CHECK-NEXT:namespace N2 {
// CHECK-NEXT:float b;
// CHECK-NEXT:}
// CHECK-NEXT:}
namespace N {
float a;
}
cbuffer B {
namespace N {
float c;
}
namespace N2 {
float d;
}
}
namespace N2 {
float b;
}
float main() : SV_Target {
return N::a + N::c + N2::d + N2::b;
}

Просмотреть файл

@ -0,0 +1,44 @@
// RUN: %dxr -decl-global-cb -line-directive %s | FileCheck %s
// Make sure namespace with method works.
// CHECK:namespace N2 {
// CHECK:struct UseA {
// CHECK:float b;
// CHECK:float foo() ;
// CHECK:};
// CHECK:}
// CHECK:cbuffer GlobalCB {
// CHECK:namespace N {
// CHECK:float4 a;
// CHECK:}
// CHECK:namespace N3 {
// CHECK:N2::UseA use;
// CHECK:}
// CHECK:}
// CHECK:namespace N2 {
// CHECK:float N2::UseA::foo() {
// CHECK:return N::a.x + this.b;
// CHECK:}
// CHECK:}
namespace N {
float4 a;
}
namespace N2 {
struct UseA {
float b;
float foo() { return N::a.x + b;}
};
}
namespace N3 {
N2::UseA use;
}
float main() : SV_Target {
return N3::use.foo();
}

Просмотреть файл

@ -0,0 +1,12 @@
float4 a;
namespace NN {
namespace N {
struct C {
int4 c;
};
}
}

Просмотреть файл

@ -417,6 +417,12 @@ public:
return S_OK;
}
HRESULT UnRegisterOutputStream() override {
m_pOutputStream.Detach();
m_pOutputStream = nullptr;
return S_OK;
}
~DxcArgsFileSystemImpl() override { };
BOOL FindNextFileW(
_In_ HANDLE hFindFile,

Просмотреть файл

@ -226,6 +226,83 @@ public:
}
};
// Collect all global constants.
class GlobalCBVisitor : public RecursiveASTVisitor<GlobalCBVisitor> {
private:
SmallVectorImpl<VarDecl *> &globalConstants;
public:
GlobalCBVisitor(SmallVectorImpl<VarDecl *> &globals)
: globalConstants(globals) {}
bool VisitVarDecl(VarDecl *vd) {
// Skip local var.
if (!vd->getDeclContext()->isTranslationUnit()) {
auto *DclContext = vd->getDeclContext();
while (NamespaceDecl *ND = dyn_cast<NamespaceDecl>(DclContext))
DclContext = ND->getDeclContext();
if (!DclContext->isTranslationUnit())
return true;
}
// Skip group shared.
if (vd->hasAttr<HLSLGroupSharedAttr>())
return true;
// Skip static global.
if (!vd->hasExternalFormalLinkage())
return true;
// Skip resource.
if (DXIL::ResourceClass::Invalid !=
hlsl::GetResourceClassForType(vd->getASTContext(), vd->getType()))
return true;
globalConstants.emplace_back(vd);
return true;
}
};
// Collect types used by a record decl.
// TODO: template support.
class TypeVisitor : public RecursiveASTVisitor<TypeVisitor> {
private:
MapVector<const TypeDecl *, DenseSet<const TypeDecl *>> &m_typeDepMap;
public:
TypeVisitor(
MapVector<const TypeDecl *, DenseSet<const TypeDecl *>> &typeDepMap)
: m_typeDepMap(typeDepMap) {}
bool VisitRecordType(const RecordType *RT) {
RecordDecl *RD = RT->getDecl();
if (m_typeDepMap.count(RD))
return true;
// Create empty dep set.
m_typeDepMap[RD];
if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
for (const auto &I : CXXRD->bases()) {
const CXXRecordDecl *BaseDecl =
cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
if (BaseDecl->field_empty())
continue;
QualType baseTy = QualType(BaseDecl->getTypeForDecl(), 0);
TraverseType(baseTy);
m_typeDepMap[RD].insert(BaseDecl);
}
}
for (auto *field : RD->fields()) {
QualType Ty = field->getType();
if (hlsl::IsHLSLResourceType(Ty))
continue;
if (hlsl::IsHLSLVecMatType(Ty))
continue;
TraverseType(Ty);
const clang::Type *TyPtr = Ty.getTypePtr();
m_typeDepMap[RD].insert(TyPtr->getAsTagDecl());
}
return true;
}
};
// Macro related.
namespace {
@ -441,8 +518,7 @@ void SetupCompilerCommon(CompilerInstance &compiler,
_In_ LPCSTR pMainFile,
_In_ TextDiagnosticPrinter *diagPrinter,
_In_opt_ ASTUnit::RemappedFile *rewrite,
_In_ hlsl::options::DxcOpts &opts,
_In_opt_ dxcutil::DxcArgsFileSystem *msfPtr) {
_In_ hlsl::options::DxcOpts &opts) {
// Setup a compiler instance.
std::shared_ptr<TargetOptions> targetOptions(new TargetOptions);
targetOptions->Triple = llvm::sys::getDefaultTargetTriple();
@ -504,8 +580,7 @@ void SetupCompilerForRewrite(
_In_opt_ ASTUnit::RemappedFile *rewrite, _In_ hlsl::options::DxcOpts &opts,
_In_opt_ LPCSTR pDefines, _In_opt_ dxcutil::DxcArgsFileSystem *msfPtr) {
SetupCompilerCommon(compiler, helper, pMainFile, diagPrinter, rewrite, opts,
msfPtr);
SetupCompilerCommon(compiler, helper, pMainFile, diagPrinter, rewrite, opts);
if (msfPtr) {
msfPtr->SetupForCompilerInstance(compiler);
@ -541,8 +616,7 @@ void SetupCompilerForPreprocess(
_In_ DxcDefine *pDefines, _In_ UINT32 defineCount,
_In_opt_ dxcutil::DxcArgsFileSystem *msfPtr) {
SetupCompilerCommon(compiler, helper, pMainFile, diagPrinter, rewrite, opts,
msfPtr);
SetupCompilerCommon(compiler, helper, pMainFile, diagPrinter, rewrite, opts);
if (pDefines) {
PreprocessorOptions &PPOpts = compiler.getPreprocessorOpts();
@ -581,6 +655,8 @@ HRESULT GenerateAST(DxcLangExtensionsHelper *pExtHelper, LPCSTR pFileName,
dxcutil::DxcArgsFileSystem *msfPtr, raw_ostream &w) {
// Setup a compiler instance.
CompilerInstance &compiler = astHelper.compiler;
compiler.getLangOpts().EnableTemplates = opts.EnableTemplates;
std::unique_ptr<TextDiagnosticPrinter> diagPrinter =
llvm::make_unique<TextDiagnosticPrinter>(w,
&compiler.getDiagnosticOpts());
@ -1146,6 +1222,72 @@ private:
bool bNeedLineInfo;
};
// Preprocess rewritten files.
HRESULT preprocessRewrittenFiles(
_In_ DxcLangExtensionsHelper *pExtHelper, Rewriter &R,
_In_ LPCSTR pFileName, _In_ ASTUnit::RemappedFile *pRemap,
_In_ hlsl::options::DxcOpts &opts, _In_ DxcDefine *pDefines,
_In_ UINT32 defineCount, raw_string_ostream &w, raw_string_ostream &o,
_In_opt_ dxcutil::DxcArgsFileSystem *msfPtr, IMalloc *pMalloc) {
CComPtr<AbstractMemoryStream> pOutputStream;
IFT(CreateMemoryStream(pMalloc, &pOutputStream));
raw_stream_ostream outStream(pOutputStream.p);
// TODO: how to reuse msfPtr when ReigsterOutputStream.
IFT(msfPtr->RegisterOutputStream(L"output.bc", pOutputStream));
llvm::MemoryBuffer *pMemBuf = pRemap->second;
std::unique_ptr<llvm::MemoryBuffer> pBuffer(
llvm::MemoryBuffer::getMemBufferCopy(pMemBuf->getBuffer(), pFileName));
std::unique_ptr<ASTUnit::RemappedFile> pPreprocessRemap(
new ASTUnit::RemappedFile(pFileName, pBuffer.release()));
// Need another compiler instance for preprocess because
// PrintPreprocessedAction will createPreprocessor.
CompilerInstance compiler;
std::unique_ptr<TextDiagnosticPrinter> diagPrinter =
llvm::make_unique<TextDiagnosticPrinter>(w,
&compiler.getDiagnosticOpts());
SetupCompilerForPreprocess(compiler, pExtHelper, pFileName, diagPrinter.get(),
pPreprocessRemap.get(), opts, pDefines,
defineCount, msfPtr);
auto &sourceManager = R.getSourceMgr();
auto &preprocessorOpts = compiler.getPreprocessorOpts();
// Map rewrite buf to source manager of preprocessor compiler.
for (auto it = R.buffer_begin(); it != R.buffer_end(); it++) {
RewriteBuffer &buf = it->second;
const FileEntry *Entry = sourceManager.getFileEntryForID(it->first);
std::string lineStr;
raw_string_ostream o(lineStr);
buf.write(o);
o.flush();
StringRef fileName = Entry->getName();
std::unique_ptr<llvm::MemoryBuffer> rewriteBuf =
MemoryBuffer::getMemBufferCopy(lineStr, fileName);
preprocessorOpts.addRemappedFile(fileName, rewriteBuf.release());
}
compiler.getFrontendOpts().OutputFile = "output.bc";
compiler.WriteDefaultOutputDirectly = true;
compiler.setOutStream(&outStream);
try {
PreprocessResult(compiler, pFileName);
StringRef out((char *)pOutputStream.p->GetPtr(),
pOutputStream.p->GetPtrSize());
o << out;
compiler.setSourceManager(nullptr);
msfPtr->UnRegisterOutputStream();
} catch (Exception &exp) {
w << exp.msg;
return E_FAIL;
} catch (...) {
return E_FAIL;
}
return S_OK;
}
HRESULT DoReWriteWithLineDirective(
_In_ DxcLangExtensionsHelper *pExtHelper, _In_ LPCSTR pFileName,
_In_ ASTUnit::RemappedFile *pRemap, _In_ hlsl::options::DxcOpts &opts,
@ -1198,62 +1340,8 @@ HRESULT DoReWriteWithLineDirective(
}
// Preprocess rewritten files.
{
CComPtr<AbstractMemoryStream> pOutputStream;
IFT(CreateMemoryStream(pMalloc, &pOutputStream));
raw_stream_ostream outStream(pOutputStream.p);
IFT(msfPtr->RegisterOutputStream(L"output.bc", pOutputStream));
llvm::MemoryBuffer *pMemBuf = pRemap->second;
std::unique_ptr<llvm::MemoryBuffer> pBuffer(
llvm::MemoryBuffer::getMemBufferCopy(pMemBuf->getBuffer(), pFileName));
std::unique_ptr<ASTUnit::RemappedFile> pPreprocessRemap(
new ASTUnit::RemappedFile(pFileName, pBuffer.release()));
// Need another compiler instance for preprocess because
// PrintPreprocessedAction will createPreprocessor.
CompilerInstance compiler;
std::unique_ptr<TextDiagnosticPrinter> diagPrinter =
llvm::make_unique<TextDiagnosticPrinter>(w,
&compiler.getDiagnosticOpts());
SetupCompilerForPreprocess(compiler, pExtHelper, pFileName,
diagPrinter.get(), pPreprocessRemap.get(), opts,
pDefines, defineCount, msfPtr);
auto &sourceManager = rewriter.getSourceMgr();
auto &preprocessorOpts = compiler.getPreprocessorOpts();
// Map rewrite buf to source manager of preprocessor compiler.
for (auto it = rewriter.buffer_begin(); it != rewriter.buffer_end(); it++) {
RewriteBuffer &buf = it->second;
const FileEntry *Entry = sourceManager.getFileEntryForID(it->first);
std::string lineStr;
raw_string_ostream o(lineStr);
buf.write(o);
o.flush();
StringRef fileName = Entry->getName();
std::unique_ptr<llvm::MemoryBuffer> rewriteBuf =
MemoryBuffer::getMemBufferCopy(lineStr, fileName);
preprocessorOpts.addRemappedFile(fileName, rewriteBuf.release());
}
compiler.getFrontendOpts().OutputFile = "output.bc";
compiler.WriteDefaultOutputDirectly = true;
compiler.setOutStream(&outStream);
try {
PreprocessResult(compiler, pFileName);
StringRef out((char *)pOutputStream.p->GetPtr(),
pOutputStream.p->GetPtrSize());
o << out;
compiler.setSourceManager(nullptr);
} catch (Exception &exp) {
w << exp.msg;
return E_FAIL;
} catch (...) {
return E_FAIL;
}
}
preprocessRewrittenFiles(pExtHelper, rewriter, pFileName, pRemap, opts,
pDefines, defineCount, w, o, msfPtr, pMalloc);
WriteMacroDefines(astHelper.semanticMacros, o);
if (opts.RWOpt.KeepUserMacro)
@ -1265,6 +1353,196 @@ HRESULT DoReWriteWithLineDirective(
return S_OK;
}
template<typename DT>
void printWithNamespace(DT *VD, raw_string_ostream &OS, PrintingPolicy &p) {
SmallVector<StringRef, 2> namespaceList;
auto const *Context = VD->getDeclContext();
while (const NamespaceDecl *ND = dyn_cast<NamespaceDecl>(Context)) {
namespaceList.emplace_back(ND->getName());
Context = ND->getDeclContext();
}
for (auto it = namespaceList.rbegin(); it != namespaceList.rend(); ++it) {
OS << "namespace " << *it << " {\n";
}
VD->print(OS, p);
OS << ";\n";
for (unsigned i = 0; i < namespaceList.size(); ++i) {
OS << "}\n";
}
}
void printTypeWithoutMethodBody(const TypeDecl *TD, raw_string_ostream &OS,
PrintingPolicy &p) {
PrintingPolicy declP(p);
declP.HLSLOnlyDecl = true;
printWithNamespace(TD, OS, declP);
}
class MethodsVisitor : public DeclVisitor<MethodsVisitor> {
public:
MethodsVisitor(raw_string_ostream &o, PrintingPolicy &p)
: OS(o), declP(p) {
declP.HLSLNoinlineMethod = true;
}
void VisitFunctionDecl(FunctionDecl *f) {
// Don't need to do namespace, the location is not change.
f->print(OS, declP);
return;
}
void VisitDeclContext(DeclContext *DC) {
SmallVector<Decl *, 2> Decls;
for (DeclContext::decl_iterator D = DC->decls_begin(),
DEnd = DC->decls_end();
D != DEnd; ++D) {
// Don't print ObjCIvarDecls, as they are printed when visiting the
// containing ObjCInterfaceDecl.
if (isa<ObjCIvarDecl>(*D))
continue;
// Skip over implicit declarations in pretty-printing mode.
if (D->isImplicit())
continue;
Visit(*D);
}
}
void VisitCXXRecordDecl(CXXRecordDecl *D) {
if (D->isCompleteDefinition()) {
VisitDeclContext(D);
}
}
private:
raw_string_ostream &OS;
PrintingPolicy declP;
};
HRESULT DoRewriteGlobalCB(_In_ DxcLangExtensionsHelper *pExtHelper,
_In_ LPCSTR pFileName,
_In_ ASTUnit::RemappedFile *pRemap,
_In_ hlsl::options::DxcOpts &opts,
_In_ DxcDefine *pDefines, _In_ UINT32 defineCount,
std::string &warnings, std::string &result,
_In_opt_ dxcutil::DxcArgsFileSystem *msfPtr,
IMalloc *pMalloc) {
raw_string_ostream o(result);
raw_string_ostream w(warnings);
ASTHelper astHelper;
GenerateAST(pExtHelper, pFileName, pRemap, pDefines, defineCount, astHelper,
opts, msfPtr, w);
if (astHelper.bHasErrors)
return E_FAIL;
TranslationUnitDecl *tu = astHelper.tu;
// Collect global constants.
SmallVector<VarDecl *, 128> globalConstants;
GlobalCBVisitor visitor(globalConstants);
visitor.TraverseDecl(tu);
// Collect types for global constants.
MapVector<const TypeDecl *, DenseSet<const TypeDecl *>> typeDepMap;
TypeVisitor tyVisitor(typeDepMap);
for (VarDecl *VD : globalConstants) {
QualType Type = VD->getType();
tyVisitor.TraverseType(Type);
}
ASTContext &C = tu->getASTContext();
Rewriter R(C.getSourceManager(), C.getLangOpts());
std::string globalCBStr;
raw_string_ostream OS(globalCBStr);
PrintingPolicy p = PrintingPolicy(C.getPrintingPolicy());
// Sort types with typeDepMap.
SmallVector<const TypeDecl *, 32> sortedGlobalConstantTypes;
while (!typeDepMap.empty()) {
SmallSet<const TypeDecl *, 4> noDepTypes;
for (auto it : typeDepMap) {
const TypeDecl *TD = it.first;
auto &dep = it.second;
if (dep.empty()) {
sortedGlobalConstantTypes.emplace_back(TD);
noDepTypes.insert(TD);
} else {
for (auto *depDecl : dep) {
if (typeDepMap.count(depDecl) == 0) {
noDepTypes.insert(depDecl);
}
}
for (auto *noDepDecl : noDepTypes) {
if (dep.count(noDepDecl))
dep.erase(noDepDecl);
}
if (dep.empty()) {
sortedGlobalConstantTypes.emplace_back(TD);
noDepTypes.insert(TD);
}
}
}
for (auto *noDepDecl : noDepTypes)
typeDepMap.erase(noDepDecl);
}
// Move all type decl to top of tu.
for (const TypeDecl *TD : sortedGlobalConstantTypes) {
printTypeWithoutMethodBody(TD, OS, p);
std::string methodsStr;
raw_string_ostream methodsOS(methodsStr);
MethodsVisitor Visitor(methodsOS, p);
Visitor.Visit(const_cast<TypeDecl*>(TD));
methodsOS.flush();
R.ReplaceText(TD->getSourceRange(), methodsStr);
// TODO: remove ; for type decl.
}
OS << "cbuffer GlobalCB {\n";
// Create HLSLBufferDecl after the types.
for (VarDecl *VD : globalConstants) {
printWithNamespace(VD, OS, p);
R.RemoveText(VD->getSourceRange());
// TODO: remove ; for var decl.
}
OS << "}\n";
OS.flush();
// Cannot find begin of tu, just write first when output.
// R.InsertTextBefore(tu->decls_begin()->getLocation(), globalCBStr);
o << globalCBStr;
// Preprocess rewritten files.
preprocessRewrittenFiles(pExtHelper, R, pFileName, pRemap, opts, pDefines,
defineCount, w, o, msfPtr, pMalloc);
WriteMacroDefines(astHelper.semanticMacros, o);
if (opts.RWOpt.KeepUserMacro)
WriteMacroDefines(astHelper.userMacros, o);
// Flush and return results.
o.flush();
w.flush();
return S_OK;
}
} // namespace
class DxcRewriter : public IDxcRewriter2, public IDxcLangExtensions3 {
@ -1478,6 +1756,7 @@ public:
new ASTUnit::RemappedFile(fName, pBuffer.release()));
hlsl::options::MainArgs mainArgs(argCount, pArguments, 0);
hlsl::options::DxcOpts opts;
IFR(ReadOptsAndValidate(mainArgs, opts, ppResult));
HRESULT hr;
@ -1486,6 +1765,20 @@ public:
return S_OK;
}
if (opts.RWOpt.DeclGlobalCB) {
std::string errors;
std::string rewrite;
HRESULT status = S_OK;
status = DoRewriteGlobalCB(&m_langExtensionsHelper, fName, pRemap.get(),
opts, pDefines, defineCount, errors, rewrite,
msfPtr, m_pMalloc);
if (status != S_OK) {
return S_OK;
}
pBuffer = llvm::MemoryBuffer::getMemBufferCopy(rewrite, fName);
pRemap.reset(new ASTUnit::RemappedFile(fName, pBuffer.release()));
}
std::string errors;
std::string rewrite;
HRESULT status = S_OK;