#include "dxc/DxrFallback/DxrFallbackCompiler.h" #include "dxc/DXIL/DxilFunctionProps.h" #include "dxc/DXIL/DxilInstructions.h" #include "dxc/DXIL/DxilModule.h" #include "dxc/DXIL/DxilOperations.h" #include "dxc/HLSL/DxilLinker.h" #include "dxc/Support/FileIOHelper.h" #include "dxc/Support/Global.h" #include "dxc/Support/Unicode.h" #include "dxc/Support/WinIncludes.h" #include "dxc/Support/dxcapi.impl.h" #include "dxc/Support/dxcapi.use.h" #include "dxc/dxcapi.h" #include "dxc/dxcdxrfallbackcompiler.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "FunctionBuilder.h" #include "LLVMUtils.h" #include "StateFunctionTransform.h" #include "runtime.h" #include using namespace hlsl; using namespace llvm; static std::vector getFunctionsWithPrefix(Module *mod, const std::string &prefix) { std::vector functions; for (auto F = mod->begin(), E = mod->end(); F != E; ++F) { StringRef name = F->getName(); if (name.startswith(prefix)) functions.push_back(F); } return functions; } static bool inlineFunc(CallInst *call, Function *Fimpl) { // Note. LLVM inlining may not be sufficient if the function references DX // resources because the corresponding metadata is not created if the function // comes from another module. // Make sure that we have a definition for the called function in this module Function *F = call->getCalledFunction(); Module *dstM = F->getParent(); if (F->isDeclaration()) { // Map called functions in impl module to functions in this one (because the // cloning step doesn't do this automatically) ValueToValueMapTy VMap; for (auto &I : inst_range(Fimpl)) { if (CallInst *c = dyn_cast(&I)) { Function *calledFimpl = c->getCalledFunction(); if (VMap.count(calledFimpl)) continue; Constant *calledF = dstM->getOrInsertFunction( calledFimpl->getName(), calledFimpl->getFunctionType(), calledFimpl->getAttributes()); VMap[calledFimpl] = calledF; } } // Map arguments for (auto SI = Fimpl->arg_begin(), SE = Fimpl->arg_end(), DI = F->arg_begin(); SI != SE; ++SI, ++DI) VMap[SI] = DI; SmallVector returns; CloneFunctionInto(F, Fimpl, VMap, true, returns); F->setLinkage(GlobalValue::InternalLinkage); } InlineFunctionInfo IFI; return InlineFunction(call, IFI, false); } // Remove ELF mangling static std::string cleanName(StringRef name) { if (!name.startswith("\x1?")) return name; size_t pos = name.find("@@"); if (pos == name.npos) return name; std::string newName = name.substr(2, pos - 2); return newName; } static inline Function *getOrInsertFunction(Module *mod, Function *F) { return dyn_cast( mod->getOrInsertFunction(F->getName(), F->getFunctionType())); } template V get(std::map &theMap, const K &key, V defaultVal = static_cast(nullptr)) { auto it = theMap.find(key); if (it == theMap.end()) return defaultVal; else return it->second; } DxrFallbackCompiler::DxrFallbackCompiler( llvm::Module *mod, const std::vector &shaderNames, unsigned maxAttributeSize, unsigned stackSizeInBytes, bool findCalledShaders /*= false*/) : m_module(mod), m_entryShaderNames(shaderNames), m_stackSizeInBytes(stackSizeInBytes), m_maxAttributeSize(maxAttributeSize), m_findCalledShaders(findCalledShaders) {} void DxrFallbackCompiler::compile(std::vector &shaderEntryStateIds, std::vector &shaderStackSizes, IntToFuncNameMap *pCachedMap) { std::vector shaderNames = m_entryShaderNames; initShaderMap(shaderNames); // Bring in runtime so we can get the runtime data type linkRuntime(); Type *runtimeDataArgTy = getRuntimeDataArgType(); // Make sure all calls to intrinsics and shaders are at function scope and // fix up control flow. lowerAnyHitControlFlowFuncs(); lowerReportHit(); lowerTraceRay(runtimeDataArgTy); // Create state functions IntToFuncMap stateFunctionMap; // stateID -> state function const int baseStateId = 1000; // could be anything but this makes stateIds more recognizable createStateFunctions(stateFunctionMap, shaderEntryStateIds, shaderStackSizes, baseStateId, shaderNames, runtimeDataArgTy); if (pCachedMap) { for (auto &entry : stateFunctionMap) { (*pCachedMap)[entry.first] = entry.second->getName().str(); } } } void DxrFallbackCompiler::link(std::vector &shaderEntryStateIds, std::vector &shaderStackSizes, IntToFuncNameMap *pCachedMap) { IntToFuncMap stateFunctionMap; // stateID -> state function if (pCachedMap) { for (auto entry : *pCachedMap) { stateFunctionMap[entry.first] = m_module->getFunction(entry.second); } } else { for (UINT i = 0; i < shaderEntryStateIds.size(); i++) { UINT substateIndex = 0; UINT baseStateId = shaderEntryStateIds[i]; while (true) { auto substateName = m_entryShaderNames[i] + ".ss_" + std::to_string(substateIndex); auto function = m_module->getFunction(substateName); if (!function) break; stateFunctionMap[baseStateId + substateIndex] = m_module->getFunction(substateName); substateIndex++; } } } // Fix up scheduler Function *schedulerFunc = m_module->getFunction("fb_Fallback_Scheduler"); createLaunchParams(schedulerFunc); Type *runtimeDataArgTy = getRuntimeDataArgType(); createStateDispatch(schedulerFunc, stateFunctionMap, runtimeDataArgTy); createStack(schedulerFunc); lowerIntrinsics(); } void DxrFallbackCompiler::setDebugOutputLevel(int val) { m_debugOutputLevel = val; } static bool isShader(Function *F) { if (F->hasFnAttribute("exp-shader")) return true; DxilModule &DM = F->getParent()->GetDxilModule(); return (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay()); } DXIL::ShaderKind getRayShaderKind(Function *F) { if (F->hasFnAttribute("exp-shader")) return DXIL::ShaderKind::RayGeneration; DxilModule &DM = F->getParent()->GetDxilModule(); if (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay()) return DM.GetDxilFunctionProps(F).shaderKind; return DXIL::ShaderKind::Invalid; } // Some shaders should use the "pending" values of intrinsics instead of the // committed ones. In particular anyhit and intersection shaders use the // pending values with the exception that the committed rayTCurrent should be // used in intersection. static bool shouldUsePendingValue(Function *F, StringRef instrinsicName) { DxilModule &DM = F->getParent()->GetDxilModule(); if (!DM.HasDxilFunctionProps(F)) return false; const hlsl::DxilFunctionProps &props = DM.GetDxilFunctionProps(F); return props.IsAnyHit() || (props.IsIntersection() && instrinsicName != "rayTCurrent"); } void DxrFallbackCompiler::initShaderMap(std::vector &shaderNames) { // Clean names and initialize shaderMap StringToFuncMap allShadersMap; for (Function &F : m_module->functions()) { if (isShader(&F)) { if (!F.isDeclaration()) allShadersMap[cleanName(F.getName())] = &F; } F.removeFnAttr(Attribute::NoInline); } for (auto &name : shaderNames) m_shaderMap[name] = allShadersMap[name]; if (!m_findCalledShaders) return; // Create a map from shader name to CallGraphNode CallGraph callGraph(*m_module); std::map allShaderNodes; for (auto &kv : m_shaderMap) { const std::string &name = kv.first; Function *func = kv.second; allShaderNodes[name] = callGraph[func]; } // Start traversing the call graph from given shaderNames std::deque workList; for (auto &name : shaderNames) workList.push_back(allShaderNodes[name]); while (!workList.empty()) { CallGraphNode *cur = workList.front(); workList.pop_front(); for (size_t i = 0; i < cur->size(); ++i) { Function *nextFunc = (*cur)[i]->getFunction(); if (!nextFunc) continue; if (isShader(nextFunc)) { const std::string nextName = cleanName(nextFunc->getName()); if (m_shaderMap.count(nextName) == 0) // not in the shaderMap yet? { workList.push_back(allShaderNodes[nextName]); shaderNames.push_back(nextName); m_shaderMap[nextName] = workList.back()->getFunction(); } } } } } void DxrFallbackCompiler::linkRuntime() { Linker linker(m_module); std::unique_ptr runtimeModule = loadModuleFromAsmString(m_module->getContext(), getRuntimeString()); bool linkErr = linker.linkInModule(runtimeModule.get()); assert(!linkErr && "Error linking runtime"); UNREFERENCED_PARAMETER(linkErr); } static void inlineFuncAndAddRet(CallInst *call, Function *F) { // Add a return after the function call. // Should be followed immediately by "unreachable". Turn that into a "ret // void". Instruction *ret = ReturnInst::Create(call->getContext()); ReplaceInstWithInst(call->getParent()->getTerminator(), ret); bool success = inlineFunc(call, F); assert(success); UNREFERENCED_PARAMETER(success); } void DxrFallbackCompiler::lowerAnyHitControlFlowFuncs() { std::vector callsToIgnoreHit = getCallsInShadersToFunction("dx.op.ignoreHit"); if (!callsToIgnoreHit.empty()) { Function *ignoreHitFunc = m_module->getFunction("\x1?Fallback_IgnoreHit@@YAXXZ"); assert(ignoreHitFunc && "IgnoreHit() implementation not found"); for (CallInst *call : callsToIgnoreHit) inlineFuncAndAddRet(call, ignoreHitFunc); } std::vector callsToAcceptHitAndEndSearch = getCallsInShadersToFunction("dx.op.acceptHitAndEndSearch"); if (!callsToAcceptHitAndEndSearch.empty()) { Function *acceptHitAndEndSearchFunc = m_module->getFunction("\x1?Fallback_AcceptHitAndEndSearch@@YAXXZ"); assert(acceptHitAndEndSearchFunc && "AcceptHitAndEndSearch() implementation not found"); for (CallInst *call : callsToAcceptHitAndEndSearch) inlineFuncAndAddRet(call, acceptHitAndEndSearchFunc); } } void DxrFallbackCompiler::lowerReportHit() { std::vector callsToReportHit = getCallsInShadersToFunctionWithPrefix("dx.op.reportHit"); if (callsToReportHit.empty()) return; Function *reportHitFunc = m_module->getFunction("\x1?Fallback_ReportHit@@YAHMI@Z"); assert(reportHitFunc && "ReportHit() implementation not found"); LLVMContext &C = m_module->getContext(); for (CallInst *call : callsToReportHit) { // Wrap attribute arguments in Fallback_SetPendingAttr() call Instruction *insertBefore = call; hlsl::DxilInst_ReportHit reportHitCall(call); Value *attr = reportHitCall.get_Attributes(); Function *setPendingAttrFunc = FunctionBuilder(m_module, "\x1?Fallback_SetPendingAttr@@") .voidTy() .type(attr->getType(), "attr") .build(); CallInst::Create(setPendingAttrFunc, {attr}, "", insertBefore); // Make call to implementation and load result CallInst *callImpl = CallInst::Create( reportHitFunc, {reportHitCall.get_THit(), reportHitCall.get_HitKind()}, "reportHit.result", insertBefore); Value *result = callImpl; // Result < 0 ==> ret Value *zero = makeInt32(0, C); Value *ltz = new ICmpInst(insertBefore, CmpInst::ICMP_SLT, result, zero, "endSearch"); BasicBlock *prevBlock = call->getParent(); BasicBlock *retBlock = prevBlock->splitBasicBlock(call, "endSearch"); BasicBlock *nextBlock = retBlock->splitBasicBlock(call, "afterReportHit"); ReplaceInstWithInst(prevBlock->getTerminator(), BranchInst::Create(retBlock, nextBlock, ltz)); ReplaceInstWithInst(retBlock->getTerminator(), ReturnInst::Create(C)); // Compare result to zero and store into original result Value *gtz = new ICmpInst(insertBefore, CmpInst::ICMP_SGT, result, zero, "accepted"); call->replaceAllUsesWith(gtz); bool success = inlineFunc(callImpl, reportHitFunc); assert(success); (void)success; call->eraseFromParent(); } } void DxrFallbackCompiler::lowerTraceRay(Type *runtimeDataArgTy) { std::vector callsToTraceRay = getCallsInShadersToFunctionWithPrefix("dx.op.traceRay"); if (callsToTraceRay.empty()) { // TODO: It might be worth dropping this from the tests eventually callsToTraceRay = getCallsInShadersToFunctionWithPrefix("\x1?TraceRayTest@@"); if (callsToTraceRay.empty()) return; } std::vector traceRayImpl = getFunctionsWithPrefix(m_module, "\x1?Fallback_TraceRay@@"); assert(traceRayImpl.size() == 1 && "Could not find Fallback_TraceRay() implementation"); enum { CLOSEST_HIT = 0, MISS = 1 }; Function *traceRaySave[] = {m_module->getFunction("traceRaySave_ClosestHit"), m_module->getFunction("traceRaySave_Miss")}; Function *traceRayRestore[] = { m_module->getFunction("traceRayRestore_ClosestHit"), m_module->getFunction("traceRayRestore_Miss")}; assert(traceRaySave[CLOSEST_HIT] && traceRayRestore[CLOSEST_HIT] && traceRaySave[MISS] && traceRayRestore[MISS] && "Could not find TraceRay spill functions"); Function *dummyRuntimeDataArgFunc = StateFunctionTransform::createDummyRuntimeDataArgFunc(m_module, runtimeDataArgTy); assert(dummyRuntimeDataArgFunc && "dummyRuntimeDataArg function could not be created."); // Process calls LLVMContext &C = m_module->getContext(); Type *int32Ty = Type::getInt32Ty(C); std::map movePayloadToStackFuncs; std::map funcToSpillAlloca; for (CallInst *call : callsToTraceRay) { Instruction *insertBefore = call; // Spill runtime data values, if necessary (closesthit and miss shaders) Function *caller = call->getParent()->getParent(); DXIL::ShaderKind kind = getRayShaderKind(caller); if (kind == DXIL::ShaderKind::ClosestHit || kind == DXIL::ShaderKind::Miss) { int sh = (kind == DXIL::ShaderKind::ClosestHit) ? CLOSEST_HIT : MISS; AllocaInst *spillAlloca = get(funcToSpillAlloca, caller); if (!spillAlloca) { Argument *spillAllocaArg = (++traceRaySave[sh]->arg_begin()); Type *spillAllocaTy = spillAllocaArg->getType()->getPointerElementType(); spillAlloca = new AllocaInst(spillAllocaTy, "spill.alloca", caller->getEntryBlock().begin()); funcToSpillAlloca[caller] = spillAlloca; } // Create calls. SFT will inline them. Value *runtimeDataArg = CallInst::Create(dummyRuntimeDataArgFunc, "runtimeData", insertBefore); CallInst::Create(traceRaySave[sh], {runtimeDataArg, spillAlloca}, "", insertBefore); CallInst::Create(traceRayRestore[sh], {runtimeDataArg, spillAlloca}, "", getInstructionAfter(call)); } // Get the payload offset to pass to trace implementation // hlsl::DxilInst_TraceRay traceRayCall(call); // TODO: Avoiding the intrinsic to support the test's use of TraceRayTest Value *payload = call->getOperand(call->getNumArgOperands() - 1); FunctionType *funcType = FunctionType::get(int32Ty, {payload->getType()}, false); Function *movePayloadToStackFunc = getOrCreateFunction( "movePayloadToStack", m_module, funcType, movePayloadToStackFuncs); Value *newPayloadOffset = CallInst::Create( movePayloadToStackFunc, {payload}, "new.payload.offset", insertBefore); // Call implementation unsigned i = 0; if (call->getCalledFunction()->getName().startswith("dx.op")) i += 2; // skip intrinsic number and acceleration structure (for now) std::vector args; for (; i < call->getNumArgOperands() - 1; ++i) args.push_back(call->getArgOperand(i)); args.push_back(newPayloadOffset); CallInst::Create(traceRayImpl[0], args, "", insertBefore); call->eraseFromParent(); } } static std::vector getParameterTypes(Function *F, DXIL::ShaderKind shaderKind) { std::vector paramTypes; if (shaderKind == DXIL::ShaderKind::AnyHit || shaderKind == DXIL::ShaderKind::ClosestHit) { paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD); paramTypes.push_back(StateFunctionTransform::PST_ATTRIBUTE); } else if (shaderKind == DXIL::ShaderKind::Miss) { paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD); } else { paramTypes.assign(F->getNumOperands(), StateFunctionTransform::PST_NONE); } return paramTypes; } static void collectResources(DxilModule &DM, std::set &resources) { for (auto &r : DM.GetCBuffers()) resources.insert(r->GetGlobalSymbol()); for (auto &r : DM.GetUAVs()) resources.insert(r->GetGlobalSymbol()); for (auto &r : DM.GetSRVs()) resources.insert(r->GetGlobalSymbol()); for (auto &r : DM.GetSamplers()) resources.insert(r->GetGlobalSymbol()); } void DxrFallbackCompiler::createStateFunctions( IntToFuncMap &stateFunctionMap, std::vector &shaderEntryStateIds, std::vector &shaderStackSizes, int baseStateId, const std::vector &shaderNames, Type *runtimeDataArgTy) { for (auto &kv : m_shaderMap) { if (kv.second == nullptr) errs() << "Function not found for shader " << kv.first << "\n"; } DxilModule &DM = m_module->GetOrCreateDxilModule(); std::set resources; collectResources(DM, resources); shaderEntryStateIds.clear(); shaderStackSizes.clear(); int stateId = baseStateId; for (auto &shader : shaderNames) { std::vector stateFunctions; Function *F = m_shaderMap[shader]; StateFunctionTransform sft(F, shaderNames, runtimeDataArgTy); if (m_debugOutputLevel >= 2) sft.setVerbose(true); if (m_debugOutputLevel >= 3) sft.setDumpFilename("dump.ll"); if (shader == "Fallback_TraceRay") sft.setAttributeSize(m_maxAttributeSize); DXIL::ShaderKind shaderKind = getRayShaderKind(F); if (shaderKind != DXIL::ShaderKind::Invalid) sft.setParameterInfo(getParameterTypes(F, shaderKind), shaderKind == DXIL::ShaderKind::ClosestHit); sft.setResourceGlobals(resources); UINT shaderStackSize = 0; sft.run(stateFunctions, shaderStackSize); shaderEntryStateIds.push_back(stateId); shaderStackSizes.push_back(shaderStackSize); for (Function *stateF : stateFunctions) { stateFunctionMap[stateId++] = stateF; if (DM.HasDxilFunctionProps(F)) { DM.CloneDxilEntryProps(F, stateF); } } } StateFunctionTransform::finalizeStateIds(m_module, shaderEntryStateIds); } void DxrFallbackCompiler::createLaunchParams(Function *func) { Module *mod = func->getParent(); Function *rewrite_setLaunchParams = mod->getFunction("rewrite_setLaunchParams"); CallInst *call = dyn_cast(*rewrite_setLaunchParams->user_begin()); LLVMContext &context = mod->getContext(); Instruction *insertBefore = call; Function *DTidFunc = FunctionBuilder(mod, "dx.op.threadId.i32").i32().i32().i32().build(); Value *DTidx = CallInst::Create(DTidFunc, {makeInt32((int)hlsl::OP::OpCode::ThreadId, context), makeInt32(0, context)}, "DTidx", insertBefore); Value *DTidy = CallInst::Create(DTidFunc, {makeInt32((int)hlsl::OP::OpCode::ThreadId, context), makeInt32(1, context)}, "DTidy", insertBefore); Value *dimx = call->getArgOperand(1); Value *dimy = call->getArgOperand(2); Function *groupIndexFunc = FunctionBuilder(mod, "dx.op.flattenedThreadIdInGroup.i32") .i32() .i32() .build(); Value *groupIndex = CallInst::Create(groupIndexFunc, {makeInt32(96, context)}, "groupIndex", insertBefore); Function *fb_setLaunchParams = mod->getFunction("fb_Fallback_SetLaunchParams"); Value *runtimeDataArg = call->getArgOperand(0); CallInst::Create(fb_setLaunchParams, {runtimeDataArg, DTidx, DTidy, dimx, dimy, groupIndex}, "", insertBefore); call->eraseFromParent(); rewrite_setLaunchParams->eraseFromParent(); } void DxrFallbackCompiler::createStateDispatch( Function *func, const IntToFuncMap &stateFunctionMap, Type *runtimeDataArgTy) { Module *mod = func->getParent(); Function *dispatchFunc = createDispatchFunction(stateFunctionMap, runtimeDataArgTy); Function *rewrite_dispatchFunc = mod->getFunction("rewrite_dispatch"); rewrite_dispatchFunc->replaceAllUsesWith(dispatchFunc); rewrite_dispatchFunc->eraseFromParent(); } void DxrFallbackCompiler::createStack(Function *func) { LLVMContext &context = func->getContext(); // We would like to allocate the properly sized stack here, but DXIL doesn't // allow bitcasts between objects of different sizes. So we have to use the // default size from the runtime and replace all the accesses later. Function *rewrite_createStack = m_module->getFunction("rewrite_createStack"); CallInst *call = dyn_cast(*rewrite_createStack->user_begin()); AllocaInst *stack = new AllocaInst(call->getType()->getPointerElementType(), "theStack", call); stack->setAlignment(sizeof(int)); call->replaceAllUsesWith(stack); call->eraseFromParent(); rewrite_createStack->eraseFromParent(); if (m_stackSizeInBytes == 0) // Take the default m_stackSizeInBytes = stack->getType()->getPointerElementType()->getArrayNumElements() * sizeof(int); Function *rewrite_getStackSize = m_module->getFunction("rewrite_getStackSize"); call = dyn_cast(*rewrite_getStackSize->user_begin()); Value *stackSizeVal = makeInt32(m_stackSizeInBytes, context); call->replaceAllUsesWith(stackSizeVal); call->eraseFromParent(); rewrite_getStackSize->eraseFromParent(); } // WAR to avoid crazy <3 x float> code emitted by vanilla clang in the runtime static bool expandFloat3(std::vector &args, Value *arg, Instruction *insertBefore) { VectorType *argTy = dyn_cast(arg->getType()); if (!argTy || argTy->getVectorNumElements() != 3) return false; LLVMContext &C = arg->getContext(); args.push_back( ExtractElementInst::Create(arg, makeInt32(0, C), "vec.x", insertBefore)); args.push_back( ExtractElementInst::Create(arg, makeInt32(1, C), "vec.y", insertBefore)); args.push_back( ExtractElementInst::Create(arg, makeInt32(2, C), "vec.z", insertBefore)); return true; } static bool float3x4ToFloat12(std::vector &args, Value *arg, Instruction *insertBefore) { StructType *STy = dyn_cast(arg->getType()); if (!STy || STy->getName() != "class.matrix.float.3.4") return false; BasicBlock &entryBlock = insertBefore->getParent()->getParent()->getEntryBlock(); AllocaInst *alloca = new AllocaInst(arg->getType(), "tmp", entryBlock.begin()); new StoreInst(arg, alloca, insertBefore); VectorType *VTy = VectorType::get(Type::getFloatTy(arg->getContext()), 12); Value *vec12Ptr = new BitCastInst(alloca, VTy->getPointerTo(), "vec12.ptr", insertBefore); Value *vec12 = new LoadInst(vec12Ptr, "vec12.", insertBefore); args.push_back(vec12); return true; } void DxrFallbackCompiler::lowerIntrinsics() { std::vector intrinsics = getFunctionsWithPrefix(m_module, "fb_"); assert(intrinsics.size() > 0); // Replace intrinsics in anyhit shaders with their pending versions LLVMContext &C = m_module->getContext(); std::map pendingIntrinsics; std::string pendingPrefixes[] = {"fb_dxop_pending_", "fb_Fallback_Pending"}; for (auto &F : intrinsics) { std::string intrinsicName; if (F->getName().startswith(pendingPrefixes[0])) intrinsicName = F->getName().substr(pendingPrefixes[0].length()); else if (F->getName().startswith(pendingPrefixes[1])) intrinsicName = "Fallback_" + F->getName().substr(pendingPrefixes[1].length()).str(); else continue; pendingIntrinsics[intrinsicName] = F; } for (Function *func : intrinsics) { StringRef intrinsicName; std::string name; bool isDxilOp = false; if (func->getName().startswith("fb_Fallback_")) { intrinsicName = func->getName().substr(3); // after the "fb_" prefix name = "\x1?" + intrinsicName.str(); } else if (func->getName().startswith("fb_dxop_")) { intrinsicName = func->getName().substr(8); name = "dx.op." + intrinsicName.str(); isDxilOp = true; } else { assert(0 && "Bad intrinsic"); } std::vector calledFunc = getFunctionsWithPrefix(m_module, name); if (calledFunc.empty()) continue; std::vector calls = getCallsToFunction(calledFunc[0]); if (calls.empty()) continue; bool needsRuntimeDataArg = (intrinsicName != "Fallback_Scheduler"); Function *pendingFunc = get(pendingIntrinsics, intrinsicName.str()); Function *funcInModule = nullptr; Function *pendingFuncInModule = nullptr; for (CallInst *call : calls) { Function *caller = call->getParent()->getParent(); if (needsRuntimeDataArg && !caller->hasFnAttribute("state_function")) continue; Function *F = nullptr; if (pendingFunc && shouldUsePendingValue(caller, intrinsicName)) { if (!pendingFuncInModule) pendingFuncInModule = getOrInsertFunction(m_module, pendingFunc); F = pendingFuncInModule; } else { if (!funcInModule) funcInModule = getOrInsertFunction(m_module, func); F = funcInModule; } // insert runtime data and the rest of the arguments std::vector args; if (needsRuntimeDataArg) args.push_back(caller->arg_begin()); int argIdx = 0; for (auto &arg : call->arg_operands()) { if (argIdx++ == 0 && isDxilOp) continue; // skip the intrinsic number if (!expandFloat3(args, arg, call) && !float3x4ToFloat12(args, arg, call)) args.push_back(arg); } CallInst *newCall = CallInst::Create(F, args, "", call); if (F->getFunctionType()->getReturnType() != Type::getVoidTy(C)) { newCall->takeName(call); call->replaceAllUsesWith(newCall); } call->eraseFromParent(); } } } Type *DxrFallbackCompiler::getRuntimeDataArgType() { // Get the first argument from a known runtime function (assuming the runtime // has already been linked in). Function *F = m_module->getFunction("stackIntPtr"); return F->arg_begin()->getType(); } Function *DxrFallbackCompiler::createDispatchFunction( const IntToFuncMap &stateFunctionMap, Type *runtimeDataArgTy) { LLVMContext &context = m_module->getContext(); FunctionType *stateFuncTy = FunctionType::get(Type::getInt32Ty(context), {runtimeDataArgTy}, false); Function *dispatchFunc = FunctionBuilder(m_module, "dispatch") .i32() .type(runtimeDataArgTy, "runtimeData") .i32("stateID") .build(); Value *runtimeDataArg = dispatchFunc->arg_begin(); Value *stateIdArg = ++dispatchFunc->arg_begin(); BasicBlock *entryBlock = BasicBlock::Create(context, "entry", dispatchFunc); BasicBlock *badBlock = BasicBlock::Create(context, "badStateID", dispatchFunc); IRBuilder<> builder(badBlock); builder.SetInsertPoint(badBlock); builder.CreateRet(makeInt32(-3, context)); // return an error value builder.SetInsertPoint(entryBlock); SwitchInst *switchInst = builder.CreateSwitch(stateIdArg, badBlock, stateFunctionMap.size()); BasicBlock *endBlock = badBlock; for (auto &kv : stateFunctionMap) { int stateId = kv.first; Function *stateFunc = kv.second; Value *stateFuncInModule = m_module->getOrInsertFunction(stateFunc->getName(), stateFuncTy); BasicBlock *block = BasicBlock::Create( context, "state_" + Twine(stateId) + "." + stateFunc->getName(), dispatchFunc, endBlock); builder.SetInsertPoint(block); Value *nextStateId = builder.CreateCall(stateFuncInModule, {runtimeDataArg}, "nextStateId"); builder.CreateRet(nextStateId); switchInst->addCase(makeInt32(stateId, context), block); } return dispatchFunc; } std::vector DxrFallbackCompiler::getCallsInShadersToFunction(const std::string &funcName) { std::vector calls; Function *F = m_module->getFunction(funcName); if (!F) return calls; for (User *U : F->users()) { CallInst *call = dyn_cast(U); if (!call) continue; Function *caller = call->getParent()->getParent(); auto it = m_shaderMap.find(cleanName(caller->getName())); if (it != m_shaderMap.end()) calls.push_back(call); } return calls; } std::vector DxrFallbackCompiler::getCallsInShadersToFunctionWithPrefix( const std::string &funcNamePrefix) { std::vector calls; for (Function *F : getFunctionsWithPrefix(m_module, funcNamePrefix)) { for (User *U : F->users()) { CallInst *call = dyn_cast(U); if (!call) continue; Function *caller = call->getParent()->getParent(); if (m_shaderMap.count(cleanName(caller->getName()))) calls.push_back(call); } } return calls; } void DxrFallbackCompiler::resizeStack(Function *F, unsigned sizeInBytes) { // Find the stack AllocaInst *stack = nullptr; for (auto &I : F->getEntryBlock().getInstList()) { AllocaInst *alloc = dyn_cast(&I); if (alloc && alloc->getName().startswith("theStack")) { stack = alloc; break; } } if (!stack) return; // Create a new stack LLVMContext &C = F->getContext(); ArrayType *newStackTy = ArrayType::get(Type::getInt32Ty(C), sizeInBytes / sizeof(int)); AllocaInst *newStack = new AllocaInst(newStackTy, "", stack); newStack->takeName(stack); // Remap all GEPs - replaceAllUsesWith() won't change types for (auto U = stack->user_begin(), UE = stack->user_end(); U != UE;) { GetElementPtrInst *gep = dyn_cast(*U++); assert(gep && "theStack has non-gep user."); std::vector idxList(gep->idx_begin(), gep->idx_end()); GetElementPtrInst *newGep = GetElementPtrInst::CreateInBounds(newStack, idxList, "", gep); newGep->takeName(gep); gep->replaceAllUsesWith(newGep); gep->eraseFromParent(); } stack->eraseFromParent(); }