Ensure the cached Function->OpCodeClass map is updated (#186)

* Ensure the cached Function->OpCodeClass map is updated

The original goal of this change was to use the opcode class to decide when we
can perform constant folding on a function.

We maintain a mapping from Function* to OpCodeClass inside the OP class.
We wanted to use this map in constant folding to decide whether a function can
be constant folded, avoiding string comparisons on the function names.
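For illustration, this is the kind of lookup we had in mind (a minimal sketch;
the wrapper function and the set of foldable classes are made up for the
example, only OP::GetOpCodeClass comes from this change):

  // Sketch only: `hlslOP` is the OP instance owned by the DxilModule.
  // Assumes the DxilOperations header providing hlsl::OP is included.
  // GetOpCodeClass (declared below) replaces string comparisons on the
  // function name; the classes listed here are just an example subset.
  static bool CanFoldWithOpClass(hlsl::OP *hlslOP, const llvm::Function *F) {
    hlsl::OP::OpCodeClass opClass;
    if (!hlslOP->GetOpCodeClass(F, opClass))
      return false;                       // Not a dxil intrinsic.
    switch (opClass) {
    case hlsl::OP::OpCodeClass::Unary:
    case hlsl::OP::OpCodeClass::Binary:
    case hlsl::OP::OpCodeClass::Tertiary:
      return true;                        // Classes we know how to fold.
    default:
      return false;
    }
  }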

However, it turns out that the DxilModule is not always available during
constant folding of dxil calls, so we cannot use the map inside OP. The
change contains a few bug fixes and improvements that came out of trying
to get the opcode class working inside constant folding.

  1. Use opcode class in dxil constant folding where possible.
  2. Make sure the opcode cache is refreshed properly.
  3. Remove 64-bit test for bfi.
  4. Add equality comparison for the ShaderModel class.

When switching to the opcode class for constant folding, we discovered
that our test for 64-bit bfi was invalid. There is no 64-bit overload for
bfi in dxil, so the test we had written was not legal dxil. This change
removes the 64-bit bfi constant folding test.
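Since the folder has to work even when no DxilModule (and therefore no OP
instance with its Function->OpCodeClass cache) is attached,
CanConstantFoldCallTo keeps matching on the function name prefix. Roughly,
as a sketch with an abbreviated prefix list:

  // Sketch of the name-based fallback used when no Function -> OpCodeClass
  // cache is available. startswith catches every type overload, e.g.
  // "dx.op.unary.f32"; the prefixes shown here are only a subset.
  static bool CanFoldByName(const llvm::Function *F) {
    llvm::StringRef Name = F->getName();
    return Name.startswith("dx.op.unary") ||
           Name.startswith("dx.op.binary") ||
           Name.startswith("dx.op.tertiary");
  }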
David Peixotto 2017-04-06 13:48:27 -07:00, committed by GitHub
Parent 0677af11e4
Commit deec58372a
9 changed files with 85 additions and 30 deletions


@ -54,6 +54,11 @@ public:
llvm::Type *GetResRetType(llvm::Type *pOverloadType);
llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
// Try to get the opcode class for a function.
// Return true and set `opClass` if the given function is a dxil function.
// Return false if the given function is not a dxil function.
bool GetOpCodeClass(const llvm::Function *F, OpCodeClass &opClass);
// LLVM helpers. Perhaps, move to a separate utility class.
llvm::Constant *GetI1Const(bool v);
llvm::Constant *GetI8Const(char v);
@ -102,8 +107,9 @@ private:
llvm::Function *pOverloads[kNumTypeOverloads];
};
OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
std::unordered_map<llvm::Function *, OpCodeClass> m_FunctionToOpClass;
std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
void RefreshCache(llvm::Module *pModule);
void UpdateCache(OpCodeClass opClass, unsigned typeSlot, llvm::Function *F);
private:
// Static properties.
struct OpCodeProperty {


@ -60,6 +60,9 @@ public:
static const ShaderModel *Get(Kind Kind, unsigned Major, unsigned Minor);
static const ShaderModel *GetByName(const char *pszName);
bool operator==(const ShaderModel &other) const;
bool operator!=(const ShaderModel &other) const { return !(*this == other); }
private:
Kind m_Kind;
unsigned m_Major;


@ -697,7 +697,7 @@ public:
void ResetHLModule();
bool HasDxilModule() const { return TheDxilModule != nullptr; }
void SetDxilModule(hlsl::DxilModule *pValue) { TheDxilModule = pValue; }
hlsl::DxilModule &GetDxilModule() { return *TheDxilModule; }
hlsl::DxilModule &GetDxilModule() const { return *TheDxilModule; }
hlsl::DxilModule &GetOrCreateDxilModule(bool skipInit = false);
void ResetDxilModule();
// HLSL Change end


@ -421,19 +421,17 @@ static Constant *ConstantFoldQuaternaryIntInstrinsic(OP::OpCode opcode, Type *Ty
return ConstantInt::get(Ty, result);
}
// Return true if opcode is for a dot operation.
static bool IsDotOpcode(OP::OpCode opcode) {
return opcode == OP::OpCode::Dot2
|| opcode == OP::OpCode::Dot3
|| opcode == OP::OpCode::Dot4;
}
// Top level function to constant fold floating point intrinsics.
static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
return nullptr;
if (IntrinsicOperands.Size() == 1) {
OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
switch (opClass) {
default: break;
case OP::OpCodeClass::Unary: {
assert(IntrinsicOperands.Size() == 1);
ConstantFP *Op = IntrinsicOperands.GetConstantFloat(0);
if (!IsValidOp(Op))
@ -441,7 +439,8 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
return ConstantFoldUnaryFPIntrinsic(opcode, Ty, Op);
}
else if (IntrinsicOperands.Size() == 2) {
case OP::OpCodeClass::Binary: {
assert(IntrinsicOperands.Size() == 2);
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
@ -450,7 +449,8 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
return ConstantFoldBinaryFPIntrinsic(opcode, Ty, Op1, Op2);
}
else if (IntrinsicOperands.Size() == 3) {
case OP::OpCodeClass::Tertiary: {
assert(IntrinsicOperands.Size() == 3);
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
ConstantFP *Op3 = IntrinsicOperands.GetConstantFloat(2);
@ -460,7 +460,9 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
return ConstantFoldTernaryFPIntrinsic(opcode, Ty, Op1, Op2, Op3);
}
else if (IsDotOpcode(opcode)) {
case OP::OpCodeClass::Dot2:
case OP::OpCodeClass::Dot3:
case OP::OpCodeClass::Dot4:
return ConstantFoldDot(opcode, Ty, IntrinsicOperands);
}
@ -472,14 +474,21 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
if (Ty->getScalarSizeInBits() > (sizeof(int64_t) * CHAR_BIT))
return nullptr;
if (IntrinsicOperands.Size() == 1) {
OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
switch (opClass) {
default: break;
case OP::OpCodeClass::Unary:
case OP::OpCodeClass::UnaryBits: {
assert(IntrinsicOperands.Size() == 1);
ConstantInt *Op = IntrinsicOperands.GetConstantInt(0);
if (!Op)
return nullptr;
return ConstantFoldUnaryIntIntrinsic(opcode, Ty, Op);
}
else if (IntrinsicOperands.Size() == 2) {
case OP::OpCodeClass::Binary: {
assert(IntrinsicOperands.Size() == 2);
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
if (!Op1 || !Op2)
@ -487,7 +496,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
return ConstantFoldBinaryIntIntrinsic(opcode, Ty, Op1, Op2);
}
else if (IntrinsicOperands.Size() == 3) {
case OP::OpCodeClass::Tertiary: {
assert(IntrinsicOperands.Size() == 3);
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
@ -496,7 +506,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
return ConstantFoldTernaryIntIntrinsic(opcode, Ty, Op1, Op2, Op3);
}
else if (IntrinsicOperands.Size() == 4) {
case OP::OpCodeClass::Quaternary: {
assert(IntrinsicOperands.Size() == 4);
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
@ -506,6 +517,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
return ConstantFoldQuaternaryIntInstrinsic(opcode, Ty, Op1, Op2, Op3, Op4);
}
}
return nullptr;
}
@ -535,6 +548,8 @@ bool hlsl::CanConstantFoldCallTo(const Function *F) {
return false;
// Check match using startswith to get all overloads.
// We cannot use the opcode class here because constant folding
// may run without a DxilModule available.
StringRef Name = F->getName();
if (Name.startswith("dx.op.unary"))
return true;


@ -526,7 +526,6 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer(
return E_INVALIDARG;
}
legacy::PassManager ModulePasses;
legacy::FunctionPassManager FunctionPasses(M.get());
legacy::PassManagerBase *pPassManager = &ModulePasses;


@ -127,7 +127,7 @@ Module *DxilModule::GetModule() const { return m_pModule; }
OP *DxilModule::GetOP() const { return m_pOP.get(); }
void DxilModule::SetShaderModel(const ShaderModel *pSM) {
DXASSERT(m_pSM == nullptr, "shader model must not change for the module");
DXASSERT(m_pSM == nullptr || (pSM != nullptr && *m_pSM == *pSM), "shader model must not change for the module");
m_pSM = pSM;
m_pMDHelper->SetShaderModel(m_pSM);
DXIL::ShaderKind shaderKind = pSM->GetKind();


@ -427,10 +427,16 @@ void OP::RefreshCache(llvm::Module *pModule) {
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
DXASSERT_NOMSG(OpFunc == &F);
}
}
}
void OP::UpdateCache(OpCodeClass opClass, unsigned typeSlot, llvm::Function *F) {
m_OpCodeClassCache[(unsigned)opClass].pOverloads[typeSlot] = F;
m_FunctionToOpClass[F] = opClass;
}
Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
DXASSERT(0 <= (unsigned)OpCode && OpCode < OpCode::NumOpCodes, "otherwise caller passed OOB OpCode");
_Analysis_assume_(0 <= (unsigned)OpCode && OpCode < OpCode::NumOpCodes);
@ -438,8 +444,10 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
unsigned TypeSlot = GetTypeSlot(pOverloadType);
OpCodeClass opClass = m_OpCodeProps[(unsigned)OpCode].OpCodeClass;
Function *&F = m_OpCodeClassCache[(unsigned)opClass].pOverloads[TypeSlot];
if (F != nullptr)
if (F != nullptr) {
UpdateCache(opClass, TypeSlot, F);
return F;
}
vector<Type*> ArgTypes; // RetType is ArgTypes[0]
Type *pETy = pOverloadType;
@ -470,6 +478,7 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
// Try to find exist function with the same name in the module.
if (Function *existF = m_pModule->getFunction(funcName)) {
F = existF;
UpdateCache(opClass, TypeSlot, F);
return F;
}
@ -690,7 +699,7 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
F = cast<Function>(m_pModule->getOrInsertFunction(funcName, pFT));
m_FunctionToOpClass[F] = opClass;
UpdateCache(opClass, TypeSlot, F);
F->setCallingConv(CallingConv::C);
F->addFnAttr(Attribute::NoUnwind);
if (m_OpCodeProps[(unsigned)OpCode].FuncAttr != Attribute::None)
@ -718,6 +727,16 @@ void OP::RemoveFunction(Function *F) {
}
}
bool OP::GetOpCodeClass(const Function *F, OP::OpCodeClass &opClass) {
auto iter = m_FunctionToOpClass.find(F);
if (iter == m_FunctionToOpClass.end()) {
DXASSERT(!IsDxilOpFunc(F), "dxil function without an opcode class mapping?");
return false;
}
opClass = iter->second;
return true;
}
llvm::Type *OP::GetOverloadType(OpCode OpCode, llvm::Function *F) {
DXASSERT(F, "not work on nullptr");
Type *Ty = F->getReturnType();


@ -28,6 +28,17 @@ ShaderModel::ShaderModel(Kind Kind, unsigned Major, unsigned Minor, const char *
, m_NumUAVRegs(NumUAVRegs) {
}
bool ShaderModel::operator==(const ShaderModel &other) const {
return m_Kind == other.m_Kind
&& m_Major == other.m_Major
&& m_Minor == other.m_Minor
&& strcmp(m_pszName, other.m_pszName) == 0
&& m_NumInputRegs == other.m_NumInputRegs
&& m_NumOutputRegs == other.m_NumOutputRegs
&& m_bTypedUavs == other.m_bTypedUavs
&& m_NumUAVRegs == other.m_NumUAVRegs;
}
bool ShaderModel::IsValid() const {
DXASSERT(IsPS() || IsVS() || IsGS() || IsHS() || IsDS() || IsCS() || m_Kind == Kind::Invalid, "invalid shader model");
return m_Kind != Kind::Invalid;


@ -30,16 +30,18 @@ entry:
%4 = call i32 @dx.op.quaternary.i32(i32 53, i32 0, i32 8, i32 0, i32 15)
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 16, i32 undef, i32 %4, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
; CHECK: @dx.op.bufferStore{{.*}}, i32 2560,
%5 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 8, i64 4010, i64 0)
%6 = trunc i64 %5 to i32
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 20, i32 undef, i32 %6, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
; No i64 overloads for bfi in dxil yet.
; xHECK: @dx.op.bufferStore{{.*}}, i32 2560,
;%5 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 8, i64 4010, i64 0)
;%6 = trunc i64 %5 to i32
;call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 20, i32 undef, i32 %6, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
; CHECK: @dx.op.bufferStore{{.*}}, i32 10,
%7 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 32, i64 4010, i64 0)
%8 = lshr i64 %7, 32
%9 = trunc i64 %8 to i32
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 24, i32 undef, i32 %9, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
; No i64 overloads for bfi in dxil yet.
; xHECK: @dx.op.bufferStore{{.*}}, i32 10,
;%7 = call i64 @dx.op.quaternary.i64(i32 53, i64 4, i64 32, i64 4010, i64 0)
;%8 = lshr i64 %7, 32
;%9 = trunc i64 %8 to i32
;call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_rawbuf, i32 24, i32 undef, i32 %9, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
call void @dx.op.storeOutput.i32(i32 5, i32 0, i32 0, i8 0, i32 0) ; StoreOutput(outputtSigId,rowIndex,colIndex,value)
ret void