DirectXShaderCompiler/lib/DxrFallback/DxrFallbackCompiler.cpp

864 строки
30 KiB
C++

#include "dxc/DxrFallback/DxrFallbackCompiler.h"
#include "dxc/Support/Global.h"
#include "dxc/Support/Unicode.h"
#include "dxc/Support/WinIncludes.h"
#include "dxc/Support/FileIOHelper.h"
#include "dxc/dxcapi.h"
#include "dxc/dxcdxrfallbackcompiler.h"
#include "dxc/Support/dxcapi.use.h"
#include "dxc/Support/dxcapi.impl.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/HLSL/DxilLinker.h"
#include "dxc/DXIL/DxilFunctionProps.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilInstructions.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "FunctionBuilder.h"
#include "LLVMUtils.h"
#include "runtime.h"
#include "StateFunctionTransform.h"
#include <queue>
using namespace hlsl;
using namespace llvm;
// Returns every function of `module` whose name begins with `prefix`, in the
// order the module enumerates them.
static std::vector<Function*> getFunctionsWithPrefix(Module* module, const std::string& prefix)
{
  std::vector<Function*> matches;
  for (Function& candidate : module->functions())
  {
    if (!candidate.getName().startswith(prefix))
      continue;
    matches.push_back(&candidate);
  }
  return matches;
}
// Inlines `call`. If the callee is only a declaration in the caller's module,
// the body of `Fimpl` is first cloned into that declaration (which is then made
// internal) so that there is something to inline.
//
// Returns the result of InlineFunction(), or false when `call` is an indirect
// call (no statically known callee) and therefore cannot be processed.
static bool inlineFunc(CallInst* call, Function* Fimpl)
{
  // Note. LLVM inlining may not be sufficient if the function references DX
  // resources because the corresponding metadata is not created if the function
  // comes from another module.

  // Make sure that we have a definition for the called function in this module
  Function* F = call->getCalledFunction();
  if (!F)
    return false; // indirect call: nothing to clone into, nothing to inline
  Module* dstM = F->getParent();
  if (F->isDeclaration())
  {
    // Map called functions in impl module to functions in this one (because the
    // cloning step doesn't do this automatically)
    ValueToValueMapTy VMap;
    for (auto& I : inst_range(Fimpl))
    {
      CallInst* c = dyn_cast<CallInst>(&I);
      if (!c)
        continue;
      Function* calledFimpl = c->getCalledFunction();
      // getCalledFunction() is null for indirect calls; there is no named
      // callee to remap in that case.
      if (!calledFimpl || VMap.count(calledFimpl))
        continue;
      Constant* calledF = dstM->getOrInsertFunction(calledFimpl->getName(), calledFimpl->getFunctionType(), calledFimpl->getAttributes());
      VMap[calledFimpl] = calledF;
    }

    // Map arguments
    for (auto SI = Fimpl->arg_begin(), SE = Fimpl->arg_end(), DI = F->arg_begin(); SI != SE; ++SI, ++DI)
      VMap[SI] = DI;

    SmallVector<ReturnInst*, 4> returns;
    CloneFunctionInto(F, Fimpl, VMap, true, returns);
    F->setLinkage(GlobalValue::InternalLinkage);
  }

  InlineFunctionInfo IFI;
  return InlineFunction(call, IFI, false);
}
// Strips name decoration of the form "\x1?<name>@@<suffix>" and returns just
// <name>. Names that do not start with "\x1?" or that contain no "@@" are
// returned unchanged.
static std::string cleanName(StringRef name)
{
  // Only look for the "@@" terminator when the decorated prefix is present.
  const size_t terminator = name.startswith("\x1?") ? name.find("@@") : StringRef::npos;
  if (terminator == StringRef::npos)
    return name;
  // Skip the two prefix characters and stop right before "@@".
  return name.substr(2, terminator - 2);
}
// Looks up (or declares) in `module` a function with the same name and type as
// `F`. Yields null if the module entry found under that name is not a Function.
static inline Function* getOrInsertFunction(Module* module, Function* F)
{
  Constant* declared = module->getOrInsertFunction(F->getName(), F->getFunctionType());
  return dyn_cast<Function>(declared);
}
// Map lookup that never inserts: returns the value stored under `key`, or
// `defaultVal` (null by default) when the key is absent.
template<typename K, typename V>
V get(std::map<K, V>& theMap, const K& key, V defaultVal = static_cast<V>(nullptr))
{
  const auto found = theMap.find(key);
  return (found != theMap.end()) ? found->second : defaultVal;
}
// Constructs a compiler for `module`. Each argument is stored in the matching
// member; no other work happens here (the constructor body is empty).
//   module            - module that the other methods operate on
//   shaderNames       - entry shader names (copied into m_entryShaderNames)
//   maxAttributeSize  - stored in m_maxAttributeSize
//   stackSizeInBytes  - stored in m_stackSizeInBytes
//   findCalledShaders - stored in m_findCalledShaders
DxrFallbackCompiler::DxrFallbackCompiler(llvm::Module* module, const std::vector<std::string>& shaderNames, unsigned maxAttributeSize, unsigned stackSizeInBytes, bool findCalledShaders /*= false*/)
: m_module(module)
, m_entryShaderNames(shaderNames)
, m_stackSizeInBytes(stackSizeInBytes)
, m_maxAttributeSize(maxAttributeSize)
, m_findCalledShaders(findCalledShaders)
{}
void DxrFallbackCompiler::compile(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap)
{
std::vector<std::string> shaderNames = m_entryShaderNames;
initShaderMap(shaderNames);
// Bring in runtime so we can get the runtime data type
linkRuntime();
Type* runtimeDataArgTy = getRuntimeDataArgType();
// Make sure all calls to intrinsics and shaders are at function scope and
// fix up control flow.
lowerAnyHitControlFlowFuncs();
lowerReportHit();
lowerTraceRay(runtimeDataArgTy);
// Create state functions
IntToFuncMap stateFunctionMap; // stateID -> state function
const int baseStateId = 1000; // could be anything but this makes stateIds more recognizable
createStateFunctions(stateFunctionMap, shaderEntryStateIds, shaderStackSizes, baseStateId, shaderNames, runtimeDataArgTy);
if (pCachedMap)
{
for (auto &entry : stateFunctionMap)
{
(*pCachedMap)[entry.first] = entry.second->getName().str();
}
}
}
void DxrFallbackCompiler::link(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap)
{
IntToFuncMap stateFunctionMap; // stateID -> state function
if (pCachedMap)
{
for (auto entry : *pCachedMap)
{
stateFunctionMap[entry.first] = m_module->getFunction(entry.second);
}
}
else
{
for (UINT i = 0; i < shaderEntryStateIds.size(); i++)
{
UINT substateIndex = 0;
UINT baseStateId = shaderEntryStateIds[i];
while (true)
{
auto substateName = m_entryShaderNames[i] + ".ss_" + std::to_string(substateIndex);
auto function = m_module->getFunction(substateName);
if (!function) break;
stateFunctionMap[baseStateId + substateIndex] = m_module->getFunction(substateName);
substateIndex++;
}
}
}
// Fix up scheduler
Function* schedulerFunc = m_module->getFunction("fb_Fallback_Scheduler");
createLaunchParams(schedulerFunc);
Type* runtimeDataArgTy = getRuntimeDataArgType();
createStateDispatch(schedulerFunc, stateFunctionMap, runtimeDataArgTy);
createStack(schedulerFunc);
lowerIntrinsics();
}
void DxrFallbackCompiler::setDebugOutputLevel(int val)
{
m_debugOutputLevel = val;
}
// True if `F` carries the "exp-shader" function attribute, or if the DXIL
// module has function props for it that report IsRay().
static bool isShader(Function* F)
{
  if (!F->hasFnAttribute("exp-shader"))
  {
    DxilModule& dxilModule = F->getParent()->GetDxilModule();
    if (!dxilModule.HasDxilFunctionProps(F))
      return false;
    return dxilModule.GetDxilFunctionProps(F).IsRay();
  }
  return true;
}
// Classifies `F`: functions with the "exp-shader" attribute are treated as
// RayGeneration; functions whose DXIL props report IsRay() yield the recorded
// shader kind; everything else is Invalid.
DXIL::ShaderKind getRayShaderKind(Function* F)
{
  if (F->hasFnAttribute("exp-shader"))
    return DXIL::ShaderKind::RayGeneration;

  DxilModule& dxilModule = F->getParent()->GetDxilModule();
  const bool isRayShader = dxilModule.HasDxilFunctionProps(F) && dxilModule.GetDxilFunctionProps(F).IsRay();
  return isRayShader ? dxilModule.GetDxilFunctionProps(F).shaderKind : DXIL::ShaderKind::Invalid;
}
// Decides whether a call to intrinsic `instrinsicName` made from `F` should
// read the "pending" value instead of the committed one:
//   - anyhit shaders: always pending;
//   - intersection shaders: pending, except for "rayTCurrent" which stays
//     committed;
//   - functions without DXIL function props, and all other shaders: committed.
static bool shouldUsePendingValue(Function* F, StringRef instrinsicName)
{
  DxilModule& dxilModule = F->getParent()->GetDxilModule();
  if (!dxilModule.HasDxilFunctionProps(F))
    return false;

  const hlsl::DxilFunctionProps& props = dxilModule.GetDxilFunctionProps(F);
  if (props.IsAnyHit())
    return true;
  if (!props.IsIntersection())
    return false;
  return instrinsicName != "rayTCurrent";
}
void DxrFallbackCompiler::initShaderMap(std::vector<std::string>& shaderNames)
{
// Clean names and initialize shaderMap
StringToFuncMap allShadersMap;
for (Function& F : m_module->functions())
{
if (isShader(&F))
{
if (!F.isDeclaration())
allShadersMap[cleanName(F.getName())] = &F;
}
F.removeFnAttr(Attribute::NoInline);
}
for (auto& name : shaderNames)
m_shaderMap[name] = allShadersMap[name];
if (!m_findCalledShaders)
return;
// Create a map from shader name to CallGraphNode
CallGraph callGraph(*m_module);
std::map<std::string, CallGraphNode*> allShaderNodes;
for (auto& kv : m_shaderMap)
{
const std::string& name = kv.first;
Function* func = kv.second;
allShaderNodes[name] = callGraph[func];
}
// Start traversing the call graph from given shaderNames
std::deque<CallGraphNode*> workList;
for (auto& name : shaderNames)
workList.push_back(allShaderNodes[name]);
while (!workList.empty())
{
CallGraphNode* cur = workList.front();
workList.pop_front();
for (size_t i = 0; i < cur->size(); ++i)
{
Function* nextFunc = (*cur)[i]->getFunction();
if (!nextFunc)
continue;
if (isShader(nextFunc))
{
const std::string nextName = cleanName(nextFunc->getName());
if (m_shaderMap.count(nextName) == 0) // not in the shaderMap yet?
{
workList.push_back(allShaderNodes[nextName]);
shaderNames.push_back(nextName);
m_shaderMap[nextName] = workList.back()->getFunction();
}
}
}
}
}
void DxrFallbackCompiler::linkRuntime()
{
Linker linker(m_module);
std::unique_ptr<Module> runtimeModule = loadModuleFromAsmString(m_module->getContext(), getRuntimeString());
bool linkErr = linker.linkInModule(runtimeModule.get());
assert(!linkErr && "Error linking runtime");
UNREFERENCED_PARAMETER(linkErr);
}
// Replaces the terminator of the block containing `call` with a "ret void" and
// then inlines `F` at the call site.
static void inlineFuncAndAddRet(CallInst* call, Function*F)
{
  // Add a return after the function call.
  // Should be followed immediately by "unreachable". Turn that into a "ret void".
  BasicBlock* callBlock = call->getParent();
  ReplaceInstWithInst(callBlock->getTerminator(), ReturnInst::Create(call->getContext()));

  const bool success = inlineFunc(call, F);
  assert(success);
  UNREFERENCED_PARAMETER(success);
}
void DxrFallbackCompiler::lowerAnyHitControlFlowFuncs()
{
std::vector<CallInst*> callsToIgnoreHit = getCallsInShadersToFunction("dx.op.ignoreHit");
if (!callsToIgnoreHit.empty())
{
Function* ignoreHitFunc = m_module->getFunction("\x1?Fallback_IgnoreHit@@YAXXZ");
assert(ignoreHitFunc && "IgnoreHit() implementation not found");
for (CallInst* call : callsToIgnoreHit)
inlineFuncAndAddRet(call, ignoreHitFunc);
}
std::vector<CallInst*> callsToAcceptHitAndEndSearch = getCallsInShadersToFunction("dx.op.acceptHitAndEndSearch");
if (!callsToAcceptHitAndEndSearch.empty())
{
Function* acceptHitAndEndSearchFunc = m_module->getFunction("\x1?Fallback_AcceptHitAndEndSearch@@YAXXZ");
assert(acceptHitAndEndSearchFunc && "AcceptHitAndEndSearch() implementation not found");
for (CallInst* call : callsToAcceptHitAndEndSearch)
inlineFuncAndAddRet(call, acceptHitAndEndSearchFunc);
}
}
// Lowers every call in a known shader to a function whose name starts with
// "dx.op.reportHit". For each such call:
//   1. the attribute operand is passed to a Fallback_SetPendingAttr call;
//   2. the runtime's Fallback_ReportHit implementation is called with the
//      THit and HitKind operands and then inlined;
//   3. if the implementation's result is < 0 the shader returns immediately;
//   4. users of the original call receive the i1 value (result > 0).
// Does nothing when there are no such calls.
void DxrFallbackCompiler::lowerReportHit()
{
std::vector<CallInst*> callsToReportHit = getCallsInShadersToFunctionWithPrefix("dx.op.reportHit");
if (callsToReportHit.empty())
return;
Function* reportHitFunc = m_module->getFunction("\x1?Fallback_ReportHit@@YAHMI@Z");
assert(reportHitFunc && "ReportHit() implementation not found");
LLVMContext& C = m_module->getContext();
for (CallInst* call : callsToReportHit)
{
// Wrap attribute arguments in Fallback_SetPendingAttr() call
Instruction* insertBefore = call;
hlsl::DxilInst_ReportHit reportHitCall(call);
Value* attr = reportHitCall.get_Attributes();
// The declaration is built with the attribute's own type as its parameter.
Function* setPendingAttrFunc = FunctionBuilder(m_module, "\x1?Fallback_SetPendingAttr@@").voidTy().type(attr->getType(), "attr").build();
CallInst::Create(setPendingAttrFunc, { attr }, "", insertBefore);
// Make call to implementation and load result
CallInst* callImpl = CallInst::Create(reportHitFunc, { reportHitCall.get_THit(), reportHitCall.get_HitKind() }, "reportHit.result", insertBefore);
Value* result = callImpl;
// Result < 0 ==> ret
Value* zero = makeInt32(0, C);
Value* ltz = new ICmpInst(insertBefore, CmpInst::ICMP_SLT, result, zero, "endSearch");
// Split twice at `call`: prevBlock keeps everything emitted so far, retBlock
// becomes the early-return block, and nextBlock starts at the original call.
BasicBlock* prevBlock = call->getParent();
BasicBlock* retBlock = prevBlock->splitBasicBlock(call, "endSearch");
BasicBlock* nextBlock = retBlock->splitBasicBlock(call, "afterReportHit");
ReplaceInstWithInst(prevBlock->getTerminator(), BranchInst::Create(retBlock, nextBlock, ltz));
ReplaceInstWithInst(retBlock->getTerminator(), ReturnInst::Create(C));
// Compare result to zero and store into original result
// (insertBefore is still `call`, so the compare is emitted right before it.)
Value* gtz = new ICmpInst(insertBefore, CmpInst::ICMP_SGT, result, zero, "accepted");
call->replaceAllUsesWith(gtz);
bool success = inlineFunc(callImpl, reportHitFunc);
assert(success);
(void)success;
call->eraseFromParent();
}
}
// Replaces trace-ray calls in the known shaders with calls to the runtime's
// Fallback_TraceRay implementation.
//
// The calls considered are those to functions whose name starts with
// "dx.op.traceRay"; if there are none, calls to functions starting with
// "\x1?TraceRayTest@@" are used instead. For each call:
//   - in closesthit and miss callers, a traceRaySave_* call is emitted before
//     and a traceRayRestore_* call after the trace call, both operating on one
//     "spill.alloca" per caller;
//   - the last argument (the payload) is passed to a "movePayloadToStack"
//     function and replaced by that function's i32 result;
//   - the remaining arguments are forwarded unchanged, minus the first two for
//     dx.op calls.
//
// runtimeDataArgTy is handed to createDummyRuntimeDataArgFunc(), whose result
// supplies the first argument of the save/restore calls.
void DxrFallbackCompiler::lowerTraceRay(Type* runtimeDataArgTy)
{
std::vector<CallInst*> callsToTraceRay = getCallsInShadersToFunctionWithPrefix("dx.op.traceRay");
if (callsToTraceRay.empty())
{
// TODO: It might be worth dropping this from the tests eventually
callsToTraceRay = getCallsInShadersToFunctionWithPrefix("\x1?TraceRayTest@@");
if (callsToTraceRay.empty())
return;
}
std::vector<Function*> traceRayImpl = getFunctionsWithPrefix(m_module, "\x1?Fallback_TraceRay@@");
assert(traceRayImpl.size() == 1 && "Could not find Fallback_TraceRay() implementation");
// Indices into the save/restore tables below.
enum { CLOSEST_HIT = 0, MISS = 1 };
Function* traceRaySave[] = { m_module->getFunction("traceRaySave_ClosestHit"), m_module->getFunction("traceRaySave_Miss") };
Function* traceRayRestore[] = { m_module->getFunction("traceRayRestore_ClosestHit"), m_module->getFunction("traceRayRestore_Miss") };
assert(traceRaySave[CLOSEST_HIT] && traceRayRestore[CLOSEST_HIT] && traceRaySave[MISS] && traceRayRestore[MISS] &&
"Could not find TraceRay spill functions");
Function* dummyRuntimeDataArgFunc = StateFunctionTransform::createDummyRuntimeDataArgFunc(m_module, runtimeDataArgTy);
assert(dummyRuntimeDataArgFunc && "dummyRuntimeDataArg function could not be created.");
// Process calls
LLVMContext& C = m_module->getContext();
Type* int32Ty = Type::getInt32Ty(C);
// Caches: one movePayloadToStack function per signature, one spill alloca per caller.
std::map<FunctionType*, Function*> movePayloadToStackFuncs;
std::map<Function*, AllocaInst*> funcToSpillAlloca;
for (CallInst* call : callsToTraceRay)
{
Instruction* insertBefore = call;
// Spill runtime data values, if necessary (closesthit and miss shaders)
Function* caller = call->getParent()->getParent();
DXIL::ShaderKind kind = getRayShaderKind(caller);
if (kind == DXIL::ShaderKind::ClosestHit || kind == DXIL::ShaderKind::Miss)
{
int sh = (kind == DXIL::ShaderKind::ClosestHit) ? CLOSEST_HIT : MISS;
AllocaInst* spillAlloca = get(funcToSpillAlloca, caller);
if (!spillAlloca)
{
// The spill buffer type is the pointee type of the save function's second parameter.
Argument* spillAllocaArg = (++traceRaySave[sh]->arg_begin());
Type* spillAllocaTy = spillAllocaArg->getType()->getPointerElementType();
spillAlloca = new AllocaInst(spillAllocaTy, "spill.alloca", caller->getEntryBlock().begin());
funcToSpillAlloca[caller] = spillAlloca;
}
// Create calls. SFT will inline them.
Value* runtimeDataArg = CallInst::Create(dummyRuntimeDataArgFunc, "runtimeData", insertBefore);
CallInst::Create(traceRaySave[sh], {runtimeDataArg, spillAlloca}, "", insertBefore);
CallInst::Create(traceRayRestore[sh], {runtimeDataArg, spillAlloca}, "", getInstructionAfter(call));
}
// Get the payload offset to pass to trace implementation
//hlsl::DxilInst_TraceRay traceRayCall(call);
// TODO: Avoiding the intrinsic to support the test's use of TraceRayTest
// The payload is taken to be the last argument of the call.
Value* payload = call->getOperand(call->getNumArgOperands() - 1);
FunctionType* funcType = FunctionType::get(int32Ty, { payload->getType() }, false);
Function* movePayloadToStackFunc = getOrCreateFunction("movePayloadToStack", m_module, funcType, movePayloadToStackFuncs);
Value* newPayloadOffset = CallInst::Create(movePayloadToStackFunc, { payload }, "new.payload.offset", insertBefore);
// Call implementation
unsigned i = 0;
if (call->getCalledFunction()->getName().startswith("dx.op"))
i += 2; // skip intrinsic number and acceleration structure (for now)
std::vector<Value*> args;
// Forward everything up to (not including) the payload, then the new offset.
for (; i < call->getNumArgOperands() - 1; ++i)
args.push_back(call->getArgOperand(i));
args.push_back(newPayloadOffset);
CallInst::Create(traceRayImpl[0], args, "", insertBefore);
call->eraseFromParent();
}
}
// Builds the parameter semantic list for a shader of the given kind:
//   anyhit / closesthit -> { PST_PAYLOAD, PST_ATTRIBUTE }
//   miss                -> { PST_PAYLOAD }
//   anything else       -> F->getNumOperands() entries of PST_NONE
// NOTE(review): getNumOperands() is the Function's operand count, which is not
// necessarily its argument count — confirm this is intended.
static std::vector<StateFunctionTransform::ParameterSemanticType> getParameterTypes(Function* F, DXIL::ShaderKind shaderKind)
{
  std::vector<StateFunctionTransform::ParameterSemanticType> semantics;
  switch (shaderKind)
  {
  case DXIL::ShaderKind::AnyHit:
  case DXIL::ShaderKind::ClosestHit:
    semantics.push_back(StateFunctionTransform::PST_PAYLOAD);
    semantics.push_back(StateFunctionTransform::PST_ATTRIBUTE);
    break;
  case DXIL::ShaderKind::Miss:
    semantics.push_back(StateFunctionTransform::PST_PAYLOAD);
    break;
  default:
    semantics.assign(F->getNumOperands(), StateFunctionTransform::PST_NONE);
    break;
  }
  return semantics;
}
// Adds the global symbol of every CBuffer, UAV, SRV and sampler registered in
// `DM` to `resources`.
static void collectResources(DxilModule& DM, std::set<Value*>& resources)
{
  for (auto& cbuffer : DM.GetCBuffers())
    resources.insert(cbuffer->GetGlobalSymbol());

  for (auto& uav : DM.GetUAVs())
    resources.insert(uav->GetGlobalSymbol());

  for (auto& srv : DM.GetSRVs())
    resources.insert(srv->GetGlobalSymbol());

  for (auto& sampler : DM.GetSamplers())
    resources.insert(sampler->GetGlobalSymbol());
}
// Runs StateFunctionTransform over every shader in `shaderNames`.
//
// Outputs (both vectors are cleared first and get one entry per shader, in
// `shaderNames` order):
//   shaderEntryStateIds - state id given to the shader's first state function
//   shaderStackSizes    - stack size reported by the transform
//   stateFunctionMap    - stateID -> state function; ids are handed out
//                         consecutively starting at `baseStateId`
//
// Shaders in m_shaderMap with no function are reported on errs().
void DxrFallbackCompiler::createStateFunctions(
  IntToFuncMap& stateFunctionMap,
  std::vector<int>& shaderEntryStateIds,
  std::vector<unsigned int>& shaderStackSizes,
  int baseStateId,
  const std::vector<std::string>& shaderNames,
  Type* runtimeDataArgTy
)
{
  for (auto& nameAndFunc : m_shaderMap)
  {
    if (!nameAndFunc.second)
      errs() << "Function not found for shader " << nameAndFunc.first << "\n";
  }

  DxilModule& DM = m_module->GetOrCreateDxilModule();
  std::set<Value*> resources;
  collectResources(DM, resources);

  shaderEntryStateIds.clear();
  shaderStackSizes.clear();

  int nextStateId = baseStateId;
  for (const std::string& shaderName : shaderNames)
  {
    Function* shaderFunc = m_shaderMap[shaderName];
    StateFunctionTransform sft(shaderFunc, shaderNames, runtimeDataArgTy);

    // Debug configuration
    if (m_debugOutputLevel >= 2)
      sft.setVerbose(true);
    if (m_debugOutputLevel >= 3)
      sft.setDumpFilename("dump.ll");

    if (shaderName == "Fallback_TraceRay")
      sft.setAttributeSize(m_maxAttributeSize);

    const DXIL::ShaderKind shaderKind = getRayShaderKind(shaderFunc);
    if (shaderKind != DXIL::ShaderKind::Invalid)
      sft.setParameterInfo(getParameterTypes(shaderFunc, shaderKind), shaderKind == DXIL::ShaderKind::ClosestHit);
    sft.setResourceGlobals(resources);

    std::vector<Function*> stateFunctions;
    UINT shaderStackSize = 0;
    sft.run(stateFunctions, shaderStackSize);

    // The entry state id is the id about to be given to the first state function.
    shaderEntryStateIds.push_back(nextStateId);
    shaderStackSizes.push_back(shaderStackSize);

    for (Function* stateF : stateFunctions)
    {
      stateFunctionMap[nextStateId++] = stateF;
      if (DM.HasDxilFunctionProps(shaderFunc))
        DM.CloneDxilEntryProps(shaderFunc, stateF);
    }
  }

  StateFunctionTransform::finalizeStateIds(m_module, shaderEntryStateIds);
}
// Replaces the placeholder call to "rewrite_setLaunchParams" in func's module
// with a call to "fb_Fallback_SetLaunchParams".
//
// The new call receives the placeholder's runtime-data argument (operand 0),
// the x and y components of dx.op.threadId, the placeholder's two dimension
// arguments (operands 1 and 2), and dx.op.flattenedThreadIdInGroup. The
// placeholder call and the placeholder function are then erased.
void DxrFallbackCompiler::createLaunchParams(Function* func)
{
  Module* module = func->getParent();
  Function* rewrite_setLaunchParams = module->getFunction("rewrite_setLaunchParams");
  assert(rewrite_setLaunchParams && !rewrite_setLaunchParams->use_empty() && "rewrite_setLaunchParams placeholder call not found");
  // The first user is expected to be the placeholder call; cast<> asserts that.
  CallInst* call = cast<CallInst>(*rewrite_setLaunchParams->user_begin());

  LLVMContext& context = module->getContext();
  Instruction* insertBefore = call;

  Function* DTidFunc = FunctionBuilder(module, "dx.op.threadId.i32").i32().i32().i32().build();
  Value* threadIdOpcode = makeInt32(static_cast<int>(hlsl::OP::OpCode::ThreadId), context);
  Value* DTidx = CallInst::Create(DTidFunc, { threadIdOpcode, makeInt32(0, context) }, "DTidx", insertBefore);
  Value* DTidy = CallInst::Create(DTidFunc, { threadIdOpcode, makeInt32(1, context) }, "DTidy", insertBefore);

  Value* dimx = call->getArgOperand(1);
  Value* dimy = call->getArgOperand(2);

  // Use the named opcode instead of its raw value (96), matching ThreadId above.
  Function* groupIndexFunc = FunctionBuilder(module, "dx.op.flattenedThreadIdInGroup.i32").i32().i32().build();
  Value* groupIndexOpcode = makeInt32(static_cast<int>(hlsl::OP::OpCode::FlattenedThreadIdInGroup), context);
  Value* groupIndex = CallInst::Create(groupIndexFunc, { groupIndexOpcode }, "groupIndex", insertBefore);

  Function* fb_setLaunchParams = module->getFunction("fb_Fallback_SetLaunchParams");
  assert(fb_setLaunchParams && "fb_Fallback_SetLaunchParams not found");
  Value* runtimeDataArg = call->getArgOperand(0);
  CallInst::Create(fb_setLaunchParams, { runtimeDataArg, DTidx, DTidy, dimx, dimy, groupIndex }, "", insertBefore);

  call->eraseFromParent();
  rewrite_setLaunchParams->eraseFromParent();
}
// Builds the state dispatch function and substitutes it for the
// "rewrite_dispatch" placeholder in func's module; the placeholder is erased.
void DxrFallbackCompiler::createStateDispatch(Function* func, const IntToFuncMap& stateFunctionMap, Type* runtimeDataArgTy)
{
  Function* dispatchFunc = createDispatchFunction(stateFunctionMap, runtimeDataArgTy);

  Function* placeholder = func->getParent()->getFunction("rewrite_dispatch");
  placeholder->replaceAllUsesWith(dispatchFunc);
  placeholder->eraseFromParent();
}
// Materializes the runtime stack and its size:
//   - the call to "rewrite_createStack" becomes an alloca named "theStack" of
//     the call's pointee type, aligned to sizeof(int);
//   - the call to "rewrite_getStackSize" becomes the i32 constant
//     m_stackSizeInBytes (if that is 0, it is first set to the alloca's array
//     element count times sizeof(int)).
// Both placeholder functions are erased afterwards.
void DxrFallbackCompiler::createStack(Function* func)
{
  LLVMContext& context = func->getContext();

  // We would like to allocate the properly sized stack here, but DXIL doesn't
  // allow bitcasts between objects of different sizes. So we have to use the
  // default size from the runtime and replace all the accesses later.
  Function* createStackStub = m_module->getFunction("rewrite_createStack");
  CallInst* createStackCall = dyn_cast<CallInst>(*createStackStub->user_begin());
  Type* stackTy = createStackCall->getType()->getPointerElementType();
  AllocaInst* stack = new AllocaInst(stackTy, "theStack", createStackCall);
  stack->setAlignment(sizeof(int));
  createStackCall->replaceAllUsesWith(stack);
  createStackCall->eraseFromParent();
  createStackStub->eraseFromParent();

  // Take the default
  if (m_stackSizeInBytes == 0)
    m_stackSizeInBytes = stackTy->getArrayNumElements() * sizeof(int);

  Function* getStackSizeStub = m_module->getFunction("rewrite_getStackSize");
  CallInst* getStackSizeCall = dyn_cast<CallInst>(*getStackSizeStub->user_begin());
  getStackSizeCall->replaceAllUsesWith(makeInt32(m_stackSizeInBytes, context));
  getStackSizeCall->eraseFromParent();
  getStackSizeStub->eraseFromParent();
}
// WAR to avoid crazy <3 x float> code emitted by vanilla clang in the runtime.
// If `arg` is a 3-element vector, appends its three components (extracted
// before `insertBefore`) to `args` and returns true; otherwise leaves `args`
// untouched and returns false.
static bool expandFloat3(std::vector<Value*>& args, Value* arg, Instruction* insertBefore)
{
  VectorType* vecTy = dyn_cast<VectorType>(arg->getType());
  const bool isVec3 = vecTy && vecTy->getVectorNumElements() == 3;
  if (!isVec3)
    return false;

  static const char* const componentNames[] = { "vec.x", "vec.y", "vec.z" };
  LLVMContext& C = arg->getContext();
  for (int component = 0; component < 3; ++component)
    args.push_back(ExtractElementInst::Create(arg, makeInt32(component, C), componentNames[component], insertBefore));
  return true;
}
// If `arg` is a struct named "class.matrix.float.3.4", stores it into a
// temporary alloca placed at the top of the enclosing function's entry block,
// reloads it as a <12 x float> through a bitcast pointer, appends that vector
// to `args` and returns true. Returns false (and appends nothing) otherwise.
static bool float3x4ToFloat12(std::vector<Value*>& args, Value* arg, Instruction* insertBefore)
{
  Type* matrixTy = arg->getType();
  StructType* STy = dyn_cast<StructType>(matrixTy);
  if (!STy || STy->getName() != "class.matrix.float.3.4")
    return false;

  // Spill the matrix to memory.
  Function* enclosingFunc = insertBefore->getParent()->getParent();
  AllocaInst* tmpSlot = new AllocaInst(matrixTy, "tmp", enclosingFunc->getEntryBlock().begin());
  new StoreInst(arg, tmpSlot, insertBefore);

  // Reinterpret the same memory as twelve floats.
  VectorType* vec12Ty = VectorType::get(Type::getFloatTy(arg->getContext()), 12);
  Value* vec12Ptr = new BitCastInst(tmpSlot, vec12Ty->getPointerTo(), "vec12.ptr", insertBefore);
  args.push_back(new LoadInst(vec12Ptr, "vec12.", insertBefore));
  return true;
}
// Redirects intrinsic calls to their "fb_" implementations.
//
// For every module function named fb_Fallback_<X> or fb_dxop_<X>, the first
// module function whose name starts with "\x1?Fallback_<X>" (resp.
// "dx.op.<X>") is looked up and each of its calls is replaced by a call to the
// fb_ function, or to its pending variant (fb_Fallback_Pending<X> /
// fb_dxop_pending_<X>) when shouldUsePendingValue() says so for the caller.
//
// Argument rewriting per call:
//   - the caller's first argument is prepended as runtime data (for everything
//     except Fallback_Scheduler; such calls are only rewritten inside callers
//     that carry the "state_function" attribute);
//   - for dx.op calls the leading opcode argument is dropped;
//   - 3-element vectors are expanded to three scalars and
//     class.matrix.float.3.4 values are turned into a <12 x float>.
void DxrFallbackCompiler::lowerIntrinsics()
{
std::vector<Function*> intrinsics = getFunctionsWithPrefix(m_module, "fb_");
assert(intrinsics.size() > 0);
// Replace intrinsics in anyhit shaders with their pending versions
LLVMContext& C = m_module->getContext();
// intrinsic name (as used in the second loop) -> pending implementation
std::map<std::string, Function*> pendingIntrinsics;
std::string pendingPrefixes[] = { "fb_dxop_pending_", "fb_Fallback_Pending" };
for (auto& F : intrinsics)
{
std::string intrinsicName;
if (F->getName().startswith(pendingPrefixes[0]))
intrinsicName = F->getName().substr(pendingPrefixes[0].length());
else if (F->getName().startswith(pendingPrefixes[1]))
intrinsicName = "Fallback_" + F->getName().substr(pendingPrefixes[1].length()).str();
else
continue;
pendingIntrinsics[intrinsicName] = F;
}
for (Function* func : intrinsics)
{
StringRef intrinsicName;
// Name prefix of the function whose calls will be redirected to `func`.
std::string name;
bool isDxilOp = false;
if (func->getName().startswith("fb_Fallback_"))
{
intrinsicName = func->getName().substr(3); // after the "fb_" prefix
name = "\x1?" + intrinsicName.str();
}
else if (func->getName().startswith("fb_dxop_"))
{
intrinsicName = func->getName().substr(8);
name = "dx.op." + intrinsicName.str();
isDxilOp = true;
}
else
{
// NOTE(review): when asserts are compiled out, execution continues with an
// empty `name`, which matches every function in the lookup below — confirm
// that no other "fb_"-prefixed functions can exist.
assert(0 && "Bad intrinsic");
}
std::vector<Function*> calledFunc = getFunctionsWithPrefix(m_module, name);
if (calledFunc.empty())
continue;
// Only the first matching function is considered.
std::vector<CallInst*> calls = getCallsToFunction(calledFunc[0]);
if (calls.empty())
continue;
bool needsRuntimeDataArg = (intrinsicName != "Fallback_Scheduler");
Function* pendingFunc = get(pendingIntrinsics, intrinsicName.str());
// Looked up lazily, at most once per intrinsic.
Function* funcInModule = nullptr;
Function* pendingFuncInModule = nullptr;
for (CallInst* call : calls)
{
Function* caller = call->getParent()->getParent();
if (needsRuntimeDataArg && !caller->hasFnAttribute("state_function"))
continue;
Function* F = nullptr;
if (pendingFunc && shouldUsePendingValue(caller, intrinsicName))
{
if (!pendingFuncInModule)
pendingFuncInModule = getOrInsertFunction(m_module, pendingFunc);
F = pendingFuncInModule;
}
else
{
if (!funcInModule)
funcInModule = getOrInsertFunction(m_module, func);
F = funcInModule;
}
// insert runtime data and the rest of the arguments
std::vector<Value*> args;
if (needsRuntimeDataArg)
args.push_back(caller->arg_begin());
int argIdx = 0;
for (auto& arg : call->arg_operands())
{
if (argIdx++ == 0 && isDxilOp)
continue; // skip the intrinsic number
if (!expandFloat3(args, arg, call) && !float3x4ToFloat12(args, arg, call))
args.push_back(arg);
}
CallInst* newCall = CallInst::Create(F, args, "", call);
// Only non-void replacements take over the old call's name and uses.
if (F->getFunctionType()->getReturnType() != Type::getVoidTy(C))
{
newCall->takeName(call);
call->replaceAllUsesWith(newCall);
}
call->eraseFromParent();
}
}
}
// Returns the type of the first parameter of the module's "stackIntPtr"
// function, which serves as the runtime data argument type. The function must
// already be present in the module (i.e. the runtime has been linked in).
Type* DxrFallbackCompiler::getRuntimeDataArgType()
{
  Function* stackIntPtrFunc = m_module->getFunction("stackIntPtr");
  Argument& firstParam = *stackIntPtrFunc->arg_begin();
  return firstParam.getType();
}
// Creates the function "dispatch"(runtimeData, stateID) -> i32. Its entry block
// switches on stateID: each id in `stateFunctionMap` jumps to a block that
// calls the matching state function with runtimeData and returns its result;
// any other id reaches the "badStateID" block, which returns -3.
Function* DxrFallbackCompiler::createDispatchFunction(const IntToFuncMap &stateFunctionMap, Type* runtimeDataArgTy)
{
  LLVMContext& context = m_module->getContext();

  Function* dispatchFunc = FunctionBuilder(m_module, "dispatch").i32().type(runtimeDataArgTy, "runtimeData").i32("stateID").build();
  auto argIt = dispatchFunc->arg_begin();
  Value* runtimeDataArg = argIt;
  Value* stateIdArg = ++argIt;

  BasicBlock* entryBlock = BasicBlock::Create(context, "entry", dispatchFunc);
  BasicBlock* badBlock = BasicBlock::Create(context, "badStateID", dispatchFunc);

  // Default target: return an error value
  IRBuilder<> builder(badBlock);
  builder.CreateRet(makeInt32(-3, context));

  builder.SetInsertPoint(entryBlock);
  SwitchInst* switchInst = builder.CreateSwitch(stateIdArg, badBlock, stateFunctionMap.size());

  // Signature shared by all state functions: i32 (runtimeData).
  FunctionType* stateFuncTy = FunctionType::get(Type::getInt32Ty(context), { runtimeDataArgTy }, false);
  for (auto& idAndFunc : stateFunctionMap)
  {
    const int stateId = idAndFunc.first;
    Function* stateFunc = idAndFunc.second;
    Value* stateFuncInModule = m_module->getOrInsertFunction(stateFunc->getName(), stateFuncTy);

    // State blocks are inserted ahead of the bad-state block.
    BasicBlock* stateBlock = BasicBlock::Create(context, "state_" + Twine(stateId) + "." + stateFunc->getName(), dispatchFunc, badBlock);
    builder.SetInsertPoint(stateBlock);
    Value* nextStateId = builder.CreateCall(stateFuncInModule, { runtimeDataArg }, "nextStateId");
    builder.CreateRet(nextStateId);

    switchInst->addCase(makeInt32(stateId, context), stateBlock);
  }

  return dispatchFunc;
}
std::vector<CallInst*> DxrFallbackCompiler::getCallsInShadersToFunction(const std::string& funcName)
{
std::vector<CallInst*> calls;
Function* F = m_module->getFunction(funcName);
if (!F)
return calls;
for (User* U : F->users())
{
CallInst* call = dyn_cast<CallInst>(U);
if (!call)
continue;
Function* caller = call->getParent()->getParent();
auto it = m_shaderMap.find(cleanName(caller->getName()));
if (it != m_shaderMap.end())
calls.push_back(call);
}
return calls;
}
std::vector<CallInst*> DxrFallbackCompiler::getCallsInShadersToFunctionWithPrefix(const std::string& funcNamePrefix)
{
std::vector<CallInst*> calls;
for (Function* F : getFunctionsWithPrefix(m_module, funcNamePrefix))
{
for (User* U : F->users())
{
CallInst* call = dyn_cast<CallInst>(U);
if (!call)
continue;
Function* caller = call->getParent()->getParent();
if (m_shaderMap.count(cleanName(caller->getName())))
calls.push_back(call);
}
}
return calls;
}
// Replaces the alloca whose name starts with "theStack" in F's entry block by
// an i32 array alloca of sizeInBytes / sizeof(int) elements. Every user of the
// old alloca must be a GEP; each is re-created (in bounds, same indices)
// against the new alloca. Does nothing if no such alloca exists.
void DxrFallbackCompiler::resizeStack(Function* F, unsigned sizeInBytes)
{
  // Find the stack
  AllocaInst* stack = nullptr;
  for (Instruction& inst : F->getEntryBlock())
  {
    AllocaInst* candidate = dyn_cast<AllocaInst>(&inst);
    if (candidate && candidate->getName().startswith("theStack"))
    {
      stack = candidate;
      break;
    }
  }
  if (!stack)
    return;

  // Create a new stack
  LLVMContext& C = F->getContext();
  ArrayType* newStackTy = ArrayType::get(Type::getInt32Ty(C), sizeInBytes / sizeof(int));
  AllocaInst* newStack = new AllocaInst(newStackTy, "", stack);
  newStack->takeName(stack);

  // Remap all GEPs - replaceAllUsesWith() won't change types
  auto userIt = stack->user_begin();
  const auto userEnd = stack->user_end();
  while (userIt != userEnd)
  {
    User* user = *userIt;
    ++userIt; // advance before the current user is erased

    GetElementPtrInst* gep = dyn_cast<GetElementPtrInst>(user);
    assert(gep && "theStack has non-gep user.");

    std::vector<Value*> indices(gep->idx_begin(), gep->idx_end());
    GetElementPtrInst* newGep = GetElementPtrInst::CreateInBounds(newStack, indices, "", gep);
    newGep->takeName(gep);
    gep->replaceAllUsesWith(newGep);
    gep->eraseFromParent();
  }

  stack->eraseFromParent();
}