Use Attribute to designate wave-sensitive intrinsics (#2853)
* Use Attribute to designate wave-sensitive intrinsics This adds an intrinsic attribute to indicate wave-sensitivity that can be indicated in gen_intrin_main.txt. This and other attributes are passed along through function representations and lowerings. The wave-sensitivity needs to be maintained specifically through SROA passes since it is used by the CleanupDxbreak pass that comes after them. Specifically this is done to allow extension intrinsics to indicate wave-sensitivity, but the same mechanism is now used for builtin intrinsics. Most intrinsics get a mostly nameless function to represent them between Codegen and dxilgen. This allows any that have the same prototype to share the same function. For wave-sensitive intrinsics, they need a different function or else the attrubute would be similarly shared with intrinsics matching the prototype. So a minor change is made to their function names to prevent this. Adds testing for all these ops and a dummy extension one.
This commit is contained in:
Родитель
49876c21b4
Коммит
a0196bcc22
|
@ -11,15 +11,20 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include <string>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class Function;
|
||||
class CallInst;
|
||||
class Argument;
|
||||
class StringRef;
|
||||
template<typename T> class ArrayRef;
|
||||
class AttributeSet;
|
||||
class CallInst;
|
||||
class Function;
|
||||
class FunctionType;
|
||||
class Module;
|
||||
class StringRef;
|
||||
class Type;
|
||||
class Value;
|
||||
}
|
||||
|
||||
namespace hlsl {
|
||||
|
@ -123,6 +128,9 @@ unsigned GetHLOpcode(const llvm::CallInst *CI);
|
|||
unsigned GetRowMajorOpcode(HLOpcodeGroup group, unsigned opcode);
|
||||
void SetHLLowerStrategy(llvm::Function *F, llvm::StringRef S);
|
||||
|
||||
void SetHLWaveSensitive(llvm::Function *F);
|
||||
bool IsHLWaveSensitive(llvm::Function *F);
|
||||
|
||||
// For intrinsic opcode.
|
||||
bool HasUnsignedOpcode(unsigned opcode);
|
||||
unsigned GetUnsignedOpcode(unsigned opcode);
|
||||
|
@ -132,9 +140,6 @@ HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode);
|
|||
|
||||
llvm::StringRef GetHLOpcodeGroupName(HLOpcodeGroup op);
|
||||
|
||||
// Determine if this call is to an operation that is dependent on other members of its wave
|
||||
bool IsCallWaveSensitive(llvm::CallInst *CI);
|
||||
|
||||
namespace HLOperandIndex {
|
||||
// Opcode parameter.
|
||||
const unsigned kOpcodeIdx = 0;
|
||||
|
@ -389,9 +394,30 @@ llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
|
|||
llvm::StringRef *fnName,
|
||||
unsigned opcode);
|
||||
|
||||
llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
|
||||
llvm::FunctionType *funcTy,
|
||||
HLOpcodeGroup group, unsigned opcode,
|
||||
const llvm::AttributeSet &attribs);
|
||||
llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
|
||||
llvm::FunctionType *funcTy,
|
||||
HLOpcodeGroup group,
|
||||
llvm::StringRef *groupName,
|
||||
llvm::StringRef *fnName,
|
||||
unsigned opcode,
|
||||
const llvm::AttributeSet &attribs);
|
||||
|
||||
llvm::Function *GetOrCreateHLFunctionWithBody(llvm::Module &M,
|
||||
llvm::FunctionType *funcTy,
|
||||
HLOpcodeGroup group,
|
||||
unsigned opcode,
|
||||
llvm::StringRef name);
|
||||
|
||||
llvm::Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
||||
llvm::Type *RetTy, llvm::ArrayRef<llvm::Value*> Args,
|
||||
const llvm::AttributeSet &attribs, llvm::IRBuilder<> &Builder);
|
||||
|
||||
llvm::Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
||||
llvm::Type *RetTy, llvm::ArrayRef<llvm::Value*> Args,
|
||||
llvm::IRBuilder<> &Builder);
|
||||
|
||||
} // namespace hlsl
|
||||
|
|
|
@ -117,6 +117,7 @@ struct HLSL_INTRINSIC {
|
|||
UINT Op; // Intrinsic Op ID
|
||||
BOOL bReadOnly; // Only read memory
|
||||
BOOL bReadNone; // Not read memory
|
||||
BOOL bIsWave; // Is a wave-sensitive op
|
||||
INT iOverloadParamIndex; // Parameter decide the overload type, -1 means ret type
|
||||
UINT uNumArgs; // Count of arguments in pArgs.
|
||||
const HLSL_INTRINSIC_ARGUMENT* pArgs; // Pointer to first argument.
|
||||
|
|
|
@ -1142,14 +1142,13 @@ public:
|
|||
// For each wave operation, collect the blocks sensitive to it
|
||||
SmallPtrSet<BasicBlock *, 16> SensitiveBBs;
|
||||
for (Function &IF : M->functions()) {
|
||||
if (&IF == &F || !IF.getNumUses() || !IF.isDeclaration() ||
|
||||
hlsl::GetHLOpcodeGroupByName(&IF) != HLOpcodeGroup::HLIntrinsic)
|
||||
continue;
|
||||
|
||||
for (User *U : IF.users()) {
|
||||
CallInst *CI = dyn_cast<CallInst>(U);
|
||||
if (CI && IsCallWaveSensitive(CI))
|
||||
HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
|
||||
if (&IF != &F && IF.getNumUses() && IF.isDeclaration() && IsHLWaveSensitive(&IF) &&
|
||||
(opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
|
||||
for (User *U : IF.users()) {
|
||||
CallInst *CI = cast<CallInst>(U);
|
||||
CollectSensitiveBlocks(LInfo, CI, BreakFunc, SensitiveBBs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -37,20 +37,6 @@
|
|||
using namespace llvm;
|
||||
using namespace hlsl;
|
||||
|
||||
static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
||||
Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
|
||||
SmallVector<Type*, 4> ArgTys;
|
||||
ArgTys.reserve(Args.size());
|
||||
for (Value *Arg : Args)
|
||||
ArgTys.emplace_back(Arg->getType());
|
||||
|
||||
FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
|
||||
Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
|
||||
|
||||
return Builder.CreateCall(Func, Args);
|
||||
}
|
||||
|
||||
|
||||
// Lowered UDT is the same layout, but with vectors and matrices translated to
|
||||
// arrays.
|
||||
// Returns nullptr for failure due to embedded HLSL object type.
|
||||
|
|
|
@ -159,14 +159,13 @@ private:
|
|||
Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
|
||||
Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
|
||||
Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
|
||||
Value *lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
|
||||
Value *lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
|
||||
Value *lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
|
||||
Value *lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
|
||||
Value *lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
|
||||
Value *lowerHLCast(CallInst *Call, Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
|
||||
Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
|
||||
Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
|
||||
Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
|
||||
void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
|
||||
Value *lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
|
||||
Value *lowerHLInit(CallInst *Call);
|
||||
Value *lowerHLSelect(CallInst *Call);
|
||||
|
||||
|
@ -784,7 +783,7 @@ Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeG
|
|||
return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
|
||||
|
||||
case HLOpcodeGroup::HLCast:
|
||||
return lowerHLCast(
|
||||
return lowerHLCast(Call,
|
||||
Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
|
||||
static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
|
||||
|
||||
|
@ -802,19 +801,6 @@ Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeG
|
|||
}
|
||||
}
|
||||
|
||||
static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
||||
Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
|
||||
SmallVector<Type*, 4> ArgTys;
|
||||
ArgTys.reserve(Args.size());
|
||||
for (Value *Arg : Args)
|
||||
ArgTys.emplace_back(Arg->getType());
|
||||
|
||||
FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
|
||||
Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
|
||||
|
||||
return Builder.CreateCall(Func, Args);
|
||||
}
|
||||
|
||||
Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
|
||||
IRBuilder<> Builder(Call);
|
||||
|
||||
|
@ -850,8 +836,9 @@ Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
|
|||
}
|
||||
|
||||
Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
|
||||
return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
|
||||
LoweredRetTy, LoweredArgs, Builder);
|
||||
return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
|
||||
LoweredRetTy, LoweredArgs,
|
||||
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
|
||||
}
|
||||
|
||||
// Handles multiplcation of a scalar with a matrix
|
||||
|
@ -1229,12 +1216,12 @@ Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode
|
|||
switch (Opcode) {
|
||||
case HLMatLoadStoreOpcode::RowMatLoad:
|
||||
case HLMatLoadStoreOpcode::ColMatLoad:
|
||||
return lowerHLLoad(Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
|
||||
return lowerHLLoad(Call, Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
|
||||
/* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
|
||||
|
||||
case HLMatLoadStoreOpcode::RowMatStore:
|
||||
case HLMatLoadStoreOpcode::ColMatStore:
|
||||
return lowerHLStore(
|
||||
return lowerHLStore(Call,
|
||||
Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
|
||||
Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
|
||||
/* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
|
||||
|
@ -1245,7 +1232,7 @@ Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode
|
|||
}
|
||||
}
|
||||
|
||||
Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
|
||||
Value *HLMatrixLowerPass::lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
|
||||
HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
|
||||
|
||||
Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
|
||||
|
@ -1254,13 +1241,15 @@ Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<>
|
|||
HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
|
||||
return callHLFunction(
|
||||
*m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
|
||||
MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr }, Builder);
|
||||
MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr },
|
||||
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
|
||||
}
|
||||
|
||||
return MatTy.emitLoweredLoad(LoweredPtr, Builder);
|
||||
}
|
||||
|
||||
Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder) {
|
||||
Value *HLMatrixLowerPass::lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr,
|
||||
bool RowMajor, bool Return, IRBuilder<> &Builder) {
|
||||
DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
|
||||
"Matrix store value/pointer type mismatch.");
|
||||
|
||||
|
@ -1272,7 +1261,8 @@ Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMaj
|
|||
return callHLFunction(
|
||||
*m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
|
||||
Return ? LoweredVal->getType() : Builder.getVoidTy(),
|
||||
{ Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal }, Builder);
|
||||
{ Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal },
|
||||
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
|
||||
}
|
||||
|
||||
HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
|
||||
|
@ -1309,7 +1299,8 @@ static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opc
|
|||
return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
|
||||
}
|
||||
|
||||
Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder) {
|
||||
Value *HLMatrixLowerPass::lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
|
||||
HLCastOpcode Opcode, IRBuilder<> &Builder) {
|
||||
// The opcode really doesn't mean much here, the types involved are what drive most of the casting.
|
||||
DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
|
||||
|
||||
|
@ -1358,7 +1349,8 @@ Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opco
|
|||
DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
|
||||
"Unexpected cast of matrix argument.");
|
||||
LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
|
||||
LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src }, Builder);
|
||||
LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src },
|
||||
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
|
||||
}
|
||||
else {
|
||||
LoweredSrc = getLoweredByValOperand(Src, Builder);
|
||||
|
@ -1498,21 +1490,6 @@ void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, Small
|
|||
addToDeadInsts(Call);
|
||||
}
|
||||
|
||||
// Lowers StructuredBuffer<matrix>[index] or similar with constant buffers
|
||||
Value *HLMatrixLowerPass::lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
|
||||
// Just replace the intrinsic by its equivalent with a lowered return type
|
||||
IRBuilder<> Builder(Call);
|
||||
|
||||
SmallVector<Value*, 4> Args;
|
||||
Args.reserve(Call->getNumArgOperands());
|
||||
for (Value *Arg : Call->arg_operands())
|
||||
Args.emplace_back(Arg);
|
||||
|
||||
Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
|
||||
return callHLFunction(*m_pModule, HLOpcodeGroup::HLSubscript, static_cast<unsigned>(Opcode),
|
||||
LoweredRetTy, Args, Builder);
|
||||
}
|
||||
|
||||
Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
|
||||
DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
|
||||
|
||||
|
|
|
@ -26,6 +26,9 @@ const char * const HLPrefix = HLPrefixStr;
|
|||
static const char HLLowerStrategyStr[] = "dx.hlls";
|
||||
static const char * const HLLowerStrategy = HLLowerStrategyStr;
|
||||
|
||||
static const char HLWaveSensitiveStr[] = "dx.wave-sensitive";
|
||||
static const char * const HLWaveSensitive = HLWaveSensitiveStr;
|
||||
|
||||
static StringRef HLOpcodeGroupNames[]{
|
||||
"notHLDXIL", // NotHL,
|
||||
"<ext>", // HLExtIntrinsic - should always refer through extension
|
||||
|
@ -289,6 +292,17 @@ void SetHLLowerStrategy(Function *F, StringRef S) {
|
|||
F->addFnAttr(HLLowerStrategy, S);
|
||||
}
|
||||
|
||||
// Set function attribute indicating wave-sensitivity
|
||||
void SetHLWaveSensitive(Function *F) {
|
||||
F->addFnAttr(HLWaveSensitive, "y");
|
||||
}
|
||||
|
||||
// Return if this Function is dependent on other wave members indicated by attribute
|
||||
bool IsHLWaveSensitive(Function *F) {
|
||||
AttributeSet attrSet = F->getAttributes();
|
||||
return attrSet.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive);
|
||||
}
|
||||
|
||||
std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
|
||||
assert(op != HLOpcodeGroup::HLExtIntrinsic && "else table name should be used");
|
||||
std::string opName = GetHLOpcodeGroupFullName(op).str() + ".";
|
||||
|
@ -452,53 +466,29 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
|
|||
}
|
||||
|
||||
|
||||
// Determine if this Call Instruction refers to an HLOpcode that is dependent on other wave members
|
||||
bool IsCallWaveSensitive(CallInst *CI) {
|
||||
hlsl::IntrinsicOp opcode = static_cast<hlsl::IntrinsicOp>(hlsl::GetHLOpcode(CI));
|
||||
switch(opcode) {
|
||||
case IntrinsicOp::IOP_WaveActiveAllEqual:
|
||||
case IntrinsicOp::IOP_WaveActiveAllTrue:
|
||||
case IntrinsicOp::IOP_WaveActiveAnyTrue:
|
||||
case IntrinsicOp::IOP_WaveActiveBallot:
|
||||
case IntrinsicOp::IOP_WaveActiveBitAnd:
|
||||
case IntrinsicOp::IOP_WaveActiveBitOr:
|
||||
case IntrinsicOp::IOP_WaveActiveBitXor:
|
||||
case IntrinsicOp::IOP_WaveActiveCountBits:
|
||||
case IntrinsicOp::IOP_WaveActiveMax:
|
||||
case IntrinsicOp::IOP_WaveActiveMin:
|
||||
case IntrinsicOp::IOP_WaveActiveProduct:
|
||||
case IntrinsicOp::IOP_WaveActiveSum:
|
||||
case IntrinsicOp::IOP_WaveIsFirstLane:
|
||||
case IntrinsicOp::IOP_WaveMatch:
|
||||
case IntrinsicOp::IOP_WaveMultiPrefixBitAnd:
|
||||
case IntrinsicOp::IOP_WaveMultiPrefixBitOr:
|
||||
case IntrinsicOp::IOP_WaveMultiPrefixBitXor:
|
||||
case IntrinsicOp::IOP_WaveMultiPrefixCountBits:
|
||||
case IntrinsicOp::IOP_WaveMultiPrefixProduct:
|
||||
case IntrinsicOp::IOP_WaveMultiPrefixSum:
|
||||
case IntrinsicOp::IOP_WavePrefixCountBits:
|
||||
case IntrinsicOp::IOP_WavePrefixProduct:
|
||||
case IntrinsicOp::IOP_WavePrefixSum:
|
||||
case IntrinsicOp::IOP_WaveReadLaneAt:
|
||||
case IntrinsicOp::IOP_WaveReadLaneFirst:
|
||||
case IntrinsicOp::IOP_QuadReadAcrossDiagonal:
|
||||
case IntrinsicOp::IOP_QuadReadAcrossX:
|
||||
case IntrinsicOp::IOP_QuadReadAcrossY:
|
||||
case IntrinsicOp::IOP_QuadReadLaneAt:
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
|
||||
HLOpcodeGroup group, unsigned opcode) {
|
||||
return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode);
|
||||
AttributeSet attribs;
|
||||
return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode, attribs);
|
||||
}
|
||||
|
||||
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
|
||||
HLOpcodeGroup group, llvm::StringRef *groupName,
|
||||
llvm::StringRef *fnName, unsigned opcode) {
|
||||
HLOpcodeGroup group, StringRef *groupName,
|
||||
StringRef *fnName, unsigned opcode) {
|
||||
AttributeSet attribs;
|
||||
return GetOrCreateHLFunction(M, funcTy, group, groupName, fnName, opcode, attribs);
|
||||
}
|
||||
|
||||
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
|
||||
HLOpcodeGroup group, unsigned opcode,
|
||||
const AttributeSet &attribs) {
|
||||
return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode, attribs);
|
||||
}
|
||||
|
||||
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
|
||||
HLOpcodeGroup group, StringRef *groupName,
|
||||
StringRef *fnName, unsigned opcode,
|
||||
const AttributeSet &attribs) {
|
||||
std::string mangledName;
|
||||
raw_string_ostream mangledNameStr(mangledName);
|
||||
if (group == HLOpcodeGroup::HLExtIntrinsic) {
|
||||
|
@ -510,6 +500,9 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
|
|||
}
|
||||
else {
|
||||
mangledNameStr << GetHLFullName(group, opcode);
|
||||
// Need to add wave sensitivity to name to prevent clashes with non-wave intrinsic
|
||||
if(attribs.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive))
|
||||
mangledNameStr << "wave";
|
||||
mangledNameStr << '.';
|
||||
funcTy->print(mangledNameStr);
|
||||
}
|
||||
|
@ -523,6 +516,14 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
|
|||
|
||||
SetHLFunctionAttribute(F, group, opcode);
|
||||
|
||||
// Copy attributes
|
||||
if (attribs.hasAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone))
|
||||
F->addFnAttr(Attribute::ReadNone);
|
||||
if (attribs.hasAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly))
|
||||
F->addFnAttr(Attribute::ReadOnly);
|
||||
if (attribs.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive))
|
||||
F->addFnAttr(HLWaveSensitive, "y");
|
||||
|
||||
return F;
|
||||
}
|
||||
|
||||
|
@ -546,4 +547,23 @@ Function *GetOrCreateHLFunctionWithBody(Module &M, FunctionType *funcTy,
|
|||
return F;
|
||||
}
|
||||
|
||||
Value *callHLFunction(Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
||||
Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
|
||||
AttributeSet attribs;
|
||||
return callHLFunction(Module, OpcodeGroup, Opcode, RetTy, Args, attribs, Builder);
|
||||
}
|
||||
|
||||
Value *callHLFunction(Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
||||
Type *RetTy, ArrayRef<Value*> Args, const AttributeSet &attribs, IRBuilder<> &Builder) {
|
||||
SmallVector<Type*, 4> ArgTys;
|
||||
ArgTys.reserve(Args.size());
|
||||
for (Value *Arg : Args)
|
||||
ArgTys.emplace_back(Arg->getType());
|
||||
|
||||
FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
|
||||
Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode, attribs);
|
||||
|
||||
return Builder.CreateCall(Func, Args);
|
||||
}
|
||||
|
||||
} // namespace hlsl
|
||||
|
|
|
@ -944,7 +944,7 @@ static unsigned IsPtrUsedByLoweredFn(
|
|||
|
||||
} else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
|
||||
unsigned opcode = CE->getOpcode();
|
||||
if (opcode == Instruction::AddrSpaceCast || Instruction::GetElementPtr)
|
||||
if (opcode == Instruction::AddrSpaceCast || opcode == Instruction::GetElementPtr)
|
||||
if (IsPtrUsedByLoweredFn(user, CollectedUses))
|
||||
bFound = true;
|
||||
}
|
||||
|
@ -967,7 +967,8 @@ static CallInst *RewriteIntrinsicCallForCastedArg(CallInst *CI, unsigned argIdx)
|
|||
newArgs[argIdx] = newArg;
|
||||
|
||||
FunctionType *newFuncTy = FunctionType::get(CI->getType(), newArgTypes, false);
|
||||
Function *newF = GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode);
|
||||
Function *newF = GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode,
|
||||
F->getAttributes().getFnAttributes());
|
||||
|
||||
IRBuilder<> Builder(CI);
|
||||
return Builder.CreateCall(newF, newArgs);
|
||||
|
@ -2779,7 +2780,8 @@ static CallInst *CreateFlattenedHLIntrinsicCall(
|
|||
FunctionType *flatFuncTy =
|
||||
FunctionType::get(CI->getType(), flatParamTys, false);
|
||||
Function *flatF =
|
||||
GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode);
|
||||
GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode,
|
||||
F->getAttributes().getFnAttributes());
|
||||
|
||||
return Builder.CreateCall(flatF, flatArgs);
|
||||
}
|
||||
|
|
|
@ -895,6 +895,12 @@ def HLSLPayload : InheritableAttr {
|
|||
let Documentation = [Undocumented];
|
||||
}
|
||||
|
||||
def HLSLWaveSensitive : InheritableAttr {
|
||||
let Spellings = [CXX11<"", "wavesensitive", 2015>];
|
||||
let Subjects = SubjectList<[ParmVar]>;
|
||||
let Documentation = [Undocumented];
|
||||
}
|
||||
|
||||
// HLSL Change Ends
|
||||
|
||||
// SPIRV Change Starts
|
||||
|
|
|
@ -1157,6 +1157,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
|
|||
if (hlsl::GetIntrinsicLowering(FD, lower))
|
||||
hlsl::SetHLLowerStrategy(F, lower);
|
||||
|
||||
|
||||
if (FD->hasAttr<HLSLWaveSensitiveAttr>())
|
||||
hlsl::SetHLWaveSensitive(F);
|
||||
|
||||
// Don't need to add FunctionQual for intrinsic function.
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -191,7 +191,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
|
|||
llvm::FunctionType *funcTy, HLOpcodeGroup group,
|
||||
unsigned opcode) {
|
||||
Function *opFunc = nullptr;
|
||||
|
||||
AttributeSet attribs = F->getAttributes().getFnAttributes();
|
||||
llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
|
||||
if (group == HLOpcodeGroup::HLIntrinsic) {
|
||||
IntrinsicOp intriOp = static_cast<IntrinsicOp>(opcode);
|
||||
|
@ -202,7 +202,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
|
|||
llvm::Type *handleTy = funcTy->getParamType(HLOperandIndex::kHandleOpIdx);
|
||||
// Don't generate body for OutputStream::Append.
|
||||
if (bAppend && HLModule::IsStreamOutputPtrType(handleTy)) {
|
||||
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
|
||||
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode, attribs);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -215,7 +215,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
|
|||
bAppend ? (unsigned)IntrinsicOp::MOP_IncrementCounter
|
||||
: (unsigned)IntrinsicOp::MOP_DecrementCounter;
|
||||
Function *incCounterFunc =
|
||||
GetOrCreateHLFunction(M, IncCounterFuncTy, group, counterOpcode);
|
||||
GetOrCreateHLFunction(M, IncCounterFuncTy, group, counterOpcode, attribs);
|
||||
|
||||
llvm::Type *idxTy = counterTy;
|
||||
llvm::Type *valTy =
|
||||
|
@ -245,7 +245,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
|
|||
|
||||
Function *subscriptFunc =
|
||||
GetOrCreateHLFunction(M, SubscriptFuncTy, HLOpcodeGroup::HLSubscript,
|
||||
(unsigned)HLSubscriptOpcode::DefaultSubscript);
|
||||
(unsigned)HLSubscriptOpcode::DefaultSubscript, attribs);
|
||||
|
||||
BasicBlock *BB =
|
||||
BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
|
||||
|
@ -304,8 +304,8 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
|
|||
llvm::FunctionType::get(valTy, {opcodeTy, valTy}, false);
|
||||
unsigned sinOp = static_cast<unsigned>(IntrinsicOp::IOP_sin);
|
||||
unsigned cosOp = static_cast<unsigned>(IntrinsicOp::IOP_cos);
|
||||
Function *sinFunc = GetOrCreateHLFunction(M, sinFuncTy, group, sinOp);
|
||||
Function *cosFunc = GetOrCreateHLFunction(M, sinFuncTy, group, cosOp);
|
||||
Function *sinFunc = GetOrCreateHLFunction(M, sinFuncTy, group, sinOp, attribs);
|
||||
Function *cosFunc = GetOrCreateHLFunction(M, sinFuncTy, group, cosOp, attribs);
|
||||
|
||||
BasicBlock *BB =
|
||||
BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
|
||||
|
@ -328,23 +328,18 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
|
|||
Builder.CreateRetVoid();
|
||||
} break;
|
||||
default:
|
||||
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
|
||||
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode, attribs);
|
||||
break;
|
||||
}
|
||||
} else if (group == HLOpcodeGroup::HLExtIntrinsic) {
|
||||
llvm::StringRef fnName = F->getName();
|
||||
llvm::StringRef groupName = GetHLOpcodeGroupNameByAttr(F);
|
||||
opFunc =
|
||||
GetOrCreateHLFunction(M, funcTy, group, &groupName, &fnName, opcode);
|
||||
GetOrCreateHLFunction(M, funcTy, group, &groupName, &fnName, opcode, attribs);
|
||||
} else {
|
||||
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
|
||||
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode, attribs);
|
||||
}
|
||||
|
||||
// Add attribute
|
||||
if (F->hasFnAttribute(Attribute::ReadNone))
|
||||
opFunc->addFnAttr(Attribute::ReadNone);
|
||||
if (F->hasFnAttribute(Attribute::ReadOnly))
|
||||
opFunc->addFnAttr(Attribute::ReadOnly);
|
||||
return opFunc;
|
||||
}
|
||||
|
||||
|
|
|
@ -1761,6 +1761,8 @@ static void AddHLSLIntrinsicAttr(FunctionDecl *FD, ASTContext &context,
|
|||
FD->addAttr(ConstAttr::CreateImplicit(context));
|
||||
if (pIntrinsic->bReadOnly)
|
||||
FD->addAttr(PureAttr::CreateImplicit(context));
|
||||
if (pIntrinsic->bIsWave)
|
||||
FD->addAttr(HLSLWaveSensitiveAttr::CreateImplicit(context));
|
||||
}
|
||||
|
||||
static
|
||||
|
@ -11207,6 +11209,9 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
|
|||
case AttributeList::AT_HLSLExport:
|
||||
declAttr = ::new (S.Context) HLSLExportAttr(A.getRange(), S.Context, A.getAttributeSpellingListIndex());
|
||||
break;
|
||||
case AttributeList::AT_HLSLWaveSensitive:
|
||||
declAttr = ::new (S.Context) HLSLWaveSensitiveAttr(A.getRange(), S.Context, A.getAttributeSpellingListIndex());
|
||||
break;
|
||||
default:
|
||||
Handled = false;
|
||||
break; // SPIRV Change: was return;
|
||||
|
@ -12644,6 +12649,7 @@ bool hlsl::IsHLSLAttr(clang::attr::Kind AttrKind) {
|
|||
case clang::attr::HLSLPayload:
|
||||
case clang::attr::NoInline:
|
||||
case clang::attr::HLSLExport:
|
||||
case clang::attr::HLSLWaveSensitive:
|
||||
case clang::attr::VKBinding:
|
||||
case clang::attr::VKBuiltIn:
|
||||
case clang::attr::VKConstantId:
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,365 @@
|
|||
// RUN: %dxc -T ps_6_5 %s | FileCheck %s
|
||||
StructuredBuffer<int> buf[]: register(t2);
|
||||
StructuredBuffer<uint4> g_mask;
|
||||
|
||||
// CHECK: @dx.break.cond = internal constant
|
||||
|
||||
// Cannonical example. Expected to keep the block in loop
|
||||
// Verify this function loads the global
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
||||
int main(int a : A, int b : B) : SV_Target
|
||||
{
|
||||
int res = 0;
|
||||
int i = WaveGetLaneCount() + WaveGetLaneIndex();
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i1 @dx.op.waveIsFirstLane
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
bool u = WaveIsFirstLane();
|
||||
if (a != u) {
|
||||
res += buf[b][(int)u];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i1 @dx.op.waveAnyTrue
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
bool u = WaveActiveAnyTrue(a);
|
||||
if (a != u) {
|
||||
res += buf[(int)u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i1 @dx.op.waveAllTrue
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
bool u = WaveActiveAllTrue(a);
|
||||
if (a != u) {
|
||||
res += buf[(int)u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i1 @dx.op.waveActiveAllEqual
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
bool u = WaveActiveAllEqual(a);
|
||||
if (a != u) {
|
||||
res += buf[(int)u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveReadLaneFirst(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveAllOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WaveActiveCountBits(a == b);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveSum(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveProduct(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveBit
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveBitAnd((uint)a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveBit
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveBitOr((uint)a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveBit
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveBitXor((uint)a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveMin(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveActiveOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveActiveMax(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.wavePrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WavePrefixCountBits(a < b);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.wavePrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WavePrefixSum(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.wavePrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WavePrefixProduct(a);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call %dx.types.fouri32 @dx.op.waveMatch
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint4 u = WaveMatch(a);
|
||||
if (a != u.y) {
|
||||
res += buf[u.x][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
uint4 mask = g_mask[0];
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveMultiPrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveMultiPrefixBitAnd((uint)a, mask);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveMultiPrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveMultiPrefixBitOr((uint)a, mask);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveMultiPrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
int u = WaveMultiPrefixBitXor((uint)a, mask);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveMultiPrefixBitCount
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WaveMultiPrefixCountBits(a <= b, mask);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveMultiPrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WaveMultiPrefixProduct(a, mask);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call i32 @dx.op.waveMultiPrefixOp
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// Loop with wave-dependent conditional break block
|
||||
for (;;) {
|
||||
uint u = WaveMultiPrefixSum(a, mask);
|
||||
if (a != u) {
|
||||
res += buf[u][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
|
@ -4,6 +4,8 @@ StructuredBuffer<int> buf[]: register(t2);
|
|||
|
||||
// Cannonical example. Expected to keep the block in loop
|
||||
// Verify this function loads the global
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveInLoop
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
@ -23,6 +25,7 @@ StructuredBuffer<int> buf[]: register(t2);
|
|||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveInLoop(int a : A, int b : B)
|
||||
{
|
||||
|
@ -52,6 +55,8 @@ int WaveInLoop(int a : A, int b : B)
|
|||
|
||||
// Wave moved to after the break block. Expected to keep the block in loop
|
||||
// Verify this function loads the global
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveInPostLoop
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
@ -71,6 +76,7 @@ int WaveInLoop(int a : A, int b : B)
|
|||
// CHECK: br i1
|
||||
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveInPostLoop(int a : A, int b : B)
|
||||
{
|
||||
|
@ -101,6 +107,8 @@ int WaveInPostLoop(int a : A, int b : B)
|
|||
|
||||
// Wave op inside break block. Expected to keep the block in loop
|
||||
// Verify this function loads the global
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveInBreakBlock
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
@ -111,7 +119,7 @@ int WaveInPostLoop(int a : A, int b : B)
|
|||
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: br i1
|
||||
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveInBreakBlock(int a : A, int b : B)
|
||||
{
|
||||
|
@ -131,6 +139,8 @@ int WaveInBreakBlock(int a : A, int b : B)
|
|||
}
|
||||
|
||||
// Wave in entry block. Expected to allow the break block to move out of loop
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveInEntry
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
|
||||
// These verify the break block doesn't keep the conditional
|
||||
|
@ -140,6 +150,7 @@ int WaveInBreakBlock(int a : A, int b : B)
|
|||
// These verify the break block doesn't keep the conditional
|
||||
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveInEntry(int a : A, int b : B)
|
||||
{
|
||||
|
@ -169,6 +180,8 @@ int WaveInEntry(int a : A, int b : B)
|
|||
|
||||
// Wave in subloop of larger loop. Expected to keep the block in loop
|
||||
// Verify this function loads the global
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveInSubLoop
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
@ -186,6 +199,7 @@ int WaveInEntry(int a : A, int b : B)
|
|||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveInSubLoop(int a : A, int b : B)
|
||||
{
|
||||
|
@ -218,6 +232,8 @@ int WaveInSubLoop(int a : A, int b : B)
|
|||
}
|
||||
|
||||
// Wave in a separate loop. Expected to allow the break block to move out of loop
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveInOtherLoop
|
||||
// CHECK: load i32
|
||||
// CHECK: icmp eq i32
|
||||
|
||||
|
@ -233,13 +249,14 @@ int WaveInSubLoop(int a : A, int b : B)
|
|||
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHICK-NOT: br i1
|
||||
// CHECK-NOT: br i1
|
||||
|
||||
// These verify the third break block doesn't
|
||||
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK-NOT: br i1
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveInOtherLoop(int a : A, int b : B)
|
||||
{
|
||||
|
@ -273,3 +290,44 @@ int WaveInOtherLoop(int a : A, int b : B)
|
|||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Test for matrices which pass through additional lowering which might lose wave-sensitive attrib
|
||||
// Verify this function loads the global
|
||||
// CHECK: define i32
|
||||
// CHECK-SAME: WaveWithMatrix
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
||||
|
||||
// These verify the break block keeps the conditional
|
||||
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
// CHECK: ret i32
|
||||
export
|
||||
int WaveWithMatrix(int3x3 a : A, int b : B)
|
||||
{
|
||||
int res = 0;
|
||||
|
||||
// Loop with wave-dependent matrix conditional break block
|
||||
for (;;) {
|
||||
int3x3 u = WaveReadLaneFirst(a);
|
||||
if (a[0][0] == u[1][1]) {
|
||||
res += buf[u[2][2]][b];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
// RUN: %dxc -T ps_6_3 %s | FileCheck %s
|
||||
StructuredBuffer<int> buf[]: register(t2);
|
||||
// CHECK: @dx.break.cond = internal constant
|
||||
|
||||
// Cannonical example. Expected to keep the block in loop
|
||||
// Verify this function loads the global
|
||||
// CHECK: load i32
|
||||
// CHECK-SAME: @dx.break.cond
|
||||
// CHECK: icmp eq i32
|
||||
|
||||
// CHECK: call i32 @dx.op.waveActiveOp.i32
|
||||
|
||||
// These verify the first break block keeps the conditional
|
||||
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHECK: br i1
|
||||
|
||||
// These verify the second break block doesn't
|
||||
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
||||
// CHECK: add
|
||||
// CHICK-NOT: br i1
|
||||
|
||||
int main(int a : A, int b : B) : SV_Target
|
||||
{
|
||||
int res = 0;
|
||||
|
||||
// Loop with wave op
|
||||
for (;;) {
|
||||
int u = WaveActiveSum(a);
|
||||
if (a == u) {
|
||||
res += buf[b][u];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Loop with non-wave op with same signature as previous wave op
|
||||
// Without prototype manipulation, this will share an intermediate
|
||||
// op call with the previous and get the wave attribute.
|
||||
for (;;) {
|
||||
int u = abs(a--);
|
||||
if (a == u) {
|
||||
res += buf[b][u];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
|
@ -117,6 +117,12 @@ static const HLSL_INTRINSIC_ARGUMENT TestMySamplerOp[] = {
|
|||
{ "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
|
||||
};
|
||||
|
||||
// $result = wave_proc(any_vector<any_cardinality> value)
|
||||
static const HLSL_INTRINSIC_ARGUMENT WaveProcArgs[] = {
|
||||
{ "wave_proc", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
|
||||
{ "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
|
||||
};
|
||||
|
||||
struct Intrinsic {
|
||||
LPCWSTR hlslName;
|
||||
const char *dxilName;
|
||||
|
@ -130,31 +136,32 @@ template <class T, std::size_t N>
|
|||
UINT countof(T(&)[N]) { return static_cast<UINT>(N); }
|
||||
|
||||
Intrinsic Intrinsics[] = {
|
||||
{L"test_fn", DEFAULT_NAME, "r", { 1, false, true, -1, countof(TestFnArgs), TestFnArgs }},
|
||||
{L"test_proc", DEFAULT_NAME, "r", { 2, false, false,-1, countof(TestProcArgs), TestProcArgs }},
|
||||
{L"test_poly", "test_poly.$o", "r", { 3, false, true, -1, countof(TestFnCustomArgs), TestFnCustomArgs }},
|
||||
{L"test_int", "test_int", "r", { 4, false, true, -1, countof(TestFnIntArgs), TestFnIntArgs}},
|
||||
{L"test_nolower", "test_nolower.$o", "n", { 5, false, true, -1, countof(TestFnNoLowerArgs), TestFnNoLowerArgs}},
|
||||
{L"test_pack_0", "test_pack_0.$o", "p", { 6, false, false,-1, countof(TestFnPack0), TestFnPack0}},
|
||||
{L"test_pack_1", "test_pack_1.$o", "p", { 7, false, true, -1, countof(TestFnPack1), TestFnPack1}},
|
||||
{L"test_pack_2", "test_pack_2.$o", "p", { 8, false, true, -1, countof(TestFnPack2), TestFnPack2}},
|
||||
{L"test_pack_3", "test_pack_3.$o", "p", { 9, false, true, -1, countof(TestFnPack3), TestFnPack3}},
|
||||
{L"test_pack_4", "test_pack_4.$o", "p", { 10, false, false,-1, countof(TestFnPack4), TestFnPack4}},
|
||||
{L"test_rand", "test_rand", "r", { 11, false, false,-1, countof(TestRand), TestRand}},
|
||||
{L"test_isinf", "test_isinf", "d", { 13, true, true, -1, countof(TestIsInf), TestIsInf}},
|
||||
{L"test_ibfe", "test_ibfe", "d", { 14, true, true, -1, countof(TestIBFE), TestIBFE}},
|
||||
{L"test_fn", DEFAULT_NAME, "r", { 1, false, true, false, -1, countof(TestFnArgs), TestFnArgs }},
|
||||
{L"test_proc", DEFAULT_NAME, "r", { 2, false, false, false,-1, countof(TestProcArgs), TestProcArgs }},
|
||||
{L"test_poly", "test_poly.$o", "r", { 3, false, true, false, -1, countof(TestFnCustomArgs), TestFnCustomArgs }},
|
||||
{L"test_int", "test_int", "r", { 4, false, true, false, -1, countof(TestFnIntArgs), TestFnIntArgs}},
|
||||
{L"test_nolower", "test_nolower.$o", "n", { 5, false, true, false, -1, countof(TestFnNoLowerArgs), TestFnNoLowerArgs}},
|
||||
{L"test_pack_0", "test_pack_0.$o", "p", { 6, false, false, false,-1, countof(TestFnPack0), TestFnPack0}},
|
||||
{L"test_pack_1", "test_pack_1.$o", "p", { 7, false, true, false, -1, countof(TestFnPack1), TestFnPack1}},
|
||||
{L"test_pack_2", "test_pack_2.$o", "p", { 8, false, true, false, -1, countof(TestFnPack2), TestFnPack2}},
|
||||
{L"test_pack_3", "test_pack_3.$o", "p", { 9, false, true, false, -1, countof(TestFnPack3), TestFnPack3}},
|
||||
{L"test_pack_4", "test_pack_4.$o", "p", { 10, false, false, false,-1, countof(TestFnPack4), TestFnPack4}},
|
||||
{L"test_rand", "test_rand", "r", { 11, false, false, false,-1, countof(TestRand), TestRand}},
|
||||
{L"test_isinf", "test_isinf", "d", { 13, true, true, false, -1, countof(TestIsInf), TestIsInf}},
|
||||
{L"test_ibfe", "test_ibfe", "d", { 14, true, true, false, -1, countof(TestIBFE), TestIBFE}},
|
||||
// Make this intrinsic have the same opcode as an hlsl intrinsic with an unsigned
|
||||
// counterpart for testing purposes.
|
||||
{L"test_unsigned","test_unsigned", "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, -1, countof(TestUnsigned), TestUnsigned}},
|
||||
{L"test_unsigned","test_unsigned", "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, false, -1, countof(TestUnsigned), TestUnsigned}},
|
||||
{L"wave_proc", DEFAULT_NAME, "r", { 16, false, true, true, -1, countof(WaveProcArgs), WaveProcArgs }},
|
||||
};
|
||||
|
||||
Intrinsic BufferIntrinsics[] = {
|
||||
{L"MyBufferOp", "MyBufferOp", "m", { 12, false, true, -1, countof(TestMyBufferOp), TestMyBufferOp}},
|
||||
{L"MyBufferOp", "MyBufferOp", "m", { 12, false, true, false, -1, countof(TestMyBufferOp), TestMyBufferOp}},
|
||||
};
|
||||
|
||||
// Test adding a method to an object that normally has no methods (SamplerState will do).
|
||||
Intrinsic SamplerIntrinsics[] = {
|
||||
{L"MySamplerOp", "MySamplerOp", "m", { 15, false, true, -1, countof(TestMySamplerOp), TestMySamplerOp}},
|
||||
{L"MySamplerOp", "MySamplerOp", "m", { 15, false, true, false, -1, countof(TestMySamplerOp), TestMySamplerOp}},
|
||||
};
|
||||
|
||||
class IntrinsicTable {
|
||||
|
@ -454,6 +461,7 @@ public:
|
|||
TEST_METHOD(DxilLoweringVector2)
|
||||
TEST_METHOD(DxilLoweringScalar)
|
||||
TEST_METHOD(SamplerExtensionIntrinsic)
|
||||
TEST_METHOD(WaveIntrinsic)
|
||||
};
|
||||
|
||||
TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
|
||||
|
@ -880,3 +888,43 @@ TEST_F(ExtensionTest, SamplerExtensionIntrinsic) {
|
|||
CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
|
||||
}
|
||||
|
||||
TEST_F(ExtensionTest, WaveIntrinsic) {
|
||||
// Test wave-sensitive intrinsic in breaked loop
|
||||
Compiler c(m_dllSupport);
|
||||
c.RegisterIntrinsicTable(new TestIntrinsicTable());
|
||||
c.Compile(
|
||||
"StructuredBuffer<int> buf[]: register(t2);"
|
||||
"float2 main(float2 a : A, int b : B) : SV_Target {"
|
||||
" int res = 0;"
|
||||
" float2 u = {0,0};"
|
||||
" for (;;) {"
|
||||
" u += wave_proc(a);"
|
||||
" if (a.x == u.x) {"
|
||||
" res += buf[b][(int)u.y];"
|
||||
" break;"
|
||||
" }"
|
||||
" }"
|
||||
" return res;"
|
||||
"}",
|
||||
{ L"/Vd" }, {}
|
||||
);
|
||||
std::string disassembly = c.Disassemble();
|
||||
|
||||
// Check that the wave op causes the break block to be retained
|
||||
VERIFY_IS_TRUE(
|
||||
disassembly.npos !=
|
||||
disassembly.find("@dx.break.cond = internal constant [1 x i32] zeroinitializer"));
|
||||
VERIFY_IS_TRUE(
|
||||
disassembly.npos !=
|
||||
disassembly.find("%1 = load i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @dx.break.cond"));
|
||||
VERIFY_IS_TRUE(
|
||||
disassembly.npos !=
|
||||
disassembly.find("%2 = icmp eq i32 %1, 0"));
|
||||
VERIFY_IS_TRUE(
|
||||
disassembly.npos !=
|
||||
disassembly.find("call float @\"test.\\01?wave_proc@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32 16, float"));
|
||||
VERIFY_IS_TRUE(
|
||||
disassembly.npos !=
|
||||
disassembly.find("br i1 %2"));
|
||||
}
|
||||
|
||||
|
|
|
@ -245,34 +245,34 @@ bool [[rn]] CheckAccessFullyMapped(in uint_only status) : check_access_fully_map
|
|||
uint<c> [[rn]] AddUint64(in $match<1, 0> uint<c> a, in $match<2, 0> uint<c> b) : adduint64;
|
||||
$type1 [[rn]] NonUniformResourceIndex(in any<> index) : nonuniform_resource_index;
|
||||
|
||||
// Wave intrinsics.
|
||||
bool [[]] WaveIsFirstLane();
|
||||
// Wave intrinsics. Only those that depend on the exec mask are marked as wave-sensitive
|
||||
bool [[wv]] WaveIsFirstLane();
|
||||
uint [[rn]] WaveGetLaneIndex();
|
||||
uint [[rn]] WaveGetLaneCount();
|
||||
bool [[]] WaveActiveAnyTrue(in bool cond);
|
||||
bool [[]] WaveActiveAllTrue(in bool cond);
|
||||
$match<1, 0> bool<> [[]] WaveActiveAllEqual(in any<> value);
|
||||
uint<4> [[]] WaveActiveBallot(in bool cond);
|
||||
bool [[wv]] WaveActiveAnyTrue(in bool cond);
|
||||
bool [[wv]] WaveActiveAllTrue(in bool cond);
|
||||
$match<1, 0> bool<> [[wv]] WaveActiveAllEqual(in any<> value);
|
||||
uint<4> [[wv]] WaveActiveBallot(in bool cond);
|
||||
$type1 [[]] WaveReadLaneAt(in any<> value, in uint lane);
|
||||
$type1 [[]] WaveReadLaneFirst(in any<> value);
|
||||
uint [[]] WaveActiveCountBits(in bool value);
|
||||
$type1 [[unsigned_op=WaveActiveUSum]] WaveActiveSum(in numeric<> value);
|
||||
$type1 [[unsigned_op=WaveActiveUProduct]] WaveActiveProduct(in numeric<> value);
|
||||
$type1 [[]] WaveActiveBitAnd(in uint_only<> value);
|
||||
$type1 [[]] WaveActiveBitOr(in uint_only<> value);
|
||||
$type1 [[]] WaveActiveBitXor(in uint_only<> value);
|
||||
$type1 [[unsigned_op=WaveActiveUMin]] WaveActiveMin(in numeric<> value);
|
||||
$type1 [[unsigned_op=WaveActiveUMax]] WaveActiveMax(in numeric<> value);
|
||||
uint [[]] WavePrefixCountBits(in bool value);
|
||||
$type1 [[unsigned_op=WavePrefixUSum]] WavePrefixSum(in numeric<> value);
|
||||
$type1 [[unsigned_op=WavePrefixUProduct]] WavePrefixProduct(in numeric<> value);
|
||||
uint<4> [[]] WaveMatch(in numeric<> value);
|
||||
$type1 [[]] WaveMultiPrefixBitAnd(in any_int<> value, in uint<4> mask);
|
||||
$type1 [[]] WaveMultiPrefixBitOr(in any_int<> value, in uint<4> mask);
|
||||
$type1 [[]] WaveMultiPrefixBitXor(in any_int<> value, in uint<4> mask);
|
||||
uint [[]] WaveMultiPrefixCountBits(in bool value, in uint<4> mask);
|
||||
$type1 [[unsigned_op=WaveMultiPrefixUProduct]] WaveMultiPrefixProduct(in numeric<> value, in uint<4> mask);
|
||||
$type1 [[unsigned_op=WaveMultiPrefixUSum]] WaveMultiPrefixSum(in numeric<> value, in uint<4> mask);
|
||||
$type1 [[wv]] WaveReadLaneFirst(in any<> value);
|
||||
uint [[wv]] WaveActiveCountBits(in bool value);
|
||||
$type1 [[unsigned_op=WaveActiveUSum,wv]] WaveActiveSum(in numeric<> value);
|
||||
$type1 [[unsigned_op=WaveActiveUProduct,wv]] WaveActiveProduct(in numeric<> value);
|
||||
$type1 [[wv]] WaveActiveBitAnd(in uint_only<> value);
|
||||
$type1 [[wv]] WaveActiveBitOr(in uint_only<> value);
|
||||
$type1 [[wv]] WaveActiveBitXor(in uint_only<> value);
|
||||
$type1 [[unsigned_op=WaveActiveUMin,wv]] WaveActiveMin(in numeric<> value);
|
||||
$type1 [[unsigned_op=WaveActiveUMax,wv]] WaveActiveMax(in numeric<> value);
|
||||
uint [[wv]] WavePrefixCountBits(in bool value);
|
||||
$type1 [[unsigned_op=WavePrefixUSum,wv]] WavePrefixSum(in numeric<> value);
|
||||
$type1 [[unsigned_op=WavePrefixUProduct,wv]] WavePrefixProduct(in numeric<> value);
|
||||
uint<4> [[wv]] WaveMatch(in numeric<> value);
|
||||
$type1 [[wv]] WaveMultiPrefixBitAnd(in any_int<> value, in uint<4> mask);
|
||||
$type1 [[wv]] WaveMultiPrefixBitOr(in any_int<> value, in uint<4> mask);
|
||||
$type1 [[wv]] WaveMultiPrefixBitXor(in any_int<> value, in uint<4> mask);
|
||||
uint [[wv]] WaveMultiPrefixCountBits(in bool value, in uint<4> mask);
|
||||
$type1 [[unsigned_op=WaveMultiPrefixUProduct,wv]] WaveMultiPrefixProduct(in numeric<> value, in uint<4> mask);
|
||||
$type1 [[unsigned_op=WaveMultiPrefixUSum,wv]] WaveMultiPrefixSum(in numeric<> value, in uint<4> mask);
|
||||
$type1 [[]] QuadReadLaneAt(in numeric<> value, in uint quadLane);
|
||||
$type1 [[]] QuadReadAcrossX(in numeric<> value);
|
||||
$type1 [[]] QuadReadAcrossY(in numeric<> value);
|
||||
|
|
|
@ -2621,7 +2621,7 @@ class db_hlsl_attribute(object):
|
|||
|
||||
class db_hlsl_intrinsic(object):
|
||||
"An HLSL intrinsic declaration"
|
||||
def __init__(self, name, idx, opname, params, ns, ns_idx, doc, ro, rn, unsigned_op, overload_idx, hidden):
|
||||
def __init__(self, name, idx, opname, params, ns, ns_idx, doc, ro, rn, wv, unsigned_op, overload_idx, hidden):
|
||||
self.name = name # Function name
|
||||
self.idx = idx # Unique number within namespace
|
||||
self.opname = opname # D3D-style name
|
||||
|
@ -2633,6 +2633,7 @@ class db_hlsl_intrinsic(object):
|
|||
self.enum_name = "%s_%s" % (id_prefix, name) # enum name
|
||||
self.readonly = ro # Only read memory
|
||||
self.readnone = rn # Not read memory
|
||||
self.wave = wv # Is wave-sensitive
|
||||
self.unsigned_op = unsigned_op # Unsigned opcode if exist
|
||||
if unsigned_op != "":
|
||||
self.unsigned_op = "%s_%s" % (id_prefix, unsigned_op)
|
||||
|
@ -2873,6 +2874,7 @@ class db_hlsl(object):
|
|||
attrs = attr.split(',')
|
||||
readonly = False # Only read memory
|
||||
readnone = False # Not read memory
|
||||
is_wave = False; # Is wave-sensitive
|
||||
unsigned_op = "" # Unsigned opcode if exist
|
||||
overload_param_index = -1 # Parameter determines the overload type, -1 means ret type.
|
||||
hidden = False
|
||||
|
@ -2885,6 +2887,9 @@ class db_hlsl(object):
|
|||
if (a == "rn"):
|
||||
readnone = True
|
||||
continue
|
||||
if (a == "wv"):
|
||||
is_wave = True
|
||||
continue
|
||||
if (a == "hidden"):
|
||||
hidden = True
|
||||
continue
|
||||
|
@ -2904,7 +2909,7 @@ class db_hlsl(object):
|
|||
continue
|
||||
assert False, "invalid attr %s" % (a)
|
||||
|
||||
return readonly, readnone, unsigned_op, overload_param_index, hidden
|
||||
return readonly, readnone, is_wave, unsigned_op, overload_param_index, hidden
|
||||
|
||||
current_namespace = None
|
||||
for line in intrinsic_defs:
|
||||
|
@ -2937,7 +2942,7 @@ class db_hlsl(object):
|
|||
op = operand_match.group(1)
|
||||
if not op:
|
||||
op = name
|
||||
readonly, readnone, unsigned_op, overload_param_index, hidden = process_attr(attr)
|
||||
readonly, readnone, is_wave, unsigned_op, overload_param_index, hidden = process_attr(attr)
|
||||
# Add an entry for this intrinsic.
|
||||
if bracket_cleanup_re.search(opts):
|
||||
opts = bracket_cleanup_re.sub(r"<\1@\2>", opts)
|
||||
|
@ -2961,7 +2966,7 @@ class db_hlsl(object):
|
|||
# TODO: verify a single level of indirection
|
||||
self.intrinsics.append(db_hlsl_intrinsic(
|
||||
name, num_entries, op, args, current_namespace, ns_idx, "pending doc for " + name,
|
||||
readonly, readnone, unsigned_op, overload_param_index, hidden))
|
||||
readonly, readnone, is_wave, unsigned_op, overload_param_index, hidden))
|
||||
num_entries += 1
|
||||
continue
|
||||
assert False, "cannot parse line %s" % (line)
|
||||
|
|
|
@ -363,7 +363,7 @@ class db_oload_gen:
|
|||
f = lambda i,c : "true" if i.oload_types.find(c) >= 0 else "false"
|
||||
lower_exceptions = { "CBufferLoad" : "cbufferLoad", "CBufferLoadLegacy" : "cbufferLoadLegacy", "GSInstanceID" : "gsInstanceID" }
|
||||
lower_fn = lambda t: lower_exceptions[t] if t in lower_exceptions else t[:1].lower() + t[1:]
|
||||
attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate", "nr": "NoReturn" }
|
||||
attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate", "nr": "NoReturn", "wv" : "None" }
|
||||
attr_fn = lambda i : "Attribute::" + attr_dict[i.fn_attr] + ","
|
||||
for i in self.instrs:
|
||||
if last_category != i.category:
|
||||
|
@ -691,7 +691,7 @@ def get_hlsl_intrinsics():
|
|||
result += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
|
||||
# SPIRV Change Ends
|
||||
arg_idx = 0
|
||||
ns_table += " {(UINT)%s::%s_%s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
|
||||
ns_table += " {(UINT)%s::%s_%s, %s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), str(i.wave).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
|
||||
result += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (last_ns, arg_idx)
|
||||
for p in i.params:
|
||||
name = p.name
|
||||
|
|
Загрузка…
Ссылка в новой задаче