864 строки
30 KiB
C++
864 строки
30 KiB
C++
#include "dxc/DxrFallback/DxrFallbackCompiler.h"
|
|
|
|
#include "dxc/Support/Global.h"
|
|
#include "dxc/Support/Unicode.h"
|
|
#include "dxc/Support/WinIncludes.h"
|
|
#include "dxc/Support/FileIOHelper.h"
|
|
#include "dxc/dxcapi.h"
|
|
#include "dxc/dxcdxrfallbackcompiler.h"
|
|
#include "dxc/Support/dxcapi.use.h"
|
|
#include "dxc/Support/dxcapi.impl.h"
|
|
#include "dxc/DXIL/DxilModule.h"
|
|
#include "dxc/HLSL/DxilLinker.h"
|
|
#include "dxc/DXIL/DxilFunctionProps.h"
|
|
#include "dxc/DXIL/DxilOperations.h"
|
|
#include "dxc/DXIL/DxilInstructions.h"
|
|
|
|
#include "llvm/Analysis/CallGraph.h"
|
|
#include "llvm/IR/InstIterator.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/LegacyPassManager.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/Linker/Linker.h"
|
|
#include "llvm/Transforms/IPO.h"
|
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
|
#include "llvm/Transforms/Utils/Cloning.h"
|
|
|
|
#include "FunctionBuilder.h"
|
|
#include "LLVMUtils.h"
|
|
#include "runtime.h"
|
|
#include "StateFunctionTransform.h"
|
|
|
|
#include <queue>
|
|
|
|
using namespace hlsl;
|
|
using namespace llvm;
|
|
|
|
// Collects every function of `module` (declarations as well as definitions)
// whose name begins with `prefix`, in module order.
static std::vector<Function*> getFunctionsWithPrefix(Module* module, const std::string& prefix)
{
  std::vector<Function*> matches;
  for (Function& fn : module->functions())
  {
    if (fn.getName().startswith(prefix))
      matches.push_back(&fn);
  }
  return matches;
}
|
|
|
|
|
|
// Inlines `call`, using `Fimpl` as the body of the called function.
//
// If the callee of `call` is only a declaration in the destination module, the
// body of `Fimpl` (which may live in another module) is first cloned into it
// and the callee is given internal linkage. Functions called by `Fimpl` are
// remapped to same-named functions in the destination module; declarations
// are inserted where needed.
//
// @param call  Direct call to inline (its called function must be non-null).
// @param Fimpl Function whose body provides the callee's definition.
// @return      The result of llvm::InlineFunction().
static bool inlineFunc(CallInst* call, Function* Fimpl)
{
  // Note. LLVM inlining may not be sufficient if the function references DX
  // resources because the corresponding metadata is not created if the function
  // comes from another module.

  // Make sure that we have a definition for the called function in this module
  Function* F = call->getCalledFunction();
  assert(F && "inlineFunc() requires a direct call");
  Module* dstM = F->getParent();
  if (F->isDeclaration())
  {
    // Map called functions in impl module to functions in this one (because the
    // cloning step doesn't do this automatically)
    ValueToValueMapTy VMap;
    for (auto& I : inst_range(Fimpl))
    {
      CallInst* c = dyn_cast<CallInst>(&I);
      if (!c)
        continue;

      // Indirect calls have no called Function to remap; skip them instead of
      // dereferencing a null pointer below.
      Function* calledFimpl = c->getCalledFunction();
      if (!calledFimpl || VMap.count(calledFimpl))
        continue;

      Constant* calledF = dstM->getOrInsertFunction(calledFimpl->getName(), calledFimpl->getFunctionType(), calledFimpl->getAttributes());
      VMap[calledFimpl] = calledF;
    }

    // Map arguments
    for (auto SI = Fimpl->arg_begin(), SE = Fimpl->arg_end(), DI = F->arg_begin(); SI != SE; ++SI, ++DI)
      VMap[SI] = DI;

    SmallVector<ReturnInst*, 4> returns;
    CloneFunctionInto(F, Fimpl, VMap, true, returns);
    F->setLinkage(GlobalValue::InternalLinkage);
  }

  InlineFunctionInfo IFI;
  return InlineFunction(call, IFI, false);
}
|
|
|
|
|
|
// Remove ELF mangling
|
|
static std::string cleanName(StringRef name)
|
|
{
|
|
if (!name.startswith("\x1?"))
|
|
return name;
|
|
|
|
size_t pos = name.find("@@");
|
|
if (pos == name.npos)
|
|
return name;
|
|
|
|
std::string newName = name.substr(2, pos - 2);
|
|
return newName;
|
|
}
|
|
|
|
|
|
// Returns the function in `module` with F's name, declaring it with F's type
// if it is absent. Yields nullptr when Module::getOrInsertFunction hands back
// something other than a Function (LLVM returns a bitcast when an existing
// same-named function has a different type).
static inline Function* getOrInsertFunction(Module* module, Function* F)
{
  Constant* entry = module->getOrInsertFunction(F->getName(), F->getFunctionType());
  return dyn_cast<Function>(entry);
}
|
|
|
|
|
|
// Looks up `key` in `theMap`. Returns the mapped value if the key is present,
// otherwise `defaultVal`. The default is nullptr converted to V, so V is
// normally a pointer type. Unlike operator[], this never inserts.
template<typename K, typename V>
V get(std::map<K, V>& theMap, const K& key, V defaultVal = static_cast<V>(nullptr))
{
  const auto found = theMap.find(key);
  return found != theMap.end() ? found->second : defaultVal;
}
|
|
|
|
|
|
// Stores the configuration used by compile() and link():
//  - the module to transform and the entry shader names;
//  - the maximum attribute size, applied to the Fallback_TraceRay transform;
//  - the stack size in bytes (0 selects the runtime's default);
//  - whether shaders called from the entry shaders are also compiled.
DxrFallbackCompiler::DxrFallbackCompiler(llvm::Module* module, const std::vector<std::string>& shaderNames, unsigned maxAttributeSize, unsigned stackSizeInBytes, bool findCalledShaders /*= false*/) :
  m_module(module),
  m_entryShaderNames(shaderNames),
  m_stackSizeInBytes(stackSizeInBytes),
  m_maxAttributeSize(maxAttributeSize),
  m_findCalledShaders(findCalledShaders)
{
}
|
|
|
|
void DxrFallbackCompiler::compile(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap)
|
|
{
|
|
std::vector<std::string> shaderNames = m_entryShaderNames;
|
|
initShaderMap(shaderNames);
|
|
|
|
// Bring in runtime so we can get the runtime data type
|
|
linkRuntime();
|
|
Type* runtimeDataArgTy = getRuntimeDataArgType();
|
|
|
|
// Make sure all calls to intrinsics and shaders are at function scope and
|
|
// fix up control flow.
|
|
lowerAnyHitControlFlowFuncs();
|
|
lowerReportHit();
|
|
lowerTraceRay(runtimeDataArgTy);
|
|
|
|
// Create state functions
|
|
IntToFuncMap stateFunctionMap; // stateID -> state function
|
|
const int baseStateId = 1000; // could be anything but this makes stateIds more recognizable
|
|
createStateFunctions(stateFunctionMap, shaderEntryStateIds, shaderStackSizes, baseStateId, shaderNames, runtimeDataArgTy);
|
|
|
|
if (pCachedMap)
|
|
{
|
|
for (auto &entry : stateFunctionMap)
|
|
{
|
|
(*pCachedMap)[entry.first] = entry.second->getName().str();
|
|
}
|
|
}
|
|
}
|
|
|
|
void DxrFallbackCompiler::link(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap)
|
|
{
|
|
IntToFuncMap stateFunctionMap; // stateID -> state function
|
|
if (pCachedMap)
|
|
{
|
|
for (auto entry : *pCachedMap)
|
|
{
|
|
stateFunctionMap[entry.first] = m_module->getFunction(entry.second);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
for (UINT i = 0; i < shaderEntryStateIds.size(); i++)
|
|
{
|
|
UINT substateIndex = 0;
|
|
UINT baseStateId = shaderEntryStateIds[i];
|
|
while (true)
|
|
{
|
|
auto substateName = m_entryShaderNames[i] + ".ss_" + std::to_string(substateIndex);
|
|
|
|
auto function = m_module->getFunction(substateName);
|
|
if (!function) break;
|
|
stateFunctionMap[baseStateId + substateIndex] = m_module->getFunction(substateName);
|
|
substateIndex++;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fix up scheduler
|
|
Function* schedulerFunc = m_module->getFunction("fb_Fallback_Scheduler");
|
|
createLaunchParams(schedulerFunc);
|
|
|
|
Type* runtimeDataArgTy = getRuntimeDataArgType();
|
|
createStateDispatch(schedulerFunc, stateFunctionMap, runtimeDataArgTy);
|
|
createStack(schedulerFunc);
|
|
|
|
lowerIntrinsics();
|
|
}
|
|
|
|
|
|
void DxrFallbackCompiler::setDebugOutputLevel(int val)
|
|
{
|
|
m_debugOutputLevel = val;
|
|
}
|
|
|
|
// Returns true if F is treated as a ray tracing shader. That is the case when
// it carries the "exp-shader" function attribute, or when the module's DXIL
// function properties describe it as a ray shader.
static bool isShader(Function* F)
{
  if (!F->hasFnAttribute("exp-shader"))
  {
    DxilModule& dxilModule = F->getParent()->GetDxilModule();
    if (!dxilModule.HasDxilFunctionProps(F))
      return false;
    return dxilModule.GetDxilFunctionProps(F).IsRay();
  }
  return true;
}
|
|
|
|
// Returns the ray tracing shader kind of F:
//  - RayGeneration if F carries the "exp-shader" attribute;
//  - otherwise the kind from its DXIL ray shader properties;
//  - Invalid if F has no such properties.
DXIL::ShaderKind getRayShaderKind(Function* F)
{
  if (F->hasFnAttribute("exp-shader"))
    return DXIL::ShaderKind::RayGeneration;

  DxilModule& dxilModule = F->getParent()->GetDxilModule();
  const bool hasRayProps = dxilModule.HasDxilFunctionProps(F) && dxilModule.GetDxilFunctionProps(F).IsRay();
  return hasRayProps ? dxilModule.GetDxilFunctionProps(F).shaderKind : DXIL::ShaderKind::Invalid;
}
|
|
|
|
|
|
// Some shaders should use the "pending" values of intrinsics instead of the
|
|
// committed ones. In particular anyhit and intersection shaders use the
|
|
// pending values with the exception that the committed rayTCurrent should be
|
|
// used in intersection.
|
|
static bool shouldUsePendingValue(Function* F, StringRef instrinsicName)
|
|
{
|
|
DxilModule& DM = F->getParent()->GetDxilModule();
|
|
if (!DM.HasDxilFunctionProps(F))
|
|
return false;
|
|
const hlsl::DxilFunctionProps& props = DM.GetDxilFunctionProps(F);
|
|
|
|
return props.IsAnyHit() || (props.IsIntersection() && instrinsicName != "rayTCurrent");
|
|
}
|
|
|
|
void DxrFallbackCompiler::initShaderMap(std::vector<std::string>& shaderNames)
|
|
{
|
|
// Clean names and initialize shaderMap
|
|
StringToFuncMap allShadersMap;
|
|
for (Function& F : m_module->functions())
|
|
{
|
|
if (isShader(&F))
|
|
{
|
|
if (!F.isDeclaration())
|
|
allShadersMap[cleanName(F.getName())] = &F;
|
|
}
|
|
|
|
F.removeFnAttr(Attribute::NoInline);
|
|
}
|
|
|
|
|
|
for (auto& name : shaderNames)
|
|
m_shaderMap[name] = allShadersMap[name];
|
|
|
|
|
|
if (!m_findCalledShaders)
|
|
return;
|
|
|
|
|
|
// Create a map from shader name to CallGraphNode
|
|
CallGraph callGraph(*m_module);
|
|
std::map<std::string, CallGraphNode*> allShaderNodes;
|
|
for (auto& kv : m_shaderMap)
|
|
{
|
|
const std::string& name = kv.first;
|
|
Function* func = kv.second;
|
|
allShaderNodes[name] = callGraph[func];
|
|
}
|
|
|
|
// Start traversing the call graph from given shaderNames
|
|
std::deque<CallGraphNode*> workList;
|
|
for (auto& name : shaderNames)
|
|
workList.push_back(allShaderNodes[name]);
|
|
while (!workList.empty())
|
|
{
|
|
CallGraphNode* cur = workList.front();
|
|
workList.pop_front();
|
|
for (size_t i = 0; i < cur->size(); ++i)
|
|
{
|
|
Function* nextFunc = (*cur)[i]->getFunction();
|
|
if (!nextFunc)
|
|
continue;
|
|
if (isShader(nextFunc))
|
|
{
|
|
const std::string nextName = cleanName(nextFunc->getName());
|
|
if (m_shaderMap.count(nextName) == 0) // not in the shaderMap yet?
|
|
{
|
|
workList.push_back(allShaderNodes[nextName]);
|
|
shaderNames.push_back(nextName);
|
|
m_shaderMap[nextName] = workList.back()->getFunction();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void DxrFallbackCompiler::linkRuntime()
|
|
{
|
|
Linker linker(m_module);
|
|
std::unique_ptr<Module> runtimeModule = loadModuleFromAsmString(m_module->getContext(), getRuntimeString());
|
|
bool linkErr = linker.linkInModule(runtimeModule.get());
|
|
assert(!linkErr && "Error linking runtime");
|
|
UNREFERENCED_PARAMETER(linkErr);
|
|
|
|
}
|
|
|
|
// Inlines `F` at `call` and makes the block containing the call return
// afterwards. The block's terminator, expected to be the "unreachable" that
// follows such a call, is replaced by "ret void" before inlining.
static void inlineFuncAndAddRet(CallInst* call, Function*F)
{
  Instruction* retVoid = ReturnInst::Create(call->getContext());
  BasicBlock* callBlock = call->getParent();
  ReplaceInstWithInst(callBlock->getTerminator(), retVoid);

  bool success = inlineFunc(call, F);
  assert(success);
  UNREFERENCED_PARAMETER(success);
}
|
|
|
|
void DxrFallbackCompiler::lowerAnyHitControlFlowFuncs()
|
|
{
|
|
std::vector<CallInst*> callsToIgnoreHit = getCallsInShadersToFunction("dx.op.ignoreHit");
|
|
if (!callsToIgnoreHit.empty())
|
|
{
|
|
Function* ignoreHitFunc = m_module->getFunction("\x1?Fallback_IgnoreHit@@YAXXZ");
|
|
assert(ignoreHitFunc && "IgnoreHit() implementation not found");
|
|
for (CallInst* call : callsToIgnoreHit)
|
|
inlineFuncAndAddRet(call, ignoreHitFunc);
|
|
}
|
|
|
|
std::vector<CallInst*> callsToAcceptHitAndEndSearch = getCallsInShadersToFunction("dx.op.acceptHitAndEndSearch");
|
|
if (!callsToAcceptHitAndEndSearch.empty())
|
|
{
|
|
Function* acceptHitAndEndSearchFunc = m_module->getFunction("\x1?Fallback_AcceptHitAndEndSearch@@YAXXZ");
|
|
assert(acceptHitAndEndSearchFunc && "AcceptHitAndEndSearch() implementation not found");
|
|
for (CallInst* call : callsToAcceptHitAndEndSearch)
|
|
inlineFuncAndAddRet(call, acceptHitAndEndSearchFunc);
|
|
}
|
|
}
|
|
|
|
void DxrFallbackCompiler::lowerReportHit()
|
|
{
|
|
std::vector<CallInst*> callsToReportHit = getCallsInShadersToFunctionWithPrefix("dx.op.reportHit");
|
|
if (callsToReportHit.empty())
|
|
return;
|
|
|
|
Function* reportHitFunc = m_module->getFunction("\x1?Fallback_ReportHit@@YAHMI@Z");
|
|
assert(reportHitFunc && "ReportHit() implementation not found");
|
|
|
|
LLVMContext& C = m_module->getContext();
|
|
for (CallInst* call : callsToReportHit)
|
|
{
|
|
// Wrap attribute arguments in Fallback_SetPendingAttr() call
|
|
Instruction* insertBefore = call;
|
|
hlsl::DxilInst_ReportHit reportHitCall(call);
|
|
|
|
Value* attr = reportHitCall.get_Attributes();
|
|
Function* setPendingAttrFunc = FunctionBuilder(m_module, "\x1?Fallback_SetPendingAttr@@").voidTy().type(attr->getType(), "attr").build();
|
|
CallInst::Create(setPendingAttrFunc, { attr }, "", insertBefore);
|
|
|
|
// Make call to implementation and load result
|
|
CallInst* callImpl = CallInst::Create(reportHitFunc, { reportHitCall.get_THit(), reportHitCall.get_HitKind() }, "reportHit.result", insertBefore);
|
|
Value* result = callImpl;
|
|
|
|
// Result < 0 ==> ret
|
|
Value* zero = makeInt32(0, C);
|
|
Value* ltz = new ICmpInst(insertBefore, CmpInst::ICMP_SLT, result, zero, "endSearch");
|
|
BasicBlock* prevBlock = call->getParent();
|
|
BasicBlock* retBlock = prevBlock->splitBasicBlock(call, "endSearch");
|
|
BasicBlock* nextBlock = retBlock->splitBasicBlock(call, "afterReportHit");
|
|
ReplaceInstWithInst(prevBlock->getTerminator(), BranchInst::Create(retBlock, nextBlock, ltz));
|
|
ReplaceInstWithInst(retBlock->getTerminator(), ReturnInst::Create(C));
|
|
|
|
// Compare result to zero and store into original result
|
|
Value* gtz = new ICmpInst(insertBefore, CmpInst::ICMP_SGT, result, zero, "accepted");
|
|
call->replaceAllUsesWith(gtz);
|
|
|
|
bool success = inlineFunc(callImpl, reportHitFunc);
|
|
assert(success);
|
|
(void)success;
|
|
|
|
call->eraseFromParent();
|
|
}
|
|
}
|
|
|
|
// Lowers TraceRay calls made from shaders in m_shaderMap into calls to the
// runtime's Fallback_TraceRay implementation.
//
// Calls are taken from functions with the "dx.op.traceRay" prefix. If there
// are none, the test-only "\x1?TraceRayTest@@" prefix is used instead.
// For each call:
//  - In closest-hit and miss shaders, runtime data is saved before the call and
//    restored after it (traceRaySave_* / traceRayRestore_*). This uses one
//    "spill.alloca" per caller, placed at the start of its entry block.
//  - The payload (last argument) is passed to a movePayloadToStack function;
//    its i32 result takes the payload's place.
//  - For dx.op calls, the first two arguments are dropped.
//  - The original call is erased.
void DxrFallbackCompiler::lowerTraceRay(Type* runtimeDataArgTy)
{
  std::vector<CallInst*> traceCalls = getCallsInShadersToFunctionWithPrefix("dx.op.traceRay");
  if (traceCalls.empty())
  {
    // TODO: It might be worth dropping this from the tests eventually
    traceCalls = getCallsInShadersToFunctionWithPrefix("\x1?TraceRayTest@@");
    if (traceCalls.empty())
      return;
  }

  std::vector<Function*> traceRayImpl = getFunctionsWithPrefix(m_module, "\x1?Fallback_TraceRay@@");
  assert(traceRayImpl.size() == 1 && "Could not find Fallback_TraceRay() implementation");

  // Runtime data spill helpers, indexed by CLOSEST_HIT / MISS.
  enum { CLOSEST_HIT = 0, MISS = 1 };
  Function* traceRaySave[] = { m_module->getFunction("traceRaySave_ClosestHit"), m_module->getFunction("traceRaySave_Miss") };
  Function* traceRayRestore[] = { m_module->getFunction("traceRayRestore_ClosestHit"), m_module->getFunction("traceRayRestore_Miss") };
  assert(traceRaySave[CLOSEST_HIT] && traceRayRestore[CLOSEST_HIT] && traceRaySave[MISS] && traceRayRestore[MISS] &&
         "Could not find TraceRay spill functions");

  Function* dummyRuntimeDataArgFunc = StateFunctionTransform::createDummyRuntimeDataArgFunc(m_module, runtimeDataArgTy);
  assert(dummyRuntimeDataArgFunc && "dummyRuntimeDataArg function could not be created.");

  LLVMContext& ctx = m_module->getContext();
  Type* i32Ty = Type::getInt32Ty(ctx);
  std::map<FunctionType*, Function*> movePayloadToStackFuncs;
  std::map<Function*, AllocaInst*> funcToSpillAlloca; // caller -> its spill alloca
  for (CallInst* traceCall : traceCalls)
  {
    // Spill runtime data values, if necessary (closesthit and miss shaders)
    Function* caller = traceCall->getParent()->getParent();
    const DXIL::ShaderKind callerKind = getRayShaderKind(caller);
    const bool needsSpill = callerKind == DXIL::ShaderKind::ClosestHit || callerKind == DXIL::ShaderKind::Miss;
    if (needsSpill)
    {
      const int sh = (callerKind == DXIL::ShaderKind::ClosestHit) ? CLOSEST_HIT : MISS;
      AllocaInst*& spillAlloca = funcToSpillAlloca[caller];
      if (spillAlloca == nullptr)
      {
        // The spill type is the pointee of the save function's second parameter.
        Argument* spillArg = (++traceRaySave[sh]->arg_begin());
        Type* spillTy = spillArg->getType()->getPointerElementType();
        spillAlloca = new AllocaInst(spillTy, "spill.alloca", caller->getEntryBlock().begin());
      }

      // Create calls. SFT will inline them.
      Value* runtimeData = CallInst::Create(dummyRuntimeDataArgFunc, "runtimeData", traceCall);
      CallInst::Create(traceRaySave[sh], {runtimeData, spillAlloca}, "", traceCall);
      CallInst::Create(traceRayRestore[sh], {runtimeData, spillAlloca}, "", getInstructionAfter(traceCall));
    }

    // Get the payload offset to pass to trace implementation
    //hlsl::DxilInst_TraceRay traceRayCall(call);
    // TODO: Avoiding the intrinsic to support the test's use of TraceRayTest
    const unsigned payloadIdx = traceCall->getNumArgOperands() - 1;
    Value* payload = traceCall->getOperand(payloadIdx);
    FunctionType* moveTy = FunctionType::get(i32Ty, { payload->getType() }, false);
    Function* movePayloadToStackFunc = getOrCreateFunction("movePayloadToStack", m_module, moveTy, movePayloadToStackFuncs);
    Value* newPayloadOffset = CallInst::Create(movePayloadToStackFunc, { payload }, "new.payload.offset", traceCall);

    // Call implementation. For dx.op calls, skip the intrinsic number and the
    // acceleration structure (for now).
    const unsigned firstArg = traceCall->getCalledFunction()->getName().startswith("dx.op") ? 2 : 0;
    std::vector<Value*> implArgs;
    for (unsigned a = firstArg; a < payloadIdx; ++a)
      implArgs.push_back(traceCall->getArgOperand(a));
    implArgs.push_back(newPayloadOffset);
    CallInst::Create(traceRayImpl[0], implArgs, "", traceCall);

    traceCall->eraseFromParent();
  }
}
|
|
|
|
// Returns the semantic type of each parameter of shader F, by shader kind:
//  - anyhit / closesthit: (payload, attributes);
//  - miss: (payload);
//  - any other kind: PST_NONE for each of F's formal parameters.
static std::vector<StateFunctionTransform::ParameterSemanticType> getParameterTypes(Function* F, DXIL::ShaderKind shaderKind)
{
  std::vector<StateFunctionTransform::ParameterSemanticType> paramTypes;
  if (shaderKind == DXIL::ShaderKind::AnyHit || shaderKind == DXIL::ShaderKind::ClosestHit)
  {
    paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD);
    paramTypes.push_back(StateFunctionTransform::PST_ATTRIBUTE);
  }
  else if (shaderKind == DXIL::ShaderKind::Miss)
  {
    paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD);
  }
  else
  {
    // One entry per formal parameter. Function::getNumOperands() counts the
    // function's own IR operands, not its parameters.
    paramTypes.assign(F->arg_size(), StateFunctionTransform::PST_NONE);
  }
  return paramTypes;
}
|
|
|
|
// Adds to `resources` the global symbol of every CBuffer, UAV, SRV and sampler
// declared in DM.
static void collectResources(DxilModule& DM, std::set<Value*>& resources)
{
  for (const auto& cbuffer : DM.GetCBuffers())
    resources.insert(cbuffer->GetGlobalSymbol());
  for (const auto& uav : DM.GetUAVs())
    resources.insert(uav->GetGlobalSymbol());
  for (const auto& srv : DM.GetSRVs())
    resources.insert(srv->GetGlobalSymbol());
  for (const auto& sampler : DM.GetSamplers())
    resources.insert(sampler->GetGlobalSymbol());
}
|
|
|
|
|
|
// Splits every shader in `shaderNames` into state functions with
// StateFunctionTransform.
//
// Outputs:
//  - State IDs are assigned consecutively from baseStateId across all shaders,
//    in shaderNames order.
//  - shaderEntryStateIds receives each shader's first state ID.
//  - shaderStackSizes receives the stack size reported by the transform.
//  - stateFunctionMap receives every ID -> state function pair.
// Both output vectors are cleared first. Shaders with DXIL function properties
// have them cloned onto each of their state functions.
void DxrFallbackCompiler::createStateFunctions(
  IntToFuncMap& stateFunctionMap,
  std::vector<int>& shaderEntryStateIds,
  std::vector<unsigned int>& shaderStackSizes,
  int baseStateId,
  const std::vector<std::string>& shaderNames,
  Type* runtimeDataArgTy
)
{
  // Report shaders that could not be resolved to a function.
  for (auto& kv : m_shaderMap)
  {
    if (!kv.second)
      errs() << "Function not found for shader " << kv.first << "\n";
  }

  DxilModule& DM = m_module->GetOrCreateDxilModule();
  std::set<Value*> resources;
  collectResources(DM, resources);

  shaderEntryStateIds.clear();
  shaderStackSizes.clear();
  int nextStateId = baseStateId;
  for (const std::string& shader : shaderNames)
  {
    Function* F = m_shaderMap[shader];
    StateFunctionTransform sft(F, shaderNames, runtimeDataArgTy);
    if (m_debugOutputLevel >= 2)
      sft.setVerbose(true);
    if (m_debugOutputLevel >= 3)
      sft.setDumpFilename("dump.ll");
    if (shader == "Fallback_TraceRay")
      sft.setAttributeSize(m_maxAttributeSize);
    const DXIL::ShaderKind shaderKind = getRayShaderKind(F);
    if (shaderKind != DXIL::ShaderKind::Invalid)
      sft.setParameterInfo(getParameterTypes(F, shaderKind), shaderKind == DXIL::ShaderKind::ClosestHit);
    sft.setResourceGlobals(resources);

    std::vector<Function*> stateFunctions;
    UINT shaderStackSize = 0;
    sft.run(stateFunctions, shaderStackSize);

    shaderEntryStateIds.push_back(nextStateId);
    shaderStackSizes.push_back(shaderStackSize);
    for (Function* stateF : stateFunctions)
    {
      stateFunctionMap[nextStateId] = stateF;
      ++nextStateId;
      if (DM.HasDxilFunctionProps(F))
        DM.CloneDxilEntryProps(F, stateF);
    }
  }

  StateFunctionTransform::finalizeStateIds(m_module, shaderEntryStateIds);
}
|
|
|
|
// Replaces the first use of the runtime placeholder "rewrite_setLaunchParams"
// (a call) with a call to "fb_Fallback_SetLaunchParams". The new call passes:
//  - the placeholder's runtime data argument (operand 0);
//  - thread ID components 0 and 1 (dx.op.threadId);
//  - the placeholder's two dimension arguments (operands 1 and 2);
//  - the flattened thread index in the group.
// The placeholder call and its declaration are then erased. `func` is only
// used to find the module.
void DxrFallbackCompiler::createLaunchParams(Function* func)
{
  Module* module = func->getParent();
  Function* rewrite_setLaunchParams = module->getFunction("rewrite_setLaunchParams");
  // cast<> asserts that the use really is a call instead of yielding nullptr.
  CallInst* call = cast<CallInst>(*rewrite_setLaunchParams->user_begin());

  LLVMContext& context = module->getContext();
  Instruction* insertBefore = call;

  Function* DTidFunc = FunctionBuilder(module, "dx.op.threadId.i32").i32().i32().i32().build();
  Value* DTidx = CallInst::Create(DTidFunc, { makeInt32((int)hlsl::OP::OpCode::ThreadId, context), makeInt32(0, context) }, "DTidx", insertBefore);
  Value* DTidy = CallInst::Create(DTidFunc, { makeInt32((int)hlsl::OP::OpCode::ThreadId, context), makeInt32(1, context) }, "DTidy", insertBefore);

  Value* dimx = call->getArgOperand(1);
  Value* dimy = call->getArgOperand(2);

  // Use the opcode enum (value 96) rather than a magic number, like the
  // ThreadId calls above.
  Function* groupIndexFunc = FunctionBuilder(module, "dx.op.flattenedThreadIdInGroup.i32").i32().i32().build();
  Value* groupIndex = CallInst::Create(groupIndexFunc, { makeInt32((int)hlsl::OP::OpCode::FlattenedThreadIdInGroup, context) }, "groupIndex", insertBefore);

  Function* fb_setLaunchParams = module->getFunction("fb_Fallback_SetLaunchParams");
  Value* runtimeDataArg = call->getArgOperand(0);
  CallInst::Create(fb_setLaunchParams, { runtimeDataArg, DTidx, DTidy, dimx, dimy, groupIndex }, "", insertBefore);

  call->eraseFromParent();
  rewrite_setLaunchParams->eraseFromParent();
}
|
|
|
|
// Builds the state dispatch function for stateFunctionMap. All uses of the
// runtime placeholder "rewrite_dispatch" are redirected to it, and the
// placeholder is erased. `func` is only used to find the module.
void DxrFallbackCompiler::createStateDispatch(Function* func, const IntToFuncMap& stateFunctionMap, Type* runtimeDataArgTy)
{
  Function* dispatchFunc = createDispatchFunction(stateFunctionMap, runtimeDataArgTy);
  Function* placeholder = func->getParent()->getFunction("rewrite_dispatch");
  placeholder->replaceAllUsesWith(dispatchFunc);
  placeholder->eraseFromParent();
}
|
|
|
|
// Materializes the runtime stack:
//  - The first "rewrite_createStack" placeholder call becomes an alloca named
//    "theStack" of the call's pointee type, aligned to sizeof(int).
//  - The first "rewrite_getStackSize" call becomes the constant stack size.
// If m_stackSizeInBytes is 0, it is first set to the default implied by the
// alloca's array type (element count * sizeof(int)). Both placeholders are
// erased.
void DxrFallbackCompiler::createStack(Function* func)
{
  LLVMContext& context = func->getContext();

  // We would like to allocate the properly sized stack here, but DXIL doesn't
  // allow bitcasts between objects of different sizes. So we have to use the
  // default size from the runtime and replace all the accesses later.
  Function* createStackPlaceholder = m_module->getFunction("rewrite_createStack");
  CallInst* createStackCall = dyn_cast<CallInst>(*createStackPlaceholder->user_begin());
  Type* stackTy = createStackCall->getType()->getPointerElementType();
  AllocaInst* stack = new AllocaInst(stackTy, "theStack", createStackCall);
  stack->setAlignment(sizeof(int));
  createStackCall->replaceAllUsesWith(stack);
  createStackCall->eraseFromParent();
  createStackPlaceholder->eraseFromParent();

  if (m_stackSizeInBytes == 0) // Take the default
    m_stackSizeInBytes = stackTy->getArrayNumElements() * sizeof(int);

  Function* getStackSizePlaceholder = m_module->getFunction("rewrite_getStackSize");
  CallInst* getStackSizeCall = dyn_cast<CallInst>(*getStackSizePlaceholder->user_begin());
  getStackSizeCall->replaceAllUsesWith(makeInt32(m_stackSizeInBytes, context));
  getStackSizeCall->eraseFromParent();
  getStackSizePlaceholder->eraseFromParent();
}
|
|
|
|
// WAR to avoid crazy <3 x float> code emitted by vanilla clang in the runtime
|
|
static bool expandFloat3(std::vector<Value*>& args, Value* arg, Instruction* insertBefore)
|
|
{
|
|
VectorType* argTy = dyn_cast<VectorType>(arg->getType());
|
|
if (!argTy || argTy->getVectorNumElements() != 3)
|
|
return false;
|
|
|
|
LLVMContext& C = arg->getContext();
|
|
args.push_back(ExtractElementInst::Create(arg, makeInt32(0, C), "vec.x", insertBefore));
|
|
args.push_back(ExtractElementInst::Create(arg, makeInt32(1, C), "vec.y", insertBefore));
|
|
args.push_back(ExtractElementInst::Create(arg, makeInt32(2, C), "vec.z", insertBefore));
|
|
|
|
return true;
|
|
}
|
|
|
|
// Passes a float3x4 matrix ("class.matrix.float.3.4") as one <12 x float>.
// The value is stored to a temporary alloca in the entry block and reloaded
// through a <12 x float>* bitcast before `insertBefore`; the reloaded vector is
// appended to `args`. Returns false, leaving `args` untouched, for any other
// type.
static bool float3x4ToFloat12(std::vector<Value*>& args, Value* arg, Instruction* insertBefore)
{
  StructType* STy = dyn_cast<StructType>(arg->getType());
  // Literal (unnamed) structs have no name, and StructType::getName() must not
  // be called on them, so check hasName() first.
  if (!STy || !STy->hasName() || STy->getName() != "class.matrix.float.3.4")
    return false;

  // Spill the matrix to a temporary at the top of the entry block...
  BasicBlock& entryBlock = insertBefore->getParent()->getParent()->getEntryBlock();
  AllocaInst* alloca = new AllocaInst(arg->getType(), "tmp", entryBlock.begin());
  new StoreInst(arg, alloca, insertBefore);

  // ...and reload it as a flat <12 x float>.
  VectorType* VTy = VectorType::get(Type::getFloatTy(arg->getContext()), 12);
  Value* vec12Ptr = new BitCastInst(alloca, VTy->getPointerTo(), "vec12.ptr", insertBefore);
  Value* vec12 = new LoadInst(vec12Ptr, "vec12.", insertBefore);
  args.push_back(vec12);

  return true;
}
|
|
|
|
void DxrFallbackCompiler::lowerIntrinsics()
|
|
{
|
|
std::vector<Function*> intrinsics = getFunctionsWithPrefix(m_module, "fb_");
|
|
assert(intrinsics.size() > 0);
|
|
|
|
|
|
// Replace intrinsics in anyhit shaders with their pending versions
|
|
LLVMContext& C = m_module->getContext();
|
|
std::map<std::string, Function*> pendingIntrinsics;
|
|
std::string pendingPrefixes[] = { "fb_dxop_pending_", "fb_Fallback_Pending" };
|
|
for (auto& F : intrinsics)
|
|
{
|
|
std::string intrinsicName;
|
|
if (F->getName().startswith(pendingPrefixes[0]))
|
|
intrinsicName = F->getName().substr(pendingPrefixes[0].length());
|
|
else if (F->getName().startswith(pendingPrefixes[1]))
|
|
intrinsicName = "Fallback_" + F->getName().substr(pendingPrefixes[1].length()).str();
|
|
else
|
|
continue;
|
|
|
|
pendingIntrinsics[intrinsicName] = F;
|
|
}
|
|
|
|
for (Function* func : intrinsics)
|
|
{
|
|
StringRef intrinsicName;
|
|
std::string name;
|
|
bool isDxilOp = false;
|
|
if (func->getName().startswith("fb_Fallback_"))
|
|
{
|
|
intrinsicName = func->getName().substr(3); // after the "fb_" prefix
|
|
name = "\x1?" + intrinsicName.str();
|
|
}
|
|
else if (func->getName().startswith("fb_dxop_"))
|
|
{
|
|
intrinsicName = func->getName().substr(8);
|
|
name = "dx.op." + intrinsicName.str();
|
|
isDxilOp = true;
|
|
}
|
|
else
|
|
{
|
|
assert(0 && "Bad intrinsic");
|
|
}
|
|
std::vector<Function*> calledFunc = getFunctionsWithPrefix(m_module, name);
|
|
if (calledFunc.empty())
|
|
continue;
|
|
std::vector<CallInst*> calls = getCallsToFunction(calledFunc[0]);
|
|
if (calls.empty())
|
|
continue;
|
|
|
|
|
|
bool needsRuntimeDataArg = (intrinsicName != "Fallback_Scheduler");
|
|
Function* pendingFunc = get(pendingIntrinsics, intrinsicName.str());
|
|
Function* funcInModule = nullptr;
|
|
Function* pendingFuncInModule = nullptr;
|
|
for (CallInst* call : calls)
|
|
{
|
|
Function* caller = call->getParent()->getParent();
|
|
if (needsRuntimeDataArg && !caller->hasFnAttribute("state_function"))
|
|
continue;
|
|
|
|
Function* F = nullptr;
|
|
if (pendingFunc && shouldUsePendingValue(caller, intrinsicName))
|
|
{
|
|
if (!pendingFuncInModule)
|
|
pendingFuncInModule = getOrInsertFunction(m_module, pendingFunc);
|
|
F = pendingFuncInModule;
|
|
}
|
|
else
|
|
{
|
|
if (!funcInModule)
|
|
funcInModule = getOrInsertFunction(m_module, func);
|
|
F = funcInModule;
|
|
}
|
|
|
|
// insert runtime data and the rest of the arguments
|
|
std::vector<Value*> args;
|
|
if (needsRuntimeDataArg)
|
|
args.push_back(caller->arg_begin());
|
|
int argIdx = 0;
|
|
for (auto& arg : call->arg_operands())
|
|
{
|
|
if (argIdx++ == 0 && isDxilOp)
|
|
continue; // skip the intrinsic number
|
|
if (!expandFloat3(args, arg, call) && !float3x4ToFloat12(args, arg, call))
|
|
args.push_back(arg);
|
|
}
|
|
|
|
CallInst* newCall = CallInst::Create(F, args, "", call);
|
|
if (F->getFunctionType()->getReturnType() != Type::getVoidTy(C))
|
|
{
|
|
newCall->takeName(call);
|
|
call->replaceAllUsesWith(newCall);
|
|
}
|
|
call->eraseFromParent();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns the type of the runtime data argument. It is read from the first
// parameter of the runtime function "stackIntPtr", so the runtime must already
// be linked into m_module.
Type* DxrFallbackCompiler::getRuntimeDataArgType()
{
  Function* stackIntPtrFunc = m_module->getFunction("stackIntPtr");
  auto firstParam = stackIntPtrFunc->arg_begin();
  return firstParam->getType();
}
|
|
|
|
// Creates "dispatch" (i32 runtimeData-typed arg, i32 stateID) -> i32.
// It switches on stateID. Each entry of stateFunctionMap gets a case block
// that calls that state function with runtimeData and returns its result.
// Unknown state IDs return -3.
Function* DxrFallbackCompiler::createDispatchFunction(const IntToFuncMap &stateFunctionMap, Type* runtimeDataArgTy)
{
  LLVMContext& context = m_module->getContext();
  FunctionType* stateFuncTy = FunctionType::get(Type::getInt32Ty(context), { runtimeDataArgTy }, false);

  Function* dispatchFunc = FunctionBuilder(m_module, "dispatch").i32().type(runtimeDataArgTy, "runtimeData").i32("stateID").build();
  Value* runtimeDataArg = dispatchFunc->arg_begin();
  Value* stateIdArg = ++dispatchFunc->arg_begin();

  BasicBlock* entryBlock = BasicBlock::Create(context, "entry", dispatchFunc);
  BasicBlock* badBlock = BasicBlock::Create(context, "badStateID", dispatchFunc);

  // Constructing the builder on badBlock already points it at that block's end.
  IRBuilder<> builder(badBlock);
  builder.CreateRet(makeInt32(-3, context)); // return an error value

  builder.SetInsertPoint(entryBlock);
  SwitchInst* switchInst = builder.CreateSwitch(stateIdArg, badBlock, stateFunctionMap.size());
  for (const auto& entry : stateFunctionMap)
  {
    const int stateId = entry.first;
    Function* stateFunc = entry.second;

    Value* callee = m_module->getOrInsertFunction(stateFunc->getName(), stateFuncTy);
    // Every case block is placed before badBlock.
    BasicBlock* caseBlock = BasicBlock::Create(context, "state_" + Twine(stateId) + "." + stateFunc->getName(), dispatchFunc, badBlock);
    builder.SetInsertPoint(caseBlock);
    Value* nextStateId = builder.CreateCall(callee, { runtimeDataArg }, "nextStateId");
    builder.CreateRet(nextStateId);

    switchInst->addCase(makeInt32(stateId, context), caseBlock);
  }

  return dispatchFunc;
}
|
|
|
|
std::vector<CallInst*> DxrFallbackCompiler::getCallsInShadersToFunction(const std::string& funcName)
|
|
{
|
|
std::vector<CallInst*> calls;
|
|
Function* F = m_module->getFunction(funcName);
|
|
if (!F)
|
|
return calls;
|
|
|
|
for (User* U : F->users())
|
|
{
|
|
CallInst* call = dyn_cast<CallInst>(U);
|
|
if (!call)
|
|
continue;
|
|
|
|
Function* caller = call->getParent()->getParent();
|
|
auto it = m_shaderMap.find(cleanName(caller->getName()));
|
|
if (it != m_shaderMap.end())
|
|
calls.push_back(call);
|
|
}
|
|
return calls;
|
|
}
|
|
|
|
std::vector<CallInst*> DxrFallbackCompiler::getCallsInShadersToFunctionWithPrefix(const std::string& funcNamePrefix)
|
|
{
|
|
std::vector<CallInst*> calls;
|
|
for (Function* F : getFunctionsWithPrefix(m_module, funcNamePrefix))
|
|
{
|
|
for (User* U : F->users())
|
|
{
|
|
CallInst* call = dyn_cast<CallInst>(U);
|
|
if (!call)
|
|
continue;
|
|
|
|
Function* caller = call->getParent()->getParent();
|
|
if (m_shaderMap.count(cleanName(caller->getName())))
|
|
calls.push_back(call);
|
|
}
|
|
}
|
|
return calls;
|
|
}
|
|
|
|
// Replaces F's stack (the first entry-block alloca whose name starts with
// "theStack") with an i32 array of sizeInBytes / sizeof(int) elements. Each
// GEP user is rebuilt as an inbounds GEP on the new array. Does nothing if no
// such alloca exists. Every user of the old stack is expected to be a GEP.
void DxrFallbackCompiler::resizeStack(Function* F, unsigned sizeInBytes)
{
  // Find the stack
  AllocaInst* stack = nullptr;
  for (Instruction& inst : F->getEntryBlock())
  {
    AllocaInst* candidate = dyn_cast<AllocaInst>(&inst);
    if (candidate && candidate->getName().startswith("theStack"))
    {
      stack = candidate;
      break;
    }
  }
  if (stack == nullptr)
    return;

  // Create a new stack that takes over the old one's name
  LLVMContext& C = F->getContext();
  ArrayType* newStackTy = ArrayType::get(Type::getInt32Ty(C), sizeInBytes / sizeof(int));
  AllocaInst* newStack = new AllocaInst(newStackTy, "", stack);
  newStack->takeName(stack);

  // Remap all GEPs - replaceAllUsesWith() won't change types. The iterator is
  // advanced before each GEP is erased.
  for (auto userIt = stack->user_begin(), userEnd = stack->user_end(); userIt != userEnd; )
  {
    GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(*userIt++);
    assert(gep && "theStack has non-gep user.");

    std::vector<Value*> indices(gep->idx_begin(), gep->idx_end());
    GetElementPtrInst* newGep = GetElementPtrInst::CreateInBounds(newStack, indices, "", gep);
    newGep->takeName(gep);
    gep->replaceAllUsesWith(newGep);
    gep->eraseFromParent();
  }

  stack->eraseFromParent();
}