Ensure the cached Function->OpCodeClass map is updated (#186)
* Ensure the cached Function->OpCodeClass map is updated The original goal of this change was to use opcode class for deciding when we can perform constant folding on a function. We maintain a mapping from Function* to OpCodeClass inside the OP class. We wanted to use this map in constant folding to decide if we can constant fold a function to avoid string comparison on the function names. However, it turns out that the DxilModule is not always available during constant folding of dxil calls so we cannot use the map inside of OP. The change contains a few bug fixes and improvements that came out of trying to get opcode class working inside constant folding. 1. Use opcode class in dxil constant folding where possible. 2. Make sure the opcode cache is refreshed properly. 3. Remove 64-bit test for bfi. 4. Add equality comparison for the ShaderModel class. When switching to use the opcode class for constant folding, we discovered that our test for 64-bit bfi is invalid. There is no 64-bit overload for bfi in dxil, so the test we had written was not legal dxil. This change removes the 64-bit test for bfi constant prop.
This commit is contained in:
Родитель
0677af11e4
Коммит
deec58372a
|
@ -54,6 +54,11 @@ public:
|
|||
llvm::Type *GetResRetType(llvm::Type *pOverloadType);
|
||||
llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
|
||||
|
||||
// Try to get the opcode class for a function.
|
||||
// Return true and set `opClass` if the given function is a dxil function.
|
||||
// Return false if the given function is not a dxil function.
|
||||
bool GetOpCodeClass(const llvm::Function *F, OpCodeClass &opClass);
|
||||
|
||||
// LLVM helpers. Perhaps, move to a separate utility class.
|
||||
llvm::Constant *GetI1Const(bool v);
|
||||
llvm::Constant *GetI8Const(char v);
|
||||
|
@ -102,8 +107,9 @@ private:
|
|||
llvm::Function *pOverloads[kNumTypeOverloads];
|
||||
};
|
||||
OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
|
||||
std::unordered_map<llvm::Function *, OpCodeClass> m_FunctionToOpClass;
|
||||
std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
|
||||
void RefreshCache(llvm::Module *pModule);
|
||||
void UpdateCache(OpCodeClass opClass, unsigned typeSlot, llvm::Function *F);
|
||||
private:
|
||||
// Static properties.
|
||||
struct OpCodeProperty {
|
||||
|
|
|
@ -60,6 +60,9 @@ public:
|
|||
static const ShaderModel *Get(Kind Kind, unsigned Major, unsigned Minor);
|
||||
static const ShaderModel *GetByName(const char *pszName);
|
||||
|
||||
bool operator==(const ShaderModel &other) const;
|
||||
bool operator!=(const ShaderModel &other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
Kind m_Kind;
|
||||
unsigned m_Major;
|
||||
|
|
|
@ -697,7 +697,7 @@ public:
|
|||
void ResetHLModule();
|
||||
bool HasDxilModule() const { return TheDxilModule != nullptr; }
|
||||
void SetDxilModule(hlsl::DxilModule *pValue) { TheDxilModule = pValue; }
|
||||
hlsl::DxilModule &GetDxilModule() { return *TheDxilModule; }
|
||||
hlsl::DxilModule &GetDxilModule() const { return *TheDxilModule; }
|
||||
hlsl::DxilModule &GetOrCreateDxilModule(bool skipInit = false);
|
||||
void ResetDxilModule();
|
||||
// HLSL Change end
|
||||
|
|
|
@ -421,19 +421,17 @@ static Constant *ConstantFoldQuaternaryIntInstrinsic(OP::OpCode opcode, Type *Ty
|
|||
return ConstantInt::get(Ty, result);
|
||||
}
|
||||
|
||||
// Return true if opcode is for a dot operation.
|
||||
static bool IsDotOpcode(OP::OpCode opcode) {
|
||||
return opcode == OP::OpCode::Dot2
|
||||
|| opcode == OP::OpCode::Dot3
|
||||
|| opcode == OP::OpCode::Dot4;
|
||||
}
|
||||
|
||||
// Top level function to constant fold floating point intrinsics.
|
||||
static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
|
||||
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
|
||||
return nullptr;
|
||||
|
||||
if (IntrinsicOperands.Size() == 1) {
|
||||
OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
|
||||
|
||||
switch (opClass) {
|
||||
default: break;
|
||||
case OP::OpCodeClass::Unary: {
|
||||
assert(IntrinsicOperands.Size() == 1);
|
||||
ConstantFP *Op = IntrinsicOperands.GetConstantFloat(0);
|
||||
|
||||
if (!IsValidOp(Op))
|
||||
|
@ -441,7 +439,8 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
|
|||
|
||||
return ConstantFoldUnaryFPIntrinsic(opcode, Ty, Op);
|
||||
}
|
||||
else if (IntrinsicOperands.Size() == 2) {
|
||||
case OP::OpCodeClass::Binary: {
|
||||
assert(IntrinsicOperands.Size() == 2);
|
||||
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
|
||||
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
|
||||
|
||||
|
@ -450,7 +449,8 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
|
|||
|
||||
return ConstantFoldBinaryFPIntrinsic(opcode, Ty, Op1, Op2);
|
||||
}
|
||||
else if (IntrinsicOperands.Size() == 3) {
|
||||
case OP::OpCodeClass::Tertiary: {
|
||||
assert(IntrinsicOperands.Size() == 3);
|
||||
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
|
||||
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
|
||||
ConstantFP *Op3 = IntrinsicOperands.GetConstantFloat(2);
|
||||
|
@ -460,7 +460,9 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
|
|||
|
||||
return ConstantFoldTernaryFPIntrinsic(opcode, Ty, Op1, Op2, Op3);
|
||||
}
|
||||
else if (IsDotOpcode(opcode)) {
|
||||
case OP::OpCodeClass::Dot2:
|
||||
case OP::OpCodeClass::Dot3:
|
||||
case OP::OpCodeClass::Dot4:
|
||||
return ConstantFoldDot(opcode, Ty, IntrinsicOperands);
|
||||
}
|
||||
|
||||
|
@ -472,14 +474,21 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|||
if (Ty->getScalarSizeInBits() > (sizeof(int64_t) * CHAR_BIT))
|
||||
return nullptr;
|
||||
|
||||
if (IntrinsicOperands.Size() == 1) {
|
||||
OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
|
||||
|
||||
switch (opClass) {
|
||||
default: break;
|
||||
case OP::OpCodeClass::Unary:
|
||||
case OP::OpCodeClass::UnaryBits: {
|
||||
assert(IntrinsicOperands.Size() == 1);
|
||||
ConstantInt *Op = IntrinsicOperands.GetConstantInt(0);
|
||||
if (!Op)
|
||||
return nullptr;
|
||||
|
||||
return ConstantFoldUnaryIntIntrinsic(opcode, Ty, Op);
|
||||
}
|
||||
else if (IntrinsicOperands.Size() == 2) {
|
||||
case OP::OpCodeClass::Binary: {
|
||||
assert(IntrinsicOperands.Size() == 2);
|
||||
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
||||
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
||||
if (!Op1 || !Op2)
|
||||
|
@ -487,7 +496,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|||
|
||||
return ConstantFoldBinaryIntIntrinsic(opcode, Ty, Op1, Op2);
|
||||
}
|
||||
else if (IntrinsicOperands.Size() == 3) {
|
||||
case OP::OpCodeClass::Tertiary: {
|
||||
assert(IntrinsicOperands.Size() == 3);
|
||||
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
||||
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
||||
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
|
||||
|
@ -496,7 +506,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|||
|
||||
return ConstantFoldTernaryIntIntrinsic(opcode, Ty, Op1, Op2, Op3);
|
||||
}
|
||||
else if (IntrinsicOperands.Size() == 4) {
|
||||
case OP::OpCodeClass::Quaternary: {
|
||||
assert(IntrinsicOperands.Size() == 4);
|
||||
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
||||
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
||||
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
|
||||
|
@ -506,6 +517,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|||
|
||||
return ConstantFoldQuaternaryIntInstrinsic(opcode, Ty, Op1, Op2, Op3, Op4);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -535,6 +548,8 @@ bool hlsl::CanConstantFoldCallTo(const Function *F) {
|
|||
return false;
|
||||
|
||||
// Check match using startswith to get all overloads.
|
||||
// We cannot use the opcode class here because constant folding
|
||||
// may run without a DxilModule available.
|
||||
StringRef Name = F->getName();
|
||||
if (Name.startswith("dx.op.unary"))
|
||||
return true;
|
||||
|
|
|
@ -526,7 +526,6 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer(
|
|||
return E_INVALIDARG;
|
||||
}
|
||||
|
||||
|
||||
legacy::PassManager ModulePasses;
|
||||
legacy::FunctionPassManager FunctionPasses(M.get());
|
||||
legacy::PassManagerBase *pPassManager = &ModulePasses;
|
||||
|
|
|
@ -127,7 +127,7 @@ Module *DxilModule::GetModule() const { return m_pModule; }
|
|||
OP *DxilModule::GetOP() const { return m_pOP.get(); }
|
||||
|
||||
void DxilModule::SetShaderModel(const ShaderModel *pSM) {
|
||||
DXASSERT(m_pSM == nullptr, "shader model must not change for the module");
|
||||
DXASSERT(m_pSM == nullptr || (pSM != nullptr && *m_pSM == *pSM), "shader model must not change for the module");
|
||||
m_pSM = pSM;
|
||||
m_pMDHelper->SetShaderModel(m_pSM);
|
||||
DXIL::ShaderKind shaderKind = pSM->GetKind();
|
||||
|
|
|
@ -427,10 +427,16 @@ void OP::RefreshCache(llvm::Module *pModule) {
|
|||
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
|
||||
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
|
||||
DXASSERT_NOMSG(OpFunc == &F);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OP::UpdateCache(OpCodeClass opClass, unsigned typeSlot, llvm::Function *F) {
|
||||
m_OpCodeClassCache[(unsigned)opClass].pOverloads[typeSlot] = F;
|
||||
m_FunctionToOpClass[F] = opClass;
|
||||
}
|
||||
|
||||
Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
|
||||
DXASSERT(0 <= (unsigned)OpCode && OpCode < OpCode::NumOpCodes, "otherwise caller passed OOB OpCode");
|
||||
_Analysis_assume_(0 <= (unsigned)OpCode && OpCode < OpCode::NumOpCodes);
|
||||
|
@ -438,8 +444,10 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
|
|||
unsigned TypeSlot = GetTypeSlot(pOverloadType);
|
||||
OpCodeClass opClass = m_OpCodeProps[(unsigned)OpCode].OpCodeClass;
|
||||
Function *&F = m_OpCodeClassCache[(unsigned)opClass].pOverloads[TypeSlot];
|
||||
if (F != nullptr)
|
||||
if (F != nullptr) {
|
||||
UpdateCache(opClass, TypeSlot, F);
|
||||
return F;
|
||||
}
|
||||
|
||||
vector<Type*> ArgTypes; // RetType is ArgTypes[0]
|
||||
Type *pETy = pOverloadType;
|
||||
|
@ -470,6 +478,7 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
|
|||
// Try to find exist function with the same name in the module.
|
||||
if (Function *existF = m_pModule->getFunction(funcName)) {
|
||||
F = existF;
|
||||
UpdateCache(opClass, TypeSlot, F);
|
||||
return F;
|
||||
}
|
||||
|
||||
|
@ -690,7 +699,7 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
|
|||
|
||||
F = cast<Function>(m_pModule->getOrInsertFunction(funcName, pFT));
|
||||
|
||||
m_FunctionToOpClass[F] = opClass;
|
||||
UpdateCache(opClass, TypeSlot, F);
|
||||
F->setCallingConv(CallingConv::C);
|
||||
F->addFnAttr(Attribute::NoUnwind);
|
||||
if (m_OpCodeProps[(unsigned)OpCode].FuncAttr != Attribute::None)
|
||||
|
@ -718,6 +727,16 @@ void OP::RemoveFunction(Function *F) {
|
|||
}
|
||||
}
|
||||
|
||||
bool OP::GetOpCodeClass(const Function *F, OP::OpCodeClass &opClass) {
|
||||
auto iter = m_FunctionToOpClass.find(F);
|
||||
if (iter == m_FunctionToOpClass.end()) {
|
||||
DXASSERT(!IsDxilOpFunc(F), "dxil function without an opcode class mapping?");
|
||||
return false;
|
||||
}
|
||||
opClass = iter->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::Type *OP::GetOverloadType(OpCode OpCode, llvm::Function *F) {
|
||||
DXASSERT(F, "not work on nullptr");
|
||||
Type *Ty = F->getReturnType();
|
||||
|
|
|
@ -28,6 +28,17 @@ ShaderModel::ShaderModel(Kind Kind, unsigned Major, unsigned Minor, const char *
|
|||
, m_NumUAVRegs(NumUAVRegs) {
|
||||
}
|
||||
|
||||
bool ShaderModel::operator==(const ShaderModel &other) const {
|
||||
return m_Kind == other.m_Kind
|
||||
&& m_Major == other.m_Major
|
||||
&& m_Minor == other.m_Minor
|
||||
&& strcmp(m_pszName, other.m_pszName) == 0
|
||||
&& m_NumInputRegs == other.m_NumInputRegs
|
||||
&& m_NumOutputRegs == other.m_NumOutputRegs
|
||||
&& m_bTypedUavs == other.m_bTypedUavs
|
||||
&& m_NumUAVRegs == other.m_NumUAVRegs;
|
||||
}
|
||||
|
||||
bool ShaderModel::IsValid() const {
|
||||
DXASSERT(IsPS() || IsVS() || IsGS() || IsHS() || IsDS() || IsCS() || m_Kind == Kind::Invalid, "invalid shader model");
|
||||
return m_Kind != Kind::Invalid;
|
||||
|
|
|
@ -30,16 +30,18 @@ entry:
|
|||
%4 = call i32 @dx.op.quaternary.i32(i32 53, i32 0, i32 8, i32 0, i32 15)
|
||||
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 16, i32 undef, i32 %4, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
|
||||
|
||||
; CHECK: @dx.op.bufferStore{{.*}}, i32 2560,
|
||||
%5 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 8, i64 4010, i64 0)
|
||||
%6 = trunc i64 %5 to i32
|
||||
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 20, i32 undef, i32 %6, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
|
||||
; No i64 overloads for bfi in dxil yet.
|
||||
; xHECK: @dx.op.bufferStore{{.*}}, i32 2560,
|
||||
;%5 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 8, i64 4010, i64 0)
|
||||
;%6 = trunc i64 %5 to i32
|
||||
;call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 20, i32 undef, i32 %6, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
|
||||
|
||||
; CHECK: @dx.op.bufferStore{{.*}}, i32 10,
|
||||
%7 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 32, i64 4010, i64 0)
|
||||
%8 = lshr i64 %7, 32
|
||||
%9 = trunc i64 %8 to i32
|
||||
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 24, i32 undef, i32 %9, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
|
||||
; No i64 overloads for bfi in dxil yet.
|
||||
; xHECK: @dx.op.bufferStore{{.*}}, i32 10,
|
||||
;%7 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 32, i64 4010, i64 0)
|
||||
;%8 = lshr i64 %7, 32
|
||||
;%9 = trunc i64 %8 to i32
|
||||
;call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 24, i32 undef, i32 %9, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
|
||||
|
||||
call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 0, i32 0) ; StoreOutput(outputtSigId,rowIndex,colIndex,value)
|
||||
ret void
|
||||
|
|
Загрузка…
Ссылка в новой задаче