2018-06-30 02:43:40 +03:00
|
|
|
#include "dxc/DxrFallback/DxrFallbackCompiler.h"
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
#include "dxc/DXIL/DxilFunctionProps.h"
|
|
|
|
#include "dxc/DXIL/DxilInstructions.h"
|
|
|
|
#include "dxc/DXIL/DxilModule.h"
|
|
|
|
#include "dxc/DXIL/DxilOperations.h"
|
|
|
|
#include "dxc/HLSL/DxilLinker.h"
|
|
|
|
#include "dxc/Support/FileIOHelper.h"
|
2018-06-30 02:43:40 +03:00
|
|
|
#include "dxc/Support/Global.h"
|
|
|
|
#include "dxc/Support/Unicode.h"
|
|
|
|
#include "dxc/Support/WinIncludes.h"
|
2023-09-19 15:49:22 +03:00
|
|
|
#include "dxc/Support/dxcapi.impl.h"
|
|
|
|
#include "dxc/Support/dxcapi.use.h"
|
2018-06-30 02:43:40 +03:00
|
|
|
#include "dxc/dxcapi.h"
|
|
|
|
#include "dxc/dxcdxrfallbackcompiler.h"
|
|
|
|
|
|
|
|
#include "llvm/Analysis/CallGraph.h"
|
2023-09-19 15:49:22 +03:00
|
|
|
#include "llvm/IR/IRBuilder.h"
|
2018-06-30 02:43:40 +03:00
|
|
|
#include "llvm/IR/InstIterator.h"
|
|
|
|
#include "llvm/IR/Instructions.h"
|
|
|
|
#include "llvm/IR/LegacyPassManager.h"
|
|
|
|
#include "llvm/IR/Module.h"
|
|
|
|
#include "llvm/Linker/Linker.h"
|
|
|
|
#include "llvm/Transforms/IPO.h"
|
|
|
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
|
|
|
#include "llvm/Transforms/Utils/Cloning.h"
|
|
|
|
|
|
|
|
#include "FunctionBuilder.h"
|
|
|
|
#include "LLVMUtils.h"
|
|
|
|
#include "StateFunctionTransform.h"
|
2023-09-19 15:49:22 +03:00
|
|
|
#include "runtime.h"
|
2018-06-30 02:43:40 +03:00
|
|
|
|
|
|
|
#include <queue>
|
|
|
|
|
|
|
|
using namespace hlsl;
|
|
|
|
using namespace llvm;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
static std::vector<Function *>
|
|
|
|
getFunctionsWithPrefix(Module *mod, const std::string &prefix) {
|
|
|
|
std::vector<Function *> functions;
|
|
|
|
for (auto F = mod->begin(), E = mod->end(); F != E; ++F) {
|
2018-06-30 02:43:40 +03:00
|
|
|
StringRef name = F->getName();
|
|
|
|
if (name.startswith(prefix))
|
|
|
|
functions.push_back(F);
|
|
|
|
}
|
|
|
|
return functions;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Inlines the callee of `call`, using `Fimpl` as the implementation when the
// callee is only a declaration in the destination module.
//
// call  - call site to inline; must be a direct call.
// Fimpl - function whose body is cloned into the callee when the callee has no
//         definition yet.
//
// Returns the result of llvm::InlineFunction, or false when `call` is an
// indirect call (there is no callee to clone into or inline).
static bool inlineFunc(CallInst *call, Function *Fimpl) {
  // Note. LLVM inlining may not be sufficient if the function references DX
  // resources because the corresponding metadata is not created if the function
  // comes from another module.

  // Make sure that we have a definition for the called function in this module
  Function *F = call->getCalledFunction();
  if (!F)
    return false; // indirect call: getCalledFunction() yields null

  Module *dstM = F->getParent();
  if (F->isDeclaration()) {
    // Map called functions in impl module to functions in this one (because the
    // cloning step doesn't do this automatically)
    ValueToValueMapTy VMap;
    for (auto &I : inst_range(Fimpl)) {
      CallInst *c = dyn_cast<CallInst>(&I);
      if (!c)
        continue;

      // Indirect calls and inline asm have no callee Function to remap.
      Function *calledFimpl = c->getCalledFunction();
      if (!calledFimpl || VMap.count(calledFimpl))
        continue;

      Constant *calledF = dstM->getOrInsertFunction(
          calledFimpl->getName(), calledFimpl->getFunctionType(),
          calledFimpl->getAttributes());
      VMap[calledFimpl] = calledF;
    }

    // Map arguments
    for (auto SI = Fimpl->arg_begin(), SE = Fimpl->arg_end(),
              DI = F->arg_begin();
         SI != SE; ++SI, ++DI)
      VMap[SI] = DI;

    SmallVector<ReturnInst *, 4> returns;
    CloneFunctionInto(F, Fimpl, VMap, true, returns);
    // The cloned body is only needed for inlining into this module.
    F->setLinkage(GlobalValue::InternalLinkage);
  }

  InlineFunctionInfo IFI;
  return InlineFunction(call, IFI, false);
}
|
|
|
|
|
|
|
|
// Remove ELF mangling
|
2023-09-19 15:49:22 +03:00
|
|
|
static std::string cleanName(StringRef name) {
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!name.startswith("\x1?"))
|
|
|
|
return name;
|
|
|
|
|
|
|
|
size_t pos = name.find("@@");
|
|
|
|
if (pos == name.npos)
|
|
|
|
return name;
|
|
|
|
|
|
|
|
std::string newName = name.substr(2, pos - 2);
|
|
|
|
return newName;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Looks up (or declares) in `mod` a function with the same name and function
// type as `F`. Yields null when the module entry is not a plain Function.
static inline Function *getOrInsertFunction(Module *mod, Function *F) {
  Constant *entry =
      mod->getOrInsertFunction(F->getName(), F->getFunctionType());
  return dyn_cast<Function>(entry);
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Map lookup with a fallback: returns the value stored under `key`, or
// `defaultVal` (null by default) when the key is absent. Never inserts.
template <typename K, typename V>
V get(std::map<K, V> &theMap, const K &key,
      V defaultVal = static_cast<V>(nullptr)) {
  auto found = theMap.find(key);
  if (found != theMap.end())
    return found->second;
  return defaultVal;
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Stores the compilation inputs; no work is done until compile()/link().
//   mod               - module to operate on (pointer is kept, not copied)
//   shaderNames       - entry shader names (copied)
//   maxAttributeSize  - forwarded to the Fallback_TraceRay state transform
//   stackSizeInBytes  - stack size; 0 selects the runtime default in
//                       createStack()
//   findCalledShaders - when true, initShaderMap() also walks the call graph
DxrFallbackCompiler::DxrFallbackCompiler(
    llvm::Module *mod, const std::vector<std::string> &shaderNames,
    unsigned maxAttributeSize, unsigned stackSizeInBytes,
    bool findCalledShaders /*= false*/)
    : m_module(mod),
      m_entryShaderNames(shaderNames),
      m_stackSizeInBytes(stackSizeInBytes),
      m_maxAttributeSize(maxAttributeSize),
      m_findCalledShaders(findCalledShaders) {
  // Nothing else to set up.
}
|
|
|
|
|
|
|
|
void DxrFallbackCompiler::compile(std::vector<int> &shaderEntryStateIds,
|
|
|
|
std::vector<unsigned int> &shaderStackSizes,
|
|
|
|
IntToFuncNameMap *pCachedMap) {
|
2018-06-30 02:43:40 +03:00
|
|
|
std::vector<std::string> shaderNames = m_entryShaderNames;
|
|
|
|
initShaderMap(shaderNames);
|
|
|
|
|
|
|
|
// Bring in runtime so we can get the runtime data type
|
|
|
|
linkRuntime();
|
2023-09-19 15:49:22 +03:00
|
|
|
Type *runtimeDataArgTy = getRuntimeDataArgType();
|
|
|
|
|
|
|
|
// Make sure all calls to intrinsics and shaders are at function scope and
|
2018-06-30 02:43:40 +03:00
|
|
|
// fix up control flow.
|
|
|
|
lowerAnyHitControlFlowFuncs();
|
|
|
|
lowerReportHit();
|
|
|
|
lowerTraceRay(runtimeDataArgTy);
|
2023-09-19 15:49:22 +03:00
|
|
|
|
2018-06-30 02:43:40 +03:00
|
|
|
// Create state functions
|
|
|
|
IntToFuncMap stateFunctionMap; // stateID -> state function
|
2023-09-19 15:49:22 +03:00
|
|
|
const int baseStateId =
|
|
|
|
1000; // could be anything but this makes stateIds more recognizable
|
|
|
|
createStateFunctions(stateFunctionMap, shaderEntryStateIds, shaderStackSizes,
|
|
|
|
baseStateId, shaderNames, runtimeDataArgTy);
|
|
|
|
|
|
|
|
if (pCachedMap) {
|
|
|
|
for (auto &entry : stateFunctionMap) {
|
|
|
|
(*pCachedMap)[entry.first] = entry.second->getName().str();
|
|
|
|
}
|
2018-06-30 02:43:40 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::link(std::vector<int> &shaderEntryStateIds,
|
|
|
|
std::vector<unsigned int> &shaderStackSizes,
|
|
|
|
IntToFuncNameMap *pCachedMap) {
|
|
|
|
IntToFuncMap stateFunctionMap; // stateID -> state function
|
|
|
|
if (pCachedMap) {
|
|
|
|
for (auto entry : *pCachedMap) {
|
|
|
|
stateFunctionMap[entry.first] = m_module->getFunction(entry.second);
|
2018-06-30 02:43:40 +03:00
|
|
|
}
|
2023-09-19 15:49:22 +03:00
|
|
|
} else {
|
|
|
|
for (UINT i = 0; i < shaderEntryStateIds.size(); i++) {
|
|
|
|
UINT substateIndex = 0;
|
|
|
|
UINT baseStateId = shaderEntryStateIds[i];
|
|
|
|
while (true) {
|
|
|
|
auto substateName =
|
|
|
|
m_entryShaderNames[i] + ".ss_" + std::to_string(substateIndex);
|
|
|
|
|
|
|
|
auto function = m_module->getFunction(substateName);
|
|
|
|
if (!function)
|
|
|
|
break;
|
|
|
|
stateFunctionMap[baseStateId + substateIndex] =
|
|
|
|
m_module->getFunction(substateName);
|
|
|
|
substateIndex++;
|
|
|
|
}
|
2018-06-30 02:43:40 +03:00
|
|
|
}
|
2023-09-19 15:49:22 +03:00
|
|
|
}
|
2018-06-30 02:43:40 +03:00
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Fix up scheduler
|
|
|
|
Function *schedulerFunc = m_module->getFunction("fb_Fallback_Scheduler");
|
|
|
|
createLaunchParams(schedulerFunc);
|
2018-06-30 02:43:40 +03:00
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
Type *runtimeDataArgTy = getRuntimeDataArgType();
|
|
|
|
createStateDispatch(schedulerFunc, stateFunctionMap, runtimeDataArgTy);
|
|
|
|
createStack(schedulerFunc);
|
2018-06-30 02:43:40 +03:00
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
lowerIntrinsics();
|
|
|
|
}
|
2018-06-30 02:43:40 +03:00
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::setDebugOutputLevel(int val) {
|
2018-06-30 02:43:40 +03:00
|
|
|
m_debugOutputLevel = val;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// A function counts as a shader when it carries the "exp-shader" attribute or
// when its DXIL function properties mark it as a ray shader.
static bool isShader(Function *F) {
  if (!F->hasFnAttribute("exp-shader")) {
    DxilModule &dxilMod = F->getParent()->GetDxilModule();
    if (!dxilMod.HasDxilFunctionProps(F))
      return false;
    return dxilMod.GetDxilFunctionProps(F).IsRay();
  }
  return true;
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Returns the ray shader kind of `F`:
//  - functions with the "exp-shader" attribute report RayGeneration;
//  - functions with ray DXIL function properties report the recorded kind;
//  - everything else reports Invalid.
DXIL::ShaderKind getRayShaderKind(Function *F) {
  if (!F->hasFnAttribute("exp-shader")) {
    DxilModule &dxilMod = F->getParent()->GetDxilModule();
    const bool isRayShader = dxilMod.HasDxilFunctionProps(F) &&
                             dxilMod.GetDxilFunctionProps(F).IsRay();
    if (!isRayShader)
      return DXIL::ShaderKind::Invalid;
    return dxilMod.GetDxilFunctionProps(F).shaderKind;
  }
  return DXIL::ShaderKind::RayGeneration;
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Some shaders should use the "pending" values of intrinsics instead of the
|
2018-06-30 02:43:40 +03:00
|
|
|
// committed ones. In particular anyhit and intersection shaders use the
|
|
|
|
// pending values with the exception that the committed rayTCurrent should be
|
|
|
|
// used in intersection.
|
2023-09-19 15:49:22 +03:00
|
|
|
static bool shouldUsePendingValue(Function *F, StringRef instrinsicName) {
|
|
|
|
DxilModule &DM = F->getParent()->GetDxilModule();
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!DM.HasDxilFunctionProps(F))
|
|
|
|
return false;
|
2023-09-19 15:49:22 +03:00
|
|
|
const hlsl::DxilFunctionProps &props = DM.GetDxilFunctionProps(F);
|
2018-06-30 02:43:40 +03:00
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
return props.IsAnyHit() ||
|
|
|
|
(props.IsIntersection() && instrinsicName != "rayTCurrent");
|
2018-06-30 02:43:40 +03:00
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::initShaderMap(std::vector<std::string> &shaderNames) {
|
2018-06-30 02:43:40 +03:00
|
|
|
// Clean names and initialize shaderMap
|
|
|
|
StringToFuncMap allShadersMap;
|
2023-09-19 15:49:22 +03:00
|
|
|
for (Function &F : m_module->functions()) {
|
|
|
|
if (isShader(&F)) {
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!F.isDeclaration())
|
|
|
|
allShadersMap[cleanName(F.getName())] = &F;
|
|
|
|
}
|
|
|
|
|
|
|
|
F.removeFnAttr(Attribute::NoInline);
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
for (auto &name : shaderNames)
|
2018-06-30 02:43:40 +03:00
|
|
|
m_shaderMap[name] = allShadersMap[name];
|
|
|
|
|
|
|
|
if (!m_findCalledShaders)
|
|
|
|
return;
|
|
|
|
|
|
|
|
// Create a map from shader name to CallGraphNode
|
|
|
|
CallGraph callGraph(*m_module);
|
2023-09-19 15:49:22 +03:00
|
|
|
std::map<std::string, CallGraphNode *> allShaderNodes;
|
|
|
|
for (auto &kv : m_shaderMap) {
|
|
|
|
const std::string &name = kv.first;
|
|
|
|
Function *func = kv.second;
|
2018-06-30 02:43:40 +03:00
|
|
|
allShaderNodes[name] = callGraph[func];
|
|
|
|
}
|
|
|
|
|
|
|
|
// Start traversing the call graph from given shaderNames
|
2023-09-19 15:49:22 +03:00
|
|
|
std::deque<CallGraphNode *> workList;
|
|
|
|
for (auto &name : shaderNames)
|
2018-06-30 02:43:40 +03:00
|
|
|
workList.push_back(allShaderNodes[name]);
|
2023-09-19 15:49:22 +03:00
|
|
|
while (!workList.empty()) {
|
|
|
|
CallGraphNode *cur = workList.front();
|
2018-06-30 02:43:40 +03:00
|
|
|
workList.pop_front();
|
2023-09-19 15:49:22 +03:00
|
|
|
for (size_t i = 0; i < cur->size(); ++i) {
|
|
|
|
Function *nextFunc = (*cur)[i]->getFunction();
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!nextFunc)
|
|
|
|
continue;
|
2023-09-19 15:49:22 +03:00
|
|
|
if (isShader(nextFunc)) {
|
2018-06-30 02:43:40 +03:00
|
|
|
const std::string nextName = cleanName(nextFunc->getName());
|
|
|
|
if (m_shaderMap.count(nextName) == 0) // not in the shaderMap yet?
|
|
|
|
{
|
|
|
|
workList.push_back(allShaderNodes[nextName]);
|
|
|
|
shaderNames.push_back(nextName);
|
|
|
|
m_shaderMap[nextName] = workList.back()->getFunction();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::linkRuntime() {
|
2018-06-30 02:43:40 +03:00
|
|
|
Linker linker(m_module);
|
2023-09-19 15:49:22 +03:00
|
|
|
std::unique_ptr<Module> runtimeModule =
|
|
|
|
loadModuleFromAsmString(m_module->getContext(), getRuntimeString());
|
2018-06-30 02:43:40 +03:00
|
|
|
bool linkErr = linker.linkInModule(runtimeModule.get());
|
|
|
|
assert(!linkErr && "Error linking runtime");
|
|
|
|
UNREFERENCED_PARAMETER(linkErr);
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Inlines `F` at `call` and makes the call's block end in "ret void".
// The call should be followed immediately by "unreachable"; the block
// terminator is swapped for a void return before inlining.
static void inlineFuncAndAddRet(CallInst *call, Function *F) {
  BasicBlock *callBlock = call->getParent();
  Instruction *retVoid = ReturnInst::Create(call->getContext());
  ReplaceInstWithInst(callBlock->getTerminator(), retVoid);

  bool success = inlineFunc(call, F);
  assert(success);
  UNREFERENCED_PARAMETER(success);
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::lowerAnyHitControlFlowFuncs() {
|
|
|
|
std::vector<CallInst *> callsToIgnoreHit =
|
|
|
|
getCallsInShadersToFunction("dx.op.ignoreHit");
|
|
|
|
if (!callsToIgnoreHit.empty()) {
|
|
|
|
Function *ignoreHitFunc =
|
|
|
|
m_module->getFunction("\x1?Fallback_IgnoreHit@@YAXXZ");
|
2018-06-30 02:43:40 +03:00
|
|
|
assert(ignoreHitFunc && "IgnoreHit() implementation not found");
|
2023-09-19 15:49:22 +03:00
|
|
|
for (CallInst *call : callsToIgnoreHit)
|
2018-06-30 02:43:40 +03:00
|
|
|
inlineFuncAndAddRet(call, ignoreHitFunc);
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
std::vector<CallInst *> callsToAcceptHitAndEndSearch =
|
|
|
|
getCallsInShadersToFunction("dx.op.acceptHitAndEndSearch");
|
|
|
|
if (!callsToAcceptHitAndEndSearch.empty()) {
|
|
|
|
Function *acceptHitAndEndSearchFunc =
|
|
|
|
m_module->getFunction("\x1?Fallback_AcceptHitAndEndSearch@@YAXXZ");
|
|
|
|
assert(acceptHitAndEndSearchFunc &&
|
|
|
|
"AcceptHitAndEndSearch() implementation not found");
|
|
|
|
for (CallInst *call : callsToAcceptHitAndEndSearch)
|
2018-06-30 02:43:40 +03:00
|
|
|
inlineFuncAndAddRet(call, acceptHitAndEndSearchFunc);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::lowerReportHit() {
|
|
|
|
std::vector<CallInst *> callsToReportHit =
|
|
|
|
getCallsInShadersToFunctionWithPrefix("dx.op.reportHit");
|
2018-06-30 02:43:40 +03:00
|
|
|
if (callsToReportHit.empty())
|
|
|
|
return;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
Function *reportHitFunc =
|
|
|
|
m_module->getFunction("\x1?Fallback_ReportHit@@YAHMI@Z");
|
2018-06-30 02:43:40 +03:00
|
|
|
assert(reportHitFunc && "ReportHit() implementation not found");
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
LLVMContext &C = m_module->getContext();
|
|
|
|
for (CallInst *call : callsToReportHit) {
|
2018-06-30 02:43:40 +03:00
|
|
|
// Wrap attribute arguments in Fallback_SetPendingAttr() call
|
2023-09-19 15:49:22 +03:00
|
|
|
Instruction *insertBefore = call;
|
2018-06-30 02:43:40 +03:00
|
|
|
hlsl::DxilInst_ReportHit reportHitCall(call);
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
Value *attr = reportHitCall.get_Attributes();
|
|
|
|
Function *setPendingAttrFunc =
|
|
|
|
FunctionBuilder(m_module, "\x1?Fallback_SetPendingAttr@@")
|
|
|
|
.voidTy()
|
|
|
|
.type(attr->getType(), "attr")
|
|
|
|
.build();
|
|
|
|
CallInst::Create(setPendingAttrFunc, {attr}, "", insertBefore);
|
2018-06-30 02:43:40 +03:00
|
|
|
|
|
|
|
// Make call to implementation and load result
|
2023-09-19 15:49:22 +03:00
|
|
|
CallInst *callImpl = CallInst::Create(
|
|
|
|
reportHitFunc, {reportHitCall.get_THit(), reportHitCall.get_HitKind()},
|
|
|
|
"reportHit.result", insertBefore);
|
|
|
|
Value *result = callImpl;
|
2018-06-30 02:43:40 +03:00
|
|
|
|
|
|
|
// Result < 0 ==> ret
|
2023-09-19 15:49:22 +03:00
|
|
|
Value *zero = makeInt32(0, C);
|
|
|
|
Value *ltz = new ICmpInst(insertBefore, CmpInst::ICMP_SLT, result, zero,
|
|
|
|
"endSearch");
|
|
|
|
BasicBlock *prevBlock = call->getParent();
|
|
|
|
BasicBlock *retBlock = prevBlock->splitBasicBlock(call, "endSearch");
|
|
|
|
BasicBlock *nextBlock = retBlock->splitBasicBlock(call, "afterReportHit");
|
|
|
|
ReplaceInstWithInst(prevBlock->getTerminator(),
|
|
|
|
BranchInst::Create(retBlock, nextBlock, ltz));
|
2018-06-30 02:43:40 +03:00
|
|
|
ReplaceInstWithInst(retBlock->getTerminator(), ReturnInst::Create(C));
|
|
|
|
|
|
|
|
// Compare result to zero and store into original result
|
2023-09-19 15:49:22 +03:00
|
|
|
Value *gtz =
|
|
|
|
new ICmpInst(insertBefore, CmpInst::ICMP_SGT, result, zero, "accepted");
|
2018-06-30 02:43:40 +03:00
|
|
|
call->replaceAllUsesWith(gtz);
|
|
|
|
|
|
|
|
bool success = inlineFunc(callImpl, reportHitFunc);
|
|
|
|
assert(success);
|
|
|
|
(void)success;
|
|
|
|
|
|
|
|
call->eraseFromParent();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Rewrites every dx.op.traceRay call in the shaders (or, failing that, every
// TraceRayTest call) into a call to the runtime's Fallback_TraceRay().
//
// For each call:
//  - in closesthit and miss shaders, runtime data is saved before and
//    restored after the trace through a per-function spill alloca;
//  - the payload (last argument) is passed to a "movePayloadToStack" helper
//    and the i32 it returns replaces the payload argument;
//  - the remaining arguments are forwarded, minus the first two for dx.op
//    calls.
void DxrFallbackCompiler::lowerTraceRay(Type *runtimeDataArgTy) {
  std::vector<CallInst *> traceCalls =
      getCallsInShadersToFunctionWithPrefix("dx.op.traceRay");
  if (traceCalls.empty()) {
    // TODO: It might be worth dropping this from the tests eventually
    traceCalls = getCallsInShadersToFunctionWithPrefix("\x1?TraceRayTest@@");
    if (traceCalls.empty())
      return;
  }

  std::vector<Function *> traceRayImpl =
      getFunctionsWithPrefix(m_module, "\x1?Fallback_TraceRay@@");
  assert(traceRayImpl.size() == 1 &&
         "Could not find Fallback_TraceRay() implementation");

  // Index into the save/restore tables below.
  enum { CLOSEST_HIT = 0, MISS = 1 };
  Function *traceRaySave[] = {m_module->getFunction("traceRaySave_ClosestHit"),
                              m_module->getFunction("traceRaySave_Miss")};
  Function *traceRayRestore[] = {
      m_module->getFunction("traceRayRestore_ClosestHit"),
      m_module->getFunction("traceRayRestore_Miss")};
  assert(traceRaySave[CLOSEST_HIT] && traceRayRestore[CLOSEST_HIT] &&
         traceRaySave[MISS] && traceRayRestore[MISS] &&
         "Could not find TraceRay spill functions");

  Function *dummyRuntimeDataArgFunc =
      StateFunctionTransform::createDummyRuntimeDataArgFunc(m_module,
                                                            runtimeDataArgTy);
  assert(dummyRuntimeDataArgFunc &&
         "dummyRuntimeDataArg function could not be created.");

  // Process calls
  LLVMContext &C = m_module->getContext();
  Type *int32Ty = Type::getInt32Ty(C);
  std::map<FunctionType *, Function *> movePayloadToStackFuncs;
  std::map<Function *, AllocaInst *> funcToSpillAlloca;
  for (CallInst *call : traceCalls) {
    Instruction *insertBefore = call;

    // Spill runtime data values, if necessary (closesthit and miss shaders)
    Function *caller = call->getParent()->getParent();
    const DXIL::ShaderKind callerKind = getRayShaderKind(caller);
    const bool isClosestHit = callerKind == DXIL::ShaderKind::ClosestHit;
    if (isClosestHit || callerKind == DXIL::ShaderKind::Miss) {
      const int slot = isClosestHit ? CLOSEST_HIT : MISS;

      // One spill alloca per calling function, created on first use at the
      // start of its entry block. Its type is the pointee type of the save
      // function's second parameter.
      AllocaInst *spillAlloca = get(funcToSpillAlloca, caller);
      if (!spillAlloca) {
        Argument *spillAllocaArg = (++traceRaySave[slot]->arg_begin());
        Type *spillAllocaTy =
            spillAllocaArg->getType()->getPointerElementType();
        spillAlloca = new AllocaInst(spillAllocaTy, "spill.alloca",
                                     caller->getEntryBlock().begin());
        funcToSpillAlloca[caller] = spillAlloca;
      }

      // Create calls. SFT will inline them.
      Value *runtimeDataArg = CallInst::Create(dummyRuntimeDataArgFunc,
                                               "runtimeData", insertBefore);
      CallInst::Create(traceRaySave[slot], {runtimeDataArg, spillAlloca}, "",
                       insertBefore);
      CallInst::Create(traceRayRestore[slot], {runtimeDataArg, spillAlloca},
                       "", getInstructionAfter(call));
    }

    // Get the payload offset to pass to trace implementation
    // hlsl::DxilInst_TraceRay traceRayCall(call);
    // TODO: Avoiding the intrinsic to support the test's use of TraceRayTest
    const unsigned numArgs = call->getNumArgOperands();
    Value *payload = call->getOperand(numArgs - 1);
    FunctionType *movePayloadTy =
        FunctionType::get(int32Ty, {payload->getType()}, false);
    Function *movePayloadToStackFunc =
        getOrCreateFunction("movePayloadToStack", m_module, movePayloadTy,
                            movePayloadToStackFuncs);
    Value *newPayloadOffset = CallInst::Create(
        movePayloadToStackFunc, {payload}, "new.payload.offset", insertBefore);

    // Call implementation
    // For dx.op calls, skip intrinsic number and acceleration structure (for
    // now).
    unsigned argIdx =
        call->getCalledFunction()->getName().startswith("dx.op") ? 2 : 0;
    std::vector<Value *> implArgs;
    while (argIdx < call->getNumArgOperands() - 1) {
      implArgs.push_back(call->getArgOperand(argIdx));
      ++argIdx;
    }
    implArgs.push_back(newPayloadOffset);
    CallInst::Create(traceRayImpl[0], implArgs, "", insertBefore);

    call->eraseFromParent();
  }
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
static std::vector<StateFunctionTransform::ParameterSemanticType>
|
|
|
|
getParameterTypes(Function *F, DXIL::ShaderKind shaderKind) {
|
2018-06-30 02:43:40 +03:00
|
|
|
std::vector<StateFunctionTransform::ParameterSemanticType> paramTypes;
|
2023-09-19 15:49:22 +03:00
|
|
|
if (shaderKind == DXIL::ShaderKind::AnyHit ||
|
|
|
|
shaderKind == DXIL::ShaderKind::ClosestHit) {
|
2018-06-30 02:43:40 +03:00
|
|
|
paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD);
|
|
|
|
paramTypes.push_back(StateFunctionTransform::PST_ATTRIBUTE);
|
2023-09-19 15:49:22 +03:00
|
|
|
} else if (shaderKind == DXIL::ShaderKind::Miss) {
|
2018-06-30 02:43:40 +03:00
|
|
|
paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD);
|
2023-09-19 15:49:22 +03:00
|
|
|
} else {
|
2018-06-30 02:43:40 +03:00
|
|
|
paramTypes.assign(F->getNumOperands(), StateFunctionTransform::PST_NONE);
|
|
|
|
}
|
|
|
|
return paramTypes;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Adds the global symbol of every cbuffer, UAV, SRV and sampler recorded in
// `DM` to `resources`.
static void collectResources(DxilModule &DM, std::set<Value *> &resources) {
  for (auto &cbuffer : DM.GetCBuffers()) {
    resources.insert(cbuffer->GetGlobalSymbol());
  }
  for (auto &uav : DM.GetUAVs()) {
    resources.insert(uav->GetGlobalSymbol());
  }
  for (auto &srv : DM.GetSRVs()) {
    resources.insert(srv->GetGlobalSymbol());
  }
  for (auto &sampler : DM.GetSamplers()) {
    resources.insert(sampler->GetGlobalSymbol());
  }
}
|
|
|
|
|
|
|
|
void DxrFallbackCompiler::createStateFunctions(
|
2023-09-19 15:49:22 +03:00
|
|
|
IntToFuncMap &stateFunctionMap, std::vector<int> &shaderEntryStateIds,
|
|
|
|
std::vector<unsigned int> &shaderStackSizes, int baseStateId,
|
|
|
|
const std::vector<std::string> &shaderNames, Type *runtimeDataArgTy) {
|
|
|
|
for (auto &kv : m_shaderMap) {
|
2018-06-30 02:43:40 +03:00
|
|
|
if (kv.second == nullptr)
|
|
|
|
errs() << "Function not found for shader " << kv.first << "\n";
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
DxilModule &DM = m_module->GetOrCreateDxilModule();
|
|
|
|
std::set<Value *> resources;
|
2018-06-30 02:43:40 +03:00
|
|
|
collectResources(DM, resources);
|
|
|
|
|
|
|
|
shaderEntryStateIds.clear();
|
|
|
|
shaderStackSizes.clear();
|
|
|
|
int stateId = baseStateId;
|
2023-09-19 15:49:22 +03:00
|
|
|
for (auto &shader : shaderNames) {
|
|
|
|
std::vector<Function *> stateFunctions;
|
|
|
|
Function *F = m_shaderMap[shader];
|
2018-06-30 02:43:40 +03:00
|
|
|
StateFunctionTransform sft(F, shaderNames, runtimeDataArgTy);
|
|
|
|
if (m_debugOutputLevel >= 2)
|
|
|
|
sft.setVerbose(true);
|
|
|
|
if (m_debugOutputLevel >= 3)
|
|
|
|
sft.setDumpFilename("dump.ll");
|
|
|
|
if (shader == "Fallback_TraceRay")
|
|
|
|
sft.setAttributeSize(m_maxAttributeSize);
|
|
|
|
DXIL::ShaderKind shaderKind = getRayShaderKind(F);
|
|
|
|
if (shaderKind != DXIL::ShaderKind::Invalid)
|
2023-09-19 15:49:22 +03:00
|
|
|
sft.setParameterInfo(getParameterTypes(F, shaderKind),
|
|
|
|
shaderKind == DXIL::ShaderKind::ClosestHit);
|
2018-06-30 02:43:40 +03:00
|
|
|
sft.setResourceGlobals(resources);
|
|
|
|
UINT shaderStackSize = 0;
|
|
|
|
sft.run(stateFunctions, shaderStackSize);
|
|
|
|
|
|
|
|
shaderEntryStateIds.push_back(stateId);
|
|
|
|
shaderStackSizes.push_back(shaderStackSize);
|
2023-09-19 15:49:22 +03:00
|
|
|
for (Function *stateF : stateFunctions) {
|
2018-06-30 02:43:40 +03:00
|
|
|
stateFunctionMap[stateId++] = stateF;
|
|
|
|
if (DM.HasDxilFunctionProps(F)) {
|
|
|
|
DM.CloneDxilEntryProps(F, stateF);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
StateFunctionTransform::finalizeStateIds(m_module, shaderEntryStateIds);
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Replaces the runtime's "rewrite_setLaunchParams" placeholder call with a
// call to "fb_Fallback_SetLaunchParams", passing the runtime data argument,
// the x/y thread ids, the two dimension operands of the placeholder call and
// the flattened thread index in the group. The placeholder call and the
// placeholder function are erased afterwards.
//
// func - any function of the module to patch (only its parent module is used)
void DxrFallbackCompiler::createLaunchParams(Function *func) {
  Module *mod = func->getParent();
  Function *rewrite_setLaunchParams =
      mod->getFunction("rewrite_setLaunchParams");
  assert(rewrite_setLaunchParams &&
         "rewrite_setLaunchParams placeholder not found");
  assert(rewrite_setLaunchParams->user_begin() !=
             rewrite_setLaunchParams->user_end() &&
         "rewrite_setLaunchParams has no users");
  // Only the first user of the placeholder is rewritten.
  CallInst *call = dyn_cast<CallInst>(*rewrite_setLaunchParams->user_begin());
  assert(call && "rewrite_setLaunchParams user is not a call");

  LLVMContext &context = mod->getContext();
  Instruction *insertBefore = call;

  // dx.op.threadId(opcode, component): component 0 is x, component 1 is y.
  Function *DTidFunc =
      FunctionBuilder(mod, "dx.op.threadId.i32").i32().i32().i32().build();
  Value *DTidx =
      CallInst::Create(DTidFunc,
                       {makeInt32((int)hlsl::OP::OpCode::ThreadId, context),
                        makeInt32(0, context)},
                       "DTidx", insertBefore);
  Value *DTidy =
      CallInst::Create(DTidFunc,
                       {makeInt32((int)hlsl::OP::OpCode::ThreadId, context),
                        makeInt32(1, context)},
                       "DTidy", insertBefore);

  Value *dimx = call->getArgOperand(1);
  Value *dimy = call->getArgOperand(2);

  // Use the named opcode (value 96) instead of a bare literal, matching the
  // ThreadId calls above.
  Function *groupIndexFunc =
      FunctionBuilder(mod, "dx.op.flattenedThreadIdInGroup.i32")
          .i32()
          .i32()
          .build();
  Value *groupIndex = CallInst::Create(
      groupIndexFunc,
      {makeInt32((int)hlsl::OP::OpCode::FlattenedThreadIdInGroup, context)},
      "groupIndex", insertBefore);

  Function *fb_setLaunchParams =
      mod->getFunction("fb_Fallback_SetLaunchParams");
  assert(fb_setLaunchParams && "fb_Fallback_SetLaunchParams not found");
  Value *runtimeDataArg = call->getArgOperand(0);
  CallInst::Create(fb_setLaunchParams,
                   {runtimeDataArg, DTidx, DTidy, dimx, dimy, groupIndex}, "",
                   insertBefore);

  call->eraseFromParent();
  rewrite_setLaunchParams->eraseFromParent();
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Builds the dispatch function for `stateFunctionMap` and substitutes it for
// the runtime's "rewrite_dispatch" placeholder, which is then erased.
//   func - any function of the module to patch (only its parent is used)
void DxrFallbackCompiler::createStateDispatch(
    Function *func, const IntToFuncMap &stateFunctionMap,
    Type *runtimeDataArgTy) {
  Module *owner = func->getParent();

  // Create the real dispatcher before looking up the placeholder.
  Function *realDispatch =
      createDispatchFunction(stateFunctionMap, runtimeDataArgTy);

  Function *placeholder = owner->getFunction("rewrite_dispatch");
  placeholder->replaceAllUsesWith(realDispatch);
  placeholder->eraseFromParent();
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Materializes the stack and its size constant:
//  - the first "rewrite_createStack" call becomes an alloca named "theStack";
//  - the first "rewrite_getStackSize" call becomes the i32 constant
//    m_stackSizeInBytes (derived from the alloca's array length when it was
//    left at 0).
// Both placeholder functions are erased.
void DxrFallbackCompiler::createStack(Function *func) {
  LLVMContext &context = func->getContext();

  // We would like to allocate the properly sized stack here, but DXIL doesn't
  // allow bitcasts between objects of different sizes. So we have to use the
  // default size from the runtime and replace all the accesses later.
  Function *createStackStub = m_module->getFunction("rewrite_createStack");
  CallInst *createStackCall =
      dyn_cast<CallInst>(*createStackStub->user_begin());
  Type *stackTy = createStackCall->getType()->getPointerElementType();
  AllocaInst *stack = new AllocaInst(stackTy, "theStack", createStackCall);
  stack->setAlignment(sizeof(int));
  createStackCall->replaceAllUsesWith(stack);
  createStackCall->eraseFromParent();
  createStackStub->eraseFromParent();

  if (m_stackSizeInBytes == 0) { // Take the default
    m_stackSizeInBytes =
        stack->getType()->getPointerElementType()->getArrayNumElements() *
        sizeof(int);
  }

  Function *getStackSizeStub = m_module->getFunction("rewrite_getStackSize");
  CallInst *getStackSizeCall =
      dyn_cast<CallInst>(*getStackSizeStub->user_begin());
  Value *stackSizeVal = makeInt32(m_stackSizeInBytes, context);
  getStackSizeCall->replaceAllUsesWith(stackSizeVal);
  getStackSizeCall->eraseFromParent();
  getStackSizeStub->eraseFromParent();
}
|
|
|
|
|
|
|
|
// WAR to avoid crazy <3 x float> code emitted by vanilla clang in the runtime
|
2023-09-19 15:49:22 +03:00
|
|
|
static bool expandFloat3(std::vector<Value *> &args, Value *arg,
|
|
|
|
Instruction *insertBefore) {
|
|
|
|
VectorType *argTy = dyn_cast<VectorType>(arg->getType());
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!argTy || argTy->getVectorNumElements() != 3)
|
|
|
|
return false;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
LLVMContext &C = arg->getContext();
|
|
|
|
args.push_back(
|
|
|
|
ExtractElementInst::Create(arg, makeInt32(0, C), "vec.x", insertBefore));
|
|
|
|
args.push_back(
|
|
|
|
ExtractElementInst::Create(arg, makeInt32(1, C), "vec.y", insertBefore));
|
|
|
|
args.push_back(
|
|
|
|
ExtractElementInst::Create(arg, makeInt32(2, C), "vec.z", insertBefore));
|
2018-06-30 02:43:40 +03:00
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// When `arg` is a "class.matrix.float.3.4" struct, reinterprets it as a
// <12 x float> vector and appends that vector to `args`; returns true.
// Any other argument type returns false without touching `args`.
//
// The value goes through memory: it is stored to a temporary alloca placed at
// the start of the entry block, and the alloca is loaded back through a
// pointer bitcast to the vector type.
static bool float3x4ToFloat12(std::vector<Value *> &args, Value *arg,
                              Instruction *insertBefore) {
  StructType *structTy = dyn_cast<StructType>(arg->getType());
  const bool isFloat3x4 =
      structTy && !(structTy->getName() != "class.matrix.float.3.4");
  if (!isFloat3x4)
    return false;

  Function *enclosing = insertBefore->getParent()->getParent();
  BasicBlock &entryBlock = enclosing->getEntryBlock();
  AllocaInst *matrixSlot =
      new AllocaInst(arg->getType(), "tmp", entryBlock.begin());
  new StoreInst(arg, matrixSlot, insertBefore);

  VectorType *vec12Ty =
      VectorType::get(Type::getFloatTy(arg->getContext()), 12);
  Value *vec12Ptr = new BitCastInst(matrixSlot, vec12Ty->getPointerTo(),
                                    "vec12.ptr", insertBefore);
  Value *vec12 = new LoadInst(vec12Ptr, "vec12.", insertBefore);
  args.push_back(vec12);
  return true;
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
void DxrFallbackCompiler::lowerIntrinsics() {
|
|
|
|
std::vector<Function *> intrinsics = getFunctionsWithPrefix(m_module, "fb_");
|
2018-06-30 02:43:40 +03:00
|
|
|
assert(intrinsics.size() > 0);
|
|
|
|
|
|
|
|
// Replace intrinsics in anyhit shaders with their pending versions
|
2023-09-19 15:49:22 +03:00
|
|
|
LLVMContext &C = m_module->getContext();
|
|
|
|
std::map<std::string, Function *> pendingIntrinsics;
|
|
|
|
std::string pendingPrefixes[] = {"fb_dxop_pending_", "fb_Fallback_Pending"};
|
|
|
|
for (auto &F : intrinsics) {
|
2018-06-30 02:43:40 +03:00
|
|
|
std::string intrinsicName;
|
|
|
|
if (F->getName().startswith(pendingPrefixes[0]))
|
|
|
|
intrinsicName = F->getName().substr(pendingPrefixes[0].length());
|
|
|
|
else if (F->getName().startswith(pendingPrefixes[1]))
|
2023-09-19 15:49:22 +03:00
|
|
|
intrinsicName =
|
|
|
|
"Fallback_" + F->getName().substr(pendingPrefixes[1].length()).str();
|
2018-06-30 02:43:40 +03:00
|
|
|
else
|
|
|
|
continue;
|
|
|
|
|
|
|
|
pendingIntrinsics[intrinsicName] = F;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
for (Function *func : intrinsics) {
|
2018-06-30 02:43:40 +03:00
|
|
|
StringRef intrinsicName;
|
|
|
|
std::string name;
|
|
|
|
bool isDxilOp = false;
|
2023-09-19 15:49:22 +03:00
|
|
|
if (func->getName().startswith("fb_Fallback_")) {
|
2018-06-30 02:43:40 +03:00
|
|
|
intrinsicName = func->getName().substr(3); // after the "fb_" prefix
|
|
|
|
name = "\x1?" + intrinsicName.str();
|
2023-09-19 15:49:22 +03:00
|
|
|
} else if (func->getName().startswith("fb_dxop_")) {
|
2018-06-30 02:43:40 +03:00
|
|
|
intrinsicName = func->getName().substr(8);
|
|
|
|
name = "dx.op." + intrinsicName.str();
|
|
|
|
isDxilOp = true;
|
2023-09-19 15:49:22 +03:00
|
|
|
} else {
|
2018-06-30 02:43:40 +03:00
|
|
|
assert(0 && "Bad intrinsic");
|
|
|
|
}
|
2023-09-19 15:49:22 +03:00
|
|
|
std::vector<Function *> calledFunc = getFunctionsWithPrefix(m_module, name);
|
2018-06-30 02:43:40 +03:00
|
|
|
if (calledFunc.empty())
|
|
|
|
continue;
|
2023-09-19 15:49:22 +03:00
|
|
|
std::vector<CallInst *> calls = getCallsToFunction(calledFunc[0]);
|
2018-06-30 02:43:40 +03:00
|
|
|
if (calls.empty())
|
|
|
|
continue;
|
|
|
|
|
|
|
|
bool needsRuntimeDataArg = (intrinsicName != "Fallback_Scheduler");
|
2023-09-19 15:49:22 +03:00
|
|
|
Function *pendingFunc = get(pendingIntrinsics, intrinsicName.str());
|
|
|
|
Function *funcInModule = nullptr;
|
|
|
|
Function *pendingFuncInModule = nullptr;
|
|
|
|
for (CallInst *call : calls) {
|
|
|
|
Function *caller = call->getParent()->getParent();
|
2018-06-30 02:43:40 +03:00
|
|
|
if (needsRuntimeDataArg && !caller->hasFnAttribute("state_function"))
|
|
|
|
continue;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
Function *F = nullptr;
|
|
|
|
if (pendingFunc && shouldUsePendingValue(caller, intrinsicName)) {
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!pendingFuncInModule)
|
|
|
|
pendingFuncInModule = getOrInsertFunction(m_module, pendingFunc);
|
|
|
|
F = pendingFuncInModule;
|
2023-09-19 15:49:22 +03:00
|
|
|
} else {
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!funcInModule)
|
|
|
|
funcInModule = getOrInsertFunction(m_module, func);
|
|
|
|
F = funcInModule;
|
|
|
|
}
|
|
|
|
|
|
|
|
// insert runtime data and the rest of the arguments
|
2023-09-19 15:49:22 +03:00
|
|
|
std::vector<Value *> args;
|
2018-06-30 02:43:40 +03:00
|
|
|
if (needsRuntimeDataArg)
|
|
|
|
args.push_back(caller->arg_begin());
|
|
|
|
int argIdx = 0;
|
2023-09-19 15:49:22 +03:00
|
|
|
for (auto &arg : call->arg_operands()) {
|
2018-06-30 02:43:40 +03:00
|
|
|
if (argIdx++ == 0 && isDxilOp)
|
|
|
|
continue; // skip the intrinsic number
|
2023-09-19 15:49:22 +03:00
|
|
|
if (!expandFloat3(args, arg, call) &&
|
|
|
|
!float3x4ToFloat12(args, arg, call))
|
2018-06-30 02:43:40 +03:00
|
|
|
args.push_back(arg);
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
CallInst *newCall = CallInst::Create(F, args, "", call);
|
|
|
|
if (F->getFunctionType()->getReturnType() != Type::getVoidTy(C)) {
|
2018-06-30 02:43:40 +03:00
|
|
|
newCall->takeName(call);
|
|
|
|
call->replaceAllUsesWith(newCall);
|
|
|
|
}
|
|
|
|
call->eraseFromParent();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Returns the type of the first parameter of the function named
// "stackIntPtr" in the module.
//
// Precondition: the runtime has already been linked in, so "stackIntPtr"
// exists and takes at least one parameter.
Type *DxrFallbackCompiler::getRuntimeDataArgType() {
  // Get the first argument from a known runtime function (assuming the runtime
  // has already been linked in).
  Function *F = m_module->getFunction("stackIntPtr");
  assert(F && "stackIntPtr not found; runtime must be linked in first");
  assert(!F->arg_empty() && "stackIntPtr has no parameters");
  return F->arg_begin()->getType();
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Builds the "dispatch" function, which switches on its stateID argument and
// forwards to the state function registered for that ID.
//
// stateFunctionMap  - maps state IDs to state functions. Each function is
//                     looked up (or declared) in the module by name with the
//                     type i32(runtimeDataArgTy).
// runtimeDataArgTy  - type of the runtime-data argument that is passed
//                     through to every state function.
//
// The generated function returns whatever the selected state function
// returns, or -3 when stateID matches no entry in the map.
Function *DxrFallbackCompiler::createDispatchFunction(
    const IntToFuncMap &stateFunctionMap, Type *runtimeDataArgTy) {
  LLVMContext &ctx = m_module->getContext();
  FunctionType *stateFuncTy =
      FunctionType::get(Type::getInt32Ty(ctx), {runtimeDataArgTy}, false);

  Function *dispatchFunc = FunctionBuilder(m_module, "dispatch")
                               .i32()
                               .type(runtimeDataArgTy, "runtimeData")
                               .i32("stateID")
                               .build();
  Function::arg_iterator argIt = dispatchFunc->arg_begin();
  Value *runtimeDataArg = argIt;
  Value *stateIdArg = ++argIt;

  BasicBlock *entryBlock = BasicBlock::Create(ctx, "entry", dispatchFunc);
  BasicBlock *badBlock = BasicBlock::Create(ctx, "badStateID", dispatchFunc);

  // Default target of the switch: return an error value.
  IRBuilder<> builder(badBlock);
  builder.CreateRet(makeInt32(-3, ctx));

  builder.SetInsertPoint(entryBlock);
  SwitchInst *dispatchSwitch =
      builder.CreateSwitch(stateIdArg, badBlock, stateFunctionMap.size());

  // Each case block is inserted ahead of the error block, so the error block
  // stays last in the function.
  BasicBlock *insertBeforeBlock = badBlock;
  for (auto &entry : stateFunctionMap) {
    int stateId = entry.first;
    Function *stateFunc = entry.second;

    Value *callee =
        m_module->getOrInsertFunction(stateFunc->getName(), stateFuncTy);
    BasicBlock *caseBlock = BasicBlock::Create(
        ctx, "state_" + Twine(stateId) + "." + stateFunc->getName(),
        dispatchFunc, insertBeforeBlock);

    builder.SetInsertPoint(caseBlock);
    builder.CreateRet(
        builder.CreateCall(callee, {runtimeDataArg}, "nextStateId"));

    dispatchSwitch->addCase(makeInt32(stateId, ctx), caseBlock);
  }

  return dispatchFunc;
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
std::vector<CallInst *>
|
|
|
|
DxrFallbackCompiler::getCallsInShadersToFunction(const std::string &funcName) {
|
|
|
|
std::vector<CallInst *> calls;
|
|
|
|
Function *F = m_module->getFunction(funcName);
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!F)
|
|
|
|
return calls;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
for (User *U : F->users()) {
|
|
|
|
CallInst *call = dyn_cast<CallInst>(U);
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!call)
|
|
|
|
continue;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
Function *caller = call->getParent()->getParent();
|
2018-06-30 02:43:40 +03:00
|
|
|
auto it = m_shaderMap.find(cleanName(caller->getName()));
|
|
|
|
if (it != m_shaderMap.end())
|
|
|
|
calls.push_back(call);
|
|
|
|
}
|
|
|
|
return calls;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
std::vector<CallInst *>
|
|
|
|
DxrFallbackCompiler::getCallsInShadersToFunctionWithPrefix(
|
|
|
|
const std::string &funcNamePrefix) {
|
|
|
|
std::vector<CallInst *> calls;
|
|
|
|
for (Function *F : getFunctionsWithPrefix(m_module, funcNamePrefix)) {
|
|
|
|
for (User *U : F->users()) {
|
|
|
|
CallInst *call = dyn_cast<CallInst>(U);
|
2018-06-30 02:43:40 +03:00
|
|
|
if (!call)
|
|
|
|
continue;
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
Function *caller = call->getParent()->getParent();
|
2018-06-30 02:43:40 +03:00
|
|
|
if (m_shaderMap.count(cleanName(caller->getName())))
|
|
|
|
calls.push_back(call);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return calls;
|
|
|
|
}
|
|
|
|
|
2023-09-19 15:49:22 +03:00
|
|
|
// Replaces the stack alloca of `F` with an i32 array alloca of
// sizeInBytes / sizeof(int) elements.
//
// The stack is the first alloca in F's entry block whose name starts with
// "theStack"; if there is none, the function does nothing. Every user of the
// old alloca is expected to be a GEP (asserted); each one is rebuilt as an
// inbounds GEP with the same indices on the new alloca, takes over the old
// GEP's name and uses, and the old GEP and old alloca are erased.
void DxrFallbackCompiler::resizeStack(Function *F, unsigned sizeInBytes) {
  // Locate the existing stack allocation.
  AllocaInst *oldStack = nullptr;
  for (auto &inst : F->getEntryBlock().getInstList()) {
    AllocaInst *candidate = dyn_cast<AllocaInst>(&inst);
    if (!candidate || !candidate->getName().startswith("theStack"))
      continue;
    oldStack = candidate;
    break;
  }
  if (!oldStack)
    return;

  // Allocate the replacement right before the old one and hand over the name.
  LLVMContext &ctx = F->getContext();
  ArrayType *resizedTy =
      ArrayType::get(Type::getInt32Ty(ctx), sizeInBytes / sizeof(int));
  AllocaInst *resizedStack = new AllocaInst(resizedTy, "", oldStack);
  resizedStack->takeName(oldStack);

  // Remap all GEPs - replaceAllUsesWith() won't change types
  auto userIt = oldStack->user_begin();
  auto userEnd = oldStack->user_end();
  while (userIt != userEnd) {
    User *user = *userIt;
    // Step past this user before the GEP is erased below.
    ++userIt;

    GetElementPtrInst *oldGep = dyn_cast<GetElementPtrInst>(user);
    assert(oldGep && "theStack has non-gep user.");

    std::vector<Value *> indices(oldGep->idx_begin(), oldGep->idx_end());
    GetElementPtrInst *replacement =
        GetElementPtrInst::CreateInBounds(resizedStack, indices, "", oldGep);
    replacement->takeName(oldGep);
    oldGep->replaceAllUsesWith(replacement);
    oldGep->eraseFromParent();
  }

  oldStack->eraseFromParent();
}
|