Use Attribute to designate wave-sensitive intrinsics (#2853)

* Use Attribute to designate wave-sensitive intrinsics

This adds an intrinsic attribute to indicate wave-sensitivity that can
be indicated in gen_intrin_main.txt. This and other attributes are
passed along through function representations and lowerings. The
wave-sensitivity needs to be maintained specifically through SROA passes
since it is used by the CleanupDxbreak pass that comes after them.

Specifically this is done to allow extension intrinsics to indicate
wave-sensitivity, but the same mechanism is now used for builtin
intrinsics.

Most intrinsics get a mostly nameless function to represent them between
Codegen and dxilgen. This allows any that have the same prototype to
share the same function. For wave-sensitive intrinsics, they need a
different function or else the attrubute would be similarly shared with
intrinsics matching the prototype. So a minor change is made to their
function names to prevent this.

Adds testing for all these ops and a dummy extension one.
This commit is contained in:
Greg Roth 2020-04-28 18:19:51 -07:00 коммит произвёл GitHub
Родитель 49876c21b4
Коммит a0196bcc22
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
19 изменённых файлов: 1416 добавлений и 869 удалений

Просмотреть файл

@ -11,15 +11,20 @@
#pragma once
#include "llvm/IR/IRBuilder.h"
#include <string>
namespace llvm {
class Module;
class Function;
class CallInst;
class Argument;
class StringRef;
template<typename T> class ArrayRef;
class AttributeSet;
class CallInst;
class Function;
class FunctionType;
class Module;
class StringRef;
class Type;
class Value;
}
namespace hlsl {
@ -123,6 +128,9 @@ unsigned GetHLOpcode(const llvm::CallInst *CI);
unsigned GetRowMajorOpcode(HLOpcodeGroup group, unsigned opcode);
void SetHLLowerStrategy(llvm::Function *F, llvm::StringRef S);
void SetHLWaveSensitive(llvm::Function *F);
bool IsHLWaveSensitive(llvm::Function *F);
// For intrinsic opcode.
bool HasUnsignedOpcode(unsigned opcode);
unsigned GetUnsignedOpcode(unsigned opcode);
@ -132,9 +140,6 @@ HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode);
llvm::StringRef GetHLOpcodeGroupName(HLOpcodeGroup op);
// Determine if this call is to an operation that is dependent on other members of its wave
bool IsCallWaveSensitive(llvm::CallInst *CI);
namespace HLOperandIndex {
// Opcode parameter.
const unsigned kOpcodeIdx = 0;
@ -389,9 +394,30 @@ llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
llvm::StringRef *fnName,
unsigned opcode);
llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
llvm::FunctionType *funcTy,
HLOpcodeGroup group, unsigned opcode,
const llvm::AttributeSet &attribs);
llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
llvm::FunctionType *funcTy,
HLOpcodeGroup group,
llvm::StringRef *groupName,
llvm::StringRef *fnName,
unsigned opcode,
const llvm::AttributeSet &attribs);
llvm::Function *GetOrCreateHLFunctionWithBody(llvm::Module &M,
llvm::FunctionType *funcTy,
HLOpcodeGroup group,
unsigned opcode,
llvm::StringRef name);
llvm::Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
llvm::Type *RetTy, llvm::ArrayRef<llvm::Value*> Args,
const llvm::AttributeSet &attribs, llvm::IRBuilder<> &Builder);
llvm::Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
llvm::Type *RetTy, llvm::ArrayRef<llvm::Value*> Args,
llvm::IRBuilder<> &Builder);
} // namespace hlsl

Просмотреть файл

@ -117,6 +117,7 @@ struct HLSL_INTRINSIC {
UINT Op; // Intrinsic Op ID
BOOL bReadOnly; // Only read memory
BOOL bReadNone; // Not read memory
BOOL bIsWave; // Is a wave-sensitive op
INT iOverloadParamIndex; // Parameter decide the overload type, -1 means ret type
UINT uNumArgs; // Count of arguments in pArgs.
const HLSL_INTRINSIC_ARGUMENT* pArgs; // Pointer to first argument.

Просмотреть файл

@ -1142,14 +1142,13 @@ public:
// For each wave operation, collect the blocks sensitive to it
SmallPtrSet<BasicBlock *, 16> SensitiveBBs;
for (Function &IF : M->functions()) {
if (&IF == &F || !IF.getNumUses() || !IF.isDeclaration() ||
hlsl::GetHLOpcodeGroupByName(&IF) != HLOpcodeGroup::HLIntrinsic)
continue;
for (User *U : IF.users()) {
CallInst *CI = dyn_cast<CallInst>(U);
if (CI && IsCallWaveSensitive(CI))
HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
if (&IF != &F && IF.getNumUses() && IF.isDeclaration() && IsHLWaveSensitive(&IF) &&
(opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
for (User *U : IF.users()) {
CallInst *CI = cast<CallInst>(U);
CollectSensitiveBlocks(LInfo, CI, BreakFunc, SensitiveBBs);
}
}
}

Просмотреть файл

@ -37,20 +37,6 @@
using namespace llvm;
using namespace hlsl;
static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
SmallVector<Type*, 4> ArgTys;
ArgTys.reserve(Args.size());
for (Value *Arg : Args)
ArgTys.emplace_back(Arg->getType());
FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
return Builder.CreateCall(Func, Args);
}
// Lowered UDT is the same layout, but with vectors and matrices translated to
// arrays.
// Returns nullptr for failure due to embedded HLSL object type.

Просмотреть файл

@ -159,14 +159,13 @@ private:
Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
Value *lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
Value *lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
Value *lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
Value *lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
Value *lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
Value *lowerHLCast(CallInst *Call, Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
Value *lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
Value *lowerHLInit(CallInst *Call);
Value *lowerHLSelect(CallInst *Call);
@ -784,7 +783,7 @@ Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeG
return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
case HLOpcodeGroup::HLCast:
return lowerHLCast(
return lowerHLCast(Call,
Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
@ -802,19 +801,6 @@ Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeG
}
}
static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
SmallVector<Type*, 4> ArgTys;
ArgTys.reserve(Args.size());
for (Value *Arg : Args)
ArgTys.emplace_back(Arg->getType());
FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
return Builder.CreateCall(Func, Args);
}
Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
IRBuilder<> Builder(Call);
@ -850,8 +836,9 @@ Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
}
Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
LoweredRetTy, LoweredArgs, Builder);
return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
LoweredRetTy, LoweredArgs,
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
// Handles multiplcation of a scalar with a matrix
@ -1229,12 +1216,12 @@ Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode
switch (Opcode) {
case HLMatLoadStoreOpcode::RowMatLoad:
case HLMatLoadStoreOpcode::ColMatLoad:
return lowerHLLoad(Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
return lowerHLLoad(Call, Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
/* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
case HLMatLoadStoreOpcode::RowMatStore:
case HLMatLoadStoreOpcode::ColMatStore:
return lowerHLStore(
return lowerHLStore(Call,
Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
/* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
@ -1245,7 +1232,7 @@ Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode
}
}
Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
Value *HLMatrixLowerPass::lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
@ -1254,13 +1241,15 @@ Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<>
HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
return callHLFunction(
*m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr }, Builder);
MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr },
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
return MatTy.emitLoweredLoad(LoweredPtr, Builder);
}
Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder) {
Value *HLMatrixLowerPass::lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr,
bool RowMajor, bool Return, IRBuilder<> &Builder) {
DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
"Matrix store value/pointer type mismatch.");
@ -1272,7 +1261,8 @@ Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMaj
return callHLFunction(
*m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
Return ? LoweredVal->getType() : Builder.getVoidTy(),
{ Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal }, Builder);
{ Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal },
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
@ -1309,7 +1299,8 @@ static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opc
return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
}
Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder) {
Value *HLMatrixLowerPass::lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
HLCastOpcode Opcode, IRBuilder<> &Builder) {
// The opcode really doesn't mean much here, the types involved are what drive most of the casting.
DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
@ -1358,7 +1349,8 @@ Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opco
DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
"Unexpected cast of matrix argument.");
LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src }, Builder);
LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src },
Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
else {
LoweredSrc = getLoweredByValOperand(Src, Builder);
@ -1498,21 +1490,6 @@ void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, Small
addToDeadInsts(Call);
}
// Lowers StructuredBuffer<matrix>[index] or similar with constant buffers
Value *HLMatrixLowerPass::lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
// Just replace the intrinsic by its equivalent with a lowered return type
IRBuilder<> Builder(Call);
SmallVector<Value*, 4> Args;
Args.reserve(Call->getNumArgOperands());
for (Value *Arg : Call->arg_operands())
Args.emplace_back(Arg);
Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
return callHLFunction(*m_pModule, HLOpcodeGroup::HLSubscript, static_cast<unsigned>(Opcode),
LoweredRetTy, Args, Builder);
}
Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");

Просмотреть файл

@ -26,6 +26,9 @@ const char * const HLPrefix = HLPrefixStr;
static const char HLLowerStrategyStr[] = "dx.hlls";
static const char * const HLLowerStrategy = HLLowerStrategyStr;
static const char HLWaveSensitiveStr[] = "dx.wave-sensitive";
static const char * const HLWaveSensitive = HLWaveSensitiveStr;
static StringRef HLOpcodeGroupNames[]{
"notHLDXIL", // NotHL,
"<ext>", // HLExtIntrinsic - should always refer through extension
@ -289,6 +292,17 @@ void SetHLLowerStrategy(Function *F, StringRef S) {
F->addFnAttr(HLLowerStrategy, S);
}
// Set function attribute indicating wave-sensitivity
void SetHLWaveSensitive(Function *F) {
F->addFnAttr(HLWaveSensitive, "y");
}
// Return if this Function is dependent on other wave members indicated by attribute
bool IsHLWaveSensitive(Function *F) {
AttributeSet attrSet = F->getAttributes();
return attrSet.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive);
}
std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
assert(op != HLOpcodeGroup::HLExtIntrinsic && "else table name should be used");
std::string opName = GetHLOpcodeGroupFullName(op).str() + ".";
@ -452,53 +466,29 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
}
// Determine if this Call Instruction refers to an HLOpcode that is dependent on other wave members
bool IsCallWaveSensitive(CallInst *CI) {
hlsl::IntrinsicOp opcode = static_cast<hlsl::IntrinsicOp>(hlsl::GetHLOpcode(CI));
switch(opcode) {
case IntrinsicOp::IOP_WaveActiveAllEqual:
case IntrinsicOp::IOP_WaveActiveAllTrue:
case IntrinsicOp::IOP_WaveActiveAnyTrue:
case IntrinsicOp::IOP_WaveActiveBallot:
case IntrinsicOp::IOP_WaveActiveBitAnd:
case IntrinsicOp::IOP_WaveActiveBitOr:
case IntrinsicOp::IOP_WaveActiveBitXor:
case IntrinsicOp::IOP_WaveActiveCountBits:
case IntrinsicOp::IOP_WaveActiveMax:
case IntrinsicOp::IOP_WaveActiveMin:
case IntrinsicOp::IOP_WaveActiveProduct:
case IntrinsicOp::IOP_WaveActiveSum:
case IntrinsicOp::IOP_WaveIsFirstLane:
case IntrinsicOp::IOP_WaveMatch:
case IntrinsicOp::IOP_WaveMultiPrefixBitAnd:
case IntrinsicOp::IOP_WaveMultiPrefixBitOr:
case IntrinsicOp::IOP_WaveMultiPrefixBitXor:
case IntrinsicOp::IOP_WaveMultiPrefixCountBits:
case IntrinsicOp::IOP_WaveMultiPrefixProduct:
case IntrinsicOp::IOP_WaveMultiPrefixSum:
case IntrinsicOp::IOP_WavePrefixCountBits:
case IntrinsicOp::IOP_WavePrefixProduct:
case IntrinsicOp::IOP_WavePrefixSum:
case IntrinsicOp::IOP_WaveReadLaneAt:
case IntrinsicOp::IOP_WaveReadLaneFirst:
case IntrinsicOp::IOP_QuadReadAcrossDiagonal:
case IntrinsicOp::IOP_QuadReadAcrossX:
case IntrinsicOp::IOP_QuadReadAcrossY:
case IntrinsicOp::IOP_QuadReadLaneAt:
return true;
}
return false;
}
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
HLOpcodeGroup group, unsigned opcode) {
return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode);
AttributeSet attribs;
return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode, attribs);
}
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
HLOpcodeGroup group, llvm::StringRef *groupName,
llvm::StringRef *fnName, unsigned opcode) {
HLOpcodeGroup group, StringRef *groupName,
StringRef *fnName, unsigned opcode) {
AttributeSet attribs;
return GetOrCreateHLFunction(M, funcTy, group, groupName, fnName, opcode, attribs);
}
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
HLOpcodeGroup group, unsigned opcode,
const AttributeSet &attribs) {
return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode, attribs);
}
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
HLOpcodeGroup group, StringRef *groupName,
StringRef *fnName, unsigned opcode,
const AttributeSet &attribs) {
std::string mangledName;
raw_string_ostream mangledNameStr(mangledName);
if (group == HLOpcodeGroup::HLExtIntrinsic) {
@ -510,6 +500,9 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
}
else {
mangledNameStr << GetHLFullName(group, opcode);
// Need to add wave sensitivity to name to prevent clashes with non-wave intrinsic
if(attribs.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive))
mangledNameStr << "wave";
mangledNameStr << '.';
funcTy->print(mangledNameStr);
}
@ -523,6 +516,14 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
SetHLFunctionAttribute(F, group, opcode);
// Copy attributes
if (attribs.hasAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone))
F->addFnAttr(Attribute::ReadNone);
if (attribs.hasAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly))
F->addFnAttr(Attribute::ReadOnly);
if (attribs.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive))
F->addFnAttr(HLWaveSensitive, "y");
return F;
}
@ -546,4 +547,23 @@ Function *GetOrCreateHLFunctionWithBody(Module &M, FunctionType *funcTy,
return F;
}
Value *callHLFunction(Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
AttributeSet attribs;
return callHLFunction(Module, OpcodeGroup, Opcode, RetTy, Args, attribs, Builder);
}
Value *callHLFunction(Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
Type *RetTy, ArrayRef<Value*> Args, const AttributeSet &attribs, IRBuilder<> &Builder) {
SmallVector<Type*, 4> ArgTys;
ArgTys.reserve(Args.size());
for (Value *Arg : Args)
ArgTys.emplace_back(Arg->getType());
FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode, attribs);
return Builder.CreateCall(Func, Args);
}
} // namespace hlsl

Просмотреть файл

@ -944,7 +944,7 @@ static unsigned IsPtrUsedByLoweredFn(
} else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
unsigned opcode = CE->getOpcode();
if (opcode == Instruction::AddrSpaceCast || Instruction::GetElementPtr)
if (opcode == Instruction::AddrSpaceCast || opcode == Instruction::GetElementPtr)
if (IsPtrUsedByLoweredFn(user, CollectedUses))
bFound = true;
}
@ -967,7 +967,8 @@ static CallInst *RewriteIntrinsicCallForCastedArg(CallInst *CI, unsigned argIdx)
newArgs[argIdx] = newArg;
FunctionType *newFuncTy = FunctionType::get(CI->getType(), newArgTypes, false);
Function *newF = GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode);
Function *newF = GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode,
F->getAttributes().getFnAttributes());
IRBuilder<> Builder(CI);
return Builder.CreateCall(newF, newArgs);
@ -2779,7 +2780,8 @@ static CallInst *CreateFlattenedHLIntrinsicCall(
FunctionType *flatFuncTy =
FunctionType::get(CI->getType(), flatParamTys, false);
Function *flatF =
GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode);
GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode,
F->getAttributes().getFnAttributes());
return Builder.CreateCall(flatF, flatArgs);
}

Просмотреть файл

@ -895,6 +895,12 @@ def HLSLPayload : InheritableAttr {
let Documentation = [Undocumented];
}
def HLSLWaveSensitive : InheritableAttr {
let Spellings = [CXX11<"", "wavesensitive", 2015>];
let Subjects = SubjectList<[ParmVar]>;
let Documentation = [Undocumented];
}
// HLSL Change Ends
// SPIRV Change Starts

Просмотреть файл

@ -1157,6 +1157,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
if (hlsl::GetIntrinsicLowering(FD, lower))
hlsl::SetHLLowerStrategy(F, lower);
if (FD->hasAttr<HLSLWaveSensitiveAttr>())
hlsl::SetHLWaveSensitive(F);
// Don't need to add FunctionQual for intrinsic function.
return;
}

Просмотреть файл

@ -191,7 +191,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
llvm::FunctionType *funcTy, HLOpcodeGroup group,
unsigned opcode) {
Function *opFunc = nullptr;
AttributeSet attribs = F->getAttributes().getFnAttributes();
llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
if (group == HLOpcodeGroup::HLIntrinsic) {
IntrinsicOp intriOp = static_cast<IntrinsicOp>(opcode);
@ -202,7 +202,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
llvm::Type *handleTy = funcTy->getParamType(HLOperandIndex::kHandleOpIdx);
// Don't generate body for OutputStream::Append.
if (bAppend && HLModule::IsStreamOutputPtrType(handleTy)) {
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode, attribs);
break;
}
@ -215,7 +215,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
bAppend ? (unsigned)IntrinsicOp::MOP_IncrementCounter
: (unsigned)IntrinsicOp::MOP_DecrementCounter;
Function *incCounterFunc =
GetOrCreateHLFunction(M, IncCounterFuncTy, group, counterOpcode);
GetOrCreateHLFunction(M, IncCounterFuncTy, group, counterOpcode, attribs);
llvm::Type *idxTy = counterTy;
llvm::Type *valTy =
@ -245,7 +245,7 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
Function *subscriptFunc =
GetOrCreateHLFunction(M, SubscriptFuncTy, HLOpcodeGroup::HLSubscript,
(unsigned)HLSubscriptOpcode::DefaultSubscript);
(unsigned)HLSubscriptOpcode::DefaultSubscript, attribs);
BasicBlock *BB =
BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
@ -304,8 +304,8 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
llvm::FunctionType::get(valTy, {opcodeTy, valTy}, false);
unsigned sinOp = static_cast<unsigned>(IntrinsicOp::IOP_sin);
unsigned cosOp = static_cast<unsigned>(IntrinsicOp::IOP_cos);
Function *sinFunc = GetOrCreateHLFunction(M, sinFuncTy, group, sinOp);
Function *cosFunc = GetOrCreateHLFunction(M, sinFuncTy, group, cosOp);
Function *sinFunc = GetOrCreateHLFunction(M, sinFuncTy, group, sinOp, attribs);
Function *cosFunc = GetOrCreateHLFunction(M, sinFuncTy, group, cosOp, attribs);
BasicBlock *BB =
BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
@ -328,23 +328,18 @@ Function *CreateOpFunction(llvm::Module &M, Function *F,
Builder.CreateRetVoid();
} break;
default:
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode, attribs);
break;
}
} else if (group == HLOpcodeGroup::HLExtIntrinsic) {
llvm::StringRef fnName = F->getName();
llvm::StringRef groupName = GetHLOpcodeGroupNameByAttr(F);
opFunc =
GetOrCreateHLFunction(M, funcTy, group, &groupName, &fnName, opcode);
GetOrCreateHLFunction(M, funcTy, group, &groupName, &fnName, opcode, attribs);
} else {
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode, attribs);
}
// Add attribute
if (F->hasFnAttribute(Attribute::ReadNone))
opFunc->addFnAttr(Attribute::ReadNone);
if (F->hasFnAttribute(Attribute::ReadOnly))
opFunc->addFnAttr(Attribute::ReadOnly);
return opFunc;
}

Просмотреть файл

@ -1761,6 +1761,8 @@ static void AddHLSLIntrinsicAttr(FunctionDecl *FD, ASTContext &context,
FD->addAttr(ConstAttr::CreateImplicit(context));
if (pIntrinsic->bReadOnly)
FD->addAttr(PureAttr::CreateImplicit(context));
if (pIntrinsic->bIsWave)
FD->addAttr(HLSLWaveSensitiveAttr::CreateImplicit(context));
}
static
@ -11207,6 +11209,9 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
case AttributeList::AT_HLSLExport:
declAttr = ::new (S.Context) HLSLExportAttr(A.getRange(), S.Context, A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_HLSLWaveSensitive:
declAttr = ::new (S.Context) HLSLWaveSensitiveAttr(A.getRange(), S.Context, A.getAttributeSpellingListIndex());
break;
default:
Handled = false;
break; // SPIRV Change: was return;
@ -12644,6 +12649,7 @@ bool hlsl::IsHLSLAttr(clang::attr::Kind AttrKind) {
case clang::attr::HLSLPayload:
case clang::attr::NoInline:
case clang::attr::HLSLExport:
case clang::attr::HLSLWaveSensitive:
case clang::attr::VKBinding:
case clang::attr::VKBuiltIn:
case clang::attr::VKConstantId:

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,365 @@
// RUN: %dxc -T ps_6_5 %s | FileCheck %s
StructuredBuffer<int> buf[]: register(t2);
StructuredBuffer<uint4> g_mask;
// CHECK: @dx.break.cond = internal constant
// Cannonical example. Expected to keep the block in loop
// Verify this function loads the global
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
int main(int a : A, int b : B) : SV_Target
{
int res = 0;
int i = WaveGetLaneCount() + WaveGetLaneIndex();
// These verify the break block keeps the conditional
// CHECK: call i1 @dx.op.waveIsFirstLane
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
bool u = WaveIsFirstLane();
if (a != u) {
res += buf[b][(int)u];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i1 @dx.op.waveAnyTrue
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
bool u = WaveActiveAnyTrue(a);
if (a != u) {
res += buf[(int)u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i1 @dx.op.waveAllTrue
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
bool u = WaveActiveAllTrue(a);
if (a != u) {
res += buf[(int)u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i1 @dx.op.waveActiveAllEqual
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
bool u = WaveActiveAllEqual(a);
if (a != u) {
res += buf[(int)u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveReadLaneFirst(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveAllOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WaveActiveCountBits(a == b);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveSum(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveProduct(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveBit
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveBitAnd((uint)a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveBit
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveBitOr((uint)a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveBit
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveBitXor((uint)a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveMin(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveActiveOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveActiveMax(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.wavePrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WavePrefixCountBits(a < b);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.wavePrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WavePrefixSum(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.wavePrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WavePrefixProduct(a);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call %dx.types.fouri32 @dx.op.waveMatch
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint4 u = WaveMatch(a);
if (a != u.y) {
res += buf[u.x][b];
break;
}
}
uint4 mask = g_mask[0];
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveMultiPrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveMultiPrefixBitAnd((uint)a, mask);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveMultiPrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveMultiPrefixBitOr((uint)a, mask);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveMultiPrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
int u = WaveMultiPrefixBitXor((uint)a, mask);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveMultiPrefixBitCount
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WaveMultiPrefixCountBits(a <= b, mask);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveMultiPrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WaveMultiPrefixProduct(a, mask);
if (a != u) {
res += buf[u][b];
break;
}
}
// These verify the break block keeps the conditional
// CHECK: call i32 @dx.op.waveMultiPrefixOp
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// Loop with wave-dependent conditional break block
for (;;) {
uint u = WaveMultiPrefixSum(a, mask);
if (a != u) {
res += buf[u][b];
break;
}
}
return res;
}

Просмотреть файл

@ -4,6 +4,8 @@ StructuredBuffer<int> buf[]: register(t2);
// Cannonical example. Expected to keep the block in loop
// Verify this function loads the global
// CHECK: define i32
// CHECK-SAME: WaveInLoop
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
@ -23,6 +25,7 @@ StructuredBuffer<int> buf[]: register(t2);
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// CHECK: ret i32
export
int WaveInLoop(int a : A, int b : B)
{
@ -52,6 +55,8 @@ int WaveInLoop(int a : A, int b : B)
// Wave moved to after the break block. Expected to keep the block in loop
// Verify this function loads the global
// CHECK: define i32
// CHECK-SAME: WaveInPostLoop
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
@ -71,6 +76,7 @@ int WaveInLoop(int a : A, int b : B)
// CHECK: br i1
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: ret i32
export
int WaveInPostLoop(int a : A, int b : B)
{
@ -101,6 +107,8 @@ int WaveInPostLoop(int a : A, int b : B)
// Wave op inside break block. Expected to keep the block in loop
// Verify this function loads the global
// CHECK: define i32
// CHECK-SAME: WaveInBreakBlock
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
@ -111,7 +119,7 @@ int WaveInPostLoop(int a : A, int b : B)
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: br i1
// CHECK: ret i32
export
int WaveInBreakBlock(int a : A, int b : B)
{
@ -131,6 +139,8 @@ int WaveInBreakBlock(int a : A, int b : B)
}
// Wave in entry block. Expected to allow the break block to move out of loop
// CHECK: define i32
// CHECK-SAME: WaveInEntry
// CHECK: call i32 @dx.op.waveReadLaneFirst
// These verify the break block doesn't keep the conditional
@ -140,6 +150,7 @@ int WaveInBreakBlock(int a : A, int b : B)
// These verify the break block doesn't keep the conditional
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: ret i32
export
int WaveInEntry(int a : A, int b : B)
{
@ -169,6 +180,8 @@ int WaveInEntry(int a : A, int b : B)
// Wave in subloop of larger loop. Expected to keep the block in loop
// Verify this function loads the global
// CHECK: define i32
// CHECK-SAME: WaveInSubLoop
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
@ -186,6 +199,7 @@ int WaveInEntry(int a : A, int b : B)
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// CHECK: ret i32
export
int WaveInSubLoop(int a : A, int b : B)
{
@ -218,6 +232,8 @@ int WaveInSubLoop(int a : A, int b : B)
}
// Wave in a separate loop. Expected to allow the break block to move out of loop
// CHECK: define i32
// CHECK-SAME: WaveInOtherLoop
// CHECK: load i32
// CHECK: icmp eq i32
@ -233,13 +249,14 @@ int WaveInSubLoop(int a : A, int b : B)
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHICK-NOT: br i1
// CHECK-NOT: br i1
// These verify the third break block doesn't
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK-NOT: br i1
// CHECK: ret i32
export
int WaveInOtherLoop(int a : A, int b : B)
{
@ -273,3 +290,44 @@ int WaveInOtherLoop(int a : A, int b : B)
}
return res;
}
// Test for matrices which pass through additional lowering which might lose wave-sensitive attrib
// Verify this function loads the global
// CHECK: define i32
// CHECK-SAME: WaveWithMatrix
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveReadLaneFirst
// These verify the break block keeps the conditional
// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// CHECK: ret i32
export
int WaveWithMatrix(int3x3 a : A, int b : B)
{
int res = 0;
// Loop with wave-dependent matrix conditional break block
for (;;) {
int3x3 u = WaveReadLaneFirst(a);
if (a[0][0] == u[1][1]) {
res += buf[u[2][2]][b];
break;
}
}
return res;
}

Просмотреть файл

@ -0,0 +1,49 @@
// RUN: %dxc -T ps_6_3 %s | FileCheck %s
StructuredBuffer<int> buf[]: register(t2);
// CHECK: @dx.break.cond = internal constant
// Cannonical example. Expected to keep the block in loop
// Verify this function loads the global
// CHECK: load i32
// CHECK-SAME: @dx.break.cond
// CHECK: icmp eq i32
// CHECK: call i32 @dx.op.waveActiveOp.i32
// These verify the first break block keeps the conditional
// CHECK: call %dx.types.Handle @dx.op.createHandle
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHECK: br i1
// These verify the second break block doesn't
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
// CHECK: add
// CHICK-NOT: br i1
int main(int a : A, int b : B) : SV_Target
{
int res = 0;
// Loop with wave op
for (;;) {
int u = WaveActiveSum(a);
if (a == u) {
res += buf[b][u];
break;
}
}
// Loop with non-wave op with same signature as previous wave op
// Without prototype manipulation, this will share an intermediate
// op call with the previous and get the wave attribute.
for (;;) {
int u = abs(a--);
if (a == u) {
res += buf[b][u];
break;
}
}
return res;
}

Просмотреть файл

@ -117,6 +117,12 @@ static const HLSL_INTRINSIC_ARGUMENT TestMySamplerOp[] = {
{ "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
};
// $result = wave_proc(any_vector<any_cardinality> value)
static const HLSL_INTRINSIC_ARGUMENT WaveProcArgs[] = {
{ "wave_proc", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
{ "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
};
struct Intrinsic {
LPCWSTR hlslName;
const char *dxilName;
@ -130,31 +136,32 @@ template <class T, std::size_t N>
UINT countof(T(&)[N]) { return static_cast<UINT>(N); }
Intrinsic Intrinsics[] = {
{L"test_fn", DEFAULT_NAME, "r", { 1, false, true, -1, countof(TestFnArgs), TestFnArgs }},
{L"test_proc", DEFAULT_NAME, "r", { 2, false, false,-1, countof(TestProcArgs), TestProcArgs }},
{L"test_poly", "test_poly.$o", "r", { 3, false, true, -1, countof(TestFnCustomArgs), TestFnCustomArgs }},
{L"test_int", "test_int", "r", { 4, false, true, -1, countof(TestFnIntArgs), TestFnIntArgs}},
{L"test_nolower", "test_nolower.$o", "n", { 5, false, true, -1, countof(TestFnNoLowerArgs), TestFnNoLowerArgs}},
{L"test_pack_0", "test_pack_0.$o", "p", { 6, false, false,-1, countof(TestFnPack0), TestFnPack0}},
{L"test_pack_1", "test_pack_1.$o", "p", { 7, false, true, -1, countof(TestFnPack1), TestFnPack1}},
{L"test_pack_2", "test_pack_2.$o", "p", { 8, false, true, -1, countof(TestFnPack2), TestFnPack2}},
{L"test_pack_3", "test_pack_3.$o", "p", { 9, false, true, -1, countof(TestFnPack3), TestFnPack3}},
{L"test_pack_4", "test_pack_4.$o", "p", { 10, false, false,-1, countof(TestFnPack4), TestFnPack4}},
{L"test_rand", "test_rand", "r", { 11, false, false,-1, countof(TestRand), TestRand}},
{L"test_isinf", "test_isinf", "d", { 13, true, true, -1, countof(TestIsInf), TestIsInf}},
{L"test_ibfe", "test_ibfe", "d", { 14, true, true, -1, countof(TestIBFE), TestIBFE}},
{L"test_fn", DEFAULT_NAME, "r", { 1, false, true, false, -1, countof(TestFnArgs), TestFnArgs }},
{L"test_proc", DEFAULT_NAME, "r", { 2, false, false, false,-1, countof(TestProcArgs), TestProcArgs }},
{L"test_poly", "test_poly.$o", "r", { 3, false, true, false, -1, countof(TestFnCustomArgs), TestFnCustomArgs }},
{L"test_int", "test_int", "r", { 4, false, true, false, -1, countof(TestFnIntArgs), TestFnIntArgs}},
{L"test_nolower", "test_nolower.$o", "n", { 5, false, true, false, -1, countof(TestFnNoLowerArgs), TestFnNoLowerArgs}},
{L"test_pack_0", "test_pack_0.$o", "p", { 6, false, false, false,-1, countof(TestFnPack0), TestFnPack0}},
{L"test_pack_1", "test_pack_1.$o", "p", { 7, false, true, false, -1, countof(TestFnPack1), TestFnPack1}},
{L"test_pack_2", "test_pack_2.$o", "p", { 8, false, true, false, -1, countof(TestFnPack2), TestFnPack2}},
{L"test_pack_3", "test_pack_3.$o", "p", { 9, false, true, false, -1, countof(TestFnPack3), TestFnPack3}},
{L"test_pack_4", "test_pack_4.$o", "p", { 10, false, false, false,-1, countof(TestFnPack4), TestFnPack4}},
{L"test_rand", "test_rand", "r", { 11, false, false, false,-1, countof(TestRand), TestRand}},
{L"test_isinf", "test_isinf", "d", { 13, true, true, false, -1, countof(TestIsInf), TestIsInf}},
{L"test_ibfe", "test_ibfe", "d", { 14, true, true, false, -1, countof(TestIBFE), TestIBFE}},
// Make this intrinsic have the same opcode as an hlsl intrinsic with an unsigned
// counterpart for testing purposes.
{L"test_unsigned","test_unsigned", "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, -1, countof(TestUnsigned), TestUnsigned}},
{L"test_unsigned","test_unsigned", "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, false, -1, countof(TestUnsigned), TestUnsigned}},
{L"wave_proc", DEFAULT_NAME, "r", { 16, false, true, true, -1, countof(WaveProcArgs), WaveProcArgs }},
};
Intrinsic BufferIntrinsics[] = {
{L"MyBufferOp", "MyBufferOp", "m", { 12, false, true, -1, countof(TestMyBufferOp), TestMyBufferOp}},
{L"MyBufferOp", "MyBufferOp", "m", { 12, false, true, false, -1, countof(TestMyBufferOp), TestMyBufferOp}},
};
// Test adding a method to an object that normally has no methods (SamplerState will do).
Intrinsic SamplerIntrinsics[] = {
{L"MySamplerOp", "MySamplerOp", "m", { 15, false, true, -1, countof(TestMySamplerOp), TestMySamplerOp}},
{L"MySamplerOp", "MySamplerOp", "m", { 15, false, true, false, -1, countof(TestMySamplerOp), TestMySamplerOp}},
};
class IntrinsicTable {
@ -454,6 +461,7 @@ public:
TEST_METHOD(DxilLoweringVector2)
TEST_METHOD(DxilLoweringScalar)
TEST_METHOD(SamplerExtensionIntrinsic)
TEST_METHOD(WaveIntrinsic)
};
TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@ -880,3 +888,43 @@ TEST_F(ExtensionTest, SamplerExtensionIntrinsic) {
CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
}
TEST_F(ExtensionTest, WaveIntrinsic) {
// Test wave-sensitive intrinsic in breaked loop
Compiler c(m_dllSupport);
c.RegisterIntrinsicTable(new TestIntrinsicTable());
c.Compile(
"StructuredBuffer<int> buf[]: register(t2);"
"float2 main(float2 a : A, int b : B) : SV_Target {"
" int res = 0;"
" float2 u = {0,0};"
" for (;;) {"
" u += wave_proc(a);"
" if (a.x == u.x) {"
" res += buf[b][(int)u.y];"
" break;"
" }"
" }"
" return res;"
"}",
{ L"/Vd" }, {}
);
std::string disassembly = c.Disassemble();
// Check that the wave op causes the break block to be retained
VERIFY_IS_TRUE(
disassembly.npos !=
disassembly.find("@dx.break.cond = internal constant [1 x i32] zeroinitializer"));
VERIFY_IS_TRUE(
disassembly.npos !=
disassembly.find("%1 = load i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @dx.break.cond"));
VERIFY_IS_TRUE(
disassembly.npos !=
disassembly.find("%2 = icmp eq i32 %1, 0"));
VERIFY_IS_TRUE(
disassembly.npos !=
disassembly.find("call float @\"test.\\01?wave_proc@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32 16, float"));
VERIFY_IS_TRUE(
disassembly.npos !=
disassembly.find("br i1 %2"));
}

Просмотреть файл

@ -245,34 +245,34 @@ bool [[rn]] CheckAccessFullyMapped(in uint_only status) : check_access_fully_map
uint<c> [[rn]] AddUint64(in $match<1, 0> uint<c> a, in $match<2, 0> uint<c> b) : adduint64;
$type1 [[rn]] NonUniformResourceIndex(in any<> index) : nonuniform_resource_index;
// Wave intrinsics.
bool [[]] WaveIsFirstLane();
// Wave intrinsics. Only those that depend on the exec mask are marked as wave-sensitive
bool [[wv]] WaveIsFirstLane();
uint [[rn]] WaveGetLaneIndex();
uint [[rn]] WaveGetLaneCount();
bool [[]] WaveActiveAnyTrue(in bool cond);
bool [[]] WaveActiveAllTrue(in bool cond);
$match<1, 0> bool<> [[]] WaveActiveAllEqual(in any<> value);
uint<4> [[]] WaveActiveBallot(in bool cond);
bool [[wv]] WaveActiveAnyTrue(in bool cond);
bool [[wv]] WaveActiveAllTrue(in bool cond);
$match<1, 0> bool<> [[wv]] WaveActiveAllEqual(in any<> value);
uint<4> [[wv]] WaveActiveBallot(in bool cond);
$type1 [[]] WaveReadLaneAt(in any<> value, in uint lane);
$type1 [[]] WaveReadLaneFirst(in any<> value);
uint [[]] WaveActiveCountBits(in bool value);
$type1 [[unsigned_op=WaveActiveUSum]] WaveActiveSum(in numeric<> value);
$type1 [[unsigned_op=WaveActiveUProduct]] WaveActiveProduct(in numeric<> value);
$type1 [[]] WaveActiveBitAnd(in uint_only<> value);
$type1 [[]] WaveActiveBitOr(in uint_only<> value);
$type1 [[]] WaveActiveBitXor(in uint_only<> value);
$type1 [[unsigned_op=WaveActiveUMin]] WaveActiveMin(in numeric<> value);
$type1 [[unsigned_op=WaveActiveUMax]] WaveActiveMax(in numeric<> value);
uint [[]] WavePrefixCountBits(in bool value);
$type1 [[unsigned_op=WavePrefixUSum]] WavePrefixSum(in numeric<> value);
$type1 [[unsigned_op=WavePrefixUProduct]] WavePrefixProduct(in numeric<> value);
uint<4> [[]] WaveMatch(in numeric<> value);
$type1 [[]] WaveMultiPrefixBitAnd(in any_int<> value, in uint<4> mask);
$type1 [[]] WaveMultiPrefixBitOr(in any_int<> value, in uint<4> mask);
$type1 [[]] WaveMultiPrefixBitXor(in any_int<> value, in uint<4> mask);
uint [[]] WaveMultiPrefixCountBits(in bool value, in uint<4> mask);
$type1 [[unsigned_op=WaveMultiPrefixUProduct]] WaveMultiPrefixProduct(in numeric<> value, in uint<4> mask);
$type1 [[unsigned_op=WaveMultiPrefixUSum]] WaveMultiPrefixSum(in numeric<> value, in uint<4> mask);
$type1 [[wv]] WaveReadLaneFirst(in any<> value);
uint [[wv]] WaveActiveCountBits(in bool value);
$type1 [[unsigned_op=WaveActiveUSum,wv]] WaveActiveSum(in numeric<> value);
$type1 [[unsigned_op=WaveActiveUProduct,wv]] WaveActiveProduct(in numeric<> value);
$type1 [[wv]] WaveActiveBitAnd(in uint_only<> value);
$type1 [[wv]] WaveActiveBitOr(in uint_only<> value);
$type1 [[wv]] WaveActiveBitXor(in uint_only<> value);
$type1 [[unsigned_op=WaveActiveUMin,wv]] WaveActiveMin(in numeric<> value);
$type1 [[unsigned_op=WaveActiveUMax,wv]] WaveActiveMax(in numeric<> value);
uint [[wv]] WavePrefixCountBits(in bool value);
$type1 [[unsigned_op=WavePrefixUSum,wv]] WavePrefixSum(in numeric<> value);
$type1 [[unsigned_op=WavePrefixUProduct,wv]] WavePrefixProduct(in numeric<> value);
uint<4> [[wv]] WaveMatch(in numeric<> value);
$type1 [[wv]] WaveMultiPrefixBitAnd(in any_int<> value, in uint<4> mask);
$type1 [[wv]] WaveMultiPrefixBitOr(in any_int<> value, in uint<4> mask);
$type1 [[wv]] WaveMultiPrefixBitXor(in any_int<> value, in uint<4> mask);
uint [[wv]] WaveMultiPrefixCountBits(in bool value, in uint<4> mask);
$type1 [[unsigned_op=WaveMultiPrefixUProduct,wv]] WaveMultiPrefixProduct(in numeric<> value, in uint<4> mask);
$type1 [[unsigned_op=WaveMultiPrefixUSum,wv]] WaveMultiPrefixSum(in numeric<> value, in uint<4> mask);
$type1 [[]] QuadReadLaneAt(in numeric<> value, in uint quadLane);
$type1 [[]] QuadReadAcrossX(in numeric<> value);
$type1 [[]] QuadReadAcrossY(in numeric<> value);

Просмотреть файл

@ -2621,7 +2621,7 @@ class db_hlsl_attribute(object):
class db_hlsl_intrinsic(object):
"An HLSL intrinsic declaration"
def __init__(self, name, idx, opname, params, ns, ns_idx, doc, ro, rn, unsigned_op, overload_idx, hidden):
def __init__(self, name, idx, opname, params, ns, ns_idx, doc, ro, rn, wv, unsigned_op, overload_idx, hidden):
self.name = name # Function name
self.idx = idx # Unique number within namespace
self.opname = opname # D3D-style name
@ -2633,6 +2633,7 @@ class db_hlsl_intrinsic(object):
self.enum_name = "%s_%s" % (id_prefix, name) # enum name
self.readonly = ro # Only read memory
self.readnone = rn # Not read memory
self.wave = wv # Is wave-sensitive
self.unsigned_op = unsigned_op # Unsigned opcode if exist
if unsigned_op != "":
self.unsigned_op = "%s_%s" % (id_prefix, unsigned_op)
@ -2873,6 +2874,7 @@ class db_hlsl(object):
attrs = attr.split(',')
readonly = False # Only read memory
readnone = False # Not read memory
is_wave = False; # Is wave-sensitive
unsigned_op = "" # Unsigned opcode if exist
overload_param_index = -1 # Parameter determines the overload type, -1 means ret type.
hidden = False
@ -2885,6 +2887,9 @@ class db_hlsl(object):
if (a == "rn"):
readnone = True
continue
if (a == "wv"):
is_wave = True
continue
if (a == "hidden"):
hidden = True
continue
@ -2904,7 +2909,7 @@ class db_hlsl(object):
continue
assert False, "invalid attr %s" % (a)
return readonly, readnone, unsigned_op, overload_param_index, hidden
return readonly, readnone, is_wave, unsigned_op, overload_param_index, hidden
current_namespace = None
for line in intrinsic_defs:
@ -2937,7 +2942,7 @@ class db_hlsl(object):
op = operand_match.group(1)
if not op:
op = name
readonly, readnone, unsigned_op, overload_param_index, hidden = process_attr(attr)
readonly, readnone, is_wave, unsigned_op, overload_param_index, hidden = process_attr(attr)
# Add an entry for this intrinsic.
if bracket_cleanup_re.search(opts):
opts = bracket_cleanup_re.sub(r"<\1@\2>", opts)
@ -2961,7 +2966,7 @@ class db_hlsl(object):
# TODO: verify a single level of indirection
self.intrinsics.append(db_hlsl_intrinsic(
name, num_entries, op, args, current_namespace, ns_idx, "pending doc for " + name,
readonly, readnone, unsigned_op, overload_param_index, hidden))
readonly, readnone, is_wave, unsigned_op, overload_param_index, hidden))
num_entries += 1
continue
assert False, "cannot parse line %s" % (line)

Просмотреть файл

@ -363,7 +363,7 @@ class db_oload_gen:
f = lambda i,c : "true" if i.oload_types.find(c) >= 0 else "false"
lower_exceptions = { "CBufferLoad" : "cbufferLoad", "CBufferLoadLegacy" : "cbufferLoadLegacy", "GSInstanceID" : "gsInstanceID" }
lower_fn = lambda t: lower_exceptions[t] if t in lower_exceptions else t[:1].lower() + t[1:]
attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate", "nr": "NoReturn" }
attr_dict = { "": "None", "ro": "ReadOnly", "rn": "ReadNone", "nd": "NoDuplicate", "nr": "NoReturn", "wv" : "None" }
attr_fn = lambda i : "Attribute::" + attr_dict[i.fn_attr] + ","
for i in self.instrs:
if last_category != i.category:
@ -691,7 +691,7 @@ def get_hlsl_intrinsics():
result += "#ifdef ENABLE_SPIRV_CODEGEN\n\n"
# SPIRV Change Ends
arg_idx = 0
ns_table += " {(UINT)%s::%s_%s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
ns_table += " {(UINT)%s::%s_%s, %s, %s, %s, %d, %d, g_%s_Args%s},\n" % (opcode_namespace, id_prefix, i.name, str(i.readonly).lower(), str(i.readnone).lower(), str(i.wave).lower(), i.overload_param_index,len(i.params), last_ns, arg_idx)
result += "static const HLSL_INTRINSIC_ARGUMENT g_%s_Args%s[] =\n{\n" % (last_ns, arg_idx)
for p in i.params:
name = p.name