Matrix lowering for functions with UDT params preserved.

- Keep track of patch constant functions for later identification
- Identify functions that require input/output signature processing
  with IsEntryThatUsesSignatures
- Update lib_rt.hlsl intrinsics and naming
This commit is contained in:
Tex Riddell 2018-01-30 23:05:43 -08:00
Родитель 1c8218861a
Коммит 15cd5f16e6
11 изменённых файлов: 254 добавлений и 80 удалений

Просмотреть файл

@ -26,6 +26,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
namespace llvm { namespace llvm {
class LLVMContext; class LLVMContext;
@ -132,6 +133,14 @@ public:
DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F); DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F);
// Move DxilFunctionProps of F to NewF. // Move DxilFunctionProps of F to NewF.
void ReplaceDxilFunctionProps(llvm::Function *F, llvm::Function *NewF); void ReplaceDxilFunctionProps(llvm::Function *F, llvm::Function *NewF);
void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc);
bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps
bool IsPatchConstantShader(llvm::Function *F);
bool IsComputeShader(llvm::Function *F);
// Is an entry function that uses input/output signature conventions?
// Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function.
bool IsEntryThatUsesSignatures(llvm::Function *F);
// Remove Root Signature from module metadata // Remove Root Signature from module metadata
void StripRootSignatureFromMetadata(); void StripRootSignatureFromMetadata();
@ -436,6 +445,9 @@ private:
std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>> std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>>
m_DxilEntrySignatureMap; m_DxilEntrySignatureMap;
// Keeps track of patch constant functions used by hull shaders
std::unordered_set<llvm::Function *> m_PatchConstantFunctions;
// ViewId state. // ViewId state.
std::unique_ptr<DxilViewIdState> m_pViewIdState; std::unique_ptr<DxilViewIdState> m_pViewIdState;

Просмотреть файл

@ -24,6 +24,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
namespace llvm { namespace llvm {
class LLVMContext; class LLVMContext;
@ -127,6 +128,14 @@ public:
bool HasDxilFunctionProps(llvm::Function *F); bool HasDxilFunctionProps(llvm::Function *F);
DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F); DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F);
void AddDxilFunctionProps(llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info); void AddDxilFunctionProps(llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info);
void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc);
bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps
bool IsPatchConstantShader(llvm::Function *F);
bool IsComputeShader(llvm::Function *F);
// Is an entry function that uses input/output signature conventions?
// Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function.
bool IsEntryThatUsesSignatures(llvm::Function *F);
DxilFunctionAnnotation *GetFunctionAnnotation(llvm::Function *F); DxilFunctionAnnotation *GetFunctionAnnotation(llvm::Function *F);
DxilFunctionAnnotation *AddFunctionAnnotation(llvm::Function *F); DxilFunctionAnnotation *AddFunctionAnnotation(llvm::Function *F);
@ -238,6 +247,7 @@ private:
// High level function info. // High level function info.
std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>> m_DxilFunctionPropsMap; std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>> m_DxilFunctionPropsMap;
std::unordered_set<llvm::Function *> m_PatchConstantFunctions;
// Resource type annotation. // Resource type annotation.
std::unordered_map<llvm::Type *, std::pair<DXIL::ResourceClass, DXIL::ResourceKind>> m_ResTypeAnnotation; std::unordered_map<llvm::Type *, std::pair<DXIL::ResourceClass, DXIL::ResourceKind>> m_ResTypeAnnotation;

Просмотреть файл

@ -238,7 +238,7 @@ public:
for (auto It = M.begin(); It != M.end();) { for (auto It = M.begin(); It != M.end();) {
Function &F = *(It++); Function &F = *(It++);
// Lower signature for each entry function. // Lower signature for each entry function.
if (m_pHLModule->HasDxilFunctionProps(&F)) { if (m_pHLModule->IsEntryThatUsesSignatures(&F)) {
DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(&F); DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(&F);
std::unique_ptr<DxilEntrySignature> pSig = std::unique_ptr<DxilEntrySignature> pSig =
llvm::make_unique<DxilEntrySignature>(props.shaderKind, m_pHLModule->GetHLOptions().bUseMinPrecision); llvm::make_unique<DxilEntrySignature>(props.shaderKind, m_pHLModule->GetHLOptions().bUseMinPrecision);

Просмотреть файл

@ -607,7 +607,8 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc; Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc;
Function *newPatchConstantFunc = Function *newPatchConstantFunc =
m_newFunctions[patchConstantFunc->getName()]; m_newFunctions[patchConstantFunc->getName()];
props.ShaderProps.HS.patchConstantFunc = newPatchConstantFunc; DM.SetPatchConstantFunctionForHS(entryFunc, nullptr);
DM.SetPatchConstantFunctionForHS(NewEntryFunc, newPatchConstantFunc);
if (newPatchConstantFunc->hasFnAttribute(llvm::Attribute::AlwaysInline)) if (newPatchConstantFunc->hasFnAttribute(llvm::Attribute::AlwaysInline))
newPatchConstantFunc->removeFnAttr(llvm::Attribute::AlwaysInline); newPatchConstantFunc->removeFnAttr(llvm::Attribute::AlwaysInline);

Просмотреть файл

@ -1102,6 +1102,35 @@ void DxilModule::ReplaceDxilFunctionProps(llvm::Function *F,
m_DxilFunctionPropsMap.erase(F); m_DxilFunctionPropsMap.erase(F);
m_DxilFunctionPropsMap[NewF] = std::move(props); m_DxilFunctionPropsMap[NewF] = std::move(props);
} }
// Points the hull shader `hullShaderFunc` at `patchConstantFunc` and keeps
// m_PatchConstantFunctions in step with that change.
//
// The hull shader must already have function props registered and those props
// must describe a hull shader (both are asserted). The previously recorded
// patch constant function, if any, is dropped from the tracking set before the
// new one is stored; passing nullptr clears the association without
// registering anything.
void DxilModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc) {
  auto propIter = m_DxilFunctionPropsMap.find(hullShaderFunc);
  DXASSERT(propIter != m_DxilFunctionPropsMap.end(), "Hull shader must already have function props!");
  DxilFunctionProps &props = *(propIter->second);
  DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader");

  // Work through a reference to the stored pointer so the "forget old" and
  // "remember new" steps read symmetrically.
  auto &recordedFunc = props.ShaderProps.HS.patchConstantFunc;
  if (recordedFunc) {
    m_PatchConstantFunctions.erase(recordedFunc);
  }
  recordedFunc = patchConstantFunc;
  if (recordedFunc) {
    m_PatchConstantFunctions.insert(recordedFunc);
  }
}
// Returns true when F has function props and those props report IsGraphics().
// Functions without props are never graphics shaders.
bool DxilModule::IsGraphicsShader(llvm::Function *F) {
  if (!HasDxilFunctionProps(F)) {
    return false;
  }
  return GetDxilFunctionProps(F).IsGraphics();
}
// Returns true when F is currently recorded in m_PatchConstantFunctions.
bool DxilModule::IsPatchConstantShader(llvm::Function *F) {
  return m_PatchConstantFunctions.find(F) != m_PatchConstantFunctions.end();
}
// Returns true when F has function props and those props report IsCS().
// Functions without props are never compute shaders.
bool DxilModule::IsComputeShader(llvm::Function *F) {
  if (!HasDxilFunctionProps(F)) {
    return false;
  }
  return GetDxilFunctionProps(F).IsCS();
}
// Answers whether F is an entry that uses input/output signature conventions.
//
// A function with registered props qualifies only if those props are graphics
// or compute. A function without props qualifies only if it is a tracked
// patch constant function.
bool DxilModule::IsEntryThatUsesSignatures(llvm::Function *F) {
  auto found = m_DxilFunctionPropsMap.find(F);
  if (found == m_DxilFunctionPropsMap.end()) {
    // No props recorded: only a patch constant function counts.
    return IsPatchConstantShader(F);
  }
  DxilFunctionProps &fnProps = *(found->second);
  if (fnProps.IsGraphics()) {
    return true;
  }
  return fnProps.IsCS();
}
void DxilModule::StripRootSignatureFromMetadata() { void DxilModule::StripRootSignatureFromMetadata() {
NamedMDNode *pRootSignatureNamedMD = GetModule()->getNamedMetadata(DxilMDHelper::kDxilRootSignatureMDName); NamedMDNode *pRootSignatureNamedMD = GetModule()->getNamedMetadata(DxilMDHelper::kDxilRootSignatureMDName);
@ -1319,6 +1348,11 @@ void DxilModule::LoadDxilMetadata() {
Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get()); Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get());
if (props->IsHS() && props->ShaderProps.HS.patchConstantFunc) {
// Add patch constant function to m_PatchConstantFunctions
m_PatchConstantFunctions.insert(props->ShaderProps.HS.patchConstantFunc);
}
m_DxilFunctionPropsMap[F] = std::move(props); m_DxilFunctionPropsMap[F] = std::move(props);
} }

Просмотреть файл

@ -374,7 +374,7 @@ private:
} else { } else {
std::vector<Function *> entries; std::vector<Function *> entries;
for (iplist<Function>::iterator F : M.getFunctionList()) { for (iplist<Function>::iterator F : M.getFunctionList()) {
if (DM.HasDxilFunctionProps(F)) { if (DM.IsEntryThatUsesSignatures(F)) {
entries.emplace_back(F); entries.emplace_back(F);
} }
} }
@ -384,7 +384,7 @@ private:
// Strip patch constant function first. // Strip patch constant function first.
Function *patchConstFunc = StripFunctionParameter( Function *patchConstFunc = StripFunctionParameter(
props.ShaderProps.HS.patchConstantFunc, DM, FunctionDIs); props.ShaderProps.HS.patchConstantFunc, DM, FunctionDIs);
props.ShaderProps.HS.patchConstantFunc = patchConstFunc; DM.SetPatchConstantFunctionForHS(entry, patchConstFunc);
} }
StripFunctionParameter(entry, DM, FunctionDIs); StripFunctionParameter(entry, DM, FunctionDIs);
} }

Просмотреть файл

@ -272,6 +272,9 @@ private:
// Get new matrix value corresponding to vecVal // Get new matrix value corresponding to vecVal
Value *GetMatrixForVec(Value *vecVal, Type *matTy); Value *GetMatrixForVec(Value *vecVal, Type *matTy);
// Translate library function input/output to preserve function signatures
void TranslateLibraryArgs(Function &F);
// Replace matVal with vecVal on matUseInst. // Replace matVal with vecVal on matUseInst.
void TrivialMatReplace(Value *matVal, Value *vecVal, void TrivialMatReplace(Value *matVal, Value *vecVal,
CallInst *matUseInst); CallInst *matUseInst);
@ -1269,6 +1272,16 @@ void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
} }
} }
// Emits a shufflevector over vecVal (used as both shuffle operands) that
// reorders its row * col elements: output element (c * row + r) is taken from
// input element (r * col + c), for c in [0, col) and r in [0, row).
// Returns the created shuffle as an Instruction.
static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned row, unsigned col) {
  SmallVector<int, 16> shuffleMask;
  shuffleMask.reserve(col * row);
  for (unsigned outer = 0; outer != col; ++outer) {
    for (unsigned inner = 0; inner != row; ++inner) {
      shuffleMask.push_back(static_cast<int>(inner * col + outer));
    }
  }
  Value *shuffled = Builder.CreateShuffleVector(vecVal, vecVal, shuffleMask);
  return cast<Instruction>(shuffled);
}
void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal, void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
Value *vecVal, Value *vecVal,
CallInst *castInst, CallInst *castInst,
@ -1291,25 +1304,9 @@ void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
IRBuilder<> Builder(castInst); IRBuilder<> Builder(castInst);
// shuf to change major. if (bRowToCol)
SmallVector<int, 16> castMask(col * row); std::swap(row, col);
unsigned idx = 0; Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
if (bRowToCol) {
for (unsigned c = 0; c < col; c++)
for (unsigned r = 0; r < row; r++) {
unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
castMask[idx++] = matIdx;
}
} else {
for (unsigned r = 0; r < row; r++)
for (unsigned c = 0; c < col; c++) {
unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
castMask[idx++] = matIdx;
}
}
Instruction *vecCast = cast<Instruction>(
Builder.CreateShuffleVector(vecVal, vecVal, castMask));
// Replace vec cast function call with vecCast. // Replace vec cast function call with vecCast.
DXASSERT(matToVecMap.count(castInst), "must has vec version"); DXASSERT(matToVecMap.count(castInst), "must has vec version");
@ -2109,12 +2106,10 @@ Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal, void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
Value *vecVal) { Value *vecVal) {
Type *matTy = matVal->getType();
for (Value::user_iterator user = matVal->user_begin(); for (Value::user_iterator user = matVal->user_begin();
user != matVal->user_end();) { user != matVal->user_end();) {
Instruction *useInst = cast<Instruction>(*(user++)); Instruction *useInst = cast<Instruction>(*(user++));
// Skip return here.
if (isa<ReturnInst>(useInst))
continue;
// User must be function call. // User must be function call.
if (CallInst *useCall = dyn_cast<CallInst>(useInst)) { if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
hlsl::HLOpcodeGroup group = hlsl::HLOpcodeGroup group =
@ -2183,7 +2178,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) { for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
if (useCall->getArgOperand(i) == matVal) { if (useCall->getArgOperand(i) == matVal) {
// update the user call with the correct matrix value in new code sequence // update the user call with the correct matrix value in new code sequence
Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType()); Value *newMatVal = GetMatrixForVec(vecVal, matTy);
if (matVal != newMatVal) if (matVal != newMatVal)
useCall->setArgOperand(i, newMatVal); useCall->setArgOperand(i, newMatVal);
} }
@ -2194,8 +2189,10 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
// Just replace the src with vec version. // Just replace the src with vec version.
useInst->setOperand(0, vecVal); useInst->setOperand(0, vecVal);
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) { } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType()); Value *newMatVal = GetMatrixForVec(vecVal, matTy);
RI->setOperand(0, newMatVal); RI->setOperand(0, newMatVal);
} else if (isa<StoreInst>(useInst)) {
DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
} else { } else {
// Must be GEP on mat array alloca. // Must be GEP on mat array alloca.
GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst); GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
@ -2467,6 +2464,85 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
} }
} }
// Rewrites HL matrix operations that directly use the arguments of F so the
// function's parameter types can stay as they are.
//
// For every argument, the direct CallInst users belonging to the HLCast or
// HLMatLoadStore opcode groups are collected and then replaced:
//  - Row/ColMatrixToVecCast becomes BitCastValueOrPtr, plus a transpose
//    shuffle for the column-major opcode. Other cast opcodes are left alone.
//  - Row/ColMatStore becomes (optional shuffle) + pointer bitcast + store.
//  - Row/ColMatLoad becomes pointer bitcast + load + (optional shuffle).
// Replaced calls are erased. Only direct users of the argument are visited.
void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
// Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
// Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
for (auto &arg : F.args()) {
// Gather the calls first; the rewrite loop below erases them, so it is kept
// separate from the walk over arg.users().
SmallVector<CallInst *, 4> Candidates;
for (User *U : arg.users()) {
if (CallInst *CI = dyn_cast<CallInst>(U)) {
HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
// Only these two groups are handled; every other group is ignored.
switch (group) {
case HLOpcodeGroup::HLCast:
case HLOpcodeGroup::HLMatLoadStore:
Candidates.push_back(CI);
break;
}
}
}
for (CallInst *CI : Candidates) {
// New instructions are inserted immediately before the call being replaced.
IRBuilder<> Builder(CI);
HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
switch (group) {
case HLOpcodeGroup::HLCast: {
HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
// Only matrix-to-vector casts are rewritten here.
if (opcode == HLCastOpcode::RowMatrixToVecCast ||
opcode == HLCastOpcode::ColMatrixToVecCast) {
Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
/*bOrigAllocaTy*/false,
matVal->getName());
if (opcode == HLCastOpcode::ColMatrixToVecCast) {
unsigned row, col;
// GetMatrixInfo is passed (col, row) in that order.
HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
// NOTE(review): this column-major path passes (row, col), while the
// ColMatLoad path below passes (col, row) -- confirm both are intended.
vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
}
CI->replaceAllUsesWith(vecVal);
CI->eraseFromParent();
}
} break;
case HLOpcodeGroup::HLMatLoadStore: {
HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
// Set by the column-major cases before they fall into the row-major case.
bool bTranspose = false;
switch (opcode) {
case HLMatLoadStoreOpcode::ColMatStore:
bTranspose = true;
// Falls through to the RowMatStore handling with bTranspose set.
case HLMatLoadStoreOpcode::RowMatStore: {
// shuffle if transposed, bitcast, and store
Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
if (bTranspose) {
unsigned row, col;
HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
}
// Store through the matrix pointer reinterpreted as a pointer to the
// vector type of the value being stored.
Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
Builder.CreateStore(vecVal, castPtr);
// The store call's result is not forwarded to any user; the call is simply
// erased.
CI->eraseFromParent();
} break;
case HLMatLoadStoreOpcode::ColMatLoad:
bTranspose = true;
// Falls through to the RowMatLoad handling with bTranspose set.
case HLMatLoadStoreOpcode::RowMatLoad: {
// bitcast, load, and shuffle if transposed
Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
// Load through the matrix pointer reinterpreted as a pointer to the call's
// result type.
Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
Value *vecVal = Builder.CreateLoad(castPtr);
if (bTranspose) {
unsigned row, col;
HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
// row/col swapped for col major source
vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
}
CI->replaceAllUsesWith(vecVal);
CI->eraseFromParent();
} break;
}
} break;
}
}
}
}
void HLMatrixLowerPass::runOnFunction(Function &F) { void HLMatrixLowerPass::runOnFunction(Function &F) {
// Create vector version of matrix instructions first. // Create vector version of matrix instructions first.
// The matrix operands will be undefval for these instructions. // The matrix operands will be undefval for these instructions.
@ -2531,4 +2607,12 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
DeleteDeadInsts(); DeleteDeadInsts();
matToVecMap.clear(); matToVecMap.clear();
vecToMatMap.clear();
// If this is a library function, now fix input/output matrix params
// TODO: What about Patch Constant Shaders?
if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
TranslateLibraryArgs(F);
}
return;
} }

Просмотреть файл

@ -350,6 +350,35 @@ void HLModule::AddDxilFunctionProps(llvm::Function *F, std::unique_ptr<DxilFunct
DXASSERT_NOMSG(info->shaderKind != DXIL::ShaderKind::Invalid); DXASSERT_NOMSG(info->shaderKind != DXIL::ShaderKind::Invalid);
m_DxilFunctionPropsMap[F] = std::move(info); m_DxilFunctionPropsMap[F] = std::move(info);
} }
// Associates `patchConstantFunc` with the hull shader `hullShaderFunc` and
// updates m_PatchConstantFunctions to match.
//
// Asserts that the hull shader already has function props and that they
// describe a hull shader. Any previously recorded patch constant function is
// removed from the tracking set first; a nullptr argument clears the
// association and registers nothing.
void HLModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc) {
  auto propIter = m_DxilFunctionPropsMap.find(hullShaderFunc);
  DXASSERT(propIter != m_DxilFunctionPropsMap.end(), "else Hull Shader missing function props");
  DxilFunctionProps &props = *(propIter->second);
  DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader");

  // Alias the stored pointer so the removal and the registration mirror
  // each other.
  auto &trackedFunc = props.ShaderProps.HS.patchConstantFunc;
  if (trackedFunc) {
    m_PatchConstantFunctions.erase(trackedFunc);
  }
  trackedFunc = patchConstantFunc;
  if (trackedFunc) {
    m_PatchConstantFunctions.insert(trackedFunc);
  }
}
// True only if F has function props and they report IsGraphics().
bool HLModule::IsGraphicsShader(llvm::Function *F) {
  if (!HasDxilFunctionProps(F)) {
    return false;
  }
  return GetDxilFunctionProps(F).IsGraphics();
}
// True only if F is currently present in m_PatchConstantFunctions.
bool HLModule::IsPatchConstantShader(llvm::Function *F) {
  return m_PatchConstantFunctions.find(F) != m_PatchConstantFunctions.end();
}
// True only if F has function props and they report IsCS().
bool HLModule::IsComputeShader(llvm::Function *F) {
  if (!HasDxilFunctionProps(F)) {
    return false;
  }
  return GetDxilFunctionProps(F).IsCS();
}
// Reports whether F is an entry that uses input/output signature conventions.
//
// With props registered, F qualifies only when the props are graphics or
// compute. Without props, F qualifies only when it is a tracked patch
// constant function.
bool HLModule::IsEntryThatUsesSignatures(llvm::Function *F) {
  auto found = m_DxilFunctionPropsMap.find(F);
  if (found == m_DxilFunctionPropsMap.end()) {
    // No props recorded: fall back to the patch constant function set.
    return IsPatchConstantShader(F);
  }
  DxilFunctionProps &fnProps = *(found->second);
  if (fnProps.IsGraphics()) {
    return true;
  }
  return fnProps.IsCS();
}
DxilFunctionAnnotation *HLModule::GetFunctionAnnotation(llvm::Function *F) { DxilFunctionAnnotation *HLModule::GetFunctionAnnotation(llvm::Function *F) {
return m_pTypeSystem->GetFunctionAnnotation(F); return m_pTypeSystem->GetFunctionAnnotation(F);
@ -475,6 +504,11 @@ void HLModule::LoadHLMetadata() {
Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get()); Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get());
if (props->IsHS() && props->ShaderProps.HS.patchConstantFunc) {
// Add patch constant function to m_PatchConstantFunctions
m_PatchConstantFunctions.insert(props->ShaderProps.HS.patchConstantFunc);
}
m_DxilFunctionPropsMap[F] = std::move(props); m_DxilFunctionPropsMap[F] = std::move(props);
} }

Просмотреть файл

@ -5173,6 +5173,9 @@ void SROA_Parameter_HLSL::flattenArgument(
Type *Ty = V->getType(); Type *Ty = V->getType();
if (Ty->isPointerTy()) if (Ty->isPointerTy())
Ty = Ty->getPointerElementType(); Ty = Ty->getPointerElementType();
// Stop doing this when preserving resource types and using new
// createHandleFrom??? whatever it's going to be called...
V = castResourceArgIfRequired(V, Ty, bOut, inputQual, Builder); V = castResourceArgIfRequired(V, Ty, bOut, inputQual, Builder);
// Cannot SROA, save it to final parameter list. // Cannot SROA, save it to final parameter list.
@ -5829,20 +5832,8 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get()); IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get());
RetBuilder.CreateRetVoid(); RetBuilder.CreateRetVoid();
} else { } else {
Function *Entry = m_pHLModule->GetEntryFunction(); hasShaderInputOutput = F == m_pHLModule->GetEntryFunction() ||
hasShaderInputOutput = F == Entry; m_pHLModule->IsEntryThatUsesSignatures(F);
if (m_pHLModule->HasDxilFunctionProps(F)) {
DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
if (!funcProps.IsRay())
hasShaderInputOutput = true;
}
if (m_pHLModule->HasDxilFunctionProps(Entry)) {
DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(Entry);
if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
Function *patchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc;
hasShaderInputOutput |= F == patchConstantFunc;
}
}
} }
std::vector<Value *> FlatParamList; std::vector<Value *> FlatParamList;
@ -6361,9 +6352,9 @@ void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
if (funcProps.shaderKind == DXIL::ShaderKind::Hull) { if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
Function *oldPatchConstantFunc = Function *oldPatchConstantFunc =
funcProps.ShaderProps.HS.patchConstantFunc; funcProps.ShaderProps.HS.patchConstantFunc;
if (funcMap.count(oldPatchConstantFunc)) if (funcMap.count(oldPatchConstantFunc)) {
funcProps.ShaderProps.HS.patchConstantFunc = m_pHLModule->SetPatchConstantFunctionForHS(flatF, funcMap[oldPatchConstantFunc]);
funcMap[oldPatchConstantFunc]; }
} }
} }
// TODO: flatten vector argument and lower resource argument when flatten // TODO: flatten vector argument and lower resource argument when flatten

Просмотреть файл

@ -4317,11 +4317,11 @@ void CGMSHLSLRuntime::SetPatchConstantFunctionWithAttr(
} }
Function *patchConstFunc = Entry->second.Func; Function *patchConstFunc = Entry->second.Func;
DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func); DXASSERT(m_pHLModule->HasDxilFunctionProps(EntryFunc.Func),
DXASSERT(HSProps != nullptr,
" else AddHLSLFunctionInfo did not save the dxil function props for the " " else AddHLSLFunctionInfo did not save the dxil function props for the "
"HS entry."); "HS entry.");
HSProps->ShaderProps.HS.patchConstantFunc = patchConstFunc; DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func);
m_pHLModule->SetPatchConstantFunctionForHS(EntryFunc.Func, patchConstFunc);
DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc)); DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc));
// Check no inout parameter for patch constant function. // Check no inout parameter for patch constant function.
DxilFunctionAnnotation *patchConstFuncAnnotation = DxilFunctionAnnotation *patchConstFuncAnnotation =

Просмотреть файл

@ -2,19 +2,19 @@
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Prototype header contents to be removed on implementation of features: // Prototype header contents to be removed on implementation of features:
#define HIT_KIND_TRIANGLE_FRONT_FACE 0xFE #define HIT_KIND_TRIANGLE_FRONT_FACE 0xFE
#define HIT_KIND_TRIANGLE_BACK_FACE 0xFF #define HIT_KIND_TRIANGLE_BACK_FACE 0xFF
typedef uint RAY_FLAG; typedef uint RAY_FLAG;
#define RAY_FLAG_NONE 0x00 #define RAY_FLAG_NONE 0x00
#define RAY_FLAG_FORCE_OPAQUE 0x01 #define RAY_FLAG_FORCE_OPAQUE 0x01
#define RAY_FLAG_FORCE_NON_OPAQUE 0x02 #define RAY_FLAG_FORCE_NON_OPAQUE 0x02
#define RAY_FLAG_TERMINATE_ON_FIRST_HIT 0x04 #define RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH 0x04
#define RAY_FLAG_SKIP_CLOSEST_HIT_SHADER 0x08 #define RAY_FLAG_SKIP_CLOSEST_HIT_SHADER 0x08
#define RAY_FLAG_CULL_BACK_FACING_TRIANGLES 0x10 #define RAY_FLAG_CULL_BACK_FACING_TRIANGLES 0x10
#define RAY_FLAG_CULL_FRONT_FACING_TRIANGLES 0x20 #define RAY_FLAG_CULL_FRONT_FACING_TRIANGLES 0x20
#define RAY_FLAG_CULL_OPAQUE 0x40 #define RAY_FLAG_CULL_OPAQUE 0x40
#define RAY_FLAG_CULL_NON_OPAQUE 0x80 #define RAY_FLAG_CULL_NON_OPAQUE 0x80
struct RayDesc struct RayDesc
{ {
@ -29,38 +29,46 @@ struct BuiltInTriangleIntersectionAttributes
float2 barycentrics; float2 barycentrics;
}; };
typedef ByteAddressBuffer RayTracingAccelerationStructure; typedef ByteAddressBuffer RaytracingAccelerationStructure;
// group: Indirect Shader Invocation
// Declare TraceRay overload for given payload structure // Declare TraceRay overload for given payload structure
#define Declare_TraceRay(payload_t) \ #define Declare_TraceRay(payload_t) \
void TraceRay(RayTracingAccelerationStructure, uint RayFlags, uint InstanceCullMask, uint RayContributionToHitGroupIndex, uint MultiplierForGeometryContributionToHitGroupIndex, uint MissShaderIndex, RayDesc, inout payload_t); void TraceRay(RaytracingAccelerationStructure, uint RayFlags, uint InstanceInclusionMask, uint RayContributionToHitGroupIndex, uint MultiplierForGeometryContributionToHitGroupIndex, uint MissShaderIndex, RayDesc, inout payload_t);
// Declare ReportIntersection overload for given attribute structure // Declare ReportHit overload for given attribute structure
#define Declare_ReportIntersection(attr_t) \ #define Declare_ReportHit(attr_t) \
bool ReportIntersection(float HitT, uint HitKind, attr_t); bool ReportHit(float HitT, uint HitKind, attr_t);
// Declare CallShader overload for given param structure // Declare CallShader overload for given param structure
#define Declare_CallShader(param_t) \ #define Declare_CallShader(param_t) \
void CallShader(uint ShaderIndex, inout param_t); void CallShader(uint ShaderIndex, inout param_t);
void IgnoreIntersection(); // group: AnyHit Terminals
void TerminateRay(); void IgnoreHit();
void AcceptHitAndEndSearch();
// System Value retrieval functions // System Value retrieval functions
// group: Ray Dispatch Arguments
uint2 RayDispatchIndex(); uint2 RayDispatchIndex();
uint2 RayDispatchDimension(); uint2 RayDispatchDimension();
// group: Ray Vectors
float3 WorldRayOrigin(); float3 WorldRayOrigin();
float3 WorldRayDirection(); float3 WorldRayDirection();
float RayTMin();
float CurrentRayT();
uint PrimitiveID();
uint InstanceID();
uint InstanceIndex();
float3 ObjectRayOrigin(); float3 ObjectRayOrigin();
float3 ObjectRayDirection(); float3 ObjectRayDirection();
// group: RayT
float RayTMin();
float CurrentRayT();
// group: Raytracing uint System Values
uint PrimitiveID(); // watch for existing
uint InstanceID();
uint InstanceIndex();
uint HitKind();
uint RayFlag();
// group: Ray Transforms
row_major float3x4 ObjectToWorld(); row_major float3x4 ObjectToWorld();
row_major float3x4 WorldToObject(); row_major float3x4 WorldToObject();
uint HitKind();
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
struct MyPayload { struct MyPayload {
@ -79,7 +87,7 @@ struct MyParam {
}; };
Declare_TraceRay(MyPayload); Declare_TraceRay(MyPayload);
Declare_ReportIntersection(MyAttributes); Declare_ReportHit(MyAttributes);
Declare_CallShader(MyParam); Declare_CallShader(MyParam);
// CHECK: ; S sampler NA NA S0 s1 1 // CHECK: ; S sampler NA NA S0 s1 1
@ -90,7 +98,7 @@ Declare_CallShader(MyParam);
// CHECK: @T_rangeID = external constant i32 // CHECK: @T_rangeID = external constant i32
// CHECK: @S_rangeID = external constant i32 // CHECK: @S_rangeID = external constant i32
RayTracingAccelerationStructure RTAS : register(t5); RaytracingAccelerationStructure RTAS : register(t5);
// CHECK: define void [[raygen1:@"\\01\?raygen1@[^\"]+"]]() { // CHECK: define void [[raygen1:@"\\01\?raygen1@[^\"]+"]]() {
// CHECK: [[RAWBUF_ID:[^ ]+]] = load i32, i32* @RTAS_rangeID // CHECK: [[RAWBUF_ID:[^ ]+]] = load i32, i32* @RTAS_rangeID
@ -114,7 +122,7 @@ void raygen1()
// CHECK: define void [[intersection1:@"\\01\?intersection1@[^\"]+"]]() { // CHECK: define void [[intersection1:@"\\01\?intersection1@[^\"]+"]]() {
// CHECK: call void {{.*}}CurrentRayT{{.*}}(float* nonnull [[pCurrentRayT:%[^)]+]]) // CHECK: call void {{.*}}CurrentRayT{{.*}}(float* nonnull [[pCurrentRayT:%[^)]+]])
// CHECK: [[CurrentRayT:%[^ ]+]] = load float, float* [[pCurrentRayT]], align 4 // CHECK: [[CurrentRayT:%[^ ]+]] = load float, float* [[pCurrentRayT]], align 4
// CHECK: call void {{.*}}ReportIntersection{{.*}}(float [[CurrentRayT]], i32 0, float 0.000000e+00, float 0.000000e+00, i32 0, i1* nonnull {{.*}}) // CHECK: call void {{.*}}ReportHit{{.*}}(float [[CurrentRayT]], i32 0, float 0.000000e+00, float 0.000000e+00, i32 0, i1* nonnull {{.*}})
// CHECK: ret void // CHECK: ret void
[shader("intersection")] [shader("intersection")]
@ -122,15 +130,15 @@ void intersection1()
{ {
float hitT = CurrentRayT(); float hitT = CurrentRayT();
MyAttributes attr = (MyAttributes)0; MyAttributes attr = (MyAttributes)0;
bool bReported = ReportIntersection(hitT, 0, attr); bool bReported = ReportHit(hitT, 0, attr);
} }
// CHECK: define void [[anyhit1:@"\\01\?anyhit1@[^\"]+"]](float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, i32* noalias nocapture, i32* noalias nocapture, float, float, i32) // CHECK: define void [[anyhit1:@"\\01\?anyhit1@[^\"]+"]](float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, i32* noalias nocapture, i32* noalias nocapture, float, float, i32)
// CHECK: call void {{.*}}ObjectRayOrigin{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}}) // CHECK: call void {{.*}}ObjectRayOrigin{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}})
// CHECK: call void {{.*}}ObjectRayDirection{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}}) // CHECK: call void {{.*}}ObjectRayDirection{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}})
// CHECK: call void {{.*}}CurrentRayT{{.*}}(float* nonnull {{.*}}) // CHECK: call void {{.*}}CurrentRayT{{.*}}(float* nonnull {{.*}})
// CHECK: call void {{.*}}TerminateRay{{.*}}() // CHECK: call void {{.*}}AcceptHitAndEndSearch{{.*}}()
// CHECK: call void {{.*}}IgnoreIntersection{{.*}}() // CHECK: call void {{.*}}IgnoreHit{{.*}}()
// CHECK: store float {{.*}}, float* %0, align 4 // CHECK: store float {{.*}}, float* %0, align 4
// CHECK: store float {{.*}}, float* %1, align 4 // CHECK: store float {{.*}}, float* %1, align 4
// CHECK: store float {{.*}}, float* %2, align 4 // CHECK: store float {{.*}}, float* %2, align 4
@ -145,9 +153,9 @@ void anyhit1( inout MyPayload payload : SV_RayPayload,
{ {
float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * CurrentRayT(); float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * CurrentRayT();
if (hitLocation.z < attr.bary.x) if (hitLocation.z < attr.bary.x)
TerminateRay(); // aborts function AcceptHitAndEndSearch(); // aborts function
if (hitLocation.z < attr.bary.y) if (hitLocation.z < attr.bary.y)
IgnoreIntersection(); // aborts function IgnoreHit(); // aborts function
payload.color += float4(0.125, 0.25, 0.5, 1.0); payload.color += float4(0.125, 0.25, 0.5, 1.0);
} }